module Steep
module Interface
class Builder
class Config
attr_reader :self_type, :class_type, :instance_type, :variable_bounds
def initialize(self_type:, class_type: nil, instance_type: nil, variable_bounds:)
@self_type = self_type
@class_type = class_type
@instance_type = instance_type
@variable_bounds = variable_bounds
validate
end
def self.empty
new(self_type: nil, variable_bounds: {})
end
def subst
if self_type || class_type || instance_type
Substitution.build([], [], self_type: self_type, module_type: class_type, instance_type: instance_type)
end
end
def validate
validate_fvs(:self_type, self_type)
validate_fvs(:instance_type, instance_type)
validate_fvs(:class_type, class_type)
self
end
def validate_fvs(name, type)
if type
fvs = type.free_variables
if fvs.include?(AST::Types::Self.instance)
raise "#{name} cannot include 'self' type: #{type}"
end
if fvs.include?(AST::Types::Instance.instance)
raise "#{name} cannot include 'instance' type: #{type}"
end
if fvs.include?(AST::Types::Class.instance)
raise "#{name} cannot include 'class' type: #{type}"
end
end
end
def upper_bound(a)
variable_bounds.fetch(a, nil)
end
end
attr_reader :factory, :object_shape_cache, :union_shape_cache, :singleton_shape_cache
def initialize(factory)
@factory = factory
@object_shape_cache = {}
@union_shape_cache = {}
@singleton_shape_cache = {}
end
def shape(type, config)
Steep.logger.tagged "shape(#{type})" do
if shape = raw_shape(type, config)
# Optimization that skips unnecesary substittuion
if type.free_variables.include?(AST::Types::Self.instance)
shape
else
if s = config.subst
shape.subst(s)
else
shape
end
end
end
end
end
def fetch_cache(cache, key)
if cache.key?(key)
return cache.fetch(key)
end
cache[key] = yield
end
def raw_shape(type, config)
case type
when AST::Types::Self
self_type = config.self_type or raise
self_shape(self_type, config)
when AST::Types::Instance
instance_type = config.instance_type or raise
raw_shape(instance_type, config)
when AST::Types::Class
klass_type = config.class_type or raise
raw_shape(klass_type, config)
when AST::Types::Name::Singleton
singleton_shape(type.name).subst(class_subst(type))
when AST::Types::Name::Instance
object_shape(type.name).subst(class_subst(type).merge(app_subst(type)), type: type)
when AST::Types::Name::Interface
object_shape(type.name).subst(interface_subst(type).merge(app_subst(type)), type: type)
when AST::Types::Union
groups = type.types.group_by do |type|
if type.is_a?(AST::Types::Literal)
type.back_type
else
nil
end
end
shapes = [] #: Array[Shape]
groups.each do |name, types|
if name
union = AST::Types::Union.build(types: types)
subst = class_subst(name).update(self_type: union)
shapes << object_shape(name.name).subst(subst, type: union)
else
shapes.concat(types.map {|ty| raw_shape(ty, config) or return })
end
end
fetch_cache(union_shape_cache, type) do
union_shape(type, shapes)
end
when AST::Types::Intersection
shapes = type.types.map do |type|
raw_shape(type, config) or return
end
intersection_shape(type, shapes)
when AST::Types::Name::Alias
expanded = factory.expand_alias(type)
if shape = raw_shape(expanded, config)
shape.update(type: type)
end
when AST::Types::Literal
instance_type = type.back_type
subst = class_subst(instance_type).update(self_type: type)
object_shape(instance_type.name).subst(subst, type: type)
when AST::Types::Boolean
true_shape =
(object_shape(RBS::BuiltinNames::TrueClass.name)).
subst(class_subst(AST::Builtin::TrueClass.instance_type).update(self_type: type))
false_shape =
(object_shape(RBS::BuiltinNames::FalseClass.name)).
subst(class_subst(AST::Builtin::FalseClass.instance_type).update(self_type: type))
union_shape(type, [true_shape, false_shape])
when AST::Types::Proc
shape = object_shape(AST::Builtin::Proc.module_name).subst(class_subst(AST::Builtin::Proc.instance_type).update(self_type: type))
proc_shape(type, shape)
when AST::Types::Tuple
tuple_shape(type) do |array|
object_shape(array.name).subst(
class_subst(array).update(self_type: type).merge(app_subst(array))
)
end
when AST::Types::Record
record_shape(type) do |hash|
object_shape(hash.name).subst(
class_subst(hash).update(self_type: type).merge(app_subst(hash))
)
end
when AST::Types::Var
if bound = config.upper_bound(type.name)
new_config = Config.new(self_type: bound, variable_bounds: config.variable_bounds)
sub = Substitution.build([], self_type: type)
# We have to use `self_shape` insead of `raw_shape` here.
# Keep the `self` types included in the `bound`'s shape, and replace it to the type variable.
self_shape(bound, new_config)&.subst(sub, type: type)
end
when AST::Types::Nil
subst = class_subst(AST::Builtin::NilClass.instance_type).update(self_type: type)
object_shape(AST::Builtin::NilClass.module_name).subst(subst, type: type)
when AST::Types::Logic::Base
true_shape =
(object_shape(RBS::BuiltinNames::TrueClass.name)).
subst(class_subst(AST::Builtin::TrueClass.instance_type).update(self_type: type))
false_shape =
(object_shape(RBS::BuiltinNames::FalseClass.name)).
subst(class_subst(AST::Builtin::FalseClass.instance_type).update(self_type: type))
union_shape(type, [true_shape, false_shape])
else
nil
end
end
def self_shape(type, config)
case type
when AST::Types::Self, AST::Types::Instance, AST::Types::Class
raise
when AST::Types::Name::Singleton
singleton_shape(type.name).subst(class_subst(type).update(self_type: nil))
when AST::Types::Name::Instance
object_shape(type.name)
.subst(
class_subst(type).update(self_type: nil).merge(app_subst(type)),
type: type
)
when AST::Types::Name::Interface
object_shape(type.name).subst(app_subst(type), type: type)
when AST::Types::Literal
instance_type = type.back_type
subst = class_subst(instance_type).update(self_type: nil)
object_shape(instance_type.name).subst(subst, type: type)
when AST::Types::Boolean
true_shape =
(object_shape(RBS::BuiltinNames::TrueClass.name)).
subst(class_subst(AST::Builtin::TrueClass.instance_type).update(self_type: nil))
false_shape =
(object_shape(RBS::BuiltinNames::FalseClass.name)).
subst(class_subst(AST::Builtin::FalseClass.instance_type).update(self_type: nil))
union_shape(type, [true_shape, false_shape])
when AST::Types::Proc
shape = object_shape(AST::Builtin::Proc.module_name).subst(class_subst(AST::Builtin::Proc.instance_type).update(self_type: nil))
proc_shape(type, shape)
when AST::Types::Var
if bound = config.upper_bound(type.name)
self_shape(bound, config)&.update(type: type)
end
else
raw_shape(type, config)
end
end
def app_subst(type)
if type.args.empty?
return Substitution.empty
end
vars =
case type
when AST::Types::Name::Instance
entry = factory.env.normalized_module_class_entry(type.name) or raise
entry.primary.decl.type_params.map { _1.name }
when AST::Types::Name::Interface
entry = factory.env.interface_decls.fetch(type.name)
entry.decl.type_params.map { _1.name }
when AST::Types::Name::Alias
entry = factory.env.type_alias_decls.fetch(type.name)
entry.decl.type_params.map { _1.name }
end
Substitution.build(vars, type.args)
end
def class_subst(type)
case type
when AST::Types::Name::Singleton
self_type = type
singleton_type = type
instance_type = factory.instance_type(type.name)
when AST::Types::Name::Instance
self_type = type
singleton_type = type.to_module
instance_type = factory.instance_type(type.name)
end
Substitution.build([], self_type: self_type, module_type: singleton_type, instance_type: instance_type)
end
def interface_subst(type)
Substitution.build([], self_type: type)
end
def singleton_shape(type_name)
singleton_shape_cache[type_name] ||= begin
shape = Interface::Shape.new(type: AST::Types::Name::Singleton.new(name: type_name), private: true)
definition = factory.definition_builder.build_singleton(type_name)
definition.methods.each do |name, method|
Steep.logger.tagged "method = #{type_name}.#{name}" do
shape.methods[name] = Interface::Shape::Entry.new(
private_method: method.private?,
method_types: method.defs.map do |type_def|
method_name = method_name_for(type_def, name)
decl = TypeInference::MethodCall::MethodDecl.new(method_name: method_name, method_def: type_def)
method_type = factory.method_type(type_def.type, method_decls: Set[decl])
replace_primitive_method(method_name, type_def, method_type)
end
)
end
end
shape
end
end
def object_shape(type_name)
object_shape_cache[type_name] ||= begin
shape = Interface::Shape.new(type: AST::Builtin.bottom_type, private: true)
case
when type_name.class?
definition = factory.definition_builder.build_instance(type_name)
when type_name.interface?
definition = factory.definition_builder.build_interface(type_name)
end
definition or raise
definition.methods.each do |name, method|
Steep.logger.tagged "method = #{type_name}##{name}" do
shape.methods[name] = Interface::Shape::Entry.new(
private_method: method.private?,
method_types: method.defs.map do |type_def|
method_name = method_name_for(type_def, name)
decl = TypeInference::MethodCall::MethodDecl.new(method_name: method_name, method_def: type_def)
method_type = factory.method_type(type_def.type, method_decls: Set[decl])
replace_primitive_method(method_name, type_def, method_type)
end
)
end
end
shape
end
end
def union_shape(shape_type, shapes)
s0, *sx = shapes
s0 or raise
all_common_methods = Set.new(s0.methods.each_name)
sx.each do |shape|
all_common_methods &= shape.methods.each_name
end
shape = Interface::Shape.new(type: shape_type, private: true)
all_common_methods.each do |method_name|
method_typess = [] #: Array[Array[MethodType]]
private_method = false
shapes.each do |shape|
entry = shape.methods[method_name] || raise
method_typess << entry.method_types
private_method ||= entry.private_method?
end
shape.methods[method_name] = Interface::Shape::Entry.new(private_method: private_method) do
method_typess.inject do |types1, types2|
# @type break: nil
if types1 == types2
decl_array1 = types1.map(&:method_decls)
decl_array2 = types2.map(&:method_decls)
if decl_array1 == decl_array2
next types1
end
decls1 = decl_array1.each.with_object(Set[]) {|array, decls| decls.merge(array) } #$ Set[TypeInference::MethodCall::MethodDecl]
decls2 = decl_array2.each.with_object(Set[]) {|array, decls| decls.merge(array) } #$ Set[TypeInference::MethodCall::MethodDecl]
if decls1 == decls2
next types1
end
end
method_types = {} #: Hash[MethodType, bool]
types1.each do |type1|
types2.each do |type2|
if type1 == type2
method_types[type1.with(method_decls: type1.method_decls + type2.method_decls)] = true
else
if type = MethodType.union(type1, type2, subtyping)
method_types[type] = true
end
end
end
end
break nil if method_types.empty?
method_types.keys
end
end
end
shape
end
def intersection_shape(type, shapes)
shape = Interface::Shape.new(type: type, private: true)
shapes.each do |s|
shape.methods.merge!(s.methods) do |name, old_entry, new_entry|
if old_entry.public_method? && new_entry.private_method?
old_entry
else
new_entry
end
end
end
shape
end
def method_name_for(type_def, name)
type_name = type_def.implemented_in || type_def.defined_in
if name == :new && type_def.member.is_a?(RBS::AST::Members::MethodDefinition) && type_def.member.name == :initialize
return SingletonMethodName.new(type_name: type_name, method_name: name)
end
case type_def.member.kind
when :instance
InstanceMethodName.new(type_name: type_name, method_name: name)
when :singleton
SingletonMethodName.new(type_name: type_name, method_name: name)
when :singleton_instance
# Assume it a instance method, because `module_function` methods are typically defined with `def`
InstanceMethodName.new(type_name: type_name, method_name: name)
else
raise
end
end
def subtyping
@subtyping ||= Subtyping::Check.new(builder: self)
end
def tuple_shape(tuple)
element_type = AST::Types::Union.build(types: tuple.types, location: nil)
array_type = AST::Builtin::Array.instance_type(element_type)
array_shape = yield(array_type) or raise
shape = Shape.new(type: tuple, private: true)
shape.methods.merge!(array_shape.methods)
aref_entry = array_shape.methods[:[]].yield_self do |aref|
raise unless aref
Shape::Entry.new(
private_method: false,
method_types: tuple.types.map.with_index {|elem_type, index|
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
return_type: elem_type,
location: nil
),
block: nil,
method_decls: Set[]
)
} + aref.method_types
)
end
aref_update_entry = array_shape.methods[:[]=].yield_self do |update|
raise unless update
Shape::Entry.new(
private_method: false,
method_types: tuple.types.map.with_index {|elem_type, index|
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index), elem_type]),
return_type: elem_type,
location: nil
),
block: nil,
method_decls: Set[]
)
} + update.method_types
)
end
fetch_entry = array_shape.methods[:fetch].yield_self do |fetch|
raise unless fetch
Shape::Entry.new(
private_method: false,
method_types: tuple.types.flat_map.with_index {|elem_type, index|
[
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
return_type: elem_type,
location: nil
),
block: nil,
method_decls: Set[]
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type: Function.new(
params: Function::Params.build(
required: [
AST::Types::Literal.new(value: index),
AST::Types::Var.new(name: :T)
]
),
return_type: AST::Types::Union.build(types: [elem_type, AST::Types::Var.new(name: :T)]),
location: nil
),
block: nil,
method_decls: Set[]
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type: Function.new(
params: Function::Params.build(required: [AST::Types::Literal.new(value: index)]),
return_type: AST::Types::Union.build(types: [elem_type, AST::Types::Var.new(name: :T)]),
location: nil
),
block: Block.new(
type: Function.new(
params: Function::Params.build(required: [AST::Builtin::Integer.instance_type]),
return_type: AST::Types::Var.new(name: :T),
location: nil
),
optional: false,
self_type: nil
),
method_decls: Set[]
)
]
} + fetch.method_types
)
end
first_entry = array_shape.methods[:first].yield_self do |first|
Shape::Entry.new(
private_method: false,
method_types: [
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.empty,
return_type: tuple.types[0] || AST::Builtin.nil_type,
location: nil
),
block: nil,
method_decls: Set[]
)
]
)
end
last_entry = array_shape.methods[:last].yield_self do |last|
Shape::Entry.new(
private_method: false,
method_types: [
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.empty,
return_type: tuple.types.last || AST::Builtin.nil_type,
location: nil
),
block: nil,
method_decls: Set[]
)
]
)
end
shape.methods[:[]] = aref_entry
shape.methods[:[]=] = aref_update_entry
shape.methods[:fetch] = fetch_entry
shape.methods[:first] = first_entry
shape.methods[:last] = last_entry
shape
end
def record_shape(record)
all_key_type = AST::Types::Union.build(
types: record.elements.each_key.map {|value| AST::Types::Literal.new(value: value, location: nil) },
location: nil
)
all_value_type = AST::Types::Union.build(types: record.elements.values, location: nil)
hash_type = AST::Builtin::Hash.instance_type(all_key_type, all_value_type)
hash_shape = yield(hash_type) or raise
shape = Shape.new(type: record, private: true)
shape.methods.merge!(hash_shape.methods)
shape.methods[:[]] = hash_shape.methods[:[]].yield_self do |aref|
aref or raise
Shape::Entry.new(
private_method: false,
method_types: record.elements.map do |key_value, value_type|
key_type = AST::Types::Literal.new(value: key_value, location: nil)
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.build(required: [key_type]),
return_type: value_type,
location: nil
),
block: nil,
method_decls: Set[]
)
end + aref.method_types
)
end
shape.methods[:[]=] = hash_shape.methods[:[]=].yield_self do |update|
update or raise
Shape::Entry.new(
private_method: false,
method_types: record.elements.map do |key_value, value_type|
key_type = AST::Types::Literal.new(value: key_value, location: nil)
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.build(required: [key_type, value_type]),
return_type: value_type,
location: nil),
block: nil,
method_decls: Set[]
)
end + update.method_types
)
end
shape.methods[:fetch] = hash_shape.methods[:fetch].yield_self do |update|
update or raise
Shape::Entry.new(
private_method: false,
method_types: record.elements.flat_map {|key_value, value_type|
key_type = AST::Types::Literal.new(value: key_value, location: nil)
[
MethodType.new(
type_params: [],
type: Function.new(
params: Function::Params.build(required: [key_type]),
return_type: value_type,
location: nil
),
block: nil,
method_decls: Set[]
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type: Function.new(
params: Function::Params.build(required: [key_type, AST::Types::Var.new(name: :T)]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
location: nil
),
block: nil,
method_decls: Set[]
),
MethodType.new(
type_params: [TypeParam.new(name: :T, upper_bound: nil, variance: :invariant, unchecked: false)],
type: Function.new(
params: Function::Params.build(required: [key_type]),
return_type: AST::Types::Union.build(types: [value_type, AST::Types::Var.new(name: :T)]),
location: nil
),
block: Block.new(
type: Function.new(
params: Function::Params.build(required: [all_key_type]),
return_type: AST::Types::Var.new(name: :T),
location: nil
),
optional: false,
self_type: nil
),
method_decls: Set[]
)
]
} + update.method_types
)
end
shape
end
def proc_shape(proc, proc_shape)
shape = Shape.new(type: proc, private: true)
shape.methods.merge!(proc_shape.methods)
shape.methods[:[]] = shape.methods[:call] = Shape::Entry.new(
private_method: false,
method_types: [MethodType.new(type_params: [], type: proc.type, block: proc.block, method_decls: Set[])]
)
shape
end
def replace_primitive_method(method_name, method_def, method_type)
defined_in = method_def.defined_in
member = method_def.member
if member.is_a?(RBS::AST::Members::MethodDefinition)
case method_name.method_name
when :is_a?, :kind_of?, :instance_of?
case
when RBS::BuiltinNames::Object.name,
RBS::BuiltinNames::Kernel.name
if member.instance?
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::ReceiverIsArg.new(location: method_type.type.return_type.location)
)
)
end
end
when :nil?
case defined_in
when RBS::BuiltinNames::Object.name,
AST::Builtin::NilClass.module_name,
RBS::BuiltinNames::Kernel.name
if member.instance?
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::ReceiverIsNil.new(location: method_type.type.return_type.location)
)
)
end
end
when :!
case defined_in
when RBS::BuiltinNames::BasicObject.name,
RBS::BuiltinNames::TrueClass.name,
RBS::BuiltinNames::FalseClass.name,
AST::Builtin::NilClass.module_name
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::Not.new(location: method_type.type.return_type.location)
)
)
end
when :===
case defined_in
when RBS::BuiltinNames::Module.name
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::ArgIsReceiver.new(location: method_type.type.return_type.location)
)
)
when RBS::BuiltinNames::Object.name,
RBS::BuiltinNames::Kernel.name,
RBS::BuiltinNames::String.name,
RBS::BuiltinNames::Integer.name,
RBS::BuiltinNames::Symbol.name,
RBS::BuiltinNames::TrueClass.name,
RBS::BuiltinNames::FalseClass.name,
TypeName("::NilClass")
# Value based type-case works on literal types which is available for String, Integer, Symbol, TrueClass, FalseClass, and NilClass
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::ArgEqualsReceiver.new(location: method_type.type.return_type.location)
)
)
end
when :<, :<=
case defined_in
when RBS::BuiltinNames::Module.name
return method_type.with(
type: method_type.type.with(
return_type: AST::Types::Logic::ArgIsAncestor.new(location: method_type.type.return_type.location)
)
)
end
end
end
method_type
end
end
end
end