lib/steep/type_inference/type_env.rb



module Steep
  module TypeInference
    # Typing environment for non-local variables: constants, global variables
    # and instance variables.
    #
    # Constant types recorded on this object take precedence over the types
    # found in +const_env+. A copy made with #dup gets its own type tables, so
    # changes made through the copy do not affect the original.
    class TypeEnv
      attr_reader :subtyping
      attr_reader :const_types
      attr_reader :gvar_types
      attr_reader :ivar_types
      attr_reader :const_env

      # @param subtyping subtyping checker used by the assertion helpers
      # @param const_env fallback constant environment queried via +lookup+
      def initialize(subtyping:, const_env:)
        @subtyping = subtyping
        @const_types = {}
        @gvar_types = {}
        @ivar_types = {}
        @const_env = const_env
      end

      # Duplicates the per-kind type tables so the copy can be mutated
      # independently; +subtyping+ and +const_env+ are shared with +other+.
      def initialize_copy(other)
        @subtyping = other.subtyping
        @const_types = other.const_types.dup
        @gvar_types = other.gvar_types.dup
        @ivar_types = other.ivar_types.dup
        @const_env = other.const_env
      end

      # Builds an environment seeded with the ivar and constant types from
      # +annotations+ and with a type for every global in
      # +signatures.name_to_global+ (absolutized against the root namespace
      # and converted through +subtyping.factory.type+).
      def self.build(annotations:, signatures:, subtyping:, const_env:)
        new(subtyping: subtyping, const_env: const_env).tap do |env|
          annotations.ivar_types.each do |name, type|
            env.set(ivar: name, type: type)
          end
          annotations.const_types.each do |name, type|
            env.set(const: name, type: type)
          end
          signatures.name_to_global.each do |name, global|
            type = signatures.absolute_type(global.type, namespace: RBS::Namespace.root) {|ty| ty.name.absolute! }
            env.set(gvar: name, type: subtyping.factory.type(type))
          end
        end
      end

      # Returns a copy of this environment with the given annotation types
      # applied; the receiver is not modified.
      #
      # When a name already has a type (for constants: recorded here or found
      # in +const_env+), the annotated type is checked to be a subtype of it
      # and the block is yielded +name, relation, result+ on failure. The
      # annotated type replaces the previous one either way.
      def with_annotations(ivar_types: {}, const_types: {}, gvar_types: {}, self_type:, &block)
        dup.tap do |env|
          # The keyword arguments shadow the attr_readers of the same names here.
          merge!(original_env: env.ivar_types, override_env: ivar_types, self_type: self_type, &block)
          merge!(original_env: env.gvar_types, override_env: gvar_types, self_type: self_type, &block)

          const_types.each do |name, annotated_type|
            original_type = self.const_types[name] || const_env.lookup(name)
            if original_type
              assert_annotation name,
                                original_type: original_type,
                                annotated_type: annotated_type,
                                self_type: self_type,
                                &block
            end
            env.const_types[name] = annotated_type
          end
        end
      end

      # Looks up the type of a constant, global variable or instance variable.
      # When the name is unknown, yields (no arguments) and returns
      # +AST::Types::Any+. Returns nil if no name is given.
      #
      # Informal signature:
      #   (const: Names::Module) { () -> void } -> AST::Type
      #   (gvar: Symbol) { () -> void } -> AST::Type
      #   (ivar: Symbol) { () -> void } -> AST::Type
      def get(const: nil, gvar: nil, ivar: nil)
        if const
          return const_types[const] if const_types.key?(const)

          type = const_env.lookup(const)
          return type if type

          yield
          AST::Types::Any.new
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            if dictionary.key?(var_name)
              dictionary[var_name]
            else
              yield
              AST::Types::Any.new
            end
          end
        end
      end

      # Records +type+ for the given constant, global or instance variable,
      # overwriting any previous entry on this environment.
      def set(const: nil, gvar: nil, ivar: nil, type:)
        if const
          const_types[const] = type
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            dictionary[var_name] = type
          end
        end
      end

      # Type-checks an assignment of a value of type +type+ to a variable.
      #
      # If the variable has a known type, yields the subtyping failure (if
      # any) and returns the variable's type; see #assert_assign. If it is
      # unknown, yields nil and returns +AST::Types::Any+. +self_type+ is
      # forwarded to the subtyping check.
      #
      # Informal signature:
      #   (const: Names::Module, type: AST::Type, self_type: untyped) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #   (gvar: Symbol, type: AST::Type, self_type: untyped) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #   (ivar: Symbol, type: AST::Type, self_type: untyped) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      def assign(const: nil, gvar: nil, ivar: nil, type:, self_type:, &block)
        if const
          const_type = const_types[const] || const_env.lookup(const)
          if const_type
            assert_assign(var_type: const_type, lhs_type: type, self_type: self_type, &block)
          else
            yield nil
            AST::Types::Any.new
          end
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            if dictionary.key?(var_name)
              assert_assign(var_type: dictionary[var_name], lhs_type: type, self_type: self_type, &block)
            else
              yield nil
              AST::Types::Any.new
            end
          end
        end
      end

      # Yields +name, table+ for the instance variable (preferred when both
      # are given) or the global variable. Returns the block's value, or nil
      # when neither name is given.
      def lookup_dictionary(ivar:, gvar:)
        if ivar
          yield ivar, ivar_types
        elsif gvar
          yield gvar, gvar_types
        end
      end

      # Checks that +lhs_type+ (the type being assigned) is a subtype of
      # +var_type+ after alias expansion, yielding the failure result if not.
      # Returns +var_type+ unchanged when the two are equal, otherwise its
      # alias-expanded form.
      def assert_assign(var_type:, lhs_type:, self_type:)
        return var_type if var_type == lhs_type

        var_type = subtyping.expand_alias(var_type)
        lhs_type = subtyping.expand_alias(lhs_type)

        relation = Subtyping::Relation.new(sub_type: lhs_type, super_type: var_type)
        constraints = Subtyping::Constraints.new(unknowns: Set.new)

        subtyping.check(relation, self_type: self_type, constraints: constraints).else do |result|
          yield result
        end

        var_type
      end

      # Merges +override_env+ into +original_env+ in place. For names present
      # in both, the override is checked with #assert_annotation (which may
      # yield) and its returned type is stored.
      def merge!(original_env:, override_env:, self_type:, &block)
        original_env.merge!(override_env) do |name, original_type, override_type|
          assert_annotation name, annotated_type: override_type, original_type: original_type, self_type: self_type, &block
        end
      end

      # Checks that +annotated_type+ is a subtype of +original_type+ after
      # alias expansion, yielding +name, relation, result+ on failure.
      # Returns +annotated_type+ unchanged when the two are equal, otherwise
      # its alias-expanded form.
      def assert_annotation(name, annotated_type:, original_type:, self_type:)
        return annotated_type if annotated_type == original_type

        annotated_type = subtyping.expand_alias(annotated_type)
        original_type = subtyping.expand_alias(original_type)

        relation = Subtyping::Relation.new(sub_type: annotated_type, super_type: original_type)
        constraints = Subtyping::Constraints.new(unknowns: Set.new)

        subtyping.check(relation, constraints: constraints, self_type: self_type).else do |result|
          yield name, relation, result
        end

        annotated_type
      end
    end
  end
end