module Steep
module AST
module Types
class Factory
attr_reader :definition_builder
attr_reader :type_cache
def inspect
s = "#<%s:%#018x " % [self.class, object_id]
s << "@definition_builder=#<%s:%#018x>" % [definition_builder.class, definition_builder.object_id]
s + ">"
end
def initialize(builder:)
@definition_builder = builder
@type_cache = {}
@method_type_cache = {}
@method_type_cache.compare_by_identity
end
def type_name_resolver
@type_name_resolver ||= RBS::Resolver::TypeNameResolver.new(definition_builder.env)
end
def type_opt(type)
if type
type(type)
end
end
def type_1_opt(type)
if type
type_1(type)
end
end
def type(type)
unless type.location
if ty = type_cache[type]
return ty
end
end
type_cache[type] =
case type
when RBS::Types::Bases::Any
Any.new(location: type.location)
when RBS::Types::Bases::Class
Class.new(location: type.location)
when RBS::Types::Bases::Instance
Instance.new(location: type.location)
when RBS::Types::Bases::Self
Self.new(location: type.location)
when RBS::Types::Bases::Top
Top.new(location: type.location)
when RBS::Types::Bases::Bottom
Bot.new(location: type.location)
when RBS::Types::Bases::Bool
Boolean.new(location: type.location)
when RBS::Types::Bases::Void
Void.new(location: type.location)
when RBS::Types::Bases::Nil
Nil.new(location: type.location)
when RBS::Types::Variable
Var.new(name: type.name, location: type.location)
when RBS::Types::ClassSingleton
type_name = type.name
Name::Singleton.new(name: type_name, location: type.location)
when RBS::Types::ClassInstance
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Instance.new(name: type_name, args: args, location: type.location)
when RBS::Types::Interface
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Interface.new(name: type_name, args: args, location: type.location)
when RBS::Types::Alias
type_name = type.name
args = type.args.map {|arg| type(arg) }
Name::Alias.new(name: type_name, args: args, location: type.location)
when RBS::Types::Union
Union.build(types: type.types.map {|ty| type(ty) }, location: type.location)
when RBS::Types::Intersection
Intersection.build(types: type.types.map {|ty| type(ty) }, location: type.location)
when RBS::Types::Optional
Union.build(types: [type(type.type), Nil.new(location: nil)], location: type.location)
when RBS::Types::Literal
Literal.new(value: type.literal, location: type.location)
when RBS::Types::Tuple
Tuple.new(types: type.types.map {|ty| type(ty) }, location: type.location)
when RBS::Types::Record
elements = type.fields.each.with_object({}) do |(key, value), hash|
hash[key] = type(value)
end
Record.new(elements: elements, location: type.location)
when RBS::Types::Proc
func = Interface::Function.new(
params: params(type.type),
return_type: type(type.type.return_type),
location: type.location
)
block = if type.block
Interface::Block.new(
type: Interface::Function.new(
params: params(type.block.type),
return_type: type(type.block.type.return_type),
location: type.location
),
optional: !type.block.required,
self_type: type_opt(type.block.self_type)
)
end
Proc.new(
type: func,
block: block,
self_type: type_opt(type.self_type)
)
else
raise "Unexpected type given: #{type}"
end
end
def type_1(type)
case type
when Any
RBS::Types::Bases::Any.new(location: type.location)
when Class
RBS::Types::Bases::Class.new(location: type.location)
when Instance
RBS::Types::Bases::Instance.new(location: type.location)
when Self
RBS::Types::Bases::Self.new(location: type.location)
when Top
RBS::Types::Bases::Top.new(location: type.location)
when Bot
RBS::Types::Bases::Bottom.new(location: type.location)
when Boolean
RBS::Types::Bases::Bool.new(location: type.location)
when Void
RBS::Types::Bases::Void.new(location: type.location)
when Nil
RBS::Types::Bases::Nil.new(location: type.location)
when Var
RBS::Types::Variable.new(name: type.name, location: type.location)
when Name::Singleton
RBS::Types::ClassSingleton.new(name: type.name, location: type.location)
when Name::Instance
RBS::Types::ClassInstance.new(
name: type.name,
args: type.args.map {|arg| type_1(arg) },
location: type.location
)
when Name::Interface
RBS::Types::Interface.new(
name: type.name,
args: type.args.map {|arg| type_1(arg) },
location: type.location
)
when Name::Alias
RBS::Types::Alias.new(
name: type.name,
args: type.args.map {|arg| type_1(arg) },
location: type.location
)
when Union
RBS::Types::Union.new(
types: type.types.map {|ty| type_1(ty) },
location: type.location
)
when Intersection
RBS::Types::Intersection.new(
types: type.types.map {|ty| type_1(ty) },
location: type.location
)
when Literal
RBS::Types::Literal.new(literal: type.value, location: type.location)
when Tuple
RBS::Types::Tuple.new(
types: type.types.map {|ty| type_1(ty) },
location: type.location
)
when Record
fields = type.elements.each.with_object({}) do |(key, value), hash|
hash[key] = type_1(value)
end
RBS::Types::Record.new(fields: fields, location: type.location)
when Proc
block = if type.block
RBS::Types::Block.new(
type: function_1(type.block.type),
required: !type.block.optional?,
self_type: type_1_opt(type.block.self_type)
)
end
RBS::Types::Proc.new(
type: function_1(type.type),
self_type: type_1_opt(type.self_type),
block: block,
location: type.location
)
when Logic::Base
RBS::Types::Bases::Bool.new(location: type.location)
else
raise "Unexpected type given: #{type} (#{type.class})"
end
end
def function_1(func)
params = func.params
return_type = func.return_type
RBS::Types::Function.new(
required_positionals: params.required.map {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) },
optional_positionals: params.optional.map {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) },
rest_positionals: params.rest&.yield_self {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) },
trailing_positionals: [],
required_keywords: params.required_keywords.transform_values {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) },
optional_keywords: params.optional_keywords.transform_values {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) },
rest_keywords: params.rest_keywords&.yield_self {|type| RBS::Types::Function::Param.new(name: nil, type: type_1(type)) },
return_type: type_1(return_type)
)
end
def params(type)
Interface::Function::Params.build(
required: type.required_positionals.map {|param| type(param.type) },
optional: type.optional_positionals.map {|param| type(param.type) },
rest: type.rest_positionals&.yield_self {|param| type(param.type) },
required_keywords: type.required_keywords.transform_values {|param| type(param.type) },
optional_keywords: type.optional_keywords.transform_values {|param| type(param.type) },
rest_keywords: type.rest_keywords&.yield_self {|param| type(param.type) }
)
end
def type_param(type_param)
Interface::TypeParam.new(
name: type_param.name,
upper_bound: type_opt(type_param.upper_bound),
variance: type_param.variance,
unchecked: type_param.unchecked?
)
end
def type_param_1(type_param)
RBS::AST::TypeParam.new(
name: type_param.name,
variance: type_param.variance,
upper_bound: type_param.upper_bound&.yield_self {|u|
case u_ = type_1(u)
when RBS::Types::ClassInstance, RBS::Types::ClassSingleton, RBS::Types::Interface
u_
else
raise "`#{u_}` cannot be type parameter upper bound"
end
},
location: type_param.location
).unchecked!(type_param.unchecked)
end
def method_type(method_type, method_decls:)
mt = @method_type_cache[method_type] ||=
Interface::MethodType.new(
type_params: method_type.type_params.map {|param| type_param(param) },
type: Interface::Function.new(
params: params(method_type.type),
return_type: type(method_type.type.return_type),
location: method_type.location
),
block: method_type.block&.yield_self do |block|
Interface::Block.new(
optional: !block.required,
type: Interface::Function.new(
params: params(block.type),
return_type: type(block.type.return_type),
location: nil
),
self_type: type_opt(block.self_type)
)
end,
method_decls: Set[]
)
mt.with(method_decls: method_decls)
end
def method_type_1(method_type)
RBS::MethodType.new(
type_params: method_type.type_params.map {|param| type_param_1(param) },
type: function_1(method_type.type),
block: method_type.block&.yield_self do |block|
RBS::Types::Block.new(
type: function_1(block.type),
required: !block.optional,
self_type: type_1_opt(block.self_type)
)
end,
location: nil
)
end
def unfold(type_name, args)
type(
definition_builder.expand_alias2(
type_name,
args.empty? ? [] : args.map {|t| type_1(t) }
)
)
end
def expand_alias(type)
case type
when AST::Types::Name::Alias
unfold(type.name, type.args)
else
type
end
end
def deep_expand_alias(type, recursive: Set.new)
case type
when AST::Types::Name::Alias
unless recursive.member?(type.name)
unfolded = expand_alias(type)
deep_expand_alias(unfolded, recursive: recursive.union([type.name]))
end
when AST::Types::Union
types = type.types.map {|ty| deep_expand_alias(ty, recursive: recursive) or return }
AST::Types::Union.build(types: types, location: type.location)
when AST::Types::Intersection
types = type.types.map {|ty| deep_expand_alias(ty, recursive: recursive) or return }
AST::Types::Intersection.build(types: types, location: type.location)
else
type
end
end
def flatten_union(type, acc = [])
case type
when AST::Types::Union
type.types.each {|ty| flatten_union(ty, acc) }
else
acc << type
end
acc
end
def partition_union(type)
case type
when AST::Types::Name::Alias
unfold = expand_alias(type)
if unfold == type
[type, type]
else
partition_union(unfold)
end
when AST::Types::Union
truthy_types = [] #: Array[AST::Types::t]
falsy_types = [] #: Array[AST::Types::t]
type.types.each do |type|
truthy, falsy = partition_union(type)
truthy_types << truthy if truthy
falsy_types << falsy if falsy
end
[
truthy_types.empty? ? nil : AST::Types::Union.build(types: truthy_types),
falsy_types.empty? ? nil : AST::Types::Union.build(types: falsy_types)
]
when AST::Types::Any, AST::Types::Boolean, AST::Types::Top, AST::Types::Logic::Base
[type, type]
when AST::Types::Nil
[nil, type]
when AST::Types::Literal
if type.value == false
[nil, type]
else
[type, nil]
end
else
[type, nil]
end
end
def unwrap_optional(type)
case type
when AST::Types::Union
unwrap = type.types.filter_map do |type|
unless type.is_a?(AST::Types::Nil)
type
end
end
unless unwrap.empty?
AST::Types::Union.build(types: unwrap)
end
when AST::Types::Nil
nil
when AST::Types::Name::Alias
type_ = expand_alias(type)
if type_ == type
type_
else
unwrap_optional(type_)
end
else
type
end
end
def module_name?(type_name)
env.module_entry(type_name) ? true : false
end
def class_name?(type_name)
env.class_entry(type_name) ? true : false
end
def env
definition_builder.env
end
def absolute_type(type, context:)
absolute_type = type_1(type).map_type_name do |name|
absolute_type_name(name, context: context) || name.absolute!
end
type(absolute_type)
end
def absolute_type_name(type_name, context:)
type_name_resolver.resolve(type_name, context: context)
end
def instance_type(type_name, args: nil, location: nil)
raise unless type_name.class?
definition = definition_builder.build_singleton(type_name)
def_args = definition.type_params.map { Any.new(location: nil) }
if args
raise if def_args.size != args.size
else
args = def_args
end
AST::Types::Name::Instance.new(location: location, name: type_name, args: args)
end
def try_instance_type(type)
case type
when AST::Types::Name::Instance
instance_type(type.name)
when AST::Types::Name::Singleton
instance_type(type.name)
else
nil
end
end
def try_singleton_type(type)
case type
when AST::Types::Name::Instance, AST::Types::Name::Singleton
AST::Types::Name::Singleton.new(name:type.name)
else
nil
end
end
def normalize_type(type)
case type
when AST::Types::Name::Instance
AST::Types::Name::Instance.new(
name: env.normalize_module_name(type.name),
args: type.args.map {|ty| normalize_type(ty) },
location: type.location
)
when AST::Types::Name::Singleton
AST::Types::Name::Singleton.new(
name: env.normalize_module_name(type.name),
location: type.location
)
when AST::Types::Any, AST::Types::Boolean, AST::Types::Bot, AST::Types::Nil,
AST::Types::Top, AST::Types::Void, AST::Types::Literal, AST::Types::Class, AST::Types::Instance,
AST::Types::Self, AST::Types::Var, AST::Types::Logic::Base
type
when AST::Types::Intersection
AST::Types::Intersection.build(
types: type.types.map {|type| normalize_type(type) },
location: type.location
)
when AST::Types::Union
AST::Types::Union.build(
types: type.types.map {|type| normalize_type(type) },
location: type.location
)
when AST::Types::Record
AST::Types::Record.new(
elements: type.elements.transform_values {|type| normalize_type(type) },
location: type.location
)
when AST::Types::Tuple
AST::Types::Tuple.new(
types: type.types.map {|type| normalize_type(type) },
location: type.location
)
when AST::Types::Proc
type.map_type {|type| normalize_type(type) }
when AST::Types::Name::Alias
AST::Types::Name::Alias.new(
name: type.name,
args: type.args.map {|ty| normalize_type(ty) },
location: type.location
)
when AST::Types::Name::Interface
AST::Types::Name::Interface.new(
name: type.name,
args: type.args.map {|ty| normalize_type(ty) },
location: type.location
)
end
end
end
end
end
end