# lib/ollama/documents.rb



require 'numo/narray'
require 'digest'

# Declare the Ollama::Documents namespace before requiring its component
# files. MemoryCache and RedisCache are referenced unqualified from inside
# Ollama::Documents, so presumably these files define them within this
# namespace and need it to exist first.
class Ollama::Documents
end
require 'ollama/documents/memory_cache'
require 'ollama/documents/redis_cache'
# Likewise declare the Splitters namespace before loading the splitter
# implementations, which presumably define classes inside it.
module Ollama::Documents::Splitters
end
require 'ollama/documents/splitters/character'
require 'ollama/documents/splitters/semantic'

# Embedding-backed document store.
#
# Texts are embedded through an Ollama client (`ollama.embed`) and stored as
# Record objects in a cache. Cache keys are the SHA-256 hex digest of the
# text, namespaced per collection via #prefix. Records are retrieved by
# cosine similarity against an embedded query in #find.
class Ollama::Documents
  # Presumably supplies norm, cosine_similarity and the non-String fallback
  # of convert_to_vector (reached via `super`). TODO confirm.
  include Ollama::Utils::Math

  # A stored document: text, embedding, precomputed norm, optional source and
  # tags. #find additionally sets `key` and `similarity` on returned records.
  class Record < JSON::GenericObject
    # Readable form: the quoted text, its tags (if any) and the similarity
    # from the last #find ('n/a' when none has been computed).
    def to_s
      my_tags = Ollama::Utils::Tags.new(tags)
      my_tags.empty? or my_tags = " #{my_tags}"
      "#<#{self.class} #{text.inspect}#{my_tags} #{similarity || 'n/a'}>"
    end

    # Records are equal when their texts are equal. Objects without #text
    # (nil, Strings, ...) compare unequal instead of raising NoMethodError.
    def ==(other)
      other.respond_to?(:text) && text == other.text
    end

    alias inspect to_s
  end

  # @param ollama [#embed] client whose `embed(model:, input:, options:)`
  #   result responds to #embeddings
  # @param model [String] embedding model name passed to every embed call
  # @param model_options [Object, nil] options passed to every embed call
  # @param collection [Symbol, String] namespace that is part of every cache key
  # @param cache [Class] MemoryCache (default) or RedisCache; any other
  #   value falls back to MemoryCache
  # @param redis_url [String, nil] URL used when cache is RedisCache
  def initialize(ollama:, model:, model_options: nil, collection: :default, cache: MemoryCache, redis_url: nil)
    @ollama, @model, @model_options, @collection = ollama, model, model_options, collection
    # Assign @redis_url before connecting: connect_cache uses it both to
    # connect and in its fallback warning. (A parallel assignment would run
    # connect_cache first, while @redis_url is still nil.)
    @redis_url = redis_url
    @cache     = connect_cache(cache)
  end

  attr_reader :ollama, :model, :collection

  # Switches to another collection and updates the cache's key prefix so
  # subsequent reads and writes use the new namespace.
  def collection=(new_collection)
    @collection = new_collection
    @cache.prefix = prefix
  end

  # Embeds and stores every input whose text is not already cached.
  #
  # @param inputs [Object, Array] texts, or objects responding to #read;
  #   anything else is converted with #to_s. The caller's array is not modified.
  # @param batch_size [Integer] number of texts per embed request
  # @param source [String, nil] stored on each record; its basename is also
  #   added as a tag
  # @param tags [Array] tags stored on each record
  # @return [self]
  def add(inputs, batch_size: 10, source: nil, tags: [])
    tags = Ollama::Utils::Tags.new(tags)
    source and tags.add File.basename(source)
    # Non-destructive on purpose: Array(inputs) returns the caller's own
    # array when given one, so map!/reject! would rewrite the caller's data.
    # uniq avoids embedding the same text twice within one call.
    texts = Array(inputs).
      map { |input| input.respond_to?(:read) ? input.read : input.to_s }.
      uniq.
      reject { |text| exist?(text) }
    texts.empty? and return self
    batches = texts.each_slice(batch_size).
      with_infobar(
        label: "Add #{tags}",
        total: texts.size
      )
    batches.each do |batch|
      embeddings = fetch_embeddings(model:, options: @model_options, input: batch)
      batch.zip(embeddings) do |text, embedding|
        norm       = norm(embedding)
        self[text] = Record[text:, embedding:, norm:, source:, tags: tags.to_a]
      end
      infobar.progress by: batch.size
    end
    infobar.newline
    self
  end
  alias << add

  # Returns the record stored for +text+ in the current collection.
  def [](text)
    @cache[key(text)]
  end

  # Stores +record+ under +text+ in the current collection.
  def []=(text, record)
    @cache[key(text)] = record
  end

  # True if a record for +text+ exists in the current collection.
  def exist?(text)
    @cache.key?(key(text))
  end

  # Removes the record for +text+ from the current collection.
  def delete(text)
    @cache.delete(key(text))
  end

  # Number of entries reported by the cache.
  def size
    @cache.size
  end

  # Clears the cache.
  def clear
    @cache.clear
  end

  # Ranks stored records by cosine similarity to +string+.
  #
  # @param string [String, Object] query text (embedded via the model) or a
  #   value handed to the mixin's convert_to_vector
  # @param tags [Array, nil] when non-empty, only records sharing at least
  #   one tag are considered; nil or an empty list means no tag filter
  # @param prompt [String, nil] format string applied to +string+ with `%`
  #   before embedding
  # @return [Array<Record>] most similar first; each record gets `key` and
  #   `similarity` set
  def find(string, tags: nil, prompt: nil)
    needle      = convert_to_vector(string, prompt:)
    needle_norm = norm(needle)
    records     = @cache
    tags      &&= Ollama::Utils::Tags.new(tags)
    # Intersecting with an empty tag list would discard every record, so an
    # empty list is treated as "no filter".
    if tags && !tags.empty?
      records = records.select { |_key, record| (tags & record.tags).size >= 1 }
    end
    records = records.sort_by { |key, record|
      record.key        = key
      record.similarity = cosine_similarity(
        a: needle,
        b: record.embedding,
        a_norm: needle_norm,
        b_norm: record.norm,
      )
    }
    # sort_by is ascending; keep only the records, most similar first.
    records.map(&:last).reverse
  end

  # Names of known collections. For MemoryCache only the current collection
  # is known; for RedisCache the names are parsed out of stored keys.
  def collections
    case @cache
    when MemoryCache
      [ @collection ]
    when RedisCache
      class_prefix = '%s-' % self.class
      # Unqualified RedisCache: inside `class Ollama::Documents` the lexical
      # scope does not include Ollama, so `Documents::RedisCache` would only
      # resolve if a top-level ::Documents existed (NameError otherwise).
      RedisCache.new(prefix: class_prefix, url: @redis_url).
        map { _1[/#{class_prefix}(.*)-/, 1] }.uniq
    else
      []
    end
  end

  # Merges the tags of every record in the cache into one Tags object.
  def tags
    @cache.inject(Ollama::Utils::Tags.new) { |t, (_, record)| t.merge(record.tags) }
  end

  private

  # Builds the cache instance. For RedisCache it connects to @redis_url and
  # falls back to MemoryCache only when Redis::CannotConnectError is raised.
  # Other errors propagate: the previous `return` inside `ensure` silently
  # discarded every exception, including interrupts.
  def connect_cache(cache_class)
    if cache_class == RedisCache
      begin
        cache = cache_class.new(prefix:, url: @redis_url)
        cache.size # issue a real request so connection failures surface here
        return cache
      rescue Redis::CannotConnectError
        STDERR.puts(
          "Cannot connect to redis URL #{@redis_url.inspect}, "\
          "falling back to MemoryCache."
        )
      end
    end
    MemoryCache.new(prefix:)
  end

  # Turns a query into a vector. Strings (after optional +prompt+
  # formatting) are embedded with the same model and model options used in
  # #add, so query and document vectors are computed consistently. Other
  # values are delegated to the mixin's implementation.
  def convert_to_vector(input, prompt: nil)
    if prompt
      input = prompt % input
    end
    if input.is_a?(String)
      Numo::NArray[*fetch_embeddings(model:, options: @model_options, input:).first]
    else
      super(input)
    end
  end

  # Requests embeddings for +input+ (a String or an Array of Strings) and
  # returns the response's embeddings.
  def fetch_embeddings(model:, input:, options: nil)
    @ollama.embed(model:, input:, options:).embeddings
  end

  # Cache key prefix: "<class name>-<collection>-".
  def prefix
    '%s-%s-' % [ self.class, @collection ]
  end

  # Cache key for a text: its SHA-256 hex digest.
  def key(input)
    Digest::SHA256.hexdigest(input)
  end
end