class Steep::TypeInference::LogicTypeInterpreter

def decompose_value(node)

# Splits a condition expression into the node that actually produces its value
# and the set of local variable names that end up holding that value.
#
# For example `(x = y = foo)` decomposes to `[foo_node, Set[:x, :y]]`, and a
# bare `x` decomposes to `[x_node, Set[:x]]`.
#
# @param node [Parser::AST::Node]
# @return [Array(Parser::AST::Node, Set<Symbol>)] the value node and the names
#   of local variables assigned along the way
def decompose_value(node)
  case node.type
  when :lvar
    [node, Set[node.children[0]]]
  when :masgn
    # Only the right-hand side is followed; multiple-assignment targets are not recorded.
    _, rhs = node.children
    decompose_value(rhs)
  when :lvasgn
    var, rhs = node.children
    val, vars = decompose_value(rhs)
    [val, vars + [var]]
  when :begin
    # `()` produces a `begin` node without children; treat it as an opaque value
    # instead of recursing into `nil` (which raised NoMethodError).
    last_node = node.children.last
    last_node ? decompose_value(last_node) : [node, Set[]]
  when :and
    left, right = node.children
    _, left_vars = decompose_value(left)
    val, right_vars = decompose_value(right)
    [val, left_vars + right_vars]
  else
    [node, Set[]]
  end
end

def eval(env:, node:)

# Public entry point: interprets `node` as a condition under `env` and returns
# whatever `evaluate_node` produces for it (using the node's recorded type).
def eval(env:, node:)
  evaluate_node(node: node, env: env)
end

def evaluate_assignment(assignment_node, env, rhs_type)

# Applies an assignment of a value of type `rhs_type` through `assignment_node`
# to `env` and returns the resulting environment.
#
# - `lvasgn`: the local variable is refined to `rhs_type`, unless its name is in
#   `TypeConstruction::SPECIAL_LVAR_NAMES` (then `env` is returned unchanged).
# - `masgn`: the left-hand side is expanded with `MultipleAssignment#expand`.
#   If that fails, the rhs type is converted with `to_ary`, else `to_a`, else
#   wrapped in a one-element tuple, and expansion is retried. Every yielded
#   `[target_node, type]` pair is then assigned recursively, threading the env.
# - any other node: `env` is returned unchanged.
#
# @raise [RuntimeError] when the multiple assignment still cannot be expanded
def evaluate_assignment(assignment_node, env, rhs_type)
  case assignment_node.type
  when :lvasgn
    var_name = assignment_node.children[0]
    return env if TypeConstruction::SPECIAL_LVAR_NAMES.include?(var_name)

    env.refine_types(local_variable_types: { var_name => rhs_type })
  when :masgn
    mlhs = assignment_node.children[0]
    expander = MultipleAssignment.new
    expansion = expander.expand(mlhs, rhs_type, false)
    expansion ||= expander.expand(
      mlhs,
      try_convert(rhs_type, :to_ary) ||
        try_convert(rhs_type, :to_a) ||
        AST::Types::Tuple.new(types: [rhs_type]),
      false
    )
    raise "Multiple assignment rhs doesn't look correct: #{rhs_type.to_s} (#{assignment_node.location.expression&.source_line})" unless expansion

    refined_env = env
    expansion.each do |pair|
      target_node, target_type = pair
      refined_env = evaluate_assignment(target_node, refined_env, target_type)
    end
    refined_env
  else
    env
  end
end

def evaluate_method_call(env:, type:, receiver:, arguments:)

# Narrows `env` for a method call whose (possibly guessed) type is one of the
# `AST::Types::Logic` predicate markers.
#
# Returns `[truthy_result, falsy_result]`: one `Result` (typed TRUE or FALSE,
# except for the `Not` branch which re-types the receiver's results) per branch,
# carrying the refined environment and a reachability flag.
#
# Returns nil when `type` is not handled below, or when the call does not have
# the shape a branch needs (no receiver, no first argument, the relevant type
# is not a `singleton(K)`, or `literal_var_type_case_select` returns nil).
# `evaluate_node` then falls through to its generic partition of the call type.
def evaluate_method_call(env:, type:, receiver:, arguments:)
  case type
  when AST::Types::Logic::ReceiverIsNil
    # `receiver.nil?` (no arguments): truthy refines the receiver to `nil`,
    # falsy to `factory.unwrap_optional(receiver_type)` or, if that returns nil,
    # to the unchanged receiver type.
    if receiver && arguments.size.zero?
      receiver_type = typing.type_of(node: receiver)
      unwrap = factory.unwrap_optional(receiver_type)
      truthy_receiver = AST::Builtin.nil_type
      falsy_receiver = unwrap || receiver_type
      truthy_env, falsy_env = refine_node_type(
        env: env,
        node: receiver,
        truthy_type: truthy_receiver,
        falsy_type: falsy_receiver
      )
      truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
      # The truthy branch cannot happen when `nil` is not a subtype of the receiver type.
      truthy_result.unreachable! if no_subtyping?(sub_type: AST::Builtin.nil_type, super_type: receiver_type)
      falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
      # NOTE(review): falsy is marked unreachable exactly when unwrap_optional
      # returned nil — confirm against Factory#unwrap_optional what that means.
      falsy_result.unreachable! unless unwrap
      [truthy_result, falsy_result]
    end
  when AST::Types::Logic::ReceiverIsArg
    # `receiver.is_a?(K)`-style call where the argument's type expands to
    # `singleton(K)`: the receiver's type is split by `type_case_select`.
    if receiver && (arg = arguments[0])
      receiver_type = typing.type_of(node: receiver)
      arg_type = factory.deep_expand_alias(typing.type_of(node: arg))
      if arg_type.is_a?(AST::Types::Name::Singleton)
        truthy_type, falsy_type = type_case_select(receiver_type, arg_type.name)
        # A nil side is replaced by a placeholder type and that branch is marked unreachable below.
        truthy_env, falsy_env = refine_node_type(
          env: env,
          node: receiver,
          truthy_type: truthy_type || factory.instance_type(arg_type.name),
          falsy_type: falsy_type || UNTYPED
        )
        truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
        truthy_result.unreachable! unless truthy_type
        falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
        falsy_result.unreachable! unless falsy_type
        [truthy_result, falsy_result]
      end
    end
  when AST::Types::Logic::ArgIsReceiver
    # `K === arg` where the receiver's type expands to `singleton(K)`: same as
    # above, but it is the argument that gets refined.
    if receiver && (arg = arguments[0])
      receiver_type = factory.deep_expand_alias(typing.type_of(node: receiver))
      arg_type = typing.type_of(node: arg)
      if receiver_type.is_a?(AST::Types::Name::Singleton)
        truthy_type, falsy_type = type_case_select(arg_type, receiver_type.name)
        truthy_env, falsy_env = refine_node_type(
          env: env,
          node: arg,
          truthy_type: truthy_type || factory.instance_type(receiver_type.name),
          falsy_type: falsy_type || UNTYPED
        )
        truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
        truthy_result.unreachable! unless truthy_type
        falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
        falsy_result.unreachable! unless falsy_type
        [truthy_result, falsy_result]
      end
    end
  when AST::Types::Logic::ArgEqualsReceiver
    # Literal comparison: the argument's type is split against the receiver
    # node (used as the literal value) by `literal_var_type_case_select`;
    # an empty side refines to BOT and is marked unreachable.
    if receiver && (arg = arguments[0])
      arg_type = factory.expand_alias(typing.type_of(node: arg))
      if (truthy_types, falsy_types = literal_var_type_case_select(receiver, arg_type))
        truthy_env, falsy_env = refine_node_type(
          env: env,
          node: arg,
          truthy_type: truthy_types.empty? ? BOT : AST::Types::Union.build(types: truthy_types),
          falsy_type: falsy_types.empty? ? BOT : AST::Types::Union.build(types: falsy_types)
        )
        truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
        truthy_result.unreachable! if truthy_types.empty?
        falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
        falsy_result.unreachable! if falsy_types.empty?
        [truthy_result, falsy_result]
      end
    end
  when AST::Types::Logic::ArgIsAncestor
    # Truthy refines the receiver to the argument's `singleton(K)` type; falsy
    # keeps the receiver's own type.
    if receiver && (arg = arguments[0])
      receiver_type = typing.type_of(node: receiver)
      arg_type = factory.deep_expand_alias(typing.type_of(node: arg))
      if arg_type.is_a?(AST::Types::Name::Singleton)
        truthy_type = arg_type
        falsy_type = receiver_type
        truthy_env, falsy_env = refine_node_type(
          env: env,
          node: receiver,
          truthy_type: truthy_type,
          falsy_type: falsy_type
        )
        truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
        # `truthy_type` is `arg_type`, which is non-nil here, so this check never fires.
        truthy_result.unreachable! unless truthy_type
        falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
        falsy_result.unreachable! unless falsy_type
        [truthy_result, falsy_result]
      end
    end
  when AST::Types::Logic::Not
    # `!receiver`: evaluate the receiver, then swap its branches and re-type
    # them as TRUE / FALSE.
    if receiver
      truthy_result, falsy_result = evaluate_node(env: env, node: receiver)
      [
        falsy_result.update_type { TRUE },
        truthy_result.update_type { FALSE }
      ]
    end
  end
end

def evaluate_node(env:, node:, type: typing.type_of(node: node))

# Evaluates `node` as a condition and returns `[truthy_result, falsy_result]`:
# the `Result` (type, environment, reachability) for the case where `node`
# evaluates to a truthy value, and for the case where it evaluates to a falsy one.
#
# `type` defaults to the type recorded for `node` in `typing`; recursive calls
# (e.g. for `csend`) may pass a different one.
def evaluate_node(env:, node:, type: typing.type_of(node: node))
  # A Logic::Env type already carries the environments for both branches.
  if type.is_a?(AST::Types::Logic::Env)
    truthy_env = type.truthy
    falsy_env = type.falsy
    truthy_type, falsy_type = factory.partition_union(type.type)
    return [
      Result.new(env: truthy_env, type: truthy_type || TRUE, unreachable: !truthy_type),
      Result.new(env: falsy_env, type: falsy_type || FALSE, unreachable: !falsy_type)
    ]
  end
  # A bottom-typed expression never produces a value: both branches are unreachable.
  if type.is_a?(AST::Types::Bot)
    return [
      Result.new(env: env, type: type, unreachable: true),
      Result.new(env: env, type: type, unreachable: true),
    ]
  end
  # Type variables are replaced by their configured upper bound, when there is one.
  if type.is_a?(AST::Types::Var)
    type = config.upper_bound(type.name) || type
  end
  case node.type
  when :lvar
    # Refine the variable itself to the truthy / falsy part of its type.
    name = node.children[0]
    truthy_type, falsy_type = factory.partition_union(type)
    truthy_result =
      if truthy_type
        Result.new(type: truthy_type, env: env.refine_types(local_variable_types: { name => truthy_type }), unreachable: false)
      else
        Result.new(type: type, env: env, unreachable: true)
      end
    falsy_result =
      if falsy_type
        Result.new(type: falsy_type, env: env.refine_types(local_variable_types: { name => falsy_type }), unreachable: false)
      else
        Result.new(type: type, env: env, unreachable: true)
      end
    return [truthy_result, falsy_result]
  when :lvasgn
    name, rhs = node.children
    # Special variable names are not refined; both branches keep `env`.
    if TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)
      return [
        Result.new(type: type, env: env, unreachable: false),
        Result.new(type: type, env: env, unreachable: false)
      ]
    end
    # Evaluate the rhs, then apply the assignment in each branch with that branch's type.
    truthy_result, falsy_result = evaluate_node(env: env, node: rhs)
    return [
      truthy_result.update_env { evaluate_assignment(node, truthy_result.env, truthy_result.type) },
      falsy_result.update_env { evaluate_assignment(node, falsy_result.env, falsy_result.type) }
    ]
  when :masgn
    _, rhs = node.children
    truthy_result, falsy_result = evaluate_node(env: env, node: rhs)
    return [
      truthy_result.update_env { evaluate_assignment(node, truthy_result.env, truthy_result.type) },
      falsy_result.update_env { evaluate_assignment(node, falsy_result.env, falsy_result.type) }
    ]
  when :begin
    # Only the last expression of a `begin` determines its value.
    last_node = node.children.last or raise
    return evaluate_node(env: env, node: last_node)
  when :csend
    if type.is_a?(AST::Types::Any)
      type = guess_type_from_method(node) || type
    end
    # NOTE(review): `receiver_type` and `arguments` are assigned but not used in this branch.
    receiver, _, *arguments = node.children
    receiver_type = typing.type_of(node: receiver)
    truthy_receiver, falsy_receiver = evaluate_node(env: env, node: receiver)
    truthy_type, _ = factory.partition_union(type)
    # The call itself only runs when the receiver is truthy: evaluate it as a
    # plain `send` in the receiver's truthy environment.
    truthy_result, falsy_result = evaluate_node(
      env: truthy_receiver.env,
      node: node.updated(:send),
      type: truthy_type || type
    )
    truthy_result.unreachable! if truthy_receiver.unreachable
    # The whole `a&.m` is falsy either when the receiver is falsy (call skipped)
    # or when the call returns a falsy value; join both environments.
    falsy_result = Result.new(
      env: env.join(falsy_receiver.env, falsy_result.env),
      unreachable: falsy_result.unreachable && falsy_receiver.unreachable,
      type: falsy_result.type
    )
    return [truthy_result, falsy_result]
  when :send
    if type.is_a?(AST::Types::Any)
      type = guess_type_from_method(node) || type
    end
    case type
    when AST::Types::Logic::Base
      # Predicate-style call; when evaluate_method_call returns nil, fall through
      # to the generic partition at the end of this method.
      receiver, _, *arguments = node.children
      if (truthy_result, falsy_result = evaluate_method_call(env: env, type: type, receiver: receiver, arguments: arguments))
        return [truthy_result, falsy_result]
      end
    else
      receiver, *_ = node.children
      receiver_type = typing.type_of(node: receiver) if receiver
      # When `env[receiver]` is truthy and the receiver is a union, try narrowing
      # it by which members' method return types can be truthy / falsy.
      if env[receiver] && receiver_type.is_a?(AST::Types::Union)
        result = evaluate_union_method_call(node: node, type: type, env: env, receiver: receiver, receiver_type: receiver_type)
        if result
          truthy_result = result[0] unless result[0].unreachable
          falsy_result = result[1] unless result[1].unreachable
        end
      end
      truthy_result ||= Result.new(type: type, env: env, unreachable: false)
      falsy_result ||= Result.new(type: type, env: env, unreachable: false)
      truthy_type, falsy_type = factory.partition_union(type)
      if truthy_type
        truthy_result = truthy_result.update_type { truthy_type }
      else
        truthy_result = truthy_result.update_type { BOT }.unreachable!
      end
      if falsy_type
        falsy_result = falsy_result.update_type { falsy_type }
      else
        falsy_result = falsy_result.update_type { BOT }.unreachable!
      end
      # When both branch environments track this call node (`env[node]` truthy),
      # also record the partitioned type for the call via `pure_call_types`.
      if truthy_result.env[node] && falsy_result.env[node]
        if truthy_type
          truthy_result = Result.new(type: truthy_type, env: truthy_result.env.refine_types(pure_call_types: { node => truthy_type }), unreachable: false)
        end
        if falsy_type
          falsy_result = Result.new(type: falsy_type, env: falsy_result.env.refine_types(pure_call_types: { node => falsy_type }), unreachable: false)
        end
      end
      return [truthy_result, falsy_result]
    end
  end
  # Generic fallback: no environment refinement, just split the type.
  truthy_type, falsy_type = factory.partition_union(type)
  return [
    Result.new(type: truthy_type || BOT, env: env, unreachable: truthy_type.nil?),
    Result.new(type: falsy_type || BOT, env: env, unreachable: falsy_type.nil?)
  ]
end

def evaluate_union_method_call(node:, type:, env:, receiver:, receiver_type:)

# Narrows a union-typed receiver of a method call according to which union
# members can make the call return a truthy or a falsy value.
#
# For each member of `receiver_type`, the method type used for this call is
# looked up on that member's shape. The member is kept on the truthy side if
# that method's return type has a truthy part, and on the falsy side if it has
# a falsy part. Members whose shape or matching method type cannot be found are
# kept on both sides.
#
# @param node [Parser::AST::Node] the `send` node
# @param type the type of the whole call; used as the type of both results
# @param env the environment before the call
# @param receiver [Parser::AST::Node] the receiver node to refine
# @param receiver_type [AST::Types::Union] the receiver's type
# @return [Array(Result, Result), nil] `[truthy_result, falsy_result]`, or nil
#   when no typed method call is recorded for `node`
def evaluate_union_method_call(node:, type:, env:, receiver:, receiver_type:)
  call_type =
    begin
      typing.call_of(node: node)
    rescue StandardError
      # Same as the former `rescue nil` modifier: if `call_of` raises
      # (presumably no call recorded for this node), do not narrow.
      nil
    end
  return unless call_type.is_a?(Steep::TypeInference::MethodCall::Typed)

  truthy_types = [] #: Array[AST::Types::t]
  falsy_types = [] #: Array[AST::Types::t]

  # Named `member_type` so the method parameter `type` is not shadowed.
  receiver_type.types.each do |member_type|
    if (shape = subtyping.builder.shape(member_type, config))
      method = shape.methods[call_type.method_name] or raise
      method_type = method.method_types.find do |candidate|
        call_type.method_decls.any? {|decl| factory.method_type(decl.method_type) == candidate }
      end

      if method_type
        truthy, falsy = factory.partition_union(method_type.type.return_type)
        truthy_types << member_type if truthy
        falsy_types << member_type if falsy
        next
      end
    end

    truthy_types << member_type
    falsy_types << member_type
  end

  truthy_type = truthy_types.empty? ? BOT : AST::Types::Union.build(types: truthy_types)
  falsy_type = falsy_types.empty? ? BOT : AST::Types::Union.build(types: falsy_types)

  truthy_env, falsy_env = refine_node_type(
    env: env,
    node: receiver,
    truthy_type: truthy_type,
    falsy_type: falsy_type
  )

  # A branch is unreachable when no receiver member survives on that side.
  # (Checking `truthy_type.nil?` can never succeed: an empty side is BOT.)
  [
    Result.new(type: type, env: truthy_env, unreachable: truthy_types.empty?),
    Result.new(type: type, env: falsy_env, unreachable: falsy_types.empty?)
  ]
end

def factory

# The type factory, taken from the subtyping checker this interpreter was built with.
def factory
  checker = subtyping
  checker.factory
end

def guess_type_from_method(node)

# Infers a logic predicate type from the method name of a plain `send` node,
# for calls whose recorded type is `untyped`.
#
# @return the matching `AST::Types::Logic` instance, or nil when `node` is not
#   a `send` or the method name is not one of the recognized predicates
def guess_type_from_method(node)
  return unless node.type == :send

  case node.children[1]
  when :is_a?, :kind_of?, :instance_of? then AST::Types::Logic::ReceiverIsArg.instance
  when :nil? then AST::Types::Logic::ReceiverIsNil.instance
  when :! then AST::Types::Logic::Not.instance
  when :=== then AST::Types::Logic::ArgIsReceiver.instance
  end
end

def initialize(subtyping:, typing:, config:)

# @param subtyping the subtyping checker (also the source of `factory`)
# @param typing the typing results queried for node and call types
# @param config the configuration consulted for type-variable upper bounds and shapes
def initialize(subtyping:, typing:, config:)
  @subtyping, @typing, @config = subtyping, typing, config
end

def literal_var_type_case_select(value_node, arg_type)

# Splits `arg_type` for an equality test against the literal `value_node`
# (`nil`, `true`, `false`, or an `int`/`str`/`sym` literal node).
#
# @return [Array(Array<AST::Types::t>, Array<AST::Types::t>), nil]
#   `[truthy_types, falsy_types]`, or nil when `value_node` is not one of the
#   supported literal kinds (checked for every member of a union)
def literal_var_type_case_select(value_node, arg_type)
  case arg_type
  when AST::Types::Union
    # @type var truthy_types: Array[AST::Types::t]
    truthy_types = []
    # @type var falsy_types: Array[AST::Types::t]
    falsy_types = []

    arg_type.types.each do |member|
      selected = literal_var_type_case_select(value_node, member)
      # One unsupported member makes the whole union unsupported.
      return unless selected

      member_truthy, member_falsy = selected
      truthy_types.push(*member_truthy)
      falsy_types.push(*member_falsy)
    end

    [truthy_types, falsy_types]
  when AST::Types::Boolean, AST::Types::Top, AST::Types::Any
    # No narrowing for these: the type stays on both sides.
    [[arg_type], [arg_type]]
  else
    case value_node.type
    when :nil
      [arg_type].partition do |candidate|
        candidate.is_a?(AST::Types::Nil) || AST::Builtin::NilClass.instance_type?(candidate)
      end
    when :true
      [arg_type].partition do |candidate|
        AST::Builtin::TrueClass.instance_type?(candidate) ||
          (candidate.is_a?(AST::Types::Literal) && candidate.value == true)
      end
    when :false
      [arg_type].partition do |candidate|
        AST::Builtin::FalseClass.instance_type?(candidate) ||
          (candidate.is_a?(AST::Types::Literal) && candidate.value == false)
      end
    when :int, :str, :sym
      literal_value = value_node.children[0]
      if arg_type.is_a?(AST::Types::Literal)
        # A literal type either is the compared value or can never equal it.
        arg_type.value == literal_value ? [[arg_type], []] : [[], [arg_type]]
      else
        # A non-literal type narrows to the literal on match, and stays as-is otherwise.
        [[AST::Types::Literal.new(value: literal_value)], [arg_type]]
      end
    end
  end
end

def no_subtyping?(sub_type:, super_type:)

# Checks `sub_type <: super_type` with empty constraints and the generic
# self/instance/class types.
#
# @return the failed check result when the relation does not hold, nil otherwise
def no_subtyping?(sub_type:, super_type:)
  check_result = subtyping.check(
    Subtyping::Relation.new(sub_type: sub_type, super_type: super_type),
    constraints: Subtyping::Constraints.empty,
    self_type: AST::Types::Self.instance,
    instance_type: AST::Types::Instance.instance,
    class_type: AST::Types::Class.instance
  )
  check_result if check_result.failure?
end

def refine_node_type(env:, node:, truthy_type:, falsy_type:)

# Produces `[truthy_env, falsy_env]` by refining whatever `node` refers to
# with `truthy_type` / `falsy_type` respectively.
#
# - `lvar`: the variable is refined (special variable names are left alone).
# - `lvasgn`: the rhs is refined first, then the assigned variable (unless special).
# - `send`: refined through `pure_call_types` only when `env[node]` is truthy.
# - `begin`: delegates to its last child.
# - anything else: `env` is returned for both sides.
def refine_node_type(env:, node:, truthy_type:, falsy_type:)
  case node.type
  when :lvar
    var_name = node.children[0]
    return [env, env] if TypeConstruction::SPECIAL_LVAR_NAMES.include?(var_name)

    [truthy_type, falsy_type].map do |branch_type|
      env.refine_types(local_variable_types: { var_name => branch_type })
    end
  when :lvasgn
    var_name, value_node = node.children
    branch_envs = refine_node_type(env: env, node: value_node, truthy_type: truthy_type, falsy_type: falsy_type)
    return branch_envs if TypeConstruction::SPECIAL_LVAR_NAMES.include?(var_name)

    branch_envs.zip([truthy_type, falsy_type]).map do |branch_env, branch_type|
      branch_env.refine_types(local_variable_types: { var_name => branch_type })
    end
  when :send
    return [env, env] unless env[node]

    [truthy_type, falsy_type].map do |branch_type|
      env.refine_types(pure_call_types: { node => branch_type })
    end
  when :begin
    tail_node = node.children.last
    raise unless tail_node

    refine_node_type(env: env, node: tail_node, truthy_type: truthy_type, falsy_type: falsy_type)
  else
    [env, env]
  end
end

def subtyping?(sub_type:, super_type:)

# True when `sub_type <: super_type` holds, i.e. `no_subtyping?` reports no failure.
def subtyping?(sub_type:, super_type:)
  failure = no_subtyping?(sub_type: sub_type, super_type: super_type)
  failure ? false : true
end

def try_convert(type, method)

# Returns the return type of `method` (e.g. `:to_ary`) on `type`, using the
# first overload that can be called without arguments (no params, or params
# reporting `optional?`). Returns nil when `type` has no shape, no such
# method, or no such overload.
def try_convert(type, method)
  shape = subtyping.builder.shape(type, config) or return
  entry = shape.methods[method] or return

  nullary = entry.method_types.find do |candidate|
    params = candidate.type.params
    params.nil? || params.optional?
  end
  return unless nullary

  nullary.type.return_type
end

def type_case_select(type, klass)

# Like `type_case_select0`, but each side is folded into a single union type,
# with nil standing for an empty side.
#
# @return [Array] `[truthy_type_or_nil, falsy_type_or_nil]`
def type_case_select(type, klass)
  type_case_select0(type, klass).map do |side_types|
    AST::Types::Union.build(types: side_types) unless side_types.empty?
  end
end

def type_case_select0(type, klass)

# Splits `type` into the parts the value may have when a `K` test (as in
# `case x when K`, `x.is_a?(K)` or `K === x`) succeeds and when it fails.
#
# @param type [AST::Types::t] the type of the tested value
# @param klass the class name `K` being tested against
# @return [Array(Array<AST::Types::t>, Array<AST::Types::t>)]
#   `[truthy_types, falsy_types]`; either list may be empty
def type_case_select0(type, klass)
  instance_type = factory.instance_type(klass)

  case type
  when AST::Types::Union
    truthy_types = [] #: Array[AST::Types::t]
    falsy_types = [] #: Array[AST::Types::t]

    type.types.each do |ty|
      truths, falses = type_case_select0(ty, klass)

      if truths.empty?
        # The member can never match; keep it unchanged on the falsy side.
        falsy_types.push(ty)
      else
        truthy_types.push(*truths)
        falsy_types.push(*falses)
      end
    end

    [truthy_types, falsy_types]
  when AST::Types::Name::Alias
    ty = factory.expand_alias(type)
    if ty == type
      # The alias did not expand any further: keep it on both sides.
      [[type], [type]]
    else
      type_case_select0(ty, klass)
    end
  when AST::Types::Any, AST::Types::Top, AST::Types::Var, AST::Types::Name::Interface
    # No subtyping check is attempted for these: a match narrows to `K`'s
    # instance type, a miss keeps `type`.
    [[instance_type], [type]]
  else
    # There are four possible relations between `type` and `instance_type`
    #
    # ```ruby
    # case object      # object: T
    # when K           # K: singleton(K)
    # when ...
    # end
    # ```
    #
    # 1. T <: K && K <: T (T == K, example: T = Integer, K = Integer)
    # 2. T <: K           (example: T = Integer, K = Numeric)
    # 3. K <: T           (example: T = Numeric, K = Integer)
    # 4. none of the above (example: T = String, K = Integer)
    if subtyping?(sub_type: type, super_type: instance_type)
      # 1 or 2. Satisfies the condition; no narrowing, because `type` is already
      # as specific as (or more specific than) `instance_type`.
      [[type], []]
    elsif subtyping?(sub_type: instance_type, super_type: type)
      # 3. Satisfies the condition and narrows to `instance_type`, but `type`
      # cannot be removed from the falsy side.
      [[instance_type], [type]]
    else
      # 4. Can never match.
      [[], [type]]
    end
  end
end