module RBS
module Prototype
class RB
Context = Struct.new(:module_function, :singleton, :namespace, keyword_init: true) do
  # Context for the start of a class/module body: plain instance methods,
  # not inside `class << self`, no `module_function` in effect yet.
  #
  # @param namespace [Namespace] namespace the following definitions belong to
  def self.initial(namespace: Namespace.root)
    new(module_function: false, singleton: false, namespace: namespace)
  end

  # Kind of a `def` processed under this context: :singleton inside
  # `class << self`, :singleton_instance after `module_function`, otherwise
  # :instance.
  def method_kind
    return :singleton if singleton
    return :singleton_instance if module_function

    :instance
  end

  # Kind of an attr_* member: only `class << self` changes it;
  # `module_function` has no effect here.
  def attribute_kind
    singleton ? :singleton : :instance
  end
end
# source_decls collects everything produced by #parse.
# NOTE(review): toplevel_members is exposed, but nothing in the code visible
# here ever assigns @toplevel_members, so it reads nil — confirm whether it is
# still needed.
attr_reader :source_decls, :toplevel_members

# Starts with an empty declaration list; every #parse call appends to it.
def initialize
  @source_decls = []
end
# Returns the declarations collected so far.
#
# Real declarations (classes, modules, constants, ...) are returned as-is.
# Members that appeared at the top level of the parsed source (methods,
# attributes, ...) are wrapped into a synthetic `class Object` declaration,
# which is appended after them.
#
# @return [Array] a freshly built array on every call
def decls
  top_decls, top_members = source_decls.partition { |decl| decl.is_a?(AST::Declarations::Base) }
  return top_decls if top_members.empty?

  object_decl = AST::Declarations::Class.new(
    name: TypeName.new(name: :Object, namespace: Namespace.empty),
    super_class: nil,
    members: top_members,
    annotations: [],
    comment: nil,
    location: nil,
    type_params: AST::Declarations::ModuleTypeParams.empty
  )
  top_decls + [object_decl]
end
# Parses Ruby source and appends the prototype declarations it describes to
# #source_decls.
#
# Comments are gathered first with Ripper. Consecutive comment lines are merged
# into one AST::Comment, stored under the line number of the run's last line,
# so #process can attach the comment block that ends right above a definition.
#
# @param string [String] Ruby source code
def parse(string)
  comments = {}
  Ripper.lex(string).each do |(line, _column), event, text, *|
    next unless event == :on_comment

    # Strip the "#" and at most one following space or tab. Slicing off a fixed
    # two characters dropped the first letter of "#foo", and returned nil (so
    # `empty?` raised) for a lone "#" at the very end of the source.
    body = text.sub(/\A#[ \t]?/, "")
    body = "\n" if body.empty?

    if (previous = comments.delete(line - 1))
      # Continue the comment block that ended on the previous line.
      comments[line] = AST::Comment.new(string: previous.string + body, location: nil)
    else
      comments[line] = AST::Comment.new(string: body, location: nil)
    end
  end

  process RubyVM::AbstractSyntaxTree.parse(string), decls: source_decls, comments: comments, context: Context.initial
end
# Translates one RubyVM::AbstractSyntaxTree node into RBS declarations/members
# and appends them to +decls+, recursing into child nodes where needed.
#
# @param node [RubyVM::AbstractSyntaxTree::Node] node to translate
# @param decls [Array] list results are appended to: the top-level list, or the
#   +members+ of the enclosing class/module declaration
# @param comments [Hash{Integer => AST::Comment}] comment blocks keyed by the
#   line they end on
# @param context [Context] namespace and method-kind state; a bare
#   `module_function` mutates it so later sibling definitions are affected
def process(node, decls:, comments:, context:)
  # The comment block (if any) ending on the line just above this node.
  comment = comments[node.first_lineno - 1]

  case node.type
  when :CLASS
    class_name, super_class, *class_body = node.children
    kls = AST::Declarations::Class.new(
      name: const_to_name(class_name),
      super_class: super_class && AST::Declarations::Class::Super.new(name: const_to_name(super_class), args: [], location: nil),
      type_params: AST::Declarations::ModuleTypeParams.empty,
      members: [],
      annotations: [],
      location: nil,
      comment: comment
    )
    decls.push kls
    new_ctx = Context.initial(namespace: context.namespace + kls.name.to_namespace)
    each_node class_body do |child|
      process child, decls: kls.members, comments: comments, context: new_ctx
    end
    remove_unnecessary_accessibility_methods! kls.members
  when :MODULE
    module_name, *module_body = node.children
    mod = AST::Declarations::Module.new(
      name: const_to_name(module_name),
      type_params: AST::Declarations::ModuleTypeParams.empty,
      self_types: [],
      members: [],
      annotations: [],
      location: nil,
      comment: comment
    )
    decls.push mod
    new_ctx = Context.initial(namespace: context.namespace + mod.name.to_namespace)
    each_node module_body do |child|
      process child, decls: mod.members, comments: comments, context: new_ctx
    end
    remove_unnecessary_accessibility_methods! mod.members
  when :SCLASS
    this, body = node.children
    if this.type != :SELF
      RBS.logger.warn "`class <<` syntax with not-self may be compiled to incorrect code: #{this}"
    end
    # Accessibility in effect outside `class << self`; restored after its body.
    accessibility = current_accessibility(decls)
    # Keep the enclosing namespace. A bare Context.initial would reset it to
    # Namespace.root, so `self` (e.g. `extend self`) inside the singleton-class
    # body would resolve against the root instead of this class/module.
    singleton_ctx = Context.initial(namespace: context.namespace).tap { |ctx| ctx.singleton = true }
    process_children(body, decls: decls, comments: comments, context: singleton_ctx)
    decls << accessibility
  when :DEFN, :DEFS
    if node.type == :DEFN
      def_name, def_body = node.children
      kind = context.method_kind
    else
      # `def self.foo` / `def obj.foo`
      _, def_name, def_body = node.children
      kind = :singleton
    end
    types = [
      MethodType.new(
        type_params: [],
        type: function_type_from_body(def_body),
        block: block_from_body(def_body),
        location: nil
      )
    ]
    member = AST::Members::MethodDefinition.new(
      name: def_name,
      location: nil,
      annotations: [],
      types: types,
      kind: kind,
      comment: comment,
      overload: false
    )
    decls.push member unless decls.include?(member)
  when :ALIAS
    new_name, old_name = node.children.map { |c| literal_to_symbol(c) }
    member = AST::Members::Alias.new(
      new_name: new_name,
      old_name: old_name,
      kind: context.singleton ? :singleton : :instance,
      annotations: [],
      location: nil,
      comment: comment
    )
    decls.push member unless decls.include?(member)
  when :FCALL, :VCALL
    # Calls inside method bodies never reach here: DEFN/DEFS bodies are not
    # walked by this method.
    args = node.children[1]&.children || []
    case node.children[0]
    when :include
      args.each do |arg|
        if (name = const_to_name(arg))
          decls << AST::Members::Include.new(
            name: name,
            args: [],
            annotations: [],
            location: nil,
            comment: comment
          )
        end
      end
    when :extend
      args.each do |arg|
        if (name = const_to_name(arg, context: context))
          decls << AST::Members::Extend.new(
            name: name,
            args: [],
            annotations: [],
            location: nil,
            comment: comment
          )
        end
      end
    when :attr_reader, :attr_accessor, :attr_writer
      member_class =
        case node.children[0]
        when :attr_reader then AST::Members::AttrReader
        when :attr_accessor then AST::Members::AttrAccessor
        else AST::Members::AttrWriter
        end
      args.each do |arg|
        next unless arg && (name = literal_to_symbol(arg))

        decls << member_class.new(
          name: name,
          ivar_name: nil,
          type: Types::Bases::Any.new(location: nil),
          kind: context.attribute_kind,
          location: nil,
          comment: comment,
          annotations: []
        )
      end
    when :alias_method
      if args[0] && args[1] && (new_name = literal_to_symbol(args[0])) && (old_name = literal_to_symbol(args[1]))
        decls << AST::Members::Alias.new(
          new_name: new_name,
          old_name: old_name,
          kind: context.singleton ? :singleton : :instance,
          annotations: [],
          location: nil,
          comment: comment
        )
      end
    when :module_function
      if args.empty?
        # Bare `module_function`: later definitions sharing this context are
        # module functions too.
        context.module_function = true
      else
        module_func_context = context.dup.tap { |ctx| ctx.module_function = true }
        args.each do |arg|
          if arg && (name = literal_to_symbol(arg))
            # `module_function :foo` re-kinds an already collected definition.
            if (i = find_def_index_by_name(decls, name))
              decls[i] = decls[i].update(kind: :singleton_instance)
            end
          elsif arg
            # e.g. `module_function def foo; end`
            process arg, decls: decls, comments: comments, context: module_func_context
          end
        end
      end
    when :public, :private
      accessibility = __send__(node.children[0])
      if args.empty?
        decls << accessibility
      else
        args.each do |arg|
          next unless arg && (name = literal_to_symbol(arg))

          i = find_def_index_by_name(decls, name)
          next unless i

          current = current_accessibility(decls, i)
          next if current == accessibility

          # Surround the earlier definition with markers: the requested
          # accessibility before it and the previous one restored after it.
          decls.insert(i + 1, current)
          decls.insert(i, accessibility)
        end
        # For `private def foo` syntax
        current = current_accessibility(decls)
        decls << accessibility
        process_children(node, decls: decls, comments: comments, context: context)
        decls << current
      end
    else
      process_children(node, decls: decls, comments: comments, context: context)
    end
  when :ITER
    method_name = node.children.first.children.first
    # `refine` blocks are ignored; every other block body is walked.
    process_children(node, decls: decls, comments: comments, context: context) unless method_name == :refine
  when :CDECL
    const_name =
      if node.children[0].is_a?(Symbol)
        TypeName.new(name: node.children[0], namespace: Namespace.empty)
      else
        const_to_name(node.children[0])
      end
    decls << AST::Declarations::Constant.new(
      name: const_name,
      type: node_type(node.children.last),
      location: nil,
      comment: comment
    )
  else
    process_children(node, decls: decls, comments: comments, context: context)
  end
end
# Runs #process on every direct child node of +node+, sharing the same
# decls list, comment table and context.
def process_children(node, decls:, comments:, context:)
  each_child(node) { |child_node| process(child_node, decls: decls, comments: comments, context: context) }
end
# Converts a constant-reference node into a TypeName.
#
# Handles `Foo` (CONST), `A::B` and prefix-less COLON2 paths, `::Foo` (COLON3)
# and `self` (SELF, resolved through +context+'s namespace). Returns nil for any
# other node, and for a path whose prefix cannot itself be resolved, such as
# `foo::Bar`.
#
# @param node [RubyVM::AbstractSyntaxTree::Node, nil]
# @param context [Context, nil] only needed to resolve `self`
# @return [TypeName, nil]
def const_to_name(node, context: nil)
  case node&.type
  when :CONST
    TypeName.new(name: node.children[0], namespace: Namespace.empty)
  when :COLON2
    prefix, const_name = node.children
    if prefix
      # Pass the context down so `self::Foo` resolves, and give up with nil
      # rather than calling #to_namespace on nil for an unresolvable prefix.
      prefix_name = const_to_name(prefix, context: context)
      return nil unless prefix_name

      namespace = prefix_name.to_namespace
    else
      namespace = Namespace.empty
    end
    TypeName.new(name: const_name, namespace: namespace)
  when :COLON3
    TypeName.new(name: node.children[0], namespace: Namespace.root)
  when :SELF
    context&.then { |c| c.namespace.to_type_name }
  end
end
# Extracts a name from a literal argument node: a symbol literal yields the
# symbol itself, a string literal is converted with #to_sym. Every other node,
# including non-symbol LIT values such as integers, yields nil.
def literal_to_symbol(node)
  node_kind = node.type
  if node_kind == :STR
    node.children.first.to_sym
  elsif node_kind == :LIT
    value = node.children.first
    value.is_a?(Symbol) ? value : nil
  end
end
# Yields only the elements of +nodes+ that are AST nodes, skipping symbols, nil
# and other plain values. Returns +nodes+ unchanged, like Array#each.
def each_node(nodes)
  node_class = RubyVM::AbstractSyntaxTree::Node
  nodes.each do |item|
    next unless node_class === item

    yield item
  end
end
# Yields each direct child of +node+ that is itself an AST node.
def each_child(node)
  children = node.children
  each_node(children) { |child| yield child }
end
# Builds the parameter list and return type of a method from its body node.
#
# The node's first two children are the local-variable table (whose leading
# entries are the parameter names, in order) and the ARGS node. Every
# parameter is `untyped` except optional positional and keyword parameters,
# whose types are guessed from their default values with #node_type. The
# return type comes from #function_return_type_from_body.
#
# @param node [RubyVM::AbstractSyntaxTree::Node] the method body node passed in by #process
# @return [Types::Function]
# @raise [RuntimeError] if a keyword parameter's value slot has an unexpected shape
def function_type_from_body(node)
table_node, args_node, *_ = node.children
pre_num, _pre_init, opt, _first_post, post_num, _post_init, rest, kw, kwrest, _block = args_from_node(args_node)
return_type = function_return_type_from_body(node)
fun = Types::Function.empty(return_type)
# Leading required positionals: the first pre_num names of the local table.
table_node.take(pre_num).each do |name|
fun.required_positionals << Types::Function::Param.new(name: name, type: untyped)
end
# Optional positionals form a chain of OPT_ARG nodes: [assignment, next].
while opt&.type == :OPT_ARG
lvasgn, opt = opt.children
name = lvasgn.children[0]
fun.optional_positionals << Types::Function::Param.new(
name: name,
type: node_type(lvasgn.children[1])
)
end
if rest
rest_name = rest == :* ? nil : rest # `:*` is an unnamed rest parameter, e.g. from `def f(...) end` syntax
fun = fun.update(rest_positionals: Types::Function::Param.new(name: rest_name, type: untyped))
end
# Trailing (post) positionals: skip the table entries consumed above, then take post_num names.
table_node.drop(fun.required_positionals.size + fun.optional_positionals.size + (fun.rest_positionals ? 1 : 0)).take(post_num).each do |name|
fun.trailing_positionals << Types::Function::Param.new(name: name, type: untyped)
end
# Keywords form another chain: [assignment, next]. A required keyword has no
# default (nil or :NODE_SPECIAL_REQUIRED_KEYWORD); an optional one carries its
# default-value node.
while kw
lvasgn, kw = kw.children
name, value = lvasgn.children
case value
when nil, :NODE_SPECIAL_REQUIRED_KEYWORD
fun.required_keywords[name] = Types::Function::Param.new(name: name, type: untyped)
when RubyVM::AbstractSyntaxTree::Node
fun.optional_keywords[name] = Types::Function::Param.new(name: name, type: node_type(value))
else
raise "Unexpected keyword arg value: #{value}"
end
end
# Keyword rest (`**opts`) is recorded only when its node has children; the first child is used as the name.
if kwrest && kwrest.children.any?
fun = fun.update(rest_keywords: Types::Function::Param.new(name: kwrest.children[0], type: untyped))
end
fun
end
# Infers a method's return type from its body node.
#
# A missing body gives the nil type. A BLOCK (several statements) unions the
# values of every `return` found anywhere inside it with the type of the last
# statement. Any other single-statement body is typed with #literal_to_type
# directly.
#
# @param node [RubyVM::AbstractSyntaxTree::Node] the method body node passed in by #process
# @return an RBS type object
def function_return_type_from_body(node)
  body = node.children[2]
  return Types::Bases::Nil.new(location: nil) unless body
  return literal_to_type(body) unless body.type == :BLOCK

  return_nodes = any_node?(body) { |n| n.type == :RETURN } || []
  types = return_nodes.map do |return_node|
    value = return_node.children[0]
    value ? literal_to_type(value) : Types::Bases::Nil.new(location: nil)
  end

  last_node = body.children.last
  if last_node.nil?
    types << Types::Bases::Nil.new(location: nil)
  elsif last_node.type != :RETURN
    types << literal_to_type(last_node)
  end
  # A trailing explicit `return` was already collected by the scan above.
  # Typing the RETURN node itself only falls through to literal_to_type's
  # `untyped` branch and would pollute the union.
  types_to_union_type(types)
end
# Infers an RBS type from a literal expression node (used for return values).
#
# Printable-ASCII string and symbol literals, and integers, become literal
# types. Other literals map to their class. Arrays, ranges and hashes get
# their element types inferred recursively. A hash whose keys are all literal
# types becomes a record. Anything else is `untyped`.
#
# BuiltinNames#instance_type takes its type arguments positionally (compare
# `instance_type(default, default)` in #node_type). Passing them wrapped in an
# array would nest them one level too deep.
def literal_to_type(node)
  case node.type
  when :STR
    lit = node.children[0]
    if lit.match?(/\A[ -~]+\z/)
      Types::Literal.new(literal: lit, location: nil)
    else
      BuiltinNames::String.instance_type
    end
  when :DSTR, :XSTR
    BuiltinNames::String.instance_type
  when :DSYM
    BuiltinNames::Symbol.instance_type
  when :DREGX
    BuiltinNames::Regexp.instance_type
  when :TRUE
    BuiltinNames::TrueClass.instance_type
  when :FALSE
    BuiltinNames::FalseClass.instance_type
  when :NIL
    Types::Bases::Nil.new(location: nil)
  when :LIT
    lit = node.children[0]
    case lit
    when Symbol
      if lit.match?(/\A[ -~]+\z/)
        Types::Literal.new(literal: lit, location: nil)
      else
        BuiltinNames::Symbol.instance_type
      end
    when Integer
      Types::Literal.new(literal: lit, location: nil)
    else
      # Any other literal value maps to an instance of its class.
      type_name = TypeName.new(name: lit.class.name.to_sym, namespace: Namespace.root)
      Types::ClassInstance.new(name: type_name, args: [], location: nil)
    end
  when :ZLIST, :ZARRAY
    BuiltinNames::Array.instance_type(untyped)
  when :LIST, :ARRAY
    elem_types = node.children.compact.map { |e| literal_to_type(e) }
    BuiltinNames::Array.instance_type(types_to_union_type(elem_types))
  when :DOT2, :DOT3
    endpoint_types = node.children.map { |c| literal_to_type(c) }
    BuiltinNames::Range.instance_type(range_element_type(endpoint_types))
  when :HASH
    list = node.children[0]
    # Drop the LIST's last child (its trailing terminator slot). Slice instead
    # of popping so the node's children array is never mutated.
    children = list ? list.children[0...-1] : []
    key_types = []
    value_types = []
    children.each_slice(2) do |k, v|
      if k
        key_types << literal_to_type(k)
        value_types << literal_to_type(v)
      else
        # A nil key slot (e.g. a `**splat` entry) carries no static type.
        key_types << untyped
        value_types << untyped
      end
    end
    if !key_types.empty? && key_types.all? { |t| t.is_a?(Types::Literal) }
      fields = key_types.map(&:literal).zip(value_types).to_h
      Types::Record.new(fields: fields, location: nil)
    else
      BuiltinNames::Hash.instance_type(types_to_union_type(key_types), types_to_union_type(value_types))
    end
  else
    untyped
  end
end
# Collapses a list of types into one type. An empty list gives `untyped`; a
# list whose entries are all equal gives that single type. Otherwise the
# result is a union of the distinct types, in first-seen order.
def types_to_union_type(types)
  distinct = types.uniq
  case distinct.size
  when 0 then untyped
  when 1 then distinct.first
  else Types::Union.new(types: distinct, location: nil)
  end
end
# Infers the element type of a Range literal from the types of its endpoints.
#
# `untyped` endpoints are ignored, and literal types are widened to the class
# of their value. The result is that type when all remaining endpoints agree,
# and `untyped` otherwise (including when nothing is left).
def range_element_type(types)
  candidates = types.reject { |t| t == untyped }.map do |t|
    next t unless t.is_a?(Types::Literal)

    class_name = TypeName.new(name: t.literal.class.name.to_sym, namespace: Namespace.root)
    Types::ClassInstance.new(name: class_name, args: [], location: nil)
  end
  candidates.uniq!
  candidates.size == 1 ? candidates.first : untyped
end
# Infers the block accepted by a method from its body node.
#
# Returns nil when the method neither declares a block parameter nor yields.
# A declared block parameter gives an untyped block (optional only in the
# `def m(...)` case, see the HACK note below). If the body contains any
# `yield`, the result is instead a required block whose parameters are
# collected from all yield call sites; every parameter is untyped.
#
# @param node [RubyVM::AbstractSyntaxTree::Node] the method body node passed in by #process
# @return [Types::Block, nil]
def block_from_body(node)
_, args_node, body_node = node.children
_pre_num, _pre_init, _opt, _first_post, _post_num, _post_init, _rest, _kw, _kwrest, block = args_from_node(args_node)
method_block = nil
# Explicitly declared block parameter.
if block
method_block = Types::Block.new(
# HACK: The `block` is :& on `def m(...)` syntax.
# In this case the block looks optional in most cases, so it marks optional.
# In other cases, we can't determine which is required or optional, so it marks required.
required: block != :&,
type: Types::Function.empty(untyped)
)
end
# Any `yield` anywhere under the body (including inside nested blocks)
# replaces the result above with a required block built from the yields.
if body_node
if (yields = any_node?(body_node) {|n| n.type == :YIELD })
method_block = Types::Block.new(
required: true,
type: Types::Function.empty(untyped)
)
yields.each do |yield_node|
# Arguments of this yield without nil slots; empty for a bare `yield`.
# A trailing hash with only symbol keys is split off as keyword arguments.
array_content = yield_node.children[0]&.children&.compact || []
positionals, keywords = if keyword_hash?(array_content.last)
[array_content.take(array_content.size - 1), array_content.last]
else
[array_content, nil]
end
# Grow the positional list to the largest arity seen so far.
if (diff = positionals.size - method_block.type.required_positionals.size) > 0
diff.times do
method_block.type.required_positionals << Types::Function::Param.new(
type: untyped,
name: nil
)
end
end
# Every symbol key passed as a keyword becomes a required untyped block keyword; nil key slots are skipped.
if keywords
keywords.children[0].children.each_slice(2) do |key_node, value_node|
if key_node
key = key_node.children[0]
method_block.type.required_keywords[key] ||=
Types::Function::Param.new(
type: untyped,
name: nil
)
end
end
end
end
end
end
method_block
end
# Returns the children of an ARGS node, which callers destructure as
# [pre_num, pre_init, opt, first_post, post_num, post_init, rest, kw, kwrest, block].
#
# NOTE: args_node may be nil because of a Ruby bug
# (https://bugs.ruby-lang.org/issues/17495); an "all counts zero, everything
# nil" array is substituted in that case.
def args_from_node(args_node)
  no_args = [0, nil, nil, nil, 0, nil, nil, nil, nil, nil]
  args_node.nil? ? no_args : (args_node.children || no_args)
end
# Whether +node+ looks like keyword arguments: a HASH node with at least one
# entry, whose keys are all symbol literals. Used to split a trailing keyword
# hash off the arguments of a `yield`.
#
# @param node [RubyVM::AbstractSyntaxTree::Node, nil]
# @return [Boolean]
def keyword_hash?(node)
  return false unless node && node.type == :HASH

  # An empty `{}` has no element list (children[0] is nil). It cannot be
  # keyword arguments, and dereferencing the list would raise NoMethodError.
  list = node.children[0]
  return false unless list

  list.children.compact.each_slice(2).all? do |key, _|
    key.type == :LIT && key.children[0].is_a?(Symbol)
  end
end
# Collects +node+ and every descendant node for which the block is truthy, in
# pre-order, appending them to +nodes+. Returns that array, or nil when it is
# empty. Despite the `?`, the result is an array, not a boolean.
def any_node?(node, nodes: [], &block)
  pending = [node]
  until pending.empty?
    current = pending.pop
    nodes << current if yield(current)
    children = []
    each_child(current) { |child| children << child }
    # Push in reverse so the leftmost child is visited next (pre-order).
    pending.concat(children.reverse)
  end
  nodes.empty? ? nil : nodes
end
# Guesses a type from a default-value expression (optional and keyword
# argument defaults, constant values).
#
# Symbol, integer and float literals, strings, booleans, arrays and hashes map
# to their builtin types; collections get +default+ as type arguments. `nil`
# maps to an optional untyped. Anything else yields +default+.
def node_type(node, default: Types::Bases::Any.new(location: nil))
  kind = node.type

  if kind == :LIT
    value = node.children[0]
    return BuiltinNames::Symbol.instance_type if value.is_a?(Symbol)
    return BuiltinNames::Integer.instance_type if value.is_a?(Integer)
    return BuiltinNames::Float.instance_type if value.is_a?(Float)

    return default
  end

  case kind
  when :STR, :DSTR then BuiltinNames::String.instance_type
  when :NIL
    # This type is technical non-sense, but may help practically.
    Types::Optional.new(type: Types::Bases::Any.new(location: nil), location: nil)
  when :TRUE, :FALSE then Types::Bases::Bool.new(location: nil)
  when :ARRAY, :LIST then BuiltinNames::Array.instance_type(default)
  when :HASH then BuiltinNames::Hash.instance_type(default, default)
  else default
  end
end
# Memoized `untyped` type, used wherever no better type is known.
def untyped
  return @untyped if @untyped

  @untyped = Types::Bases::Any.new(location: nil)
end

# Memoized `private` member, inserted into member lists as a section marker.
def private
  return @private if @private

  @private = AST::Members::Private.new(location: nil)
end

# Memoized `public` member, inserted into member lists as a section marker.
def public
  return @public if @public

  @public = AST::Members::Public.new(location: nil)
end
# Accessibility marker in effect just before position +index+ of +decls+: the
# nearest public/private marker before that position, or public if there is
# none.
def current_accessibility(decls, index = decls.size)
  marker = decls.first(index).reverse_each.find { |decl| decl == private || decl == public }
  marker || public
end
# Cleans up public/private markers in a member list, modifying it in place:
# - removes a marker equal to the accessibility already in effect (the list
#   starts out public);
# - when two markers are adjacent, removes the earlier one, since no member
#   falls under it;
# - removes markers left at the end of the list.
#
# @param decls [Array] member list, modified in place
def remove_unnecessary_accessibility_methods!(decls)
current = public
idx = 0
loop do
decl = decls[idx] or break
# Redundant marker: same accessibility as the one already in effect.
if current == decl
decls.delete_at(idx)
next
end
# Two markers in a row: drop the earlier one, step back onto decl, and
# recompute what is in effect there before examining decl again.
if 0 < idx && is_accessibility?(decls[idx - 1]) && is_accessibility?(decl)
decls.delete_at(idx - 1)
idx -= 1
current = current_accessibility(decls, idx)
next
end
current = decl if is_accessibility?(decl)
idx += 1
end
# Markers at the very end govern no members.
decls.pop while decls.last && is_accessibility?(decls.last)
end
# True when +decl+ equals one of the public/private section markers.
def is_accessibility?(decl)
  return true if decl == public

  decl == private
end
# Index of the first member in +decls+ that defines the method +name+, or nil.
#
# Method definitions and attr_readers match by name. attr_writers match by
# the setter they define: `attr_writer :foo` is stored with the bare name
# :foo, but defines `foo=`.
#
# @param decls [Array] member list to search
# @param name [Symbol] method name as written in e.g. `private :name` or
#   `module_function :name`
# @return [Integer, nil]
def find_def_index_by_name(decls, name)
  decls.find_index do |decl|
    case decl
    when AST::Members::MethodDefinition, AST::Members::AttrReader
      decl.name == name
    when AST::Members::AttrWriter
      # Compare the writer's setter name with +name+. The old check appended
      # "=" to +name+ instead, so `private :foo=` never matched.
      :"#{decl.name}=" == name
    end
  end
end
end
end
end