lib/activerecord-import/adapters/postgresql_adapter.rb



# frozen_string_literal: true

module ActiveRecord::Import::PostgreSQLAdapter
  include ActiveRecord::Import::ImportSupport
  include ActiveRecord::Import::OnDuplicateKeyUpdateSupport

  # Minimum value of #database_version for which ON CONFLICT (upsert) SQL is
  # generated (see #supports_on_duplicate_key_update?).
  MIN_VERSION_FOR_UPSERT = 90_500

  # Executes one multi-row INSERT built from +sql+ and +values+.
  #
  # +sql+ is either the statement prefix as a String, or an Array whose first
  # element is the prefix and whose remaining elements are appended after the
  # VALUES list, joined with spaces. The Array is not modified.
  # +values+ are the already-rendered VALUES tuples, joined with commas.
  # Any extra +args+ are forwarded to #insert / #select_all / #select_values.
  #
  # When a RETURNING clause applies (see #returning_selections), the returned
  # rows are read back and split into primary-key ids and :returning results;
  # otherwise the statement is run through #insert and nothing is read.
  #
  # Returns an ActiveRecord::Import::Result with no failed instances and a
  # single counted insert.
  #
  # Raises ArgumentError when +sql+ is neither a String nor an Array.
  def insert_many( sql, values, options = {}, *args ) # :nodoc:
    number_of_inserts = 1
    returned_values = {}
    ids = []
    results = []

    base_sql, post_sql = case sql
                         when String
                           [sql, '']
                         when Array
                           # Read without #shift so the caller's array is left intact.
                           [sql.first, sql.drop(1).join( ' ' )]
                         else
                           raise ArgumentError, "Expected sql to be a String or an Array, got #{sql.class}"
    end

    sql2insert = base_sql + values.join( ',' ) + post_sql

    selections = returning_selections(options)
    if skip_returning_clause?(selections, options)
      insert( sql2insert, *args )
    else
      returned_values = if selections.size > 1
        # Several columns: select_all keeps row values and column names together.
        db_result = select_all( sql2insert, *args )
        { values: db_result.rows, columns: db_result.columns }
      else
        # A single column comes back as a flat list of values.
        { values: select_values( sql2insert, *args ) }
      end
      # This write went through select_* rather than #insert, so the query
      # cache is cleared explicitly here (presumably select_* does not
      # invalidate it on its own).
      clear_query_cache if query_cache_enabled
    end

    if options[:returning].blank?
      ids = Array(returned_values[:values])
    elsif options[:primary_key].blank?
      options[:returning_columns] ||= returned_values[:columns]
      results = Array(returned_values[:values])
    else
      # split primary key and returning columns
      ids, results, options[:returning_columns] = split_ids_and_results(returned_values, options)
    end

    ActiveRecord::Import::Result.new([], number_of_inserts, ids, results)
  end

  # Splits rows from a multi-column RETURNING into primary-key values and the
  # remaining :returning values.
  #
  # +selections+ is a Hash with :columns (returned column names) and :values
  # (one row per inserted record). The key column(s) in +options[:primary_key]+
  # are matched against :columns by string name.
  #
  # Returns [ids, returning_values, returning_columns]. When there is exactly
  # one id column (or one returning column), each row's one-element array is
  # unwrapped to its single value.
  def split_ids_and_results( selections, options )
    columns = Array(selections[:columns])
    values = Array(selections[:values])

    # Compare by string so Symbol primary keys match the returned column names.
    id_indexes = Array(options[:primary_key]).map { |key| columns.index(key.to_s) }
    # Every position that is not a primary-key position holds a :returning column.
    returning_indexes = columns.each_index.reject { |index| id_indexes.include?(index) }
    returning_columns = columns.values_at(*returning_indexes)

    rows = values.map { |value| Array(value) }
    ids = rows.map { |row| row.values_at(*id_indexes) }
    returning_values = rows.map { |row| row.values_at(*returning_indexes) }

    ids.map!(&:first) if id_indexes.size == 1
    returning_values.map!(&:first) if returning_columns.size == 1

    [ids, returning_values, returning_columns]
  end

  # Returns the SQL expression that draws the next value of +sequence_name+.
  # The name is interpolated verbatim into the quoted literal.
  def next_value_for_sequence(sequence_name)
    %{nextval('#{sequence_name}')}
  end

  # Builds the SQL fragments appended after the VALUES list:
  # - an ON CONFLICT ... DO NOTHING clause for :ignore / :on_duplicate_key_ignore,
  # - whatever the included support modules add via +super+,
  # - a RETURNING clause when columns need to be read back.
  # Returns an Array of SQL fragments.
  def post_sql_statements( table_name, options ) # :nodoc:
    sql = []

    if supports_on_duplicate_key_update?
      # Options :recursive and :on_duplicate_key_ignore are mutually exclusive
      if (options[:ignore] || options[:on_duplicate_key_ignore]) && !options[:on_duplicate_key_update] && !options[:recursive]
        sql << sql_for_on_duplicate_key_ignore( table_name, options[:on_duplicate_key_ignore] )
      end
    elsif logger && options[:on_duplicate_key_ignore] && !options[:on_duplicate_key_update]
      logger.warn "Ignoring on_duplicate_key_ignore because it is not supported by the database."
    end

    sql += super(table_name, options)

    selections = returning_selections(options)
    sql << " RETURNING #{selections.join(', ')}" unless skip_returning_clause?(selections, options)

    sql
  end

  # Lists the RETURNING selections: primary key column(s) first, then any
  # :returning entries. Entries naming a column of +options[:model]+ are
  # wrapped in double quotes; other entries are passed through unchanged.
  def returning_selections(options)
    selections = []
    column_names = Array(options[:model].column_names)

    selections += Array(options[:primary_key]) if options[:primary_key].present?
    selections += Array(options[:returning]) if options[:returning].present?

    selections.map do |selection|
      column_names.include?(selection.to_s) ? "\"#{selection}\"" : selection
    end
  end

  # Add a column to be updated on duplicate key update.
  # Mutates +options[:on_duplicate_key_update]+ in place: appends to an Array
  # (skipping duplicates), or maps column => column in a Hash :columns entry.
  def add_column_for_on_duplicate_key_update( column, options = {} ) # :nodoc:
    arg = options[:on_duplicate_key_update]
    case arg
    when Hash
      columns = arg.fetch( :columns ) { arg[:columns] = [] }
      case columns
      when Array then columns << column.to_sym unless columns.include?( column.to_sym )
      when Hash then columns[column.to_sym] = column.to_sym
      end
    when Array
      arg << column.to_sym unless arg.include?( column.to_sym )
    end
  end

  # Returns a generated ON CONFLICT DO NOTHING statement given the passed
  # in +args+. When the first arg is a Hash, its conflict target (see
  # #sql_for_conflict_target) is included.
  def sql_for_on_duplicate_key_ignore( table_name, *args ) # :nodoc:
    arg = args.first
    conflict_target = sql_for_conflict_target( arg ) if arg.is_a?( Hash )
    " ON CONFLICT #{conflict_target}DO NOTHING"
  end

  # Returns a generated ON CONFLICT DO UPDATE statement given the passed
  # in +args+ (options, model, primary_key, locking_column).
  #
  # An Array or String option is treated as { columns: option }. An empty
  # :columns yields DO NOTHING. Without an explicit target, the primary key
  # is used as the conflict target. Returns nil for unsupported option types.
  #
  # Raises ArgumentError when no conflict target can be determined or
  # :columns has an unsupported type.
  def sql_for_on_duplicate_key_update( table_name, *args ) # :nodoc:
    arg, model, primary_key, locking_column = args
    arg = { columns: arg } if arg.is_a?( Array ) || arg.is_a?( String )
    return unless arg.is_a?( Hash )

    sql = ' ON CONFLICT '.dup
    conflict_target = sql_for_conflict_target( arg )

    columns = arg.fetch( :columns, [] )
    condition = arg[:condition]
    if columns.respond_to?( :empty? ) && columns.empty?
      return sql << "#{conflict_target}DO NOTHING"
    end

    conflict_target ||= sql_for_default_conflict_target( table_name, primary_key )
    unless conflict_target
      raise ArgumentError, 'Expected :conflict_target or :constraint_name to be specified'
    end

    sql << "#{conflict_target}DO UPDATE SET "
    case columns
    when Array
      sql << sql_for_on_duplicate_key_update_as_array( table_name, model, locking_column, columns )
    when Hash
      sql << sql_for_on_duplicate_key_update_as_hash( table_name, model, locking_column, columns )
    when String
      sql << columns
    else
      raise ArgumentError, 'Expected :columns to be an Array or Hash'
    end

    sql << " WHERE #{condition}" if condition.present?

    sql
  end

  # Renders "col"=EXCLUDED."col" assignments for each column in +arr+,
  # resolving attribute aliases on +model+ and adding the locking-column
  # increment when applicable.
  def sql_for_on_duplicate_key_update_as_array( table_name, model, locking_column, arr ) # :nodoc:
    results = arr.map do |column|
      original_column_name = model.attribute_alias?( column ) ? model.attribute_alias( column ) : column
      qc = quote_column_name( original_column_name )
      "#{qc}=EXCLUDED.#{qc}"
    end
    increment_locking_column!(table_name, results, locking_column)
    results.join( ',' )
  end

  # Renders "target"=EXCLUDED."source" assignments for each pair in +hsh+,
  # resolving attribute aliases on both sides and adding the locking-column
  # increment when applicable.
  def sql_for_on_duplicate_key_update_as_hash( table_name, model, locking_column, hsh ) # :nodoc:
    results = hsh.map do |column1, column2|
      original_column1_name = model.attribute_alias?( column1 ) ? model.attribute_alias( column1 ) : column1
      qc1 = quote_column_name( original_column1_name )

      original_column2_name = model.attribute_alias?( column2 ) ? model.attribute_alias( column2 ) : column2
      qc2 = quote_column_name( original_column2_name )

      "#{qc1}=EXCLUDED.#{qc2}"
    end
    increment_locking_column!(table_name, results, locking_column)
    results.join( ',' )
  end

  # Builds the ON CONFLICT target from :constraint_name (takes precedence) or
  # :conflict_target plus an optional :index_predicate. Returns nil when
  # neither is given. The fragment always ends with a trailing space.
  def sql_for_conflict_target( args = {} )
    constraint_name = args[:constraint_name]
    conflict_target = args[:conflict_target]
    index_predicate = args[:index_predicate]
    if constraint_name.present?
      "ON CONSTRAINT #{constraint_name} "
    elsif conflict_target.present?
      sql = "(#{Array( conflict_target ).reject( &:blank? ).join( ', ' )}) "
      # A blank predicate would render a dangling "WHERE", so require content.
      sql += "WHERE #{index_predicate} " if index_predicate.present?
      sql
    end
  end

  # Uses the primary key column(s) as the conflict target; nil when none.
  def sql_for_default_conflict_target( table_name, primary_key )
    conflict_target = Array(primary_key).join(', ')
    "(#{conflict_target}) " if conflict_target.present?
  end

  # Return true if the statement is a duplicate key record error
  def duplicate_key_update_error?(exception) # :nodoc:
    exception.is_a?(ActiveRecord::StatementInvalid) && exception.to_s.include?('duplicate key')
  end

  # ON CONFLICT clauses are only generated for servers at or above
  # MIN_VERSION_FOR_UPSERT.
  def supports_on_duplicate_key_update?
    database_version >= MIN_VERSION_FOR_UPSERT
  end

  def supports_setting_primary_key_of_imported_objects?
    true
  end

  # Prefers #postgresql_version when the adapter defines it, else defers to super.
  def database_version
    defined?(postgresql_version) ? postgresql_version : super
  end

  private

  # True when no RETURNING clause should be emitted or read back: either there
  # is nothing to select, or :no_returning was requested outside a recursive
  # import (:recursive overrides :no_returning, presumably because recursive
  # imports need the returned ids). Shared by #insert_many and
  # #post_sql_statements so the generated SQL and the result parsing agree.
  def skip_returning_clause?(selections, options)
    selections.blank? || (options[:no_returning] && !options[:recursive])
  end
end