lib/steep/type_inference/case_when.rb



module Steep
  module TypeInference
    class CaseWhen
      class WhenPatterns
        include NodeHelper

        attr_reader :logic, :initial_constr, :unreachable_clause, :pattern_results

        def initialize(logic, initial_constr, unreachable_clause, assignment_node)
          @logic = logic
          @initial_constr = initial_constr
          @unreachable_clause = unreachable_clause
          @assignment_node = assignment_node

          @pattern_results = []
        end

        def add_pattern(pat)
          test_node = pat.updated(:send, [pat, :===, assignment_node])

          latest_constr, unreachable_pattern = latest_result

          type, constr = yield(test_node, latest_constr, unreachable_pattern)
          truthy_result, falsy_result = logic.eval(env: latest_constr.context.type_env, node: test_node)

          pattern_results << [pat, truthy_result, falsy_result]
        end

        def latest_result
          if (_, truthy, falsy = pattern_results.last)
            [
              initial_constr.update_type_env { falsy.env },
              falsy.unreachable
            ]
          else
            [initial_constr, unreachable_clause]
          end
        end

        def body_result
          raise if pattern_results.empty?

          type_envs = pattern_results.map {|_, truthy, _| truthy.env }
          env = initial_constr.context.type_env.join(*type_envs)

          env = yield(env) || env

          [
            initial_constr.update_type_env { env },
            unreachable_clause || pattern_results.all? {|_, truthy, _| truthy.unreachable }
          ]
        end

        def falsy_result
          (_, _, falsy = pattern_results.last) or raise

          [
            initial_constr.update_type_env { falsy.env },
            unreachable_clause || falsy.unreachable
          ]
        end

        def assignment_node()
          clone_node(@assignment_node)
        end
      end

      include NodeHelper
      extend NodeHelper

      def self.type_check(constr, node, logic, hint:, condition:)
        case_when = new(node, logic) do |condition_node|
          constr.synthesize(condition_node)
        end

        case_when.when_clauses() do |when_pats, patterns, body_node, loc|
          patterns.each do |pat|
            when_pats.add_pattern(pat) {|test, constr| constr.synthesize(test) }
          end

          body_constr, body_unreachable = when_pats.body_result() do |env|
            case_when.propagate_value_node_type(env)
          end

          if body_node
            body_constr = body_constr.for_branch(body_node)
            type, body_constr = body_constr.synthesize(body_node, hint: hint, condition: condition)
          else
            type = AST::Builtin.nil_type
          end

          body_result = LogicTypeInterpreter::Result.new(
            type: type,
            env: body_constr.context.type_env,
            unreachable: body_unreachable
          )

          falsy_constr, falsy_unreachable = when_pats.falsy_result
          next_result = LogicTypeInterpreter::Result.new(
            type: AST::Builtin.any_type,    # Unused for falsy pattern
            env: falsy_constr.context.type_env,
            unreachable: falsy_unreachable
          )

          [body_result, next_result]
        end

        case_when.else_clause do |else_node, constr|
          constr.synthesize(else_node, hint: hint, condition: condition)
        end

        case_when.result()
      end

      attr_reader :location, :node, :condition_node, :when_nodes, :else_node
      attr_reader :initial_constr, :logic, :clause_results, :else_result
      attr_reader :assignment_node, :value_node, :var_name

      def initialize(node, logic)
        @node = node

        condition_node, when_nodes, else_node, location = deconstruct_case_node!(node)
        condition_node or raise "CaseWhen works for case-when syntax with condition node"

        @condition_node = condition_node
        @when_nodes = when_nodes
        @else_node = else_node
        @location = location
        @logic = logic
        @clause_results = []

        type, constr = yield(condition_node)

        @var_name = "__case_when:#{SecureRandom.alphanumeric(5)}__".to_sym
        @value_node, @assignment_node = rewrite_condition_node(var_name, condition_node)

        @initial_constr = constr.update_type_env do |env|
          env.merge(local_variable_types: { var_name => [type, nil] })
        end
      end

      def when_clauses()
        when_nodes.each do |when_node|
          clause_constr, unreachable = latest_result

          patterns, body, loc = deconstruct_when_node!(when_node)

          when_pats = WhenPatterns.new(
            logic,
            clause_constr,
            unreachable,
            assignment_node
          )

          body_result, next_result = yield(
            when_pats,
            patterns,
            body,
            loc
          )

          if body_result.unreachable
            if body_result.type.is_a?(AST::Types::Any) || initial_constr.no_subtyping?(sub_type: body_result.type, super_type: AST::Builtin.bottom_type)
              typing.add_error(
                Diagnostic::Ruby::UnreachableValueBranch.new(
                  node: when_node,
                  type: body_result.type,
                  location: loc.keyword
                )
              )
            end
          end

          clause_results << [body_result, next_result]
        end
      end

      def else_clause()
        unless else_loc = has_else_clause?
          return
        end

        constr, unreachable = latest_result

        constr = constr.update_type_env do |env|
          propagate_value_node_type(env) || env
        end

        @else_result =
          if else_node
            yield(else_node, constr)
          else
            TypeConstruction::Pair.new(type: AST::Builtin.nil_type, constr: constr)
          end

        else_result or raise

        if unreachable
          if else_result.type.is_a?(AST::Types::Any) || initial_constr.no_subtyping?(sub_type: else_result.type, super_type: AST::Builtin.bottom_type)
            typing.add_error(
              Diagnostic::Ruby::UnreachableValueBranch.new(
                node: else_node || node,
                type: else_result.type,
                location: else_loc
              )
            )
          end
        end
      end

      def latest_result
        if (_, falsy_result = clause_results.last)
          [
            initial_constr.update_type_env { falsy_result.env },
            falsy_result.unreachable
          ]
        else
          [initial_constr, false]
        end
      end

      def result
        results = clause_results.filter_map do |body, _|
          unless body.unreachable
            body
          end
        end
        next_constr, next_clause_unreachable = latest_result

        unless next_clause_unreachable
          if else_result
            results << LogicTypeInterpreter::Result.new(
              type: else_result.type,
              env: else_result.context.type_env,
              unreachable: false    # Unused
            )
          else
            results << LogicTypeInterpreter::Result.new(
              type: AST::Builtin.nil_type,
              env: next_constr.context.type_env,
              unreachable: false
            )
          end
        end

        types = results.map {|result| result.type }
        envs = results.map {|result| result.env }

        [
          types,
          envs
        ]
      end

      def has_else_clause?
        location.else
      end

      def typing
        logic.typing
      end

      def rewrite_condition_node(var_name, node)
        case node.type
        when :lvasgn
          name, rhs = node.children
          value, rhs = rewrite_condition_node(var_name, rhs)
          [value, node.updated(nil, [name, rhs])]
        when :lvar
          name, = node.children
          [
            nil,
            node.updated(:lvasgn, [name, node.updated(:lvar, [var_name])])
          ]
        when :begin
          *children, last = node.children
          value_node, last = rewrite_condition_node(var_name, last)
          [
            value_node,
            node.updated(nil, children.push(last))
          ]
        else
          [
            node,
            node.updated(:lvar, [var_name])
          ]
        end
      end

      def propagate_value_node_type(env)
        if value_node
          if (call = initial_constr.typing.method_calls[value_node]).is_a?(MethodCall::Typed)
            if env[value_node]
              env.merge(pure_method_calls: { value_node => [call, env[var_name]] })
            end
          end
        end
      end
    end
  end
end