module Neighbor::Model

# Returns the neighbor attribute settings registered on this object via
# has_neighbors (keyed by symbolized attribute name), with any settings
# inherited from the superclass merged underneath this object's own entries.
#
# Only classes have a superclass. When this method lives on a module (or any
# other non-Class object), there is nothing to inherit from, so the parent
# lookup is skipped instead of raising NoMethodError on `superclass`.
#
# @return [Hash{Symbol => Hash}] attribute name => {dimensions:, normalize:, type:}
def self.neighbor_attributes
  parent = superclass if respond_to?(:superclass)
  parent_attributes = parent.respond_to?(:neighbor_attributes) ? parent.neighbor_attributes : {}
  # entries registered directly on this object override inherited ones
  parent_attributes.merge(@neighbor_attributes || {})
end

# Declares one or more vector attributes on the model. On the first call for a
# class it also installs the supporting validation, the `nearest_neighbors`
# scope, and the `#nearest_neighbors` instance method.
#
# @param attribute_names [Array<Symbol, String>] attributes to register (symbolized)
# @param dimensions [Integer, nil] expected vector length; when nil, the column
#   limit is used for non-binary columns
# @param normalize [Boolean, nil] when truthy, stored values and query vectors
#   are normalized; must be set (true or false) for cosine distance on
#   adapters that require it
# @param type [Symbol, String, nil] explicit vector type passed to
#   Neighbor::Attribute; the scope only accepts it on SQLite
# @raise [ArgumentError] if no attribute name is given
# @raise [Neighbor::Error] if an attribute was already registered
#
# NOTE: every block below closes over this method's locals (`dimensions`,
# `normalize`, `type`). They must never assign to those names. Otherwise one
# validation or query would overwrite the settings seen by all later blocks,
# including the attribute decorators and concurrently running threads.
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
  if attribute_names.empty?
    raise ArgumentError, "has_neighbors requires an attribute name"
  end
  attribute_names.map!(&:to_sym)

  class_eval do
    @neighbor_attributes ||= {}

    # define the per-class reader once; it merges in the superclass's settings
    if @neighbor_attributes.empty?
      def self.neighbor_attributes
        parent_attributes =
          if superclass.respond_to?(:neighbor_attributes)
            superclass.neighbor_attributes
          else
            {}
          end
        parent_attributes.merge(@neighbor_attributes || {})
      end
    end

    attribute_names.each do |attribute_name|
      raise Neighbor::Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
      @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
    end

    if ActiveRecord::VERSION::STRING.to_f >= 7.2
      decorate_attributes(attribute_names) do |name, cast_type|
        Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
      end
    else
      attribute_names.each do |attribute_name|
        attribute attribute_name do |cast_type|
          Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
        end
      end
    end

    if normalize
      if ActiveRecord::VERSION::STRING.to_f >= 7.1
        attribute_names.each do |attribute_name|
          normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
        end
      else
        attribute_names.each do |attribute_name|
          attribute attribute_name do |cast_type|
            Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
          end
        end
      end
    end

    # The validation, scope and instance method below iterate over every
    # registered attribute, so they are installed only on the first call for
    # this class. (`return` here exits has_neighbors itself.)
    return if @neighbor_attributes.size != attribute_names.size

    validate do
      adapter = Neighbor::Utils.adapter(self.class)
      self.class.neighbor_attributes.each do |k, v|
        value = read_attribute(k)
        next if value.nil?
        column_info = self.class.columns_hash[k.to_s]
        # fresh locals: reassigning the captured `dimensions`/`type` would leak
        # this attribute's settings into every other closure of has_neighbors
        attr_dimensions = v[:dimensions]
        attr_dimensions ||= column_info&.limit unless column_info&.type == :binary
        attr_type = v[:type] || Neighbor::Utils.type(adapter, column_info&.type)
        if !Neighbor::Utils.validate_dimensions(value, attr_type, attr_dimensions, adapter).nil?
          errors.add(k, "must have #{attr_dimensions} dimensions")
        end
        if !Neighbor::Utils.validate_finite(value, attr_type)
          errors.add(k, "must have finite values")
        end
      end
    end

    scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
      attribute_name = attribute_name.to_sym
      options = neighbor_attributes[attribute_name]
      raise ArgumentError, "Invalid attribute" unless options
      # fresh locals: reassigning the captured `normalize`/`dimensions`/`type`
      # would clobber has_neighbors' own arguments (shared across queries/threads)
      attr_normalize = options[:normalize]
      attr_dimensions = options[:dimensions]
      attr_type = options[:type]
      return none if vector.nil?
      distance = distance.to_s
      column_info = columns_hash[attribute_name.to_s]
      column_type = column_info&.type
      adapter = Neighbor::Utils.adapter(klass)
      if attr_type && adapter != :sqlite
        raise ArgumentError, "type only works with SQLite"
      end
      operator = Neighbor::Utils.operator(adapter, column_type, distance)
      raise ArgumentError, "Invalid distance: #{distance}" unless operator
      # ensure normalize set (can be true or false)
      normalize_required = Neighbor::Utils.normalize_required?(adapter, column_type)
      if distance == "cosine" && normalize_required && attr_normalize.nil?
        raise Neighbor::Error, "Set normalize for cosine distance with cube"
      end
      column_attribute = klass.type_for_attribute(attribute_name)
      vector = column_attribute.cast(vector)
      attr_dimensions ||= column_info&.limit unless column_info&.type == :binary
      Neighbor::Utils.validate(vector, dimensions: attr_dimensions, type: attr_type || Neighbor::Utils.type(adapter, column_info&.type), adapter: adapter)
      vector = Neighbor::Utils.normalize(vector, column_info: column_info) if attr_normalize
      # pre-declared so the assignments inside the connection block stay visible here
      quoted_attribute = nil
      query = nil
      connection_pool.with_connection do |c|
        quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
        query = c.quote(column_attribute.serialize(vector))
      end
      if !precision.nil?
        if adapter != :postgresql || column_type != :vector
          raise ArgumentError, "Precision not supported for this type"
        end
        case precision.to_s
        when "half"
          cast_dimensions = attr_dimensions || column_info&.limit
          raise ArgumentError, "Unknown dimensions" unless cast_dimensions
          quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
        else
          raise ArgumentError, "Invalid precision"
        end
      end
      order = Neighbor::Utils.order(adapter, attr_type, operator, quoted_attribute, query)
      # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
      # with normalized vectors:
      # cosine similarity = 1 - (euclidean distance)**2 / 2
      # cosine distance = 1 - cosine similarity
      # this transformation doesn't change the order, so only needed for select
      neighbor_distance =
        if distance == "cosine" && normalize_required
          "POWER(#{order}, 2) / 2.0"
        elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
          "(#{order}) * -1"
        else
          order
        end
      # for select, use column_names instead of * to account for ignored columns
      select_columns = select_values.any? ? [] : column_names
      select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
        .where.not(attribute_name => nil)
        .reorder(Arel.sql(order))
    }

    def nearest_neighbors(attribute_name, **options)
      attribute_name = attribute_name.to_sym
      # important! check if neighbor attribute before accessing
      raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]
      self.class
        .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
        .nearest_neighbors(attribute_name, self[attribute_name], **options)
    end
  end
end

# Finds the records nearest to this one by a registered neighbor attribute,
# excluding this record itself (matched on every primary key column, so
# composite keys work too).
#
# @param attribute_name [Symbol, String] a neighbor attribute registered via has_neighbors
# @param options [Hash] forwarded to the class-level nearest_neighbors scope
#   (e.g. distance:, precision:)
# @return the relation produced by the class-level scope
# @raise [ArgumentError] if attribute_name is not a registered neighbor attribute
def nearest_neighbors(attribute_name, **options)
  attribute_name = attribute_name.to_sym
  # important! check if neighbor attribute before accessing, so self[...]
  # below is only ever read for registered attributes
  raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]
  self.class
    .where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
    .nearest_neighbors(attribute_name, self[attribute_name], **options)
end