module Steep
module TypeInference
class LogicTypeInterpreter
class Result < Struct.new(:env, :type, :unreachable, keyword_init: true)
def update_env
env = yield()
Result.new(type: type, env: env, unreachable: unreachable)
end
def update_type
Result.new(type: yield, env: env, unreachable: unreachable)
end
def unreachable!
self.unreachable = true
self
end
end
attr_reader :subtyping
attr_reader :typing
attr_reader :config
def initialize(subtyping:, typing:, config:)
@subtyping = subtyping
@typing = typing
@config = config
end
def factory
subtyping.factory
end
def guess_type_from_method(node)
if node.type == :send
method = node.children[1]
case method
when :is_a?, :kind_of?, :instance_of?
AST::Types::Logic::ReceiverIsArg.instance
when :nil?
AST::Types::Logic::ReceiverIsNil.instance
when :!
AST::Types::Logic::Not.instance
when :===
AST::Types::Logic::ArgIsReceiver.instance
end
end
end
TRUE = AST::Types::Literal.new(value: true)
FALSE = AST::Types::Literal.new(value: false)
BOOL = AST::Types::Boolean.instance
BOT = AST::Types::Bot.instance
UNTYPED = AST::Types::Any.instance
def eval(env:, node:)
evaluate_node(env: env, node: node)
end
def evaluate_node(env:, node:, type: typing.type_of(node: node))
if type.is_a?(AST::Types::Logic::Env)
truthy_env = type.truthy
falsy_env = type.falsy
truthy_type, falsy_type = factory.partition_union(type.type)
return [
Result.new(env: truthy_env, type: truthy_type || TRUE, unreachable: !truthy_type),
Result.new(env: falsy_env, type: falsy_type || FALSE, unreachable: !falsy_type)
]
end
if type.is_a?(AST::Types::Bot)
return [
Result.new(env: env, type: type, unreachable: true),
Result.new(env: env, type: type, unreachable: true),
]
end
if type.is_a?(AST::Types::Var)
type = config.upper_bound(type.name) || type
end
case node.type
when :lvar
name = node.children[0]
truthy_type, falsy_type = factory.partition_union(type)
truthy_result =
if truthy_type
Result.new(type: truthy_type, env: env.refine_types(local_variable_types: { name => truthy_type }), unreachable: false)
else
Result.new(type: type, env: env, unreachable: true)
end
falsy_result =
if falsy_type
Result.new(type: falsy_type, env: env.refine_types(local_variable_types: { name => falsy_type }), unreachable: false)
else
Result.new(type: type, env: env, unreachable: true)
end
return [truthy_result, falsy_result]
when :lvasgn
name, rhs = node.children
if TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)
return [
Result.new(type: type, env: env, unreachable: false),
Result.new(type: type, env: env, unreachable: false)
]
end
truthy_result, falsy_result = evaluate_node(env: env, node: rhs)
return [
truthy_result.update_env { evaluate_assignment(node, truthy_result.env, truthy_result.type) },
falsy_result.update_env { evaluate_assignment(node, falsy_result.env, falsy_result.type) }
]
when :masgn
_, rhs = node.children
truthy_result, falsy_result = evaluate_node(env: env, node: rhs)
return [
truthy_result.update_env { evaluate_assignment(node, truthy_result.env, truthy_result.type) },
falsy_result.update_env { evaluate_assignment(node, falsy_result.env, falsy_result.type) }
]
when :begin
last_node = node.children.last or raise
return evaluate_node(env: env, node: last_node)
when :csend
if type.is_a?(AST::Types::Any)
type = guess_type_from_method(node) || type
end
receiver, _, *arguments = node.children
receiver_type = typing.type_of(node: receiver)
truthy_receiver, falsy_receiver = evaluate_node(env: env, node: receiver)
truthy_type, _ = factory.partition_union(type)
truthy_result, falsy_result = evaluate_node(
env: truthy_receiver.env,
node: node.updated(:send),
type: truthy_type || type
)
truthy_result.unreachable! if truthy_receiver.unreachable
falsy_result = Result.new(
env: env.join(falsy_receiver.env, falsy_result.env),
unreachable: falsy_result.unreachable && falsy_receiver.unreachable,
type: falsy_result.type
)
return [truthy_result, falsy_result]
when :send
if type.is_a?(AST::Types::Any)
type = guess_type_from_method(node) || type
end
case type
when AST::Types::Logic::Base
receiver, _, *arguments = node.children
if (truthy_result, falsy_result = evaluate_method_call(env: env, type: type, receiver: receiver, arguments: arguments))
return [truthy_result, falsy_result]
end
else
receiver, *_ = node.children
receiver_type = typing.type_of(node: receiver) if receiver
if env[receiver] && receiver_type.is_a?(AST::Types::Union)
result = evaluate_union_method_call(node: node, type: type, env: env, receiver: receiver, receiver_type: receiver_type)
if result
truthy_result = result[0] unless result[0].unreachable
falsy_result = result[1] unless result[1].unreachable
end
end
truthy_result ||= Result.new(type: type, env: env, unreachable: false)
falsy_result ||= Result.new(type: type, env: env, unreachable: false)
truthy_type, falsy_type = factory.partition_union(type)
if truthy_type
truthy_result = truthy_result.update_type { truthy_type }
else
truthy_result = truthy_result.update_type { BOT }.unreachable!
end
if falsy_type
falsy_result = falsy_result.update_type { falsy_type }
else
falsy_result = falsy_result.update_type { BOT }.unreachable!
end
if truthy_result.env[node] && falsy_result.env[node]
if truthy_type
truthy_result = Result.new(type: truthy_type, env: truthy_result.env.refine_types(pure_call_types: { node => truthy_type }), unreachable: false)
end
if falsy_type
falsy_result = Result.new(type: falsy_type, env: falsy_result.env.refine_types(pure_call_types: { node => falsy_type }), unreachable: false)
end
end
return [truthy_result, falsy_result]
end
end
truthy_type, falsy_type = factory.partition_union(type)
return [
Result.new(type: truthy_type || BOT, env: env, unreachable: truthy_type.nil?),
Result.new(type: falsy_type || BOT, env: env, unreachable: falsy_type.nil?)
]
end
def evaluate_assignment(assignment_node, env, rhs_type)
case assignment_node.type
when :lvasgn
name, _ = assignment_node.children
if TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)
env
else
env.refine_types(local_variable_types: { name => rhs_type })
end
when :masgn
lhs, _ = assignment_node.children
masgn = MultipleAssignment.new()
assignments = masgn.expand(lhs, rhs_type, false)
unless assignments
rhs_type_converted = try_convert(rhs_type, :to_ary)
rhs_type_converted ||= try_convert(rhs_type, :to_a)
rhs_type_converted ||= AST::Types::Tuple.new(types: [rhs_type])
assignments = masgn.expand(lhs, rhs_type_converted, false)
end
unless assignments
raise "Multiple assignment rhs doesn't look correct: #{rhs_type.to_s} (#{assignment_node.location.expression&.source_line})"
end
assignments.each do |pair|
node, type = pair
env = evaluate_assignment(node, env, type)
end
env
else
env
end
end
def refine_node_type(env:, node:, truthy_type:, falsy_type:)
case node.type
when :lvar
name = node.children[0]
if TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)
[env, env]
else
[
env.refine_types(local_variable_types: { name => truthy_type }),
env.refine_types(local_variable_types: { name => falsy_type })
]
end
when :lvasgn
name, rhs = node.children
truthy_env, falsy_env = refine_node_type(env: env, node: rhs, truthy_type: truthy_type, falsy_type: falsy_type)
if TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)
[truthy_env, falsy_env]
else
[
truthy_env.refine_types(local_variable_types: { name => truthy_type }),
falsy_env.refine_types(local_variable_types: { name => falsy_type })
]
end
when :send
if env[node]
[
env.refine_types(pure_call_types: { node => truthy_type }),
env.refine_types(pure_call_types: { node => falsy_type })
]
else
[env, env]
end
when :begin
last_node = node.children.last or raise
refine_node_type(env: env, node: last_node, truthy_type: truthy_type, falsy_type: falsy_type)
else
[env, env]
end
end
def evaluate_method_call(env:, type:, receiver:, arguments:)
case type
when AST::Types::Logic::ReceiverIsNil
if receiver && arguments.size.zero?
receiver_type = typing.type_of(node: receiver)
unwrap = factory.unwrap_optional(receiver_type)
truthy_receiver = AST::Builtin.nil_type
falsy_receiver = unwrap || receiver_type
truthy_env, falsy_env = refine_node_type(
env: env,
node: receiver,
truthy_type: truthy_receiver,
falsy_type: falsy_receiver
)
truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! if no_subtyping?(sub_type: AST::Builtin.nil_type, super_type: receiver_type)
falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! unless unwrap
[truthy_result, falsy_result]
end
when AST::Types::Logic::ReceiverIsArg
if receiver && (arg = arguments[0])
receiver_type = typing.type_of(node: receiver)
arg_type = factory.deep_expand_alias(typing.type_of(node: arg))
if arg_type.is_a?(AST::Types::Name::Singleton)
truthy_type, falsy_type = type_case_select(receiver_type, arg_type.name)
truthy_env, falsy_env = refine_node_type(
env: env,
node: receiver,
truthy_type: truthy_type || factory.instance_type(arg_type.name),
falsy_type: falsy_type || UNTYPED
)
truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! unless truthy_type
falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! unless falsy_type
[truthy_result, falsy_result]
end
end
when AST::Types::Logic::ArgIsReceiver
if receiver && (arg = arguments[0])
receiver_type = factory.deep_expand_alias(typing.type_of(node: receiver))
arg_type = typing.type_of(node: arg)
if receiver_type.is_a?(AST::Types::Name::Singleton)
truthy_type, falsy_type = type_case_select(arg_type, receiver_type.name)
truthy_env, falsy_env = refine_node_type(
env: env,
node: arg,
truthy_type: truthy_type || factory.instance_type(receiver_type.name),
falsy_type: falsy_type || UNTYPED
)
truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! unless truthy_type
falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! unless falsy_type
[truthy_result, falsy_result]
end
end
when AST::Types::Logic::ArgEqualsReceiver
if receiver && (arg = arguments[0])
arg_type = factory.expand_alias(typing.type_of(node: arg))
if (truthy_types, falsy_types = literal_var_type_case_select(receiver, arg_type))
truthy_env, falsy_env = refine_node_type(
env: env,
node: arg,
truthy_type: truthy_types.empty? ? BOT : AST::Types::Union.build(types: truthy_types),
falsy_type: falsy_types.empty? ? BOT : AST::Types::Union.build(types: falsy_types)
)
truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! if truthy_types.empty?
falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! if falsy_types.empty?
[truthy_result, falsy_result]
end
end
when AST::Types::Logic::ArgIsAncestor
if receiver && (arg = arguments[0])
receiver_type = typing.type_of(node: receiver)
arg_type = factory.deep_expand_alias(typing.type_of(node: arg))
if arg_type.is_a?(AST::Types::Name::Singleton)
truthy_type = arg_type
falsy_type = receiver_type
truthy_env, falsy_env = refine_node_type(
env: env,
node: receiver,
truthy_type: truthy_type,
falsy_type: falsy_type
)
truthy_result = Result.new(type: TRUE, env: truthy_env, unreachable: false)
truthy_result.unreachable! unless truthy_type
falsy_result = Result.new(type: FALSE, env: falsy_env, unreachable: false)
falsy_result.unreachable! unless falsy_type
[truthy_result, falsy_result]
end
end
when AST::Types::Logic::Not
if receiver
truthy_result, falsy_result = evaluate_node(env: env, node: receiver)
[
falsy_result.update_type { TRUE },
truthy_result.update_type { FALSE }
]
end
end
end
def evaluate_union_method_call(node:, type:, env:, receiver:, receiver_type:)
call_type = typing.call_of(node: node) rescue nil
return unless call_type.is_a?(Steep::TypeInference::MethodCall::Typed)
truthy_types = [] #: Array[AST::Types::t]
falsy_types = [] #: Array[AST::Types::t]
receiver_type.types.each do |type|
if shape = subtyping.builder.shape(type, config)
method = shape.methods[call_type.method_name] or raise
method_type = method.method_types.find do |method_type|
call_type.method_decls.any? {|decl| factory.method_type(decl.method_type) == method_type }
end
if method_type
return_type = method_type.type.return_type
truthy, falsy = factory.partition_union(return_type)
truthy_types << type if truthy
falsy_types << type if falsy
next
end
end
truthy_types << type
falsy_types << type
end
truthy_type = truthy_types.empty? ? BOT : AST::Types::Union.build(types: truthy_types)
falsy_type = falsy_types.empty? ? BOT : AST::Types::Union.build(types: falsy_types)
truthy_env, falsy_env = refine_node_type(
env: env,
node: receiver,
truthy_type: truthy_type,
falsy_type: falsy_type
)
return [
Result.new(type: type, env: truthy_env, unreachable: truthy_type.nil?),
Result.new(type: type, env: falsy_env, unreachable: falsy_type.nil?)
]
end
def decompose_value(node)
case node.type
when :lvar
[node, Set[node.children[0]]]
when :masgn
_, rhs = node.children
decompose_value(rhs)
when :lvasgn
var, rhs = node.children
val, vars = decompose_value(rhs)
[val, vars + [var]]
when :begin
decompose_value(node.children.last)
when :and
left, right = node.children
_, left_vars = decompose_value(left)
val, right_vars = decompose_value(right)
[val, left_vars + right_vars]
else
[node, Set[]]
end
end
def literal_var_type_case_select(value_node, arg_type)
case arg_type
when AST::Types::Union
# @type var truthy_types: Array[AST::Types::t]
truthy_types = []
# @type var falsy_types: Array[AST::Types::t]
falsy_types = []
arg_type.types.each do |type|
if (ts, fs = literal_var_type_case_select(value_node, type))
truthy_types.push(*ts)
falsy_types.push(*fs)
else
return
end
end
[truthy_types, falsy_types]
when AST::Types::Boolean
[[arg_type], [arg_type]]
when AST::Types::Top, AST::Types::Any
[[arg_type], [arg_type]]
else
types = [arg_type]
case value_node.type
when :nil
types.partition do |type|
type.is_a?(AST::Types::Nil) || AST::Builtin::NilClass.instance_type?(type)
end
when :true
types.partition do |type|
AST::Builtin::TrueClass.instance_type?(type) ||
(type.is_a?(AST::Types::Literal) && type.value == true)
end
when :false
types.partition do |type|
AST::Builtin::FalseClass.instance_type?(type) ||
(type.is_a?(AST::Types::Literal) && type.value == false)
end
when :int, :str, :sym
# @type var pairs: [Array[AST::Types::t], Array[AST::Types::t]]
pairs = [[], []]
types.each_with_object(pairs) do |type, pair|
true_types, false_types = pair
case
when type.is_a?(AST::Types::Literal)
if type.value == value_node.children[0]
true_types << type
else
false_types << type
end
else
true_types << AST::Types::Literal.new(value: value_node.children[0])
false_types << type
end
end
end
end
end
def type_case_select(type, klass)
truth_types, false_types = type_case_select0(type, klass)
[
truth_types.empty? ? nil : AST::Types::Union.build(types: truth_types),
false_types.empty? ? nil : AST::Types::Union.build(types: false_types)
]
end
def type_case_select0(type, klass)
instance_type = factory.instance_type(klass)
case type
when AST::Types::Union
truthy_types = [] # :Array[AST::Types::t]
falsy_types = [] #: Array[AST::Types::t]
type.types.each do |ty|
truths, falses = type_case_select0(ty, klass)
if truths.empty?
falsy_types.push(ty)
else
truthy_types.push(*truths)
falsy_types.push(*falses)
end
end
[truthy_types, falsy_types]
when AST::Types::Name::Alias
ty = factory.expand_alias(type)
if ty == type
[[type], [type]]
else
type_case_select0(ty, klass)
end
when AST::Types::Any, AST::Types::Top, AST::Types::Var
[
[instance_type],
[type]
]
when AST::Types::Name::Interface
[
[instance_type],
[type]
]
else
# There are four possible relations between `type` and `instance_type`
#
# ```ruby
# case object # object: T
# when K # K: singleton(K)
# when ...
# end
# ````
#
# 1. T <: K && K <: T (T == K, T = Integer, K = Numeric)
# 2. T <: K (example: T = Integer, K = Numeric)
# 3. K <: T (example: T = Numeric, K = Integer)
# 4. none of the above (example: T = String, K = Integer)
if subtyping?(sub_type: type, super_type: instance_type)
# 1 or 2. Satisfies the condition, no narrowing because `type` is already more specific than/equals to `instance_type`
[
[type],
[]
]
else
if subtyping?(sub_type: instance_type, super_type: type)
# 3. Satisfied the condition, narrows to `instance_type`, but cannot remove it from *falsy* list
[
[instance_type],
[type]
]
else
# 4
[
[],
[type]
]
end
end
end
end
def no_subtyping?(sub_type:, super_type:)
relation = Subtyping::Relation.new(sub_type: sub_type, super_type: super_type)
result = subtyping.check(relation, constraints: Subtyping::Constraints.empty, self_type: AST::Types::Self.instance, instance_type: AST::Types::Instance.instance, class_type: AST::Types::Class.instance)
if result.failure?
result
end
end
def subtyping?(sub_type:, super_type:)
!no_subtyping?(sub_type: sub_type, super_type: super_type)
end
def try_convert(type, method)
if shape = subtyping.builder.shape(type, config)
if entry = shape.methods[method]
method_type = entry.method_types.find do |method_type|
method_type.type.params.nil? ||
method_type.type.params.optional?
end
method_type.type.return_type if method_type
end
end
end
end
end
end