# frozen_string_literal: true
module RBS
module Prototype
class RB
include Helpers
# Traversal state handed to #process while walking the syntax tree.
#
# * module_function - when true, plain `def`s are recorded as :singleton_instance methods
# * singleton       - when true, definitions are recorded as :singleton
# * namespace       - namespace of the class/module whose body is being processed
# * in_def          - true while the body of a method definition is being processed
class Context < Struct.new(:module_function, :singleton, :namespace, :in_def, keyword_init: true)
  # @implements Context

  # Context with every flag switched off for the given namespace (root by default).
  def self.initial(namespace: Namespace.root)
    new(module_function: false, singleton: false, namespace: namespace, in_def: false)
  end

  # Kind given to a method defined under this context; `singleton` wins over `module_function`.
  def method_kind
    return :singleton if singleton
    return :singleton_instance if module_function

    :instance
  end

  # Kind given to an attribute defined under this context.
  def attribute_kind
    singleton ? :singleton : :instance
  end

  # Fresh initial context (all flags reset) for a namespace nested below the current one.
  def enter_namespace(namespace)
    nested = self.namespace + namespace
    Context.initial(namespace: nested)
  end

  # Copy of this context with the given flags replaced; the namespace is carried over.
  def update(module_function: self.module_function, singleton: self.singleton, in_def: self.in_def)
    Context.new(
      namespace: namespace,
      module_function: module_function,
      singleton: singleton,
      in_def: in_def
    )
  end
end
# Declarations and members accumulated by #parse (it is the `decls:` target of the top-level #process call).
attr_reader :source_decls
# NOTE(review): nothing visible in this file assigns @toplevel_members, so this reader
# returns nil unless it is set elsewhere — confirm whether it is still needed.
attr_reader :toplevel_members
# Starts with an empty list of collected declarations.
def initialize
@source_decls = []
end
# Returns the collected declarations.
#
# Entries of #source_decls that are AST::Declarations::Base instances are returned as-is;
# everything else (members found outside any class/module) is wrapped into a single
# `Object` class declaration appended at the end.
def decls
  # @type var top_decls: Array[AST::Declarations::t]
  # @type var top_members: Array[AST::Members::t]
  top_decls, top_members = _ = source_decls.partition {|decl| decl.is_a?(AST::Declarations::Base) }

  # @type var result: Array[AST::Declarations::t]
  result = [*top_decls]
  return result if top_members.empty?

  object_name = TypeName.new(name: :Object, namespace: Namespace.empty)
  result << AST::Declarations::Class.new(
    name: object_name,
    super_class: nil,
    members: top_members,
    annotations: [],
    comment: nil,
    location: nil,
    type_params: []
  )
  result
end
# Parses Ruby source text and feeds its syntax tree to #process, which appends the
# declarations and members it finds to #source_decls.
#
# Before walking the tree, comments are collected into a hash keyed by the line number of
# their last line. Consecutive comment lines are merged into one comment; comments that
# trail code on the same line (e.g. `module Foo # :nodoc:`) are skipped.
#
# string - Ruby source code.
def parse(string)
  # @type var comments: Hash[Integer, AST::Comment]
  comments = Ripper.lex(string).yield_self do |tokens|
    code_lines = {} #: Hash[Integer, bool]
    tokens.each.with_object({}) do |token, hash| #$ Hash[Integer, AST::Comment]
      case token[1]
      when :on_sp, :on_ignored_nl
        # skip
      when :on_comment
        line = token[0][0]
        # skip like `module Foo # :nodoc:`
        next if code_lines[line]

        # Strip the `#` marker and at most one blank after it. Slicing off a fixed two
        # characters ate the first letter of `#foo` and produced nil (then a RuntimeError)
        # for a bare `#` at the very end of the input.
        body = token[2].sub(/\A#[ \t]?/, "")
        # An empty comment line still counts as a (blank) line of the comment text.
        body = "\n" if body.empty?

        comment = AST::Comment.new(string: body, location: nil)
        if (prev_comment = hash.delete(line - 1))
          # Directly follows another comment line: merge and re-key on the current line.
          hash[line] = AST::Comment.new(string: prev_comment.string + comment.string,
                                        location: nil)
        else
          hash[line] = comment
        end
      else
        # Any other token marks the line as containing code.
        code_lines[token[0][0]] = true
      end
    end
  end

  process RubyVM::AbstractSyntaxTree.parse(string), decls: source_decls, comments: comments, context: Context.initial
end
# Walks a single syntax tree node and appends the declarations/members it describes to `decls`.
#
# Dispatches on `node.type`; node types without a dedicated branch are simply traversed
# through #process_children.
#
# node     - syntax tree node (responds to #type, #children and #first_lineno).
# decls    - array receiving the generated declarations/members; mutated in place.
# comments - hash of comments keyed by line number; the comment attached to a node is
#            looked up with `node.first_lineno - 1` (the line just above it).
# context  - Context describing namespace / singleton / module_function / in_def state.
def process(node, decls:, comments:, context:)
case node.type
when :CLASS
class_name, super_class_node, *class_body = node.children
# const_to_name yields nil unless the superclass expression is `self` or a constant path.
super_class_name = const_to_name(super_class_node, context: context)
super_class =
if super_class_name
AST::Declarations::Class::Super.new(name: super_class_name, args: [], location: nil)
else
# Give up detect super class e.g. `class Foo < Struct.new(:bar)`
nil
end
kls = AST::Declarations::Class.new(
name: const_to_name!(class_name),
super_class: super_class,
type_params: [],
members: [],
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1]
)
decls.push kls
# The body is processed with a fresh context for the nested namespace, collecting into the class's own members.
new_ctx = context.enter_namespace(kls.name.to_namespace)
each_node class_body do |child|
process child, decls: kls.members, comments: comments, context: new_ctx
end
remove_unnecessary_accessibility_methods! kls.members
sort_members! kls.members
when :MODULE
module_name, *module_body = node.children
mod = AST::Declarations::Module.new(
name: const_to_name!(module_name),
type_params: [],
self_types: [],
members: [],
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1]
)
decls.push mod
# Same treatment as :CLASS: fresh context for the nested namespace, members collected into the module.
new_ctx = context.enter_namespace(mod.name.to_namespace)
each_node module_body do |child|
process child, decls: mod.members, comments: comments, context: new_ctx
end
remove_unnecessary_accessibility_methods! mod.members
sort_members! mod.members
when :SCLASS
this, body = node.children
if this.type != :SELF
RBS.logger.warn "`class <<` syntax with not-self may be compiled to incorrect code: #{this}"
end
# Remember the accessibility in effect so it can be re-asserted after the `class <<` body.
accessibility = current_accessibility(decls)
# Members of the body go into the enclosing `decls`, processed with the singleton flag set.
# NOTE(review): the context is rebuilt from Context.initial, so the enclosing namespace,
# module_function and in_def are not carried into the body — confirm this is intended.
ctx = Context.initial.tap { |ctx| ctx.singleton = true }
process_children(body, decls: decls, comments: comments, context: ctx)
decls << accessibility
when :DEFN, :DEFS
# @type var kind: Context::method_kind
if node.type == :DEFN
# Plain `def`: the kind comes from the surrounding context.
def_name, def_body = node.children
kind = context.method_kind
else
# :DEFS has one extra leading child (ignored here) and is always recorded as a singleton method.
_, def_name, def_body = node.children
kind = :singleton
end
types = [
MethodType.new(
type_params: [],
type: function_type_from_body(def_body, def_name),
block: block_from_body(def_body),
location: nil
)
]
member = AST::Members::MethodDefinition.new(
name: def_name,
location: nil,
annotations: [],
overloads: types.map {|type| AST::Members::MethodDefinition::Overload.new(annotations: [], method_type: type )},
kind: kind,
comment: comments[node.first_lineno - 1],
overloading: false,
visibility: nil
)
# Skip members equal to one that was already collected.
decls.push member unless decls.include?(member)
# Walk the method body with in_def set, so nested nodes (e.g. :IASGN) are classified accordingly.
new_ctx = context.update(singleton: kind == :singleton, in_def: true)
each_node def_body.children do |child|
process child, decls: decls, comments: comments, context: new_ctx
end
when :ALIAS
new_name, old_name = node.children.map { |c| literal_to_symbol(c) }
member = AST::Members::Alias.new(
new_name: new_name,
old_name: old_name,
kind: context.singleton ? :singleton : :instance,
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1],
)
decls.push member unless decls.include?(member)
when :FCALL, :VCALL
# Inside method definition cannot reach here.
# Argument nodes of the call; empty when the call has no argument list.
args = node.children[1]&.children || []
case node.children[0]
when :include
args.each do |arg|
if (name = const_to_name(arg, context: context))
# In a singleton context `include` is recorded as an Extend member.
klass = context.singleton ? AST::Members::Extend : AST::Members::Include
decls << klass.new(
name: name,
args: [],
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1]
)
end
end
when :prepend
args.each do |arg|
if (name = const_to_name(arg, context: context))
decls << AST::Members::Prepend.new(
name: name,
args: [],
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1]
)
end
end
when :extend
args.each do |arg|
if (name = const_to_name(arg, context: context))
decls << AST::Members::Extend.new(
name: name,
args: [],
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1]
)
end
end
when :attr_reader
# Arguments that literal_to_symbol cannot turn into a name are skipped; the type is always untyped (Any).
args.each do |arg|
if arg && (name = literal_to_symbol(arg))
decls << AST::Members::AttrReader.new(
name: name,
ivar_name: nil,
type: Types::Bases::Any.new(location: nil),
kind: context.attribute_kind,
location: nil,
comment: comments[node.first_lineno - 1],
annotations: []
)
end
end
when :attr_accessor
args.each do |arg|
if arg && (name = literal_to_symbol(arg))
decls << AST::Members::AttrAccessor.new(
name: name,
ivar_name: nil,
type: Types::Bases::Any.new(location: nil),
kind: context.attribute_kind,
location: nil,
comment: comments[node.first_lineno - 1],
annotations: []
)
end
end
when :attr_writer
args.each do |arg|
if arg && (name = literal_to_symbol(arg))
decls << AST::Members::AttrWriter.new(
name: name,
ivar_name: nil,
type: Types::Bases::Any.new(location: nil),
kind: context.attribute_kind,
location: nil,
comment: comments[node.first_lineno - 1],
annotations: []
)
end
end
when :alias_method
# Recorded only when both names can be converted by literal_to_symbol.
if args[0] && args[1] && (new_name = literal_to_symbol(args[0])) && (old_name = literal_to_symbol(args[1]))
decls << AST::Members::Alias.new(
new_name: new_name,
old_name: old_name,
kind: context.singleton ? :singleton : :instance,
annotations: [],
location: nil,
comment: comments[node.first_lineno - 1],
)
end
when :module_function
if args.empty?
# Bare `module_function`: mutate the shared context so that definitions processed later with it become :singleton_instance.
context.module_function = true
else
module_func_context = context.update(module_function: true)
args.each do |arg|
if arg && (name = literal_to_symbol(arg))
# Named method: re-tag the already collected definition, if there is one.
if (i, defn = find_def_index_by_name(decls, name))
if defn.is_a?(AST::Members::MethodDefinition)
decls[i] = defn.update(kind: :singleton_instance)
end
end
elsif arg
# Any other argument (e.g. an inline `def`) is processed under a module_function context.
process arg, decls: decls, comments: comments, context: module_func_context
end
end
end
when :public, :private
# Calls #public or #private on this object to get the matching accessibility member.
accessibility = __send__(node.children[0])
if args.empty?
decls << accessibility
else
args.each do |arg|
if arg && (name = literal_to_symbol(arg))
if (i, _ = find_def_index_by_name(decls, name))
current = current_accessibility(decls, i)
if current != accessibility
# Sandwich the existing member between the requested accessibility and the one that was in effect.
# The `i + 1` insert happens first so that index `i` still refers to the member.
decls.insert(i + 1, current)
decls.insert(i, accessibility)
end
end
end
end
# For `private def foo` syntax
current = current_accessibility(decls)
decls << accessibility
process_children(node, decls: decls, comments: comments, context: context)
decls << current
end
else
process_children(node, decls: decls, comments: comments, context: context)
end
when :ITER
# ignore
when :CDECL
# A Symbol first child is used directly as the constant name; otherwise the child is a node converted by const_to_name!.
const_name = case
when node.children[0].is_a?(Symbol)
TypeName.new(name: node.children[0], namespace: Namespace.empty)
else
const_to_name!(node.children[0], context: context)
end
value_node = node.children.last
type = if value_node.nil?
# Give up type prediction when node is MASGN.
Types::Bases::Any.new(location: nil)
else
literal_to_type(value_node)
end
decls << AST::Declarations::Constant.new(
name: const_name,
type: type,
location: nil,
comment: comments[node.first_lineno - 1],
annotations: []
)
when :IASGN
# Classify the assignment by where it occurs: [singleton context?, inside a method body?].
case [context.singleton, context.in_def]
when [true, true], [false, false]
member = AST::Members::ClassInstanceVariable.new(
name: node.children.first,
type: Types::Bases::Any.new(location: nil),
location: nil,
comment: comments[node.first_lineno - 1]
)
when [false, true]
member = AST::Members::InstanceVariable.new(
name: node.children.first,
type: Types::Bases::Any.new(location: nil),
location: nil,
comment: comments[node.first_lineno - 1]
)
when [true, false]
# The variable is for the singleton class of the class object.
# RBS does not have a way to represent it. So we ignore it.
else
raise 'unreachable'
end
# `member` stays nil in the ignored case above.
decls.push member if member && !decls.include?(member)
when :CVASGN
member = AST::Members::ClassVariable.new(
name: node.children.first,
type: Types::Bases::Any.new(location: nil),
location: nil,
comment: comments[node.first_lineno - 1]
)
decls.push member unless decls.include?(member)
else
process_children(node, decls: decls, comments: comments, context: context)
end
end
# Runs #process on every child yielded by each_child for +node+, passing the same
# decls, comments and context through unchanged.
def process_children(node, decls:, comments:, context:)
  each_child(node) do |child_node|
    process(child_node, decls: decls, comments: comments, context: context)
  end
end
# Converts a constant-reference node into a TypeName.
#
# Handles :CONST, :COLON2 (`A::B`), :COLON3 (`::A`) and :SELF. A :SELF node needs
# +context+ to supply the current namespace. Raises (bare `raise`) for a :SELF node
# without context and for any other node type.
def const_to_name!(node, context: nil)
  case node.type
  when :CONST
    TypeName.new(name: node.children[0], namespace: Namespace.empty)
  when :COLON3
    TypeName.new(name: node.children[0], namespace: Namespace.root)
  when :COLON2
    scope_node = node.children[0]
    # The scope part is itself a constant reference, resolved recursively.
    scope = scope_node ? const_to_name!(scope_node, context: context).to_namespace : Namespace.empty
    TypeName.new(name: node.children[1], namespace: scope)
  when :SELF
    raise if context.nil?

    context.namespace.to_type_name
  else
    raise
  end
end
# Lenient counterpart of #const_to_name!: returns nil instead of raising.
#
# nil is returned for a missing node, for node types other than :SELF / :CONST /
# :COLON2 / :COLON3, and when #const_to_name! raises a StandardError. Note that
# +context+ is only consulted for a top-level :SELF node; it is not forwarded.
def const_to_name(node, context:)
  return unless node

  case node.type
  when :SELF
    context.namespace.to_type_name
  when :CONST, :COLON2, :COLON3
    begin
      const_to_name!(node)
    rescue StandardError
      nil
    end
  end
end
# Extracts a Symbol from a literal node.
#
# :SYM nodes yield their value, :STR nodes their value converted with #to_sym, and
# :LIT nodes their value only when it is a Symbol. Every other case yields nil.
def literal_to_symbol(node)
  case node.type
  when :SYM then node.children.first
  when :STR then node.children.first.to_sym
  when :LIT
    value = node.children.first
    value if value.is_a?(Symbol)
  end
end
# Builds the function type (parameters and return type) of a method definition.
#
# node     - body node of the definition; its first two children are the local variable
#            table and the arguments node.
# def_name - method name; `initialize` is given a void return type, anything else the
#            type computed by function_return_type_from_body.
#
# Returns the populated Types::Function. Parameter types are untyped unless a default
# value lets param_type suggest something more specific.
def function_type_from_body(node, def_name)
table_node, args_node, *_ = node.children
# args_from_node yields the pieces of the arguments node in a fixed positional order; unused ones are underscore-prefixed.
pre_num, _pre_init, opt, _first_post, post_num, _post_init, rest, kw, kwrest, _block = args_from_node(args_node)
return_type = if def_name == :initialize
Types::Bases::Void.new(location: nil)
else
function_return_type_from_body(node)
end
fun = Types::Function.empty(return_type)
# Required positional names are the first `pre_num` entries of the local table.
table_node.take(pre_num).each do |name|
fun.required_positionals << Types::Function::Param.new(name: name, type: untyped)
end
# Optional arguments form a chain: each :OPT_ARG holds an assignment node and the next link.
while opt&.type == :OPT_ARG
lvasgn, opt = opt.children
name = lvasgn.children[0]
fun.optional_positionals << Types::Function::Param.new(
name: name,
type: param_type(lvasgn.children[1])
)
end
if rest
rest_name = rest == :* ? nil : rest # `def f(...)` syntax has `*` name
fun = fun.update(rest_positionals: Types::Function::Param.new(name: rest_name, type: untyped))
end
# Trailing positional names follow the required, optional and rest entries in the local table,
# so this lookup depends on what has been added to `fun` above.
table_node.drop(fun.required_positionals.size + fun.optional_positionals.size + (fun.rest_positionals ? 1 : 0)).take(post_num).each do |name|
fun.trailing_positionals << Types::Function::Param.new(name: name, type: untyped)
end
# Keyword arguments are chained the same way as optional arguments.
while kw
lvasgn, kw = kw.children
name, value = lvasgn.children
case value
when nil, :NODE_SPECIAL_REQUIRED_KEYWORD
# No default value: required keyword.
fun.required_keywords[name] = Types::Function::Param.new(name: nil, type: untyped)
when RubyVM::AbstractSyntaxTree::Node
# Default value present: optional keyword typed from the default.
fun.optional_keywords[name] = Types::Function::Param.new(name: nil, type: param_type(value))
else
raise "Unexpected keyword arg value: #{value}"
end
end
if kwrest && kwrest.children.any?
kwrest_name = kwrest.children[0] #: Symbol?
kwrest_name = nil if kwrest_name == :** # `def f(...)` syntax has `**` name
fun = fun.update(rest_keywords: Types::Function::Param.new(name: kwrest_name, type: untyped))
end
fun
end
# Return type of a method, derived by #body_type from the third child of the
# definition's body node.
def function_return_type_from_body(node)
  body_type(node.children[2])
end
# Type of the value a body node evaluates to.
#
# A missing body is nil; :IF / :UNLESS and :BLOCK nodes are delegated to their own
# helpers; anything else is treated as a literal expression.
def body_type(node)
  if node
    case node.type
    when :BLOCK then block_type(node)
    when :IF, :UNLESS then if_unless_type(node)
    else literal_to_type(node)
    end
  else
    Types::Bases::Nil.new(location: nil)
  end
end
# Type of an `if` / `unless` expression: the union of the types of its two branches
# (children 1 and 2 of the node; the condition in child 0 is ignored).
# Raises (bare `raise`) for any other node type.
def if_unless_type(node)
  raise unless %i[IF UNLESS].include?(node.type)

  branches = node.children.values_at(1, 2)
  types_to_union_type(branches.map { |branch| body_type(branch) })
end
# Type of a :BLOCK node (a sequence of statements): the union of the types of every
# `return` found by any_node? and of the last statement. A `return` without a value,
# or a missing last statement, counts as nil.
# Raises (bare `raise`) when the node is not a :BLOCK.
def block_type(node)
  raise unless node.type == :BLOCK

  value_type = lambda do |value_node|
    value_node ? literal_to_type(value_node) : Types::Bases::Nil.new(location: nil)
  end

  # any_node? hands back nil when nothing matched, hence the safe navigation.
  return_nodes = any_node?(node) { |n| n.type == :RETURN }
  returned_types = return_nodes&.map { |return_node| value_type.call(return_node.children[0]) } || []
  last_type = value_type.call(node.children.last)

  types_to_union_type([*returned_types, last_type])
end
# Guesses a type from a literal-looking expression node.
#
# Scalars become literal types where possible (ASCII-only strings and symbols, integers,
# true/false) and builtin instance types otherwise. Arrays, ranges and hashes are typed
# from their elements; hashes whose keys are all literals become record types. A few
# calls that hand back their receiver (`freeze`, `dup`, ...) take the receiver's type.
# Everything else is untyped.
def literal_to_type(node)
  case node.type
  when :STR
    str = node.children[0]
    str.ascii_only? ? Types::Literal.new(literal: str, location: nil) : BuiltinNames::String.instance_type
  when :DSTR, :XSTR
    BuiltinNames::String.instance_type
  when :SYM
    sym = node.children[0]
    sym.to_s.ascii_only? ? Types::Literal.new(literal: sym, location: nil) : BuiltinNames::Symbol.instance_type
  when :DSYM
    BuiltinNames::Symbol.instance_type
  when :DREGX, :REGX
    BuiltinNames::Regexp.instance_type
  when :TRUE, :FALSE
    Types::Literal.new(literal: node.type == :TRUE, location: nil)
  when :NIL
    Types::Bases::Nil.new(location: nil)
  when :INTEGER
    Types::Literal.new(literal: node.children[0], location: nil)
  when :FLOAT
    BuiltinNames::Float.instance_type
  when :RATIONAL, :IMAGINARY
    # Named after the class of the literal value itself.
    value = node.children[0]
    Types::ClassInstance.new(
      name: TypeName.new(name: value.class.name.to_sym, namespace: Namespace.root),
      args: [],
      location: nil
    )
  when :LIT
    value = node.children[0]
    case value
    when Integer, String
      # String: for Ruby <=3.3 which generates `LIT` node for string literals inside Hash literal.
      #   "a" => STR node
      #   { "a" => nil } => LIT node
      Types::Literal.new(literal: value, location: nil)
    when Symbol
      if value.to_s.ascii_only?
        Types::Literal.new(literal: value, location: nil)
      else
        BuiltinNames::Symbol.instance_type
      end
    else
      Types::ClassInstance.new(
        name: TypeName.new(name: value.class.name.to_sym, namespace: Namespace.root),
        args: [],
        location: nil
      )
    end
  when :ZLIST, :ZARRAY
    BuiltinNames::Array.instance_type(untyped)
  when :LIST, :ARRAY
    element_types = node.children.compact.map { |element| literal_to_type(element) }
    BuiltinNames::Array.instance_type(types_to_union_type(element_types))
  when :DOT2, :DOT3
    bound_types = node.children.map { |bound| literal_to_type(bound) }
    BuiltinNames::Range.instance_type(range_element_type(bound_types))
  when :HASH
    list = node.children[0]
    if list
      children = list.children
      # The last entry of the list is dropped before pairing up keys and values.
      children.pop
    else
      children = [] #: Array[untyped]
    end

    key_types = [] #: Array[Types::t]
    value_types = [] #: Array[Types::t]
    children.each_slice(2) do |key_node, value_node|
      if key_node
        key_types << literal_to_type(key_node)
        value_types << literal_to_type(value_node)
      else
        # No key node for this pair: nothing can be said about it.
        key_types << untyped
        value_types << untyped
      end
    end

    if key_types.empty? || key_types.any? { |t| !t.is_a?(Types::Literal) }
      BuiltinNames::Hash.instance_type(types_to_union_type(key_types), types_to_union_type(value_types))
    else
      literal_keys = key_types.map {|t|
        t.is_a?(Types::Literal) or raise
        t.literal
      }
      fields = literal_keys.zip(value_types).to_h #: Hash[Types::Literal::literal, Types::t]
      Types::Record.new(fields: fields, location: nil)
    end
  when :SELF
    Types::Bases::Self.new(location: nil)
  when :CALL
    receiver, method_name, * = node.children
    if %i[freeze tap itself dup clone taint untaint extend].include?(method_name)
      literal_to_type(receiver)
    else
      untyped
    end
  else
    untyped
  end
end
# Collapses a list of types into one type: untyped for an empty list, the type itself
# when only one distinct type remains after de-duplication, and a union otherwise.
def types_to_union_type(types)
  return untyped if types.empty?

  distinct = types.uniq
  case distinct.size
  when 1
    distinct.first || raise
  else
    Types::Union.new(types: distinct, location: nil)
  end
end
# Element type of a range, given the types of its two bounds.
#
# Untyped bounds are ignored and literal types are widened to the class of their value.
# The result is that single remaining type, or untyped when the bounds disagree or
# nothing is left.
def range_element_type(types)
  known = types.reject { |t| t == untyped }
  return untyped if known.empty?

  widened = known.map do |t|
    next t unless t.is_a?(Types::Literal)

    class_name = TypeName.new(name: t.literal.class.name&.to_sym || raise, namespace: Namespace.root)
    Types::ClassInstance.new(name: class_name, args: [], location: nil)
  end
  candidates = widened.uniq

  return untyped unless candidates.size == 1

  candidates.first or raise
end
# Guesses a parameter type from the node of its default value.
#
# Unlike a literal type, the result is always the general builtin type (Integer rather
# than `1`). +default+ is returned for anything unrecognised and is also used as the
# element / key / value type of arrays and hashes.
def param_type(node, default: Types::Bases::Any.new(location: nil))
  case node.type
  when :INTEGER then BuiltinNames::Integer.instance_type
  when :FLOAT then BuiltinNames::Float.instance_type
  when :SYM then BuiltinNames::Symbol.instance_type
  when :STR, :DSTR then BuiltinNames::String.instance_type
  when :TRUE, :FALSE then Types::Bases::Bool.new(location: nil)
  when :ARRAY, :LIST then BuiltinNames::Array.instance_type(default)
  when :HASH then BuiltinNames::Hash.instance_type(default, default)
  when :RATIONAL
    Types::ClassInstance.new(name: TypeName.parse("::Rational"), args: [], location: nil)
  when :IMAGINARY
    Types::ClassInstance.new(name: TypeName.parse("::Complex"), args: [], location: nil)
  when :NIL
    # This type is technical non-sense, but may help practically.
    Types::Optional.new(type: Types::Bases::Any.new(location: nil), location: nil)
  when :LIT
    case node.children[0]
    when Symbol then BuiltinNames::Symbol.instance_type
    when Integer then BuiltinNames::Integer.instance_type
    when Float then BuiltinNames::Float.instance_type
    else default
    end
  else
    default
  end
end
# Kept so that callers of the old name keep working.
alias node_type param_type
# The `private` accessibility member, created on first use and reused afterwards.
def private
  return @private if @private

  @private = AST::Members::Private.new(location: nil)
end
# The `public` accessibility member, created on first use and reused afterwards.
def public
  return @public if @public

  @public = AST::Members::Public.new(location: nil)
end
# Accessibility in effect just before position +index+ of +decls+ (the end by default):
# the closest preceding `private` / `public` member, or `public` when there is none.
# Raises (bare `raise`) when the slice up to +index+ cannot be taken.
def current_accessibility(decls, index = decls.size)
  preceding = decls.slice(0, index) or raise
  found = preceding.rindex { |decl| decl == private || decl == public }
  return public unless found

  _ = decls[found]
end
# Removes redundant `public` / `private` markers from +decls+ in place:
#
# * a marker that repeats the accessibility already in effect,
# * a marker immediately followed by another marker (only the later one is kept),
# * any markers left at the very end of the list.
def remove_unnecessary_accessibility_methods!(decls)
  # @type var current: decl
  current = public
  idx = 0

  while (decl = decls[idx])
    if current == decl
      # Re-states what is already in effect; the next element slides into `idx`.
      decls.delete_at(idx)
    elsif 0 < idx && is_accessibility?(decls[idx - 1]) && is_accessibility?(decl)
      # Two markers in a row: drop the earlier one and re-evaluate from its position.
      decls.delete_at(idx - 1)
      idx -= 1
      current = current_accessibility(decls, idx)
    else
      current = decl if is_accessibility?(decl)
      idx += 1
    end
  end

  decls.pop while decls.last && is_accessibility?(decls.last || raise)
end
# Whether +decl+ is one of the `public` / `private` accessibility members.
def is_accessibility?(decl)
  matches_public = decl == public
  matches_public || decl == private
end
# Finds the first member in +decls+ that defines the method called +name+.
#
# Method definitions and attr_readers match on their own name; attr_writers match on
# their name followed by `=`. Returns `[index, member]`, or nil when nothing matches.
def find_def_index_by_name(decls, name)
  found = decls.find_index do |decl|
    case decl
    when AST::Members::MethodDefinition, AST::Members::AttrReader
      decl.name == name
    when AST::Members::AttrWriter
      :"#{decl.name}=" == name
    end
  end
  return unless found

  [found, _ = decls[found]]
end
# Reorders +decls+ in place so that class variables come first, then class instance
# variables, then instance variables, then everything else. Members of the same group
# keep their relative order (the running position is the tie-breaker).
def sort_members!(decls)
  orders = {
    AST::Members::ClassVariable => -3,
    AST::Members::ClassInstanceVariable => -2,
    AST::Members::InstanceVariable => -1,
  } #: Hash[Class, Integer]
  decls.sort_by!.with_index(1) do |decl, position|
    [orders.fetch(decl.class, 0), position]
  end
end
end
end
end