lib/steep/ast/types/factory.rb



module Steep
  module AST
    module Types
      class Factory
        # Readers for the state set up in #initialize:
        # - definition_builder:   the builder passed to the constructor
        # - type_name_cache:      a Hash that starts empty (not otherwise used in this part of the file)
        # - type_cache:           memo table filled by #type (RBS type => converted type)
        # - type_interface_cache: memo table filled by #interface ([type, self_type, private] => interface)
        attr_reader :definition_builder,
                    :type_name_cache,
                    :type_cache,
                    :type_interface_cache

        # Creates a factory around +builder+ with three empty, independent caches.
        #
        # @param builder the object later returned by #definition_builder
        def initialize(builder:)
          @definition_builder = builder
          @type_name_cache, @type_cache, @type_interface_cache = {}, {}, {}
        end

        # Returns the RBS type-name resolver for the builder's environment.
        # The resolver is created on first use and reused afterwards.
        def type_name_resolver
          return @type_name_resolver if @type_name_resolver

          @type_name_resolver = RBS::TypeNameResolver.from_env(definition_builder.env)
        end

        # Converts an RBS type into the corresponding type of this namespace.
        #
        # Results are memoized in #type_cache, keyed by the RBS type, so converting
        # an equal RBS type again returns the cached object.
        #
        # @param type an RBS type (one of the RBS::Types classes handled below)
        # @return the converted type
        # @raise [RuntimeError] if +type+ is none of the handled RBS type classes
        def type(type)
          cached = type_cache[type]
          return cached if cached

          converted =
            case type
            when RBS::Types::Bases::Any then Any.new(location: nil)
            when RBS::Types::Bases::Class then Class.new(location: nil)
            when RBS::Types::Bases::Instance then Instance.new(location: nil)
            when RBS::Types::Bases::Self then Self.new(location: nil)
            when RBS::Types::Bases::Top then Top.new(location: nil)
            when RBS::Types::Bases::Bottom then Bot.new(location: nil)
            when RBS::Types::Bases::Bool then Boolean.new(location: nil)
            when RBS::Types::Bases::Void then Void.new(location: nil)
            when RBS::Types::Bases::Nil then Nil.new(location: nil)
            when RBS::Types::Variable then Var.new(name: type.name, location: nil)
            when RBS::Types::ClassSingleton then Name::Singleton.new(name: type.name, location: nil)
            when RBS::Types::ClassInstance
              Name::Instance.new(name: type.name, args: type.args.map { |arg| type(arg) }, location: nil)
            when RBS::Types::Interface
              Name::Interface.new(name: type.name, args: type.args.map { |arg| type(arg) }, location: nil)
            when RBS::Types::Alias
              # The alias is always built with an empty argument list.
              Name::Alias.new(name: type.name, args: [], location: nil)
            when RBS::Types::Union
              Union.build(types: type.types.map { |member| type(member) }, location: nil)
            when RBS::Types::Intersection
              Intersection.build(types: type.types.map { |member| type(member) }, location: nil)
            when RBS::Types::Optional
              # An optional type becomes the union of the inner type and nil.
              Union.build(types: [type(type.type), Nil.new(location: nil)], location: nil)
            when RBS::Types::Literal then Literal.new(value: type.literal, location: nil)
            when RBS::Types::Tuple
              Tuple.new(types: type.types.map { |member| type(member) }, location: nil)
            when RBS::Types::Record
              elements = type.fields.each_with_object({}) do |(key, value), converted_fields|
                converted_fields[key] = type(value)
              end
              Record.new(elements: elements, location: nil)
            when RBS::Types::Proc
              # The proc's own function and its block's function are both tagged
              # with the location of the proc type.
              build_function = lambda do |function|
                Interface::Function.new(
                  params: params(function),
                  return_type: type(function.return_type),
                  location: type.location
                )
              end

              func = build_function.call(type.type)
              block = if type.block
                        Interface::Block.new(
                          type: build_function.call(type.block.type),
                          optional: !type.block.required
                        )
                      end

              Proc.new(type: func, block: block)
            else
              raise "Unexpected type given: #{type}"
            end

          type_cache[type] = converted
        end

        # Converts a type of this namespace back into an RBS type (the reverse
        # direction of #type). Every produced RBS type has a nil location and
        # nothing is cached.
        #
        # Logic types are mapped to RBS bool.
        #
        # @param type a converted type (Any, Name::Instance, Union, Proc, ...)
        # @return the equivalent RBS::Types object
        # @raise [RuntimeError] for alias types that carry args and for unhandled type classes
        def type_1(type)
          case type
          when Any then RBS::Types::Bases::Any.new(location: nil)
          when Class then RBS::Types::Bases::Class.new(location: nil)
          when Instance then RBS::Types::Bases::Instance.new(location: nil)
          when Self then RBS::Types::Bases::Self.new(location: nil)
          when Top then RBS::Types::Bases::Top.new(location: nil)
          when Bot then RBS::Types::Bases::Bottom.new(location: nil)
          when Boolean then RBS::Types::Bases::Bool.new(location: nil)
          when Void then RBS::Types::Bases::Void.new(location: nil)
          when Nil then RBS::Types::Bases::Nil.new(location: nil)
          when Var then RBS::Types::Variable.new(name: type.name, location: nil)
          when Name::Singleton then RBS::Types::ClassSingleton.new(name: type.name, location: nil)
          when Name::Instance
            RBS::Types::ClassInstance.new(name: type.name, args: type.args.map { |arg| type_1(arg) }, location: nil)
          when Name::Interface
            RBS::Types::Interface.new(name: type.name, args: type.args.map { |arg| type_1(arg) }, location: nil)
          when Name::Alias
            raise "alias type with args is not supported" unless type.args.empty?

            RBS::Types::Alias.new(name: type.name, location: nil)
          when Union
            RBS::Types::Union.new(types: type.types.map { |member| type_1(member) }, location: nil)
          when Intersection
            RBS::Types::Intersection.new(types: type.types.map { |member| type_1(member) }, location: nil)
          when Literal then RBS::Types::Literal.new(literal: type.value, location: nil)
          when Tuple
            RBS::Types::Tuple.new(types: type.types.map { |member| type_1(member) }, location: nil)
          when Record
            fields = type.elements.each_with_object({}) do |(key, value), converted_fields|
              converted_fields[key] = type_1(value)
            end
            RBS::Types::Record.new(fields: fields, location: nil)
          when Proc
            rbs_block = nil
            if type.block
              rbs_block = RBS::Types::Block.new(
                type: function_1(type.block.type),
                required: !type.block.optional?
              )
            end

            RBS::Types::Proc.new(type: function_1(type.type), block: rbs_block, location: nil)
          when Logic::Base then RBS::Types::Bases::Bool.new(location: nil)
          else
            raise "Unexpected type given: #{type} (#{type.class})"
          end
        end

        # Converts a function (params + return type) into an RBS::Types::Function.
        #
        # Every parameter is wrapped in an unnamed RBS::Types::Function::Param whose
        # type is converted with #type_1. Trailing positionals are always empty.
        #
        # @param func an object responding to +params+ and +return_type+
        # @return [RBS::Types::Function]
        def function_1(func)
          to_param = lambda do |param_type|
            RBS::Types::Function::Param.new(name: nil, type: type_1(param_type))
          end
          source = func.params

          RBS::Types::Function.new(
            required_positionals: source.required.map { |ty| to_param.call(ty) },
            optional_positionals: source.optional.map { |ty| to_param.call(ty) },
            rest_positionals: source.rest&.yield_self { |ty| to_param.call(ty) },
            trailing_positionals: [],
            required_keywords: source.required_keywords.transform_values { |ty| to_param.call(ty) },
            optional_keywords: source.optional_keywords.transform_values { |ty| to_param.call(ty) },
            rest_keywords: source.rest_keywords&.yield_self { |ty| to_param.call(ty) },
            return_type: type_1(func.return_type)
          )
        end

        # Builds Interface::Function::Params from an RBS function type, converting
        # the type of every parameter with #type. Absent rest parameters stay nil.
        #
        # @param type an RBS function type exposing the positional/keyword parameter lists
        # @return the result of Interface::Function::Params.build
        def params(type)
          convert = ->(param) { type(param.type) }

          Interface::Function::Params.build(
            required: type.required_positionals.map { |param| convert.call(param) },
            optional: type.optional_positionals.map { |param| convert.call(param) },
            rest: type.rest_positionals&.yield_self { |param| convert.call(param) },
            required_keywords: type.required_keywords.transform_values { |param| convert.call(param) },
            optional_keywords: type.optional_keywords.transform_values { |param| convert.call(param) },
            rest_keywords: type.rest_keywords&.yield_self { |param| convert.call(param) }
          )
        end

        # Converts an RBS method type into an Interface::MethodType.
        #
        # Type parameters whose names also occur among the free variables of
        # +self_type+ are replaced by fresh variables; that renaming, overlaid with
        # +subst2+ when given (subst2 wins on conflicts), is applied to the
        # parameters and return types of the method and of its block.
        #
        # @param method_type the RBS method type to convert
        # @param self_type the type whose free variables must not be captured
        # @param subst2 optional extra substitution merged over the renaming
        # @param method_decls stored as-is on the resulting method type
        # @yield [Interface::MethodType] the converted method type, if a block is given
        # @return the block's result when a block is given, otherwise the converted method type
        def method_type(method_type, self_type:, subst2: nil, method_decls:)
          fvs = self_type.free_variables()

          type_params = []
          alpha_vars = []
          alpha_types = []

          method_type.type_params.each do |name|
            unless fvs.include?(name)
              type_params << name
              next
            end

            fresh_var = Types::Var.fresh(name)
            alpha_vars << name
            alpha_types << fresh_var
            type_params << fresh_var.name
          end

          subst = Interface::Substitution.build(alpha_vars, alpha_types)
          subst.merge!(subst2, overwrite: true) if subst2

          build_function = lambda do |function, location|
            Interface::Function.new(
              params: params(function).subst(subst),
              return_type: type(function.return_type).subst(subst),
              location: location
            )
          end

          result = Interface::MethodType.new(
            type_params: type_params,
            type: build_function.call(method_type.type, method_type.location),
            block: method_type.block&.yield_self do |rbs_block|
              Interface::Block.new(
                optional: !rbs_block.required,
                type: build_function.call(rbs_block.type, nil)
              )
            end,
            method_decls: method_decls
          )

          return result unless block_given?

          yield result
        end

        # Converts an Interface method type back into an RBS::MethodType (the
        # reverse direction of #method_type).
        #
        # Type parameters whose names also occur among the free variables of
        # +self_type+ are renamed to fresh variables, and that renaming is applied
        # to the method's function and to its block before they are converted with
        # #function_1.
        #
        # The replacement variables are created with Types::Var.fresh, exactly as
        # #method_type does. They must be types of this namespace because the
        # substitution is applied before the conversion to RBS, and #type_1 does
        # not accept RBS types. (A stray trailing comma previously turned the
        # replacement into an Array, which raised NoMethodError on `type.name`.)
        #
        # @param method_type the method type to convert (responds to type_params, type, block)
        # @param self_type the type whose free variables must not be captured
        # @yield [RBS::MethodType] the converted method type, if a block is given
        # @return the block's result when a block is given, otherwise the RBS::MethodType
        def method_type_1(method_type, self_type:)
          fvs = self_type.free_variables()

          type_params = []
          alpha_vars = []
          alpha_types = []

          method_type.type_params.each do |name|
            if fvs.include?(name)
              type = Types::Var.fresh(name)
              alpha_vars << name
              alpha_types << type
              type_params << type.name
            else
              type_params << name
            end
          end
          subst = Interface::Substitution.build(alpha_vars, alpha_types)

          type = RBS::MethodType.new(
            type_params: type_params,
            type: function_1(method_type.type.subst(subst)),
            block: method_type.block&.yield_self do |block|
              block_type = block.type.subst(subst)

              RBS::Types::Block.new(
                type: function_1(block_type),
                required: !block.optional
              )
            end,
            location: nil
          )

          if block_given?
            yield type
          else
            type
          end
        end

        # StandardError subclass that carries a type alongside its message.
        class InterfaceCalculationError < StandardError
          # The type given to the constructor.
          attr_reader :type

          # @param type the type to attach to the error
          # @param message the error message handed to StandardError
          def initialize(type:, message:)
            super(message)
            @type = type
          end
        end

        # Expands the alias named +type_name+ through the definition builder and
        # converts the resulting RBS type with #type.
        #
        # @param type_name the name of the alias to expand
        # @return the converted type the alias stands for
        def unfold(type_name)
          expanded = definition_builder.expand_alias(type_name)
          type(expanded)
        end

        # Expands +type+ by one level when it is an alias; any other type is
        # passed through untouched.
        #
        # @param type the type to expand
        # @yield [type] the expanded (or unchanged) type, if a block is given
        # @return the block's result when a block is given, otherwise the expanded type
        def expand_alias(type)
          unfolded = type.is_a?(AST::Types::Name::Alias) ? unfold(type.name) : type
          return unfolded unless block_given?

          yield unfolded
        end

        # Repeatedly expands aliases until a non-alias type is reached, also
        # descending into the members of union types.
        #
        # +recursive+ collects the aliases already expanded on the current path;
        # meeting one of them again raises.
        #
        # NOTE(review): the block is forwarded to the expansion of each union
        # member (so members are replaced by the block's result) but not to the
        # expansion of an alias target. Confirm whether this asymmetry is intended.
        #
        # @param type the type to expand
        # @param recursive [Set] aliases already seen on this expansion path
        # @yield [type] the fully expanded type, if a block is given
        # @return the block's result when a block is given, otherwise the expanded type
        # @raise [RuntimeError] when an alias refers back to itself
        def deep_expand_alias(type, recursive: Set.new, &block)
          raise "Recursive type definition: #{type}" if recursive.member?(type)

          expanded =
            if type.is_a?(AST::Types::Name::Alias)
              deep_expand_alias(expand_alias(type), recursive: recursive.union([type]))
            elsif type.is_a?(AST::Types::Union)
              members = type.types.map do |member|
                deep_expand_alias(member, recursive: recursive, &block)
              end
              AST::Types::Union.build(types: members, location: type.location)
            else
              type
            end

          return expanded unless block_given?

          yield expanded
        end

        # Collects the non-union leaves of +type+ into +acc+, descending through
        # nested unions in order.
        #
        # @param type the type to flatten
        # @param acc [Array] accumulator that receives the leaves (mutated in place)
        # @return [Array] +acc+
        def flatten_union(type, acc = [])
          if type.is_a?(AST::Types::Union)
            type.types.each { |member| flatten_union(member, acc) }
          else
            acc << type
          end

          acc
        end

        # Splits +type+ into a pair [truthy part, falsy part].
        #
        # - Union: members that are nil or the literal +false+ form the falsy part,
        #   everything else the truthy part (each rebuilt with Union.build).
        # - Alias: expanded one level and split again.
        # - Boolean: [true type, false type].
        # - Anything else: [type, nil].
        #
        # @param type the type to split
        # @return [Array] a two-element array
        def unwrap_optional(type)
          if type.is_a?(AST::Types::Union)
            falsy_types, truthy_types = type.types.partition do |member|
              member.is_a?(AST::Types::Nil) ||
                (member.is_a?(AST::Types::Literal) && member.value == false)
            end

            truthy = AST::Types::Union.build(types: truthy_types)
            falsy = AST::Types::Union.build(types: falsy_types)
            [truthy, falsy]
          elsif type.is_a?(AST::Types::Name::Alias)
            unwrap_optional(expand_alias(type))
          elsif type.is_a?(AST::Types::Boolean)
            [AST::Builtin.true_type, AST::Builtin.false_type]
          else
            [type, nil]
          end
        end

        # Type name of ::NilClass, compared against `defined_in` below.
        NilClassName = TypeName("::NilClass")

        # Replaces the return type of selected built-in predicate methods with a
        # Logic type, keeping the location of the original return type.
        #
        # Only methods whose member is an RBS method definition are considered:
        # - is_a? / kind_of? / instance_of? defined as instance methods in Object => Logic::ReceiverIsArg
        # - nil? defined in Object or NilClass                                     => Logic::ReceiverIsNil
        # - !    defined in BasicObject, TrueClass or FalseClass                   => Logic::Not
        # - ===  defined in Module                                                 => Logic::ArgIsReceiver
        # - ===  defined in Object, String, Integer, Symbol, TrueClass,
        #        FalseClass or NilClass                                            => Logic::ArgEqualsReceiver
        #
        # @param method_name [Symbol] name of the method being set up
        # @param method_def the definition providing +defined_in+ and +member+
        # @param method_type the method type to adjust
        # @return +method_type+ itself when no rule applies, otherwise a copy with the new return type
        def setup_primitives(method_name, method_def, method_type)
          defined_in = method_def.defined_in
          member = method_def.member

          return method_type unless member.is_a?(RBS::AST::Members::MethodDefinition)

          logic_class =
            case method_name
            when :is_a?, :kind_of?, :instance_of?
              if defined_in == RBS::BuiltinNames::Object.name && member.instance?
                AST::Types::Logic::ReceiverIsArg
              end
            when :nil?
              case defined_in
              when RBS::BuiltinNames::Object.name, NilClassName
                AST::Types::Logic::ReceiverIsNil
              end
            when :!
              case defined_in
              when RBS::BuiltinNames::BasicObject.name,
                RBS::BuiltinNames::TrueClass.name,
                RBS::BuiltinNames::FalseClass.name
                AST::Types::Logic::Not
              end
            when :===
              case defined_in
              when RBS::BuiltinNames::Module.name
                AST::Types::Logic::ArgIsReceiver
              when RBS::BuiltinNames::Object.name, RBS::BuiltinNames::String.name, RBS::BuiltinNames::Integer.name,
                RBS::BuiltinNames::Symbol.name, RBS::BuiltinNames::TrueClass.name, RBS::BuiltinNames::FalseClass.name,
                NilClassName
                # Value based type-case works on literal types which is available for
                # String, Integer, Symbol, TrueClass, FalseClass, and NilClass
                AST::Types::Logic::ArgEqualsReceiver
              end
            end

          return method_type unless logic_class

          method_type.with(
            type: method_type.type.with(
              return_type: logic_class.new(location: method_type.type.return_type.location)
            )
          )
        end

        # Computes the Interface::Interface (method table) of +type+.
        #
        # Results are memoized in #type_interface_cache under the key
        # [type, self_type, private]; every branch below funnels its result into
        # that cache through the trailing `tap`.
        #
        # @param type the type whose methods are wanted
        # @param private whether private methods are included (only consulted by
        #   the Name::Instance and Name::Singleton branches)
        # @param self_type the type recorded on the interface and used as `self`
        #   in substitutions; defaults to +type+
        # @return the interface for +type+
        # @raise [RuntimeError] for a `self` type without a distinct self_type and
        #   for type classes not handled below
        def interface(type, private:, self_type: type)
          Steep.logger.debug { "Factory#interface: #{type}, private=#{private}, self_type=#{self_type}" }

          # Serve repeated requests from the cache.
          cache_key = [type, self_type, private]
          if type_interface_cache.key?(cache_key)
            return type_interface_cache[cache_key]
          end

          case type
          when Name::Alias
            # Expand one level of the alias and compute the interface of the result.
            interface(expand_alias(type), private: private, self_type: self_type)

          when Self
            # `self` is resolved through self_type; without a distinct self_type
            # there is nothing to resolve against.
            if self_type != type
              interface self_type, private: private, self_type: Self.new
            else
              raise "Unexpected `self` type interface"
            end

          when Name::Instance
            Interface::Interface.new(type: self_type, private: private).tap do |interface|
              definition = definition_builder.build_instance(type.name)

              # `instance` is substituted by the class applied to Any for each
              # argument of +type+, `class` by the module type of +type+.
              instance_type = Name::Instance.new(name: type.name,
                                                 args: type.args.map { Any.new(location: nil) },
                                                 location: nil)
              module_type = type.to_module()

              subst = Interface::Substitution.build(
                definition.type_params,
                type.args,
                instance_type: instance_type,
                module_type: module_type,
                self_type: self_type
              )

              definition.methods.each do |name, method|
                Steep.logger.tagged "method = #{name}" do
                  # `next` leaves the tagged block, skipping private methods
                  # unless they were requested.
                  next if method.private? && !private

                  interface.methods[name] = Interface::Interface::Entry.new(
                    method_types: method.defs.map do |type_def|
                      method_name = InstanceMethodName.new(type_name: type_def.implemented_in || type_def.defined_in, method_name: name)
                      decl = TypeInference::MethodCall::MethodDecl.new(method_name: method_name, method_def: type_def)
                      setup_primitives(
                        name,
                        type_def,
                        method_type(type_def.type,
                                    method_decls: Set[decl],
                                    self_type: self_type,
                                    subst2: subst)
                      )
                    end
                  )
                end
              end
            end

          when Name::Interface
            # Unlike the instance/singleton branches, this one neither filters
            # private methods nor passes method types through setup_primitives.
            Interface::Interface.new(type: self_type, private: private).tap do |interface|
              type_name = type.name
              definition = definition_builder.build_interface(type_name)

              subst = Interface::Substitution.build(
                definition.type_params,
                type.args,
                self_type: self_type
              )

              definition.methods.each do |name, method|
                interface.methods[name] = Interface::Interface::Entry.new(
                  method_types: method.defs.map do |type_def|
                    decls = Set[TypeInference::MethodCall::MethodDecl.new(
                      method_name: InstanceMethodName.new(type_name: type_def.implemented_in || type_def.defined_in, method_name: name),
                      method_def: type_def
                    )]
                    method_type(type_def.type, method_decls: decls, self_type: self_type, subst2: subst)
                  end
                )
              end
            end

          when Name::Singleton
            Interface::Interface.new(type: self_type, private: private).tap do |interface|
              definition = definition_builder.build_singleton(type.name)

              # `instance` is substituted by the class applied to Any for each of
              # its type parameters; `class` by the singleton type itself.
              instance_type = Name::Instance.new(name: type.name,
                                                 args: definition.type_params.map {Any.new(location: nil)},
                                                 location: nil)
              subst = Interface::Substitution.build(
                [],
                instance_type: instance_type,
                module_type: type,
                self_type: self_type
              )

              definition.methods.each do |name, method|
                next if !private && method.private?

                interface.methods[name] = Interface::Interface::Entry.new(
                  method_types: method.defs.map do |type_def|
                    decl = TypeInference::MethodCall::MethodDecl.new(
                      method_name: SingletonMethodName.new(type_name: type_def.implemented_in || type_def.defined_in,
                                                           method_name: name),
                      method_def: type_def
                    )
                    setup_primitives(
                      name,
                      type_def,
                      method_type(type_def.type,
                                  method_decls: Set[decl],
                                  self_type: self_type,
                                  subst2: subst)
                    )
                  end
                )
              end
            end

          when Literal
            # A literal has the interface of its backing type.
            interface type.back_type, private: private, self_type: self_type

          when Nil
            interface Builtin::NilClass.instance_type, private: private, self_type: self_type

          when Boolean
            # bool is treated as the union of TrueClass and FalseClass instances.
            interface(AST::Types::Union.build(types: [Builtin::TrueClass.instance_type, Builtin::FalseClass.instance_type]),
                      private: private,
                      self_type: self_type)

          when Union
            # Fold the member interfaces pairwise, keeping only methods that both
            # sides define.
            yield_self do
              interfaces = type.types.map {|ty| interface(ty, private: private, self_type: self_type) }
              interfaces.inject do |interface1, interface2|
                Interface::Interface.new(type: self_type, private: private).tap do |interface|
                  common_methods = Set.new(interface1.methods.keys) & Set.new(interface2.methods.keys)
                  common_methods.each do |name|
                    types1 = interface1.methods[name].method_types
                    types2 = interface2.methods[name].method_types

                    if types1 == types2
                      # Identical overloads: reuse the first side's entry.
                      interface.methods[name] = interface1.methods[name]
                    else
                      # Combine every pair of overloads with `|`; pairs for which
                      # `|` yields nil/false are skipped. The hash keeps each
                      # combined method type once, in insertion order.
                      method_types = {}

                      types1.each do |type1|
                        types2.each do |type2|
                          type = type1 | type2 or next
                          method_types[type] = true
                        end
                      end

                      # A method with no combinable overloads is left out entirely.
                      unless method_types.empty?
                        interface.methods[name] = Interface::Interface::Entry.new(method_types: method_types.keys)
                      end
                    end
                  end
                end
              end
            end

          when Intersection
            # Merge the member interfaces; on a name clash the later member's
            # entry replaces the earlier one.
            yield_self do
              interfaces = type.types.map {|ty| interface(ty, private: private, self_type: self_type) }
              interfaces.inject do |interface1, interface2|
                Interface::Interface.new(type: self_type, private: private).tap do |interface|
                  interface.methods.merge!(interface1.methods)
                  interface.methods.merge!(interface2.methods)
                end
              end
            end

          when Tuple
            # Start from the interface of Array[union of element types] and
            # specialise [], []=, first and last.
            yield_self do
              element_type = Union.build(types: type.types, location: nil)
              array_type = Builtin::Array.instance_type(element_type)
              # NOTE(review): the nested call has already stored the returned object in
              # type_interface_cache under [array_type, self_type, private]; the
              # assignments below mutate that same object. Confirm this sharing is intended.
              interface(array_type, private: private, self_type: self_type).tap do |array_interface|
                # One overload per element, keyed by its literal index, placed
                # ahead of Array's own [] overloads.
                array_interface.methods[:[]] = array_interface.methods[:[]].yield_self do |aref|
                  Interface::Interface::Entry.new(
                    method_types: type.types.map.with_index {|elem_type, index|
                      Interface::MethodType.new(
                        type_params: [],
                        type: Interface::Function.new(
                          params: Interface::Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
                          return_type: elem_type,
                          location: nil
                        ),
                        block: nil,
                        method_decls: Set[]
                      )
                    } + aref.method_types
                  )
                end

                # Same for []=: (literal index, element type) -> element type.
                array_interface.methods[:[]=] = array_interface.methods[:[]=].yield_self do |update|
                  Interface::Interface::Entry.new(
                    method_types: type.types.map.with_index {|elem_type, index|
                      Interface::MethodType.new(
                        type_params: [],
                        type: Interface::Function.new(
                          params: Interface::Function::Params.build(required: [AST::Types::Literal.new(value: index), elem_type]),
                          return_type: elem_type,
                          location: nil
                        ),
                        block: nil,
                        method_decls: Set[]
                      )
                    } + update.method_types
                  )
                end

                # `first` is replaced (Array's overloads are dropped) by a single
                # no-argument overload returning the first element type, or the
                # nil type for an empty tuple.
                array_interface.methods[:first] = array_interface.methods[:first].yield_self do |first|
                  Interface::Interface::Entry.new(
                    method_types: [
                      Interface::MethodType.new(
                        type_params: [],
                        type: Interface::Function.new(
                          params: Interface::Function::Params.empty,
                          return_type: type.types[0] || AST::Builtin.nil_type,
                          location: nil
                        ),
                        block: nil,
                        method_decls: Set[]
                      )
                    ]
                  )
                end

                # `last` is replaced in the same way, using the last element type.
                array_interface.methods[:last] = array_interface.methods[:last].yield_self do |last|
                  Interface::Interface::Entry.new(
                    method_types: [
                      Interface::MethodType.new(
                        type_params: [],
                        type: Interface::Function.new(
                          params: Interface::Function::Params.empty,
                          return_type: type.types.last || AST::Builtin.nil_type,
                          location: nil
                        ),
                        block: nil,
                        method_decls: Set[]
                      )
                    ]
                  )
                end
              end
            end

          when Record
            # Start from the interface of Hash[union of literal keys, union of
            # value types] and specialise [] and []=.
            yield_self do
              key_type = type.elements.keys.map {|value| Literal.new(value: value, location: nil) }.yield_self do |types|
                Union.build(types: types, location: nil)
              end
              value_type = Union.build(types: type.elements.values, location: nil)
              hash_type = Builtin::Hash.instance_type(key_type, value_type)

              # NOTE(review): as in the Tuple branch, the object mutated below is the one
              # the nested call stored in type_interface_cache for hash_type.
              interface(hash_type, private: private, self_type: self_type).tap do |hash_interface|
                # One overload per field, keyed by its literal key, placed ahead
                # of Hash's own [] overloads.
                hash_interface.methods[:[]] = hash_interface.methods[:[]].yield_self do |ref|
                  Interface::Interface::Entry.new(
                    method_types: type.elements.map {|key_value, value_type|
                      key_type = Literal.new(value: key_value, location: nil)

                      Interface::MethodType.new(
                        type_params: [],
                        type: Interface::Function.new(
                          params: Interface::Function::Params.build(
                            required: [key_type],
                            optional: [],
                            rest: nil,
                            required_keywords: {},
                            optional_keywords: {},
                            rest_keywords: nil
                          ),
                          return_type: value_type,
                          location: nil
                        ),
                        block: nil,
                        method_decls: Set[]
                      )
                    } + ref.method_types
                  )
                end

                # Same for []=: (literal key, field type) -> field type.
                hash_interface.methods[:[]=] = hash_interface.methods[:[]=].yield_self do |update|
                  Interface::Interface::Entry.new(
                    method_types: type.elements.map {|key_value, value_type|
                      key_type = Literal.new(value: key_value, location: nil)
                      Interface::MethodType.new(
                        type_params: [],
                        type: Interface::Function.new(
                          params: Interface::Function::Params.build(
                            required: [key_type, value_type],
                            optional: [],
                            rest: nil,
                            required_keywords: {},
                            optional_keywords: {},
                            rest_keywords: nil
                          ),
                          return_type: value_type,
                          location: nil),
                        block: nil,
                        method_decls: Set[]
                      )
                    } + update.method_types
                  )
                end
              end
            end

          when Proc
            # Start from the interface of the Proc class and type `call` (and,
            # where possible, `[]`) with the proc's own signature.
            # NOTE(review): the object mutated below is the one the nested call stored in
            # type_interface_cache for Builtin::Proc.instance_type with this self_type.
            interface(Builtin::Proc.instance_type, private: private, self_type: self_type).tap do |interface|
              method_type = Interface::MethodType.new(
                type_params: [],
                type: type.type,
                block: type.block,
                method_decls: Set[]
              )

              interface.methods[:call] = Interface::Interface::Entry.new(method_types: [method_type])

              # `[]` is removed when the proc requires a block; otherwise it gets
              # the proc's signature without the block.
              if type.block_required?
                interface.methods.delete(:[])
              else
                interface.methods[:[]] = Interface::Interface::Entry.new(method_types: [method_type.with(block: nil)])
              end
            end

          when Logic::Base
            # Logic types expose the interface of bool.
            interface(AST::Builtin.bool_type, private: private, self_type: self_type)

          else
            raise "Unexpected type for interface: #{type}"
          end.tap do |interface|
            # Remember the result of whichever branch ran.
            type_interface_cache[cache_key] = interface
          end
        end

        # Tells whether +type_name+ is declared as a module in the environment.
        #
        # @param type_name the name looked up in env.class_decls
        # @return true or false for a declared name; the (falsy) lookup result when no entry exists
        def module_name?(type_name)
          entry = env.class_decls[type_name]
          entry && entry.is_a?(RBS::Environment::ModuleEntry)
        end

        # Tells whether +type_name+ is declared as a class in the environment.
        #
        # @param type_name the name looked up in env.class_decls
        # @return true or false for a declared name; the (falsy) lookup result when no entry exists
        def class_name?(type_name)
          entry = env.class_decls[type_name]
          entry && entry.is_a?(RBS::Environment::ClassEntry)
        end

        # The environment of the definition builder, fetched once and then reused.
        def env
          @env || (@env = definition_builder.env)
        end

        # Rewrites every type name inside +type+ to an absolute name.
        #
        # The type is converted to RBS, each name is resolved relative to
        # +namespace+ (falling back to `name.absolute!` when resolution yields
        # nothing), and the result is converted back with #type.
        #
        # @param type the type whose names should be made absolute
        # @param namespace the namespace names are resolved against
        # @return the type with absolute names
        def absolute_type(type, namespace:)
          rbs_type = type_1(type)
          resolved = rbs_type.map_type_name do |name|
            absolute_type_name(name, namespace: namespace) || name.absolute!
          end

          type(resolved)
        end

        # Resolves +type_name+ with the type-name resolver, using the ascending
        # chain of +namespace+ as lookup context.
        #
        # @param type_name the name to resolve
        # @param namespace the namespace providing the context via +ascend+
        # @return whatever the resolver returns for the name
        def absolute_type_name(type_name, namespace:)
          resolver = type_name_resolver
          resolver.resolve(type_name, context: namespace.ascend)
        end

        # Builds the instance type for the class named +type_name+.
        #
        # Without +args+, every type parameter of the definition is filled with
        # Any. With +args+, their count must match the number of type parameters.
        #
        # @param type_name a type name; must answer true to +class?+
        # @param args optional explicit type arguments
        # @param location location stored on the resulting type
        # @return [AST::Types::Name::Instance]
        # @raise [RuntimeError] if +type_name+ is not a class name or the number of +args+ does not match
        def instance_type(type_name, args: nil, location: nil)
          raise unless type_name.class?

          type_params = definition_builder.build_singleton(type_name).type_params
          default_args = type_params.map { Any.new(location: nil) }
          raise if args && default_args.size != args.size

          AST::Types::Name::Instance.new(location: location, name: type_name, args: args || default_args)
        end
      end
    end
  end
end