module Neighbor::Utils

def self.adapter(model)

# Maps the model's configured database adapter name to one of
# :sqlite, :mariadb, :mysql or :postgresql.
#
# Names matching mysql/trilogy are split further by asking a pooled
# connection whether it answers truthily to `mariadb?`. Any name matching
# neither pattern is reported as :postgresql.
def self.adapter(model)
  case model.connection_db_config.adapter
  when /sqlite/i then :sqlite
  when /mysql|trilogy/i
    mariadb = model.connection_pool.with_connection { |conn| conn.try(:mariadb?) }
    mariadb ? :mariadb : :mysql
  else
    :postgresql
  end
end

def self.array?(value)

# True for any non-nil value that responds to `to_a`.
# nil is rejected up front because it also responds to `to_a`.
def self.array?(value)
  return false if value.nil?

  value.respond_to?(:to_a)
end

def self.normalize(value, column_info:)

# Scales +value+ to unit Euclidean length.
#
# Returns nil for a nil value. Raises Error unless the column type is
# :cube, :vector or :halfvec (a nil column_info is rejected as well).
def self.normalize(value, column_info:)
  return nil if value.nil?

  supported_types = [:cube, :vector, :halfvec]
  raise Error, "Normalize not supported for type" unless supported_types.include?(column_info&.type)

  magnitude = Math.sqrt(value.sum { |x| x * x })

  # A zero vector is handed back unchanged (all zeros): dividing by a zero
  # norm would yield NaN, which makes every distance come out as 0.
  # Raising here would be an alternative.
  return value unless magnitude > 0

  value.map { |x| x / magnitude }
end

def self.normalize_required?(adapter, column_type)

# Normalization is required only for :cube columns on the :postgresql
# adapter; every other adapter/type combination returns false.
def self.normalize_required?(adapter, column_type)
  adapter == :postgresql && column_type == :cube
end

def self.operator(adapter, column_type, distance)

# Looks up the SQL operator or function name for a distance metric.
#
# - :sqlite ignores +column_type+ and keys on +distance+ alone.
# - :mariadb, :mysql and every other adapter (the fallback branch) first
#   select a table by +column_type+, raising ArgumentError for a type that
#   has no table, then key on +distance+.
#
# Returns nil when the distance name is not listed for the selected table.
# Tables are built inside the method so each call returns a fresh string.
def self.operator(adapter, column_type, distance)
  if adapter == :sqlite
    sqlite_functions = {
      "euclidean" => "vec_distance_L2",
      "cosine" => "vec_distance_cosine",
      "taxicab" => "vec_distance_L1",
      "hamming" => "vec_distance_hamming"
    }
    return sqlite_functions[distance]
  end

  by_type =
    case adapter
    when :mariadb
      {
        vector: { "euclidean" => "VEC_DISTANCE_EUCLIDEAN", "cosine" => "VEC_DISTANCE_COSINE" },
        integer: { "hamming" => "BIT_COUNT" }
      }
    when :mysql
      {
        vector: { "cosine" => "COSINE", "euclidean" => "EUCLIDEAN" },
        binary: { "hamming" => "BIT_COUNT" }
      }
    else
      # :vector, :halfvec and :sparsevec share one operator set.
      dense = { "inner_product" => "<#>", "cosine" => "<=>", "euclidean" => "<->", "taxicab" => "<+>" }
      {
        bit: { "hamming" => "<~>", "jaccard" => "<%>", "hamming2" => "#" },
        vector: dense,
        halfvec: dense,
        sparsevec: dense,
        # For cube, "cosine" maps to the same operator as "euclidean".
        cube: { "taxicab" => "<#>", "chebyshev" => "<=>", "euclidean" => "<->", "cosine" => "<->" }
      }
    end

  raise ArgumentError, "Unsupported type: #{column_type}" unless by_type.key?(column_type)

  by_type[column_type][distance]
end

def self.order(adapter, type, operator, quoted_attribute, query)

# Builds the SQL distance expression for the given adapter.
#
# - :mariadb / :mysql with the "BIT_COUNT" operator XOR the two operands
#   inside BIT_COUNT.
# - :sqlite wraps both operands in vec_int8/vec_bit for :int8/:bit types and
#   otherwise calls the operator as a two-argument function.
# - :mariadb otherwise calls the operator as a two-argument function.
# - :mysql otherwise emits DISTANCE(...) with 'COSINE' when the operator is
#   "COSINE" and 'EUCLIDEAN' for anything else.
# - Any other adapter uses infix form, except "#", which is wrapped in
#   bit_count.
def self.order(adapter, type, operator, quoted_attribute, query)
  if [:mariadb, :mysql].include?(adapter) && operator == "BIT_COUNT"
    return "BIT_COUNT(#{quoted_attribute} ^ #{query})"
  end

  case adapter
  when :sqlite
    cast = { int8: "vec_int8", bit: "vec_bit" }[type]
    if cast
      "#{operator}(#{cast}(#{quoted_attribute}), #{cast}(#{query}))"
    else
      "#{operator}(#{quoted_attribute}, #{query})"
    end
  when :mariadb
    "#{operator}(#{quoted_attribute}, #{query})"
  when :mysql
    metric = operator == "COSINE" ? "COSINE" : "EUCLIDEAN"
    "DISTANCE(#{quoted_attribute}, #{query}, '#{metric}')"
  else
    return "bit_count(#{quoted_attribute} # #{query})" if operator == "#"

    "#{quoted_attribute} #{operator} #{query}"
  end
end

def self.type(adapter, column_type)

# Returns :bit for :binary columns on the :mysql adapter; every other
# adapter/type combination returns the column type unchanged.
def self.type(adapter, column_type)
  return :bit if adapter == :mysql && column_type == :binary

  column_type
end

def self.validate(value, dimensions:, type:, adapter:)

# Raises Error when +value+ fails either check: first the dimension check
# (using the message returned by validate_dimensions), then the finiteness
# check. Returns nil when both pass.
def self.validate(value, dimensions:, type:, adapter:)
  message = validate_dimensions(value, type, dimensions, adapter)
  raise Error, message if message

  raise Error, "Values must be finite" unless validate_finite(value, type)
end

def self.validate_dimensions(value, type, expected, adapter)

# Compares the dimension count of +value+ with +expected+.
#
# The count comes from `value.dimensions` for :sparsevec and `value.size`
# otherwise; for :bit on the :sqlite and :mysql adapters it is multiplied
# by 8. Returns an error message string on mismatch, and nil when the
# counts agree or +expected+ is not given.
def self.validate_dimensions(value, type, expected, adapter)
  actual = type == :sparsevec ? value.dimensions : value.size
  scaled_adapters = [:sqlite, :mysql]
  actual *= 8 if type == :bit && scaled_adapters.include?(adapter)

  return nil unless expected
  return nil if actual == expected

  "Expected #{expected} dimensions, not #{actual}"
end

def self.validate_finite(value, type)

# Reports whether every element of +value+ is finite.
#
# :bit and :integer types are always accepted without inspecting the value.
# For :sparsevec the elements are taken from `value.values`; otherwise
# +value+ itself is enumerated.
def self.validate_finite(value, type)
  return true if type == :bit || type == :integer

  elements = type == :sparsevec ? value.values : value
  elements.all?(&:finite?)
end