# lib/steep/type_inference/send_args.rb



module Steep
  module TypeInference
    class SendArgs
      class PositionalArgs
        # Associates an argument node with the positional parameter it was
        # matched against.
        class NodeParamPair
          include Equatable

          attr_reader :node, :param

          def initialize(node:, param:)
            @node, @param = node, param
          end

          # Allows destructuring, e.g. `node, param = pair`.
          def to_ary
            [node, param]
          end
        end

        # Associates an argument node with a type.
        class NodeTypePair
          include Equatable

          attr_reader :node, :type

          def initialize(node:, type:)
            @node, @type = node, type
          end

          # Returns `type` itself, unless the node is a `:splat`, in which case
          # `type` is wrapped in the builtin Array instance type.
          def node_type
            return type unless node.type == :splat

            AST::Builtin::Array.instance_type(type)
          end
        end

        # A splat argument node. `type` starts out as nil and is writable so it
        # can be assigned after the object has been created.
        class SplatArg
          include Equatable

          attr_reader :node
          attr_accessor :type

          def initialize(node:)
            @node = node
            @type = nil
          end
        end

        # Marks an argument node for which no positional parameter was left.
        class UnexpectedArg
          include Equatable

          attr_reader :node

          def initialize(node:)
            @node = node
          end
        end

        # Marks positional parameters that were left without an argument.
        # `params` holds the remaining parameter list at the point of detection.
        class MissingArg
          include Equatable

          attr_reader :params

          def initialize(params:)
            @params = params
          end
        end

        attr_reader :args
        attr_reader :index
        attr_reader :positional_params
        attr_reader :uniform

        def initialize(args:, index:, positional_params:, uniform: false)
          @args = args
          @index = index
          @positional_params = positional_params
          @uniform = uniform
        end

        # Returns the argument at the current `index`, or nil when `index` is
        # past the end of `args`.
        def node
          args[index]
        end

        # Returns the arguments from the current `index` to the end.
        #
        # Raises (bare `raise`) when the slice is nil, i.e. when `index` lies
        # beyond the end of `args`.
        def following_args
          remaining = args[index..]
          raise unless remaining

          remaining
        end

        # Returns the head of the remaining positional parameters, or nil when
        # `positional_params` is nil.
        def param
          positional_params&.head
        end

        # Builds a new PositionalArgs over the same `args`; every keyword that
        # is not given keeps the receiver's current value.
        def update(index: self.index, positional_params: self.positional_params, uniform: self.uniform)
          PositionalArgs.new(
            args: args,
            index: index,
            positional_params: positional_params,
            uniform: uniform
          )
        end

        # Performs one matching step between the current argument and the
        # current positional parameter.
        #
        # Returns nil when there is nothing more to report, otherwise a
        # two-element array `[value, next_state]` where `value` is one of
        # NodeParamPair, SplatArg, UnexpectedArg or MissingArg.
        def next()
          current = node
          expected = param

          unless current
            # No argument left: only a pending required parameter is reported.
            return nil unless expected.is_a?(Interface::Function::Params::PositionalParams::Required)

            return [
              MissingArg.new(params: positional_params),
              update(index: index, positional_params: nil)
            ]
          end

          kind = current.type

          # If the node is a `:forwarded_args`, abort
          return nil if kind == :forwarded_args

          # A splat does not advance the state; the returned state is `self`.
          return [SplatArg.new(node: current), self] if kind == :splat

          if expected.is_a?(Interface::Function::Params::PositionalParams::Required) ||
             expected.is_a?(Interface::Function::Params::PositionalParams::Optional)
            [
              NodeParamPair.new(node: current, param: expected),
              update(index: index + 1, positional_params: positional_params&.tail)
            ]
          elsif expected.is_a?(Interface::Function::Params::PositionalParams::Rest)
            # A rest parameter is not consumed, so the parameter list is kept.
            [
              NodeParamPair.new(node: current, param: expected),
              update(index: index + 1)
            ]
          elsif !expected
            [
              UnexpectedArg.new(node: current),
              update(index: index + 1)
            ]
          end
        end

        # Returns the intersection of the types of all remaining positional
        # parameters, but only when one of them is a Rest parameter.
        # Returns nil when there are no parameters or none of them is a Rest.
        def uniform_type
          remaining = positional_params
          return nil unless remaining

          has_rest = remaining.each.any? do |candidate|
            candidate.is_a?(Interface::Function::Params::PositionalParams::Rest)
          end
          return nil unless has_rest

          AST::Types::Intersection.build(types: remaining.each.map(&:type))
        end

        # Consumes `n` positional parameters for a single argument `node`.
        #
        # Returns `[params, next_state]` with the consumed parameters on
        # success, or `[UnexpectedArg, next_state]` when the parameters ran out.
        # In both cases the next state has `index` advanced by one.
        def consume(n, node:)
          # @type var ps: Array[Interface::Function::Params::PositionalParams::param]
          ps = []
          outcome = consume0(n, node: node, params: positional_params, ps: ps)

          if outcome.is_a?(UnexpectedArg)
            [outcome, update(index: index + 1, positional_params: nil)]
          else
            [ps, update(index: index + 1, positional_params: outcome)]
          end
        end

        # Recursive worker for `consume`: takes `n` parameters off `params`,
        # appending each one to `ps`.
        #
        # Returns the parameters left after consumption, or an UnexpectedArg for
        # `node` when `params` is exhausted first. A Rest parameter is appended
        # but never removed from the list.
        def consume0(n, node:, params:, ps:)
          return params if n == 0

          head = params&.head
          return UnexpectedArg.new(node: node) if head.nil?

          if head.is_a?(Interface::Function::Params::PositionalParams::Required) ||
             head.is_a?(Interface::Function::Params::PositionalParams::Optional)
            ps << head
            consume0(n - 1, node: node, params: params&.tail, ps: ps)
          elsif head.is_a?(Interface::Function::Params::PositionalParams::Rest)
            ps << head
            consume0(n - 1, node: node, params: params, ps: ps)
          end
        end
      end

      class KeywordArgs
        # A list of pairs, exposed with Array-like `[]` and `size`.
        class ArgTypePairs
          include Equatable

          attr_reader :pairs

          def initialize(pairs:)
            @pairs = pairs
          end

          # Number of pairs.
          def size
            pairs.size
          end

          # Returns the pair at `index`.
          def [](index)
            pairs[index]
          end
        end

        # A keyword splat argument node. `type` starts out as nil and is
        # writable so it can be assigned after the object has been created.
        class SplatArg
          include Equatable

          attr_reader :node
          attr_accessor :type

          def initialize(node:)
            @node = node
            @type = nil
          end
        end

        # Marks a keyword argument that matches no keyword parameter.
        # `keyword` may be nil when the offending key is not known.
        class UnexpectedKeyword
          include Equatable

          attr_reader :keyword, :node

          def initialize(keyword:, node:)
            @keyword, @node = keyword, node
          end

          # First child of a `:pair` node; nil for any other node type.
          def key_node
            node.children[0] if node.type == :pair
          end

          # Second child of a `:pair` node; nil for any other node type.
          def value_node
            node.children[1] if node.type == :pair
          end
        end

        # Marks required keywords that were not supplied.
        class MissingKeyword
          include Equatable

          attr_reader :keywords

          def initialize(keywords:)
            @keywords = keywords
          end
        end

        attr_reader :kwarg_nodes
        attr_reader :keyword_params
        attr_reader :index
        attr_reader :consumed_keywords

        def initialize(kwarg_nodes:, keyword_params:, index: 0, consumed_keywords: Set[])
          @kwarg_nodes = kwarg_nodes
          @keyword_params = keyword_params
          @index = index
          @consumed_keywords = consumed_keywords
        end

        # Builds a new KeywordArgs over the same nodes and parameters; keywords
        # that are not given keep the receiver's current value.
        def update(index: self.index, consumed_keywords: self.consumed_keywords)
          KeywordArgs.new(kwarg_nodes: kwarg_nodes, keyword_params: keyword_params,
                          index: index, consumed_keywords: consumed_keywords)
        end

        # Returns the keyword argument node at the current `index`, or nil when
        # `index` is past the end of `kwarg_nodes`.
        def keyword_pair
          kwarg_nodes[index]
        end

        # Required keyword parameters, as exposed by `keyword_params.requireds`.
        # Used as a Hash keyed by keyword elsewhere in this class.
        def required_keywords
          keyword_params.requireds
        end

        # Optional keyword parameters, as exposed by `keyword_params.optionals`.
        # Used as a Hash keyed by keyword elsewhere in this class.
        def optional_keywords
          keyword_params.optionals
        end

        # Returns `keyword_params.rest`; callers treat a falsy value as
        # "no rest keyword parameter".
        def rest_type
          keyword_params.rest
        end

        # Looks up `key` among the required keywords first and falls back to the
        # optional keywords. Returns a falsy value when neither declares it.
        def keyword_type(key)
          required_keywords[key] || optional_keywords[key]
        end

        # Returns every required and optional keyword exactly once, ordered by
        # the keyword's string representation.
        def all_keys
          Set.new(required_keywords.each_key)
             .merge(optional_keywords.each_key)
             .sort_by(&:to_s)
        end

        # Returns every distinct value of the required and optional keywords,
        # ordered by the value's string representation.
        def all_values
          Set.new(required_keywords.each_value)
             .merge(optional_keywords.each_value)
             .sort_by(&:to_s)
        end

        # Builds the union of literal types for every known keyword, plus the
        # builtin Symbol instance type when a rest keyword parameter exists.
        def possible_key_type
          # @type var key_types: Array[AST::Types::t]
          key_types = all_keys.map do |key|
            AST::Types::Literal.new(value: key)
          end

          if rest_type
            key_types << AST::Builtin::Symbol.instance_type
          end

          AST::Types::Union.build(types: key_types)
        end

        # Builds the intersection of every keyword value type, including the
        # rest keyword type when there is one.
        def possible_value_type
          declared = all_values
          rest = rest_type
          value_types = rest ? declared + [rest] : declared

          AST::Types::Intersection.build(types: value_types)
        end

        # Performs one matching step on the current keyword argument node.
        #
        # Returns nil when there is nothing more to report, otherwise a
        # two-element array `[value, next_state]` where `value` is one of
        # ArgTypePairs, SplatArg, UnexpectedKeyword or MissingKeyword.
        def next()
          current = keyword_pair

          unless current
            # All nodes are processed: report required keywords never consumed.
            missing = Set.new(required_keywords.keys) - consumed_keywords
            return nil if missing.empty?

            return [
              MissingKeyword.new(keywords: missing),
              update(consumed_keywords: consumed_keywords + missing)
            ]
          end

          case current.type
          when :kwsplat
            # A splat does not advance the state; the returned state is `self`.
            return [SplatArg.new(node: current), self]
          when :pair
            # handled below
          else
            return nil
          end

          key_node, value_node = current.children

          unless key_node.type == :sym
            # The key is not a symbol literal, so it is typed against every
            # keyword the callee could accept.
            if all_keys.empty? && !rest_type
              return [
                UnexpectedKeyword.new(keyword: nil, node: current),
                update(index: index + 1)
              ]
            end

            return [
              ArgTypePairs.new(
                pairs: [
                  [key_node, possible_key_type],
                  [value_node, possible_value_type]
                ]
              ),
              update(index: index + 1)
            ]
          end

          key = key_node.children[0]

          # A declared keyword takes precedence over the rest keyword type.
          key_type, value_type =
            if (declared = keyword_type(key))
              [AST::Types::Literal.new(value: key), declared]
            elsif (rest = rest_type)
              [AST::Builtin::Symbol.instance_type, rest]
            end

          if value_type
            [
              ArgTypePairs.new(
                pairs: [
                  [key_node, key_type],
                  [value_node, value_type]
                ]
              ),
              update(index: index + 1, consumed_keywords: consumed_keywords + [key])
            ]
          else
            [
              UnexpectedKeyword.new(keyword: key, node: current),
              update(index: index + 1)
            ]
          end
        end

        # Matches a list of keys, all coming from one argument `node`, against
        # the keyword parameters.
        #
        # Returns `[types, next_state]` with one type per key when every key is
        # accepted, or `[UnexpectedKeyword, next_state]` carrying the last key
        # that was rejected. The next state has `index` advanced by one and the
        # declared keys among `keys` added to `consumed_keywords`.
        def consume_keys(keys, node:)
          # @type var consumed_keys: Array[Symbol]
          consumed_keys = []
          # @type var types: Array[AST::Types::t]
          types = []
          # @type var unexpected_keyword: Symbol?
          unexpected_keyword = nil

          keys.each do |key|
            if (declared = keyword_type(key))
              consumed_keys << key
              types << declared
            elsif (rest = rest_type())
              types << rest
            else
              unexpected_keyword = key
            end
          end

          outcome =
            if unexpected_keyword
              UnexpectedKeyword.new(keyword: unexpected_keyword, node: node)
            else
              types
            end

          [outcome, update(index: index + 1, consumed_keywords: consumed_keywords + consumed_keys)]
        end
      end

      # Relates a `:block_pass` argument node (if any) to the block parameter
      # of the callee (if any).
      class BlockPassArg
        include Equatable

        attr_reader :node, :block

        def initialize(node:, block:)
          @node, @block = node, block
        end

        # True when neither a block-pass argument nor a block parameter exists.
        def no_block?
          !(node || block)
        end

        # With a block-pass argument, the callee must declare a block.
        # Without one, the callee must declare no block or an optional one.
        def compatible?
          return !!block if node
          return true unless block

          block.optional?
        end

        # Truthy when no block-pass argument is given but the block is required.
        def block_missing?
          return false if node

          block&.required?
        end

        # Truthy when a block-pass argument is given but no block is declared.
        def unexpected_block?
          node ? !block : node
        end

        # Returns `[node, block.type]` when both are present, nil otherwise.
        # Raises (bare `raise`) unless `compatible?`.
        def pair
          raise unless compatible?
          return nil unless node && block

          [node, block.type]
        end

        # Returns the Proc type expected for the block-pass argument; the type
        # is unioned with nil when the block is optional.
        # Raises (bare `raise`) when no block is declared.
        def node_type
          raise unless block

          proc_type = AST::Types::Proc.new(type: block.type, block: nil, self_type: block.self_type)
          return proc_type unless block.optional?

          AST::Types::Union.build(types: [proc_type, AST::Builtin.nil_type])
        end
      end

      # Holds a `:forwarded_args` node together with the parameters that are
      # forwarded through it.
      class ForwardedArgs
        attr_reader :node
        attr_reader :params

        def initialize(node:, params:)
          @node, @params = node, params
        end
      end

      attr_reader :node
      attr_reader :arguments
      attr_reader :type

      def initialize(node:, arguments:, type:)
        @node = node
        @arguments = arguments
        @type = type
      end

      # Returns the parameters of the callee type (`type.type.params`).
      #
      # @raise [RuntimeError] when `type` is neither an Interface::MethodType
      #   nor an AST::Types::Proc; the message names the offending class so the
      #   failure can be diagnosed
      def params
        case type
        when Interface::MethodType
          type.type.params
        when AST::Types::Proc
          type.type.params
        else
          raise "Unexpected callee type for SendArgs#params: #{type.class}"
        end
      end

      # Returns the block of the callee type, or nil when `type` is neither an
      # Interface::MethodType nor an AST::Types::Proc.
      def block
        callee = type
        return nil unless callee.is_a?(Interface::MethodType) || callee.is_a?(AST::Types::Proc)

        callee.block
      end

      # Positional parameters of the callee, taken from `params`.
      def positional_params
        params.positional_params
      end

      # Keyword parameters of the callee, taken from `params`.
      def keyword_params
        params.keyword_params
      end

      # Returns the first `:kwargs` argument node, but only when the callee
      # declares keyword parameters; nil otherwise.
      def kwargs_node
        return nil if keyword_params.empty?

        arguments.find { |candidate| candidate.type == :kwargs }
      end

      # Builds the initial PositionalArgs state.
      #
      # The positional arguments are the leading arguments up to the first
      # `:block_pass` node. When the callee declares keyword parameters, a
      # `:kwargs` node ends the positional arguments as well.
      def positional_arg
        terminators = keyword_params.empty? ? [:block_pass] : [:kwargs, :block_pass]
        args = arguments.take_while { |candidate| !terminators.include?(candidate.type) }

        PositionalArgs.new(args: args, index: 0, positional_params: positional_params)
      end

      # Returns the first `:forwarded_args` argument node, or nil.
      def forwarded_args_node
        arguments.find do |candidate|
          candidate.type == :forwarded_args
        end
      end

      # Builds the initial KeywordArgs state from the children of the
      # `:kwargs` node, or from an empty list when there is no such node.
      def keyword_args
        nodes = kwargs_node&.children || []

        KeywordArgs.new(kwarg_nodes: nodes, keyword_params: keyword_params)
      end

      # Builds a BlockPassArg from the first `:block_pass` argument node (nil
      # when absent) and the callee's block.
      def block_pass_arg
        block_pass_node = arguments.find { |candidate| candidate.type == :block_pass }

        BlockPassArg.new(node: block_pass_node, block: block)
      end

      # Walks the positional and keyword arguments of the call, yielding one
      # value per matching step, and collects diagnostics along the way.
      #
      # Yielded values are PositionalArgs::NodeParamPair, NodeTypePair, SplatArg,
      # UnexpectedArg and MissingArg, followed by KeywordArgs::ArgTypePairs,
      # SplatArg, UnexpectedKeyword and MissingKeyword.
      #
      # For both kinds of SplatArg the block must assign `type` before
      # returning from the yield: the type is read right after the yield and a
      # nil type raises.
      #
      # Returns `[forwarded_args, diagnostics]` when a block is given;
      # `forwarded_args` is nil unless the call has a `:forwarded_args` node.
      # Without a block, returns an enumerator over `each`.
      def each
        if block_given?
          errors = [] #: Array[PositionalArgs::error_arg | KeywordArgs::error_arg]

          # Tracks the latest positional state so that the parameters still
          # unconsumed can be handed to forwarded arguments below.
          last_positional_args = positional_arg

          positional_arg.tap do |args|
            while (value, args = args.next())
              yield value

              case value
              when PositionalArgs::SplatArg
                type = value.type

                case type
                when nil
                  raise
                when AST::Types::Tuple
                  # A tuple splat consumes one parameter per tuple element.
                  ts, args = args.consume(type.types.size, node: value.node)

                  case ts
                  when Array
                    ty = AST::Types::Tuple.new(types: ts.map(&:type))
                    yield PositionalArgs::NodeTypePair.new(node: value.node, type: ty)
                  when PositionalArgs::UnexpectedArg
                    errors << ts
                    yield ts
                  end
                else
                  # Any other splat type ends positional matching. `next` does
                  # not advance past a splat, so `following_args` starts at the
                  # splat node itself.
                  if t = args.uniform_type
                    args.following_args.each do |node|
                      yield PositionalArgs::NodeTypePair.new(node: node, type: t)
                    end
                  else
                    args.following_args.each do |node|
                      arg = PositionalArgs::UnexpectedArg.new(node: node)
                      yield arg
                      errors << arg
                    end
                  end

                  break
                end
              when PositionalArgs::UnexpectedArg, PositionalArgs::MissingArg
                errors << value
              end

              last_positional_args = args
            end
          end

          if fag = forwarded_args_node
            # With forwarded arguments, keyword matching is skipped; the
            # remaining positional parameters and all keyword parameters are
            # packaged for the caller instead.
            forward_params = Interface::Function::Params.new(
              positional_params: last_positional_args.positional_params,
              keyword_params: keyword_params
            )

            forwarded_args = ForwardedArgs.new(node: fag, params: forward_params)
          else
            keyword_args.tap do |args|
              while (a, args = args.next)
                # Errors are recorded before the value is yielded.
                case a
                when KeywordArgs::MissingKeyword
                  errors << a
                when KeywordArgs::UnexpectedKeyword
                  errors << a
                end

                yield a

                case a
                when KeywordArgs::SplatArg
                  case type = a.type
                  when nil
                    raise
                  when AST::Types::Record

                    # @type var keys: Array[Symbol]
                    keys = _ = type.elements.keys
                    ts, args = args.consume_keys(keys, node: a.node)

                    case ts
                    when KeywordArgs::UnexpectedKeyword
                      yield ts
                      errors << ts
                    when Array
                      pairs = keys.zip(ts) #: Array[[Symbol, AST::Types::t]]
                      record = AST::Types::Record.new(elements: Hash[pairs])
                      yield KeywordArgs::ArgTypePairs.new(pairs: [[a.node, record]])
                    end
                  else
                    # `next` does not advance past a splat, so step over it here.
                    args = args.update(index: args.index + 1)

                    if args.rest_type
                      type = AST::Builtin::Hash.instance_type(AST::Builtin::Symbol.instance_type, args.possible_value_type)
                      yield KeywordArgs::ArgTypePairs.new(pairs: [[a.node, type]])
                    else
                      # NOTE(review): this UnexpectedKeyword is yielded but not
                      # added to `errors`, so no diagnostic is produced for it,
                      # unlike the record branch above. Confirm this is intended.
                      yield KeywordArgs::UnexpectedKeyword.new(keyword: nil, node: a.node)
                    end
                  end
                end
              end
            end
          end

          diagnostics = [] #: Array[Diagnostic::Ruby::Base]

          # Missing keywords are merged into a single diagnostic; every other
          # error produces one diagnostic each.
          missing_keywords = [] #: Array[Symbol]
          errors.each do |error|
            case error
            when KeywordArgs::UnexpectedKeyword
              diagnostics << Diagnostic::Ruby::UnexpectedKeywordArgument.new(node: error.node, params: params)
            when KeywordArgs::MissingKeyword
              missing_keywords.push(*error.keywords.to_a)
            when PositionalArgs::UnexpectedArg
              diagnostics << Diagnostic::Ruby::UnexpectedPositionalArgument.new(node: error.node, params: params)
            when PositionalArgs::MissingArg
              # Reported on the call node itself, not on an argument node.
              diagnostics << Diagnostic::Ruby::InsufficientPositionalArguments.new(node: node, params: params)
            end
          end

          unless missing_keywords.empty?
            diagnostics << Diagnostic::Ruby::InsufficientKeywordArguments.new(node: node, params: params, missing_keywords: missing_keywords)
          end

          [forwarded_args, diagnostics]
        else
          enum_for :each
        end
      end
    end
  end
end