lib/steep/type_inference/multiple_assignment.rb



module Steep
  module TypeInference
    # Expands the left-hand side of a multiple assignment (`a, *b, c = rhs`)
    # into pairs of `[lhs node, type]`, based on the type of the right-hand side.
    class MultipleAssignment
      # Result of an expansion.
      #
      # * `rhs_type` and `optional` are stored exactly as given to the expander.
      # * `leading_assignments` are the pairs for targets before the splat.
      # * `splat_assignment` is the pair for the splat target, or `nil`.
      # * `trailing_assignments` are the pairs for targets after the splat,
      #   kept in source (left-to-right) order.
      Assignments = _ = Struct.new(:rhs_type, :optional, :leading_assignments, :trailing_assignments, :splat_assignment, keyword_init: true) do
        # @implements Assignments

        # Yields every pair in source order: leading, splat (if any), trailing.
        # Returns an Enumerator when called without a block.
        def each(&block)
          if block
            leading_assignments.each(&block)
            if sp = splat_assignment
              yield sp
            end
            trailing_assignments.each(&block)
          else
            enum_for :each
          end
        end
      end

      # Dispatches on `rhs_type` and expands the children of `mlhs`.
      #
      # Returns an `Assignments` for tuple, `Array` instance and `any` types,
      # and `nil` for every other type (including non-Array instance types).
      # The children array is duplicated before being handed to the helpers
      # that consume it destructively.
      def expand(mlhs, rhs_type, optional)
        lhss = mlhs.children

        case rhs_type
        when AST::Types::Tuple
          expand_tuple(lhss.dup, rhs_type, rhs_type.types.dup, optional)
        when AST::Types::Name::Instance
          if AST::Builtin::Array.instance_type?(rhs_type)
            expand_array(lhss.dup, rhs_type, optional)
          end
        when AST::Types::Any
          expand_any(lhss, rhs_type, AST::Builtin.any_type, optional)
        end
      end

      # Pairs each target in `lhss` with the matching element type of a tuple.
      #
      # `lhss` and `tuples` are consumed destructively (`shift`/`pop`), so the
      # caller must pass copies. Targets without a corresponding tuple element
      # receive `AST::Builtin.nil_type`. The splat target, if present, receives
      # a tuple of whatever element types remain.
      #
      # Raises when more than one target is left unmatched (e.g. two splats).
      def expand_tuple(lhss, rhs_type, tuples, optional)
        # @type var leading_assignments: Array[node_type_pair]
        leading_assignments = []
        # @type var trailing_assignments: Array[node_type_pair]
        trailing_assignments = []
        # @type var splat_assignment: node_type_pair?
        splat_assignment = nil

        # Consume targets from the left until a splat (or the end) is reached.
        until lhss.empty?
          first = lhss.first or raise
          break if first.type == :splat

          leading_assignments << [first, tuples.first || AST::Builtin.nil_type]
          lhss.shift
          tuples.shift
        end

        # Consume targets from the right until the splat is reached.
        # NOTE(review): when fewer tuple elements remain than trailing targets,
        # Ruby fills trailing targets left-to-right, while this pairs from the
        # right -- confirm whether that difference is intended.
        until lhss.empty?
          last = lhss.last or raise
          break if last.type == :splat

          # Walking right-to-left, so prepend to keep source order.
          trailing_assignments.unshift([last, tuples.last || AST::Builtin.nil_type])
          lhss.pop
          tuples.pop
        end

        case lhss.size
        when 0
          # nop
        when 1
          splat_assignment = [lhss.first || raise, AST::Types::Tuple.new(types: tuples)]
        else
          raise
        end

        Assignments.new(
          rhs_type: rhs_type,
          optional: optional,
          leading_assignments: leading_assignments,
          trailing_assignments: trailing_assignments,
          splat_assignment: splat_assignment
        )
      end

      # Pairs each target in `lhss` with the element type of an array type.
      #
      # The element type is `rhs_type.args[0]` (raises when it is missing).
      # Every non-splat target receives `AST::Builtin.optional(element_type)`;
      # the splat target receives `AST::Builtin::Array.instance_type(element_type)`.
      # `lhss` is consumed destructively, so the caller must pass a copy.
      #
      # Raises when more than one target is left unmatched (e.g. two splats).
      def expand_array(lhss, rhs_type, optional)
        element_type = rhs_type.args[0] or raise

        # @type var leading_assignments: Array[node_type_pair]
        leading_assignments = []
        # @type var trailing_assignments: Array[node_type_pair]
        trailing_assignments = []
        # @type var splat_assignment: node_type_pair?
        splat_assignment = nil

        # Consume targets from the left until a splat (or the end) is reached.
        until lhss.empty?
          first = lhss.first or raise
          break if first.type == :splat

          leading_assignments << [first, AST::Builtin.optional(element_type)]
          lhss.shift
        end

        # Consume targets from the right until the splat is reached.
        until lhss.empty?
          last = lhss.last or raise
          break if last.type == :splat

          # Walking right-to-left, so prepend to keep source order.
          trailing_assignments.unshift([last, AST::Builtin.optional(element_type)])
          lhss.pop
        end

        case lhss.size
        when 0
          # nop
        when 1
          splat_assignment = [
            lhss.first || raise,
            AST::Builtin::Array.instance_type(element_type)
          ]
        else
          raise
        end

        Assignments.new(
          rhs_type: rhs_type,
          optional: optional,
          leading_assignments: leading_assignments,
          trailing_assignments: trailing_assignments,
          splat_assignment: splat_assignment
        )
      end

      # Pairs every target in `nodes` with `element_type`.
      #
      # Targets seen before a splat go to the leading list, targets seen after
      # it go to the trailing list, both in source order. The splat target
      # receives `AST::Builtin::Array.instance_type(element_type)`.
      # `nodes` is only iterated, not modified.
      def expand_any(nodes, rhs_type, element_type, optional)
        # @type var leading_assignments: Array[node_type_pair]
        leading_assignments = []
        # @type var trailing_assignments: Array[node_type_pair]
        trailing_assignments = []
        # @type var splat_assignment: node_type_pair?
        splat_assignment = nil

        # Points at the list currently being filled; switched at the splat.
        array = leading_assignments

        nodes.each do |node|
          case node.type
          when :splat
            splat_assignment = [node, AST::Builtin::Array.instance_type(element_type)]
            array = trailing_assignments
          else
            array << [node, element_type]
          end
        end

        Assignments.new(
          rhs_type: rhs_type,
          optional: optional,
          leading_assignments: leading_assignments,
          trailing_assignments: trailing_assignments,
          splat_assignment: splat_assignment
        )
      end

      # Builds a type hint for the right-hand side from the left-hand side.
      #
      # * `:mlhs` -- a tuple of the hints of its children; `nil` as soon as any
      #   child yields no hint.
      # * `:lvasgn`, `:ivasgn`, `:gvasgn` -- `env[name]`, falling back to
      #   `AST::Builtin.any_type` when `env` has no entry or when the name is
      #   listed in `TypeConstruction::SPECIAL_LVAR_NAMES`.
      # * anything else (including `:splat`) -- `nil`.
      def hint_for_mlhs(mlhs, env)
        case mlhs.type
        when :mlhs
          types = mlhs.children.map do |node|
            # `return` leaves the whole method with nil when a child has no hint.
            hint_for_mlhs(node, env) or return
          end
          AST::Types::Tuple.new(types: types)
        when :lvasgn, :ivasgn, :gvasgn
          name = mlhs.children[0]

          if TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)
            AST::Builtin.any_type
          else
            env[name] || AST::Builtin.any_type
          end
        when :splat
          return
        else
          return
        end
      end
    end
  end
end