lib/steep/type_inference/type_env.rb



module Steep
  module TypeInference
    class TypeEnv
      attr_reader :subtyping
      attr_reader :lvar_types
      attr_reader :const_types
      attr_reader :gvar_types
      attr_reader :ivar_types
      attr_reader :const_env

      def initialize(subtyping:, const_env:)
        @subtyping = subtyping
        @lvar_types = {}
        @const_types = {}
        @gvar_types = {}
        @ivar_types = {}
        @const_env = const_env
      end

      def initialize_copy(other)
        @subtyping = other.subtyping
        @lvar_types = other.lvar_types.dup
        @const_types = other.const_types.dup
        @gvar_types = other.gvar_types.dup
        @ivar_types = other.ivar_types.dup
        @const_env = other.const_env
      end

      def self.build(annotations:, signatures:, subtyping:, const_env:)
        new(subtyping: subtyping, const_env: const_env).tap do |env|
          annotations.lvar_types.each do |name, type|
            env.set(lvar: name, type: type)
          end
          annotations.ivar_types.each do |name, type|
            env.set(ivar: name, type: type)
          end
          annotations.const_types.each do |name, type|
            env.set(const: name, type: type)
          end
          signatures.globals.each do |name, annot|
            type = subtyping.builder.absolute_type(annot.type, current: AST::Namespace.root)
            env.set(gvar: name, type: type)
          end
        end
      end

      def with_annotations(lvar_types: {}, ivar_types: {}, const_types: {}, gvar_types: {}, &block)
        dup.tap do |env|
          merge!(original_env: env.lvar_types, override_env: lvar_types, &block)
          merge!(original_env: env.ivar_types, override_env: ivar_types, &block)
          merge!(original_env: env.gvar_types, override_env: gvar_types, &block)

          const_types.each do |name, annotated_type|
            original_type = self.const_types[name] || const_env.lookup(name)
            if original_type
              assert_annotation name, original_type: original_type, annotated_type: annotated_type, &block
            end
            env.const_types[name] = annotated_type
          end
        end
      end

      def join!(envs)
        lvars = {}

        common_vars = envs.map {|env| Set.new(env.lvar_types.keys) }.inject {|a, b| a & b }

        envs.each do |env|
          env.lvar_types.each do |name, type|
            unless lvar_types.key?(name)
              lvars[name] = [] unless lvars[name]
              lvars[name] << type
            end
          end
        end

        lvars.each do |name, types|
          if lvar_types.key?(name) || common_vars.member?(name)
            set(lvar: name, type: AST::Types::Union.build(types: types))
          else
            set(lvar: name, type: AST::Types::Union.build(types: types + [AST::Types::Nil.new]))
          end
        end
      end

      # @type method assert: (const: Names::Module) { () -> void } -> AST::Type
      #                    | (gvar: Symbol) { () -> void } -> AST::Type
      #                    | (ivar: Symbol) { () -> void } -> AST::Type
      #                    | (lvar: Symbol) { () -> AST::Type | nil } -> AST::Type
      def get(lvar: nil, const: nil, gvar: nil, ivar: nil)
        case
        when lvar
          lvar_name(lvar).yield_self do |name|
            if lvar_types.key?(name)
              lvar_types[name]
            else
              ty = yield
              lvar_types[name] = ty || AST::Types::Any.new
            end
          end
        when const
          if const_types.key?(const)
            const_types[const]
          else
            const_env.lookup(const).yield_self do |type|
              if type
                type
              else
                yield
                AST::Types::Any.new
              end
            end
          end
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            if dictionary.key?(var_name)
              dictionary[var_name]
            else
              yield
              AST::Types::Any.new
            end
          end
        end
      end

      def set(lvar: nil, const: nil, gvar: nil, ivar: nil, type:)
        case
        when lvar
          lvar_name(lvar).yield_self do |name|
            lvar_types[name] = type
          end
        when const
          const_types[const] = type
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            dictionary[var_name] = type
          end
        end
      end

      # @type method assign: (const: Names::Module, type: AST::Type) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #                    | (gvar: Symbol, type: AST::Type) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #                    | (ivar: Symbol, type: AST::Type) { (Subtyping::Result::Failure | nil) -> void } -> AST::Type
      #                    | (lvar: Symbol | LabeledName, type: AST::Type) { (Subtyping::Result::Failure) -> void } -> AST::Type
      def assign(lvar: nil, const: nil, gvar: nil, ivar: nil, type:, &block)
        case
        when lvar
          yield_self do
            name = lvar_name(lvar)
            var_type = lvar_types[name]
            if var_type
              assert_assign(var_type: var_type, lhs_type: type, &block)
            else
              lvar_types[name] = type
            end
          end
        when const
          yield_self do
            const_type = const_types[const] || const_env.lookup(const)
            if const_type
              assert_assign(var_type: const_type, lhs_type: type, &block)
            else
              yield nil
              AST::Types::Any.new
            end
          end
        else
          lookup_dictionary(ivar: ivar, gvar: gvar) do |var_name, dictionary|
            if dictionary.key?(var_name)
              assert_assign(var_type: dictionary[var_name], lhs_type: type, &block)
            else
              yield nil
              AST::Types::Any.new
            end
          end
        end
      end

      def lookup_dictionary(ivar:, gvar:)
        case
        when ivar
          yield ivar, ivar_types
        when gvar
          yield gvar, gvar_types
        end
      end

      def lvar_name(lvar)
        case lvar
        when Symbol
          lvar
        when ASTUtils::Labeling::LabeledName
          lvar.name
        end
      end

      def assert_assign(var_type:, lhs_type:)
        var_type = subtyping.expand_alias(var_type)
        lhs_type = subtyping.expand_alias(lhs_type)

        relation = Subtyping::Relation.new(sub_type: lhs_type, super_type: var_type)
        constraints = Subtyping::Constraints.new(unknowns: Set.new)

        subtyping.check(relation, constraints: constraints).else do |result|
          yield result
        end

        var_type
      end

      def merge!(original_env:, override_env:, &block)
        original_env.merge!(override_env) do |name, original_type, override_type|
          assert_annotation name, annotated_type: override_type, original_type: original_type, &block
        end
      end

      def assert_annotation(name, annotated_type:, original_type:)
        annotated_type = subtyping.expand_alias(annotated_type)
        original_type = subtyping.expand_alias(original_type)

        relation = Subtyping::Relation.new(sub_type: annotated_type, super_type: original_type)
        constraints = Subtyping::Constraints.new(unknowns: Set.new)

        subtyping.check(relation, constraints: constraints).else do |result|
          yield name, relation, result
        end

        annotated_type
      end
    end
  end
end