lib/pg_conn/schema_methods.rb



module PgConn
  # Schema methods
  #
  # Helpers for inspecting and manipulating PostgreSQL schemas and the
  # relations they contain. All SQL is sent through +conn+; the methods used
  # here are #exec, #exist?, #value, #values, #tuples, and #session.
  #
  # Schema, relation, and column arguments of the lookup methods are treated
  # as exact catalog names. They are matched against pg_namespace.nspname,
  # pg_class.relname, and pg_attribute.attname, and are embedded as escaped
  # SQL string literals.
  class SchemaMethods
    # pg_class.relkind codes for ordinary, partitioned, and foreign tables
    TABLE_KINDS = %w(r p f).freeze

    # pg_class.relkind codes for plain and materialized views
    VIEW_KINDS = %w(v m).freeze

    # All relation kinds handled by this class
    RELATION_KINDS = (TABLE_KINDS + VIEW_KINDS).freeze

    # DROP object type for each group of relkinds. PostgreSQL rejects e.g.
    # DROP VIEW on a materialized view, so each kind needs its own keyword.
    # Order matters: tables are dropped before views (see #empty!).
    DROP_OBJECT_TYPES = [
      [%w(r p), "table"],
      [%w(f), "foreign table"],
      [%w(v), "view"],
      [%w(m), "materialized view"]
    ].freeze

    attr_reader :conn

    def initialize(conn)
      @conn = conn
    end

    # Return true if schema exists. Uses information_schema.schemata, which
    # only lists schemas the current user has access to
    def exist?(schema)
      conn.exist? %(
          select 1
          from information_schema.schemata
          where schema_name = #{quote_literal(schema)}
      )
    end

    # Create a new schema. The authorization option can be used to set the
    # owner of the schema. Both +schema+ and +authorization+ are inserted into
    # the statement verbatim
    def create(schema, authorization: nil)
      authorization_clause = authorization ? "authorization #{authorization}" : nil
      conn.exec ["create schema", schema, authorization_clause].compact.join(" ")
    end

    # Drop schema if it exists, optionally with 'cascade'. +schema+ is
    # inserted verbatim. Always returns true
    def drop(schema, cascade: false)
      conn.exec ["drop schema if exists", schema, ("cascade" if cascade)].compact.join(" ")
      true
    end

    # Hollow out a schema by dropping all tables (ordinary, partitioned, and
    # foreign) and views (plain and materialized), except the relations named
    # in +exclude+. Functions and procedures are not dropped (TODO).
    #
    # Each kind is listed just before it is dropped, because the cascading
    # drops of earlier kinds may already have removed dependent relations.
    def empty!(schema, exclude: [])
      DROP_OBJECT_TYPES.each { |kinds, object_type|
        conn.values(relation_list_query(schema, exclude: exclude, kind: kinds)).each { |relation|
          conn.exec "drop #{object_type} if exists #{qualified_name(schema, relation)} cascade"
        }
      }
      nil
    end

    # Empty all tables in the given schema, except those named in +exclude+.
    # The deletes run inside conn.session.triggers(false) { ... }, which
    # presumably disables triggers while deleting
    def clean!(schema, exclude: [])
      conn.session.triggers(false) {
        list_tables(schema, exclude: exclude).each { |table|
          conn.exec "delete from #{qualified_name(schema, table)}"
        }
      }
    end

    # List schemas. Built-in schemas (pg_* and information_schema) are not
    # listed unless the :all option is true. The :exclude option can be used
    # to exclude named schemas
    def list(all: false, exclude: [])
      exclude = Array(exclude)
      conn.values(%(
          select schema_name
          from information_schema.schemata
      )).select { |schema|
        !exclude.include?(schema) && (all || !system_schema?(schema))
      }
    end

    # Returns true if relation (table or view) exists
    def exist_relation?(schema, relation)
      conn.exist? relation_exist_query(schema, relation)
    end

    # Return true if table (ordinary, partitioned, or foreign) exists
    def exist_table?(schema, table)
      conn.exist? relation_exist_query(schema, table, kind: TABLE_KINDS)
    end

    # Return true if view (plain or materialized) exists
    def exist_view?(schema, view)
      conn.exist? relation_exist_query(schema, view, kind: VIEW_KINDS)
    end

    # Return true if the column exists (dropped columns are ignored)
    def exist_column?(schema, relation, column)
      conn.exist? column_exist_query(schema, relation, column)
    end

    # TODO
#   def exist_index?(schema, relation, FIXME)

    # Return list of relations (tables and views) in the schema
    def list_relations(schema, exclude: [])
      conn.values relation_list_query(schema, exclude: exclude)
    end

    # Return list of tables (ordinary, partitioned, and foreign) in the schema
    def list_tables(schema, exclude: [])
      conn.values relation_list_query(schema, exclude: exclude, kind: TABLE_KINDS)
    end

    # Return list of views (plain and materialized) in the schema
    def list_views(schema, exclude: [])
      conn.values relation_list_query(schema, exclude: exclude, kind: VIEW_KINDS)
    end

    # Return a list of columns of tables and views. If +relation+ is defined,
    # only columns from that relation are listed. Columns are returned as
    # fully qualified names (eg. "schema.relation.column"), ordered by
    # relation name and column position
    def list_columns(schema, relation = nil)
      conn.values column_list_query(schema, relation)
    end

    # Like #list_columns but returns a tuple of column UID and column type
    def list_column_types(schema, relation = nil)
      conn.tuples column_list_type_query(schema, relation)
    end

    def exist_function(schema, function, signature)
      raise NotImplementedError
    end

    def list_functions(schema, function = nil)
      raise NotImplementedError
    end

    # Return name of the sequence owned by the table's serial column (if
    # any). The column defaults to 'id'
    def sequence(schema, table, column: "id")
      conn.value "select pg_get_serial_sequence(#{quote_literal("#{schema}.#{table}")}, #{quote_literal(column)})"
    end

    # Get the current serial value for the table. Returns nil if the serial has
    # not been used. If :next is true, the value the next nextval() call will
    # return is computed instead: last_value + 1 once the sequence has been
    # called, otherwise last_value itself (e.g. 1 for a fresh sequence, or n
    # after setval(seq, n, false)). Raises ArgumentError if the table has no
    # sequence
    def get_serial(schema, table, next: false)
      uid = "#{schema}.#{table}"
      next_option = binding.local_variable_get(:next) # because 'next' is a keyword

      seq = sequence(schema, table) or raise ArgumentError, "Table #{uid} does not have a sequence"
      value_expr =
        if next_option
          "case when is_called then last_value + 1 else last_value end"
        else
          "case when is_called then last_value end"
        end
      conn.value %(select #{value_expr} as "value" from #{seq})
    end

    # Set the serial value for the table. A nil/false +value+ resets the
    # sequence so that the next value is 1. Raises ArgumentError if the table
    # has no sequence or +value+ is not an integer
    def set_serial(schema, table, value)
      uid = "#{schema}.#{table}"
      seq = sequence(schema, table) or raise ArgumentError, "Table #{uid} does not have a sequence"
      if value
        conn.exec "select setval(#{quote_literal(seq)}, #{Integer(value)})"
      else
        conn.exec "select setval(#{quote_literal(seq)}, 1, false)"
      end
    end

  private
    # True for PostgreSQL's built-in schemas
    def system_schema?(schema)
      schema.start_with?("pg_") || schema == "information_schema"
    end

    # Quote +value+ as an SQL string literal by doubling embedded single
    # quotes. Assumes standard_conforming_strings is on (the PostgreSQL
    # default), so backslashes need no escaping
    def quote_literal(value)
      "'#{value.to_s.gsub("'", "''")}'"
    end

    # Quote +name+ as an SQL identifier by doubling embedded double quotes
    def quote_ident(name)
      %("#{name.to_s.gsub('"', '""')}")
    end

    # Quoted "schema"."relation" identifier
    def qualified_name(schema, relation)
      "#{quote_ident(schema)}.#{quote_ident(relation)}"
    end

    # Comma-separated SQL list of quoted relkind codes (default RELATION_KINDS)
    def kind_list(kind)
      (kind.nil? ? RELATION_KINDS : Array(kind).flatten).map { |k| quote_literal(k) }.join(", ")
    end

    def relation_exist_query(schema, relation, kind: nil)
      %(
          select  1
          from    pg_class c
          join    pg_namespace n on n.oid = c.relnamespace
          where   n.nspname = #{quote_literal(schema)}
          and     c.relname = #{quote_literal(relation)}
          and     c.relkind in (#{kind_list(kind)})
      )
    end

    def relation_list_query(schema, exclude: nil, kind: nil)
      exclude = Array(exclude).flatten
      exclude_clause =
        if exclude.empty?
          ""
        else
          "and     c.relname not in (#{exclude.map { |name| quote_literal(name) }.join(", ")})"
        end
      %(
          select  c.relname
          from    pg_class c
          join    pg_namespace n on n.oid = c.relnamespace
          where   n.nspname = #{quote_literal(schema)}
          and     c.relkind in (#{kind_list(kind)})
          #{exclude_clause}
      )
    end

    def column_exist_query(schema, relation, column)
      %(
        select  1
        from    pg_class c
        join    pg_namespace n on n.oid = c.relnamespace
        join    pg_attribute a on a.attrelid = c.oid
        where   n.nspname = #{quote_literal(schema)}
        and     c.relname = #{quote_literal(relation)}
        and     a.attname = #{quote_literal(column)}
        and     a.attnum > 0
        and     not a.attisdropped
      )
    end

    # Shared query for the column listings. Only user columns (attnum > 0,
    # not dropped) of tables and views are included; without the relkind
    # filter, index, sequence, and composite-type attributes would leak in
    def column_query(schema, relation, select_list)
      relation_clause = relation ? "and     c.relname = #{quote_literal(relation)}" : ""
      %(
          select  #{select_list}
          from    pg_class c
          join    pg_namespace n on n.oid = c.relnamespace
          join    pg_attribute a on a.attrelid = c.oid
          where   n.nspname = #{quote_literal(schema)}
          and     c.relkind in (#{kind_list(nil)})
          and     a.attnum > 0
          and     not a.attisdropped
          #{relation_clause}
          order   by c.relname, a.attnum
      )
    end

    def column_list_query(schema, relation)
      column_query(schema, relation, "n.nspname || '.' || c.relname || '.' || a.attname")
    end

    def column_list_type_query(schema, relation)
      column_query(
        schema, relation,
        %(n.nspname || '.' || c.relname || '.' || a.attname as "column", a.atttypid::regtype::text as "type")
      )
    end
  end
end