module Steep
module TypeInference
class SendArgs
class PositionalArgs
# Value object that ties one argument node to the positional parameter it was
# matched with.
class NodeParamPair
  include Equatable

  attr_reader :node, :param

  def initialize(node:, param:)
    @node, @param = node, param
  end

  # Enables implicit destructuring, e.g. `node, param = pair`.
  def to_ary
    [node, param]
  end
end
# Value object that ties an argument node to a type.
class NodeTypePair
  include Equatable

  attr_reader :node, :type

  def initialize(node:, type:)
    @node, @type = node, type
  end

  # Type for the node itself: for a `:splat` node, `type` is passed through
  # `AST::Builtin::Array.instance_type`; for any other node it is returned as is.
  def node_type
    return type unless node.type == :splat

    AST::Builtin::Array.instance_type(type)
  end
end
# Marks a positional `:splat` argument node. `type` starts out as nil and is
# writable so that it can be assigned after construction.
class SplatArg
  include Equatable

  attr_reader :node
  attr_accessor :type

  def initialize(node:)
    @node, @type = node, nil
  end
end
# Error value: an argument node for which no positional parameter is available.
class UnexpectedArg
  include Equatable

  attr_reader :node

  def initialize(node:)
    @node = node
  end
end
# Error value: the arguments ran out while `params` still started with a
# required positional parameter.
class MissingArg
  include Equatable

  attr_reader :params

  def initialize(params:)
    @params = params
  end
end
attr_reader :args, :index, :positional_params, :uniform

# args:              the argument nodes being matched
# index:             position of the current node inside `args`
# positional_params: parameters still to be matched (nil when none remain)
# uniform:           stored as given; defaults to false
def initialize(args:, index:, positional_params:, uniform: false)
  @args, @index = args, index
  @positional_params, @uniform = positional_params, uniform
end
# Returns the argument node at the current `index`
# (nil once `index` is past the end, assuming `args` is an Array).
def node
args[index]
end
# Returns the argument nodes from the current `index` to the end.
# Raises (bare `raise`) when the slice does not exist.
def following_args
  rest = args[index..]
  raise unless rest

  rest
end
# Returns the first of the remaining positional parameters, or nil when
# `positional_params` is nil.
def param
  remaining = positional_params
  return nil if remaining.nil?

  remaining.head
end
# Builds a new PositionalArgs over the same `args`; every keyword that is not
# given keeps the receiver's current value.
def update(index: self.index, positional_params: self.positional_params, uniform: self.uniform)
  PositionalArgs.new(
    args: args,
    index: index,
    positional_params: positional_params,
    uniform: uniform
  )
end
# Performs one step of matching argument nodes against positional parameters.
#
# Returns nil when matching stops: a `:forwarded_args` node is reached, or the
# arguments are exhausted and the current parameter is not a required one.
# Otherwise returns `[result, state]`, where `result` is a NodeParamPair,
# UnexpectedArg, MissingArg or SplatArg, and `state` is the PositionalArgs to
# continue from.
def next()
  current = node
  expected = param

  unless current
    # Out of arguments: only a pending required parameter is reported.
    return nil unless expected.is_a?(Interface::Function::Params::PositionalParams::Required)

    return [
      MissingArg.new(params: positional_params),
      update(index: index, positional_params: nil)
    ]
  end

  case current.type
  when :forwarded_args
    # If the node is a `:forwarded_args`, abort
    nil
  when :splat
    # The state is handed back unchanged (same index, same parameters).
    [SplatArg.new(node: current), self]
  else
    case expected
    when Interface::Function::Params::PositionalParams::Required,
         Interface::Function::Params::PositionalParams::Optional
      [
        NodeParamPair.new(node: current, param: expected),
        update(index: index + 1, positional_params: positional_params&.tail)
      ]
    when Interface::Function::Params::PositionalParams::Rest
      # The parameter list is not advanced, so the rest parameter stays current.
      [
        NodeParamPair.new(node: current, param: expected),
        update(index: index + 1)
      ]
    else
      # No parameter left for this argument.
      [UnexpectedArg.new(node: current), update(index: index + 1)] unless expected
    end
  end
end
# When the remaining parameters include a rest parameter, returns the
# intersection (built by `AST::Types::Intersection.build`) of the types of all
# remaining parameters. Returns nil when there are no remaining parameters or
# none of them is a rest parameter.
def uniform_type
  remaining = positional_params
  return nil unless remaining

  has_rest = remaining.each.any? {|p| p.is_a?(Interface::Function::Params::PositionalParams::Rest) }
  return nil unless has_rest

  AST::Types::Intersection.build(types: remaining.each.map {|p| p.type })
end
# Takes `n` parameters off the remaining positional parameters on behalf of
# `node`, using `consume0`.
#
# Returns `[params, state]` with the parameters taken on success, or
# `[UnexpectedArg, state]` when `consume0` reports an UnexpectedArg. In both
# cases `state` has moved one argument forward; on failure it has no
# parameters left.
def consume(n, node:)
  # @type var consumed: Array[Interface::Function::Params::PositionalParams::param]
  consumed = []
  outcome = consume0(n, node: node, params: positional_params, ps: consumed)

  if outcome.is_a?(UnexpectedArg)
    [outcome, update(index: index + 1, positional_params: nil)]
  else
    [consumed, update(index: index + 1, positional_params: outcome)]
  end
end
# Recursive worker for `consume`: appends up to `n` parameters from `params`
# to `ps`.
#
# Returns the parameters left over once `n` reaches 0, or an UnexpectedArg for
# `node` when `params` runs out first. Required and optional parameters are
# dropped from the list once taken; a rest parameter is taken but stays at the
# front.
def consume0(n, node:, params:, ps:)
  return params if n == 0

  head = params&.head
  return UnexpectedArg.new(node: node) if head.nil?

  remaining =
    case head
    when Interface::Function::Params::PositionalParams::Required,
         Interface::Function::Params::PositionalParams::Optional
      params&.tail
    when Interface::Function::Params::PositionalParams::Rest
      params
    else
      return nil
    end

  ps << head
  consume0(n - 1, node: node, params: remaining, ps: ps)
end
end
class KeywordArgs
# Wraps a list of pairs and exposes indexed access and the number of pairs.
# Code in this file fills `pairs` with `[node, type]` arrays.
class ArgTypePairs
  include Equatable

  attr_reader :pairs

  def initialize(pairs:)
    @pairs = pairs
  end

  # Returns the pair stored at `index`.
  def [](index)
    pairs[index]
  end

  # Number of pairs held.
  def size
    pairs.size
  end
end
# Marks a `:kwsplat` keyword argument node. `type` starts out as nil and is
# writable so that it can be assigned after construction.
class SplatArg
  include Equatable

  attr_reader :node
  attr_accessor :type

  def initialize(node:)
    @node, @type = node, nil
  end
end
# Error value: a keyword argument that matches no keyword parameter.
# `keyword` may be nil when the offending key is not a known symbol.
class UnexpectedKeyword
  include Equatable

  attr_reader :keyword, :node

  def initialize(keyword:, node:)
    @keyword, @node = keyword, node
  end

  # First child of a `:pair` node; nil for any other node type.
  def key_node
    node.children[0] if node.type == :pair
  end

  # Second child of a `:pair` node; nil for any other node type.
  def value_node
    node.children[1] if node.type == :pair
  end
end
# Error value: the collection of required keywords that were never supplied.
class MissingKeyword
  include Equatable

  attr_reader :keywords

  def initialize(keywords:)
    @keywords = keywords
  end
end
attr_reader :kwarg_nodes, :keyword_params, :index, :consumed_keywords

# kwarg_nodes:       the keyword argument nodes being matched
# keyword_params:    the keyword parameters to match against
# index:             position of the current node inside `kwarg_nodes` (default 0)
# consumed_keywords: keywords already accounted for (default: an empty Set)
def initialize(kwarg_nodes:, keyword_params:, index: 0, consumed_keywords: Set[])
  @kwarg_nodes, @keyword_params = kwarg_nodes, keyword_params
  @index, @consumed_keywords = index, consumed_keywords
end
# Builds a new KeywordArgs over the same nodes and parameters; `index` and
# `consumed_keywords` keep the receiver's values unless overridden.
def update(index: self.index, consumed_keywords: self.consumed_keywords)
  KeywordArgs.new(kwarg_nodes: kwarg_nodes, keyword_params: keyword_params, index: index, consumed_keywords: consumed_keywords)
end
# Returns the keyword argument node at the current `index`; `next` treats a
# falsy result as "no more keyword arguments".
def keyword_pair
kwarg_nodes[index]
end
# Delegates to `keyword_params.requireds`. The result is used in this class as
# a hash-like keyword => type mapping (`[]`, `each_key`, `each_value`, `keys`).
def required_keywords
keyword_params.requireds
end
# Delegates to `keyword_params.optionals`. The result is used in this class as
# a hash-like keyword => type mapping (`[]`, `each_key`, `each_value`).
def optional_keywords
keyword_params.optionals
end
# Delegates to `keyword_params.rest`. Callers in this class test the result for
# truthiness, so a falsy value means there is no rest-keyword type.
def rest_type
keyword_params.rest
end
# Looks `key` up among the required keywords first and the optional keywords
# second; returns whatever the optional lookup yields when the first is falsy.
def keyword_type(key)
  required = required_keywords[key]
  return required if required

  optional_keywords[key]
end
# Returns the distinct required and optional keyword names as an Array,
# ordered by their `to_s` representation.
def all_keys
  Set.new(required_keywords.each_key)
     .merge(optional_keywords.each_key)
     .sort_by(&:to_s)
end
# Returns the distinct values (types) of the required and optional keywords as
# a fresh Array, ordered by their `to_s` representation.
def all_values
  Set.new(required_keywords.each_value)
     .merge(optional_keywords.each_value)
     .sort_by(&:to_s)
end
# Builds, via `AST::Types::Union.build`, the union of one literal type per
# known keyword, plus the Symbol instance type when a rest type is present.
def possible_key_type
  # @type var key_types: Array[AST::Types::t]
  key_types = []
  all_keys.each {|key| key_types.push(AST::Types::Literal.new(value: key)) }
  key_types.push(AST::Builtin::Symbol.instance_type) if rest_type

  AST::Types::Union.build(types: key_types)
end
# Builds, via `AST::Types::Intersection.build`, the intersection of every
# keyword value type, plus the rest type when one is present.
def possible_value_type
  types = all_values
  rest = rest_type
  types.push(rest) if rest

  AST::Types::Intersection.build(types: types)
end
# Performs one step of matching keyword argument nodes against keyword
# parameters.
#
# Returns nil when there is nothing further to report; otherwise returns
# `[result, state]`, where `result` is an ArgTypePairs, UnexpectedKeyword,
# MissingKeyword or SplatArg, and `state` is the KeywordArgs to continue from.
def next()
  current = keyword_pair

  unless current
    # All nodes visited: report required keywords that were never consumed.
    missing = Set.new(required_keywords.keys) - consumed_keywords
    return nil if missing.empty?

    return [
      MissingKeyword.new(keywords: missing),
      update(consumed_keywords: consumed_keywords + missing)
    ]
  end

  case current.type
  when :kwsplat
    # The state is handed back unchanged (same index).
    [SplatArg.new(node: current), self]
  when :pair
    key_node, value_node = current.children

    if key_node.type == :sym
      key = key_node.children[0]

      if (value_type = keyword_type(key))
        # Declared keyword: the key is typed as its own literal.
        [
          ArgTypePairs.new(
            pairs: [
              [key_node, AST::Types::Literal.new(value: key)],
              [value_node, value_type]
            ]
          ),
          update(index: index + 1, consumed_keywords: consumed_keywords + [key])
        ]
      elsif (value_type = rest_type)
        # Undeclared keyword covered by the rest type.
        [
          ArgTypePairs.new(
            pairs: [
              [key_node, AST::Builtin::Symbol.instance_type],
              [value_node, value_type]
            ]
          ),
          update(index: index + 1, consumed_keywords: consumed_keywords + [key])
        ]
      else
        [
          UnexpectedKeyword.new(keyword: key, node: current),
          update(index: index + 1)
        ]
      end
    elsif all_keys.empty? && !rest_type
      # Non-symbol key and no keyword parameter of any kind to match it.
      [
        UnexpectedKeyword.new(keyword: nil, node: current),
        update(index: index + 1)
      ]
    else
      # Non-symbol key: check it against every possible key and value type.
      [
        ArgTypePairs.new(
          pairs: [
            [key_node, possible_key_type],
            [value_node, possible_value_type]
          ]
        ),
        update(index: index + 1)
      ]
    end
  end
end
# Matches every key in `keys` (supplied on behalf of `node`) against the
# keyword parameters.
#
# Returns `[result, state]`. `result` is the array of types collected for the
# keys, or an UnexpectedKeyword carrying the last key that matched neither a
# declared keyword nor the rest type. `state` has moved one node forward and
# has the declared keywords that matched added to its consumed keywords.
def consume_keys(keys, node:)
  # @type var matched_keys: Array[Symbol]
  matched_keys = []
  # @type var types: Array[AST::Types::t]
  types = []
  # @type var unexpected: Symbol?
  unexpected = nil

  keys.each do |key|
    if (type = keyword_type(key))
      matched_keys.push(key)
      types.push(type)
    elsif (type = rest_type())
      # Covered by the rest type; not recorded as a consumed keyword.
      types.push(type)
    else
      unexpected = key
    end
  end

  result = unexpected ? UnexpectedKeyword.new(keyword: unexpected, node: node) : types

  [result, update(index: index + 1, consumed_keywords: consumed_keywords + matched_keys)]
end
end
# Pairs an optional block-pass argument node (`node`) with the optional block
# declaration of the callee (`block`) and answers questions about how the two
# relate.
class BlockPassArg
  include Equatable

  attr_reader :node, :block

  def initialize(node:, block:)
    @node, @block = node, block
  end

  # True when neither a block-pass node nor a block declaration is present.
  def no_block?
    !(node || block)
  end

  # With a node: true exactly when a block is declared.
  # Without a node: true when no block is declared, otherwise the result of
  # `block.optional?`.
  def compatible?
    return !!block if node
    return true unless block

    block.optional?
  end

  # Without a node: the result of `block&.required?`. With a node: false.
  def block_missing?
    return false if node

    block&.required?
  end

  # With a node: true when no block is declared. Without a node: the falsy
  # `node` value itself.
  def unexpected_block?
    node ? !block : node
  end

  # Returns `[node, block.type]` when both are present, nil otherwise.
  # Raises (bare `raise`) when `compatible?` is falsy.
  def pair
    raise unless compatible?
    return nil unless node && block

    [node, block.type]
  end

  # Builds the proc type for the declared block, unioned with nil when the
  # block is optional. Raises (bare `raise`) when no block is declared.
  def node_type
    raise unless block

    proc_type = AST::Types::Proc.new(type: block.type, block: nil, self_type: block.self_type)
    return proc_type unless block.optional?

    AST::Types::Union.build(types: [proc_type, AST::Builtin.nil_type])
  end
end
# Holds a `:forwarded_args` node together with the parameters it stands in for.
class ForwardedArgs
  attr_reader :node
  attr_reader :params

  def initialize(node:, params:)
    @node, @params = node, params
  end
end
attr_reader :node, :arguments, :type

# node:      the node the arguments belong to (used as the location of
#            "insufficient arguments" diagnostics)
# arguments: the argument nodes
# type:      the callee's type; `params` accepts an Interface::MethodType or
#            an AST::Types::Proc
def initialize(node:, arguments:, type:)
  @node, @arguments, @type = node, arguments, type
end
# Returns `type.type.params` for a method type or a proc type.
# Raises (bare `raise`) for any other kind of `type`.
def params
  case type
  when Interface::MethodType, AST::Types::Proc
    type.type.params
  else
    raise
  end
end
# Returns `type.block` for a method type or a proc type; nil for any other
# kind of `type`.
def block
  case type
  when Interface::MethodType, AST::Types::Proc
    type.block
  end
end
# Delegates to `params.positional_params`; raises whenever `params` raises.
def positional_params
params.positional_params
end
# Delegates to `params.keyword_params`; raises whenever `params` raises.
def keyword_params
params.keyword_params
end
# Returns the first `:kwargs` argument node, or nil when there is none or when
# the callee declares no keyword parameters (`keyword_params.empty?`).
def kwargs_node
  return nil if keyword_params.empty?

  arguments.find {|arg| arg.type == :kwargs }
end
# Builds the initial PositionalArgs state from the leading argument nodes.
# The leading run ends at the first `:block_pass` node, and also at the first
# `:kwargs` node when the callee declares keyword parameters; with no keyword
# parameters a `:kwargs` node is kept among the positional arguments.
def positional_arg
  stop_types = keyword_params.empty? ? [:block_pass] : [:kwargs, :block_pass]
  leading = arguments.take_while {|arg| !stop_types.include?(arg.type) }

  PositionalArgs.new(args: leading, index: 0, positional_params: positional_params)
end
# Returns the first `:forwarded_args` argument node, or nil when there is none.
def forwarded_args_node
  arguments.find {|arg| arg.type == :forwarded_args }
end
# Builds the initial KeywordArgs state from the children of the `:kwargs` node,
# or from an empty list when `kwargs_node` yields nothing.
def keyword_args
  nodes = kwargs_node&.children || []

  KeywordArgs.new(kwarg_nodes: nodes, keyword_params: keyword_params)
end
# Pairs the first `:block_pass` argument node (nil when absent) with the
# callee's block declaration.
def block_pass_arg
  block_pass = arguments.find {|arg| arg.type == :block_pass }

  BlockPassArg.new(node: block_pass, block: block)
end
# Walks the positional arguments and then either records forwarded arguments
# or walks the keyword arguments, yielding each matching result to the block.
#
# With a block: returns `[forwarded_args, diagnostics]`. `forwarded_args` is a
# ForwardedArgs when a `:forwarded_args` node is present and nil otherwise;
# `diagnostics` holds the diagnostics built from the errors collected here.
# Without a block: returns an enumerator for this method.
#
# The block must assign `type` on every yielded SplatArg (positional or
# keyword) before it returns; a nil type raises (bare `raise`).
def each
if block_given?
errors = [] #: Array[PositionalArgs::error_arg | KeywordArgs::error_arg]
# Tracks the positional state after the most recently completed step; its
# remaining parameters are what forwarded arguments must cover.
last_positional_args = positional_arg
positional_arg.tap do |args|
# The loop ends as soon as `next` returns nil.
while (value, args = args.next())
yield value
case value
when PositionalArgs::SplatArg
# Read back the type the block assigned while handling the yield above.
type = value.type
case type
when nil
raise
when AST::Types::Tuple
# A tuple splat takes one parameter per tuple element.
ts, args = args.consume(type.types.size, node: value.node)
case ts
when Array
ty = AST::Types::Tuple.new(types: ts.map(&:type))
yield PositionalArgs::NodeTypePair.new(node: value.node, type: ty)
when PositionalArgs::UnexpectedArg
errors << ts
yield ts
end
else
# Any other splat type: `args` is still positioned on the splat node, so
# `following_args` covers the splat and everything after it. All of those
# nodes are checked against the uniform type, or reported as unexpected
# when there is none, and positional matching stops.
if t = args.uniform_type
args.following_args.each do |node|
yield PositionalArgs::NodeTypePair.new(node: node, type: t)
end
else
args.following_args.each do |node|
arg = PositionalArgs::UnexpectedArg.new(node: node)
yield arg
errors << arg
end
end
# Leaving here skips the `last_positional_args` assignment below.
break
end
when PositionalArgs::UnexpectedArg, PositionalArgs::MissingArg
errors << value
end
last_positional_args = args
end
end
if fag = forwarded_args_node
# Forwarded arguments stand in for the positional parameters still left
# plus all keyword parameters; keyword arguments are not walked in this case.
forward_params = Interface::Function::Params.new(
positional_params: last_positional_args.positional_params,
keyword_params: keyword_params
)
forwarded_args = ForwardedArgs.new(node: fag, params: forward_params)
else
keyword_args.tap do |args|
while (a, args = args.next)
# Errors are recorded before the value is yielded.
case a
when KeywordArgs::MissingKeyword
errors << a
when KeywordArgs::UnexpectedKeyword
errors << a
end
yield a
case a
when KeywordArgs::SplatArg
# Read back the type the block assigned while handling the yield above.
case type = a.type
when nil
raise
when AST::Types::Record
# A record splat is matched key by key against the keyword parameters.
# @type var keys: Array[Symbol]
keys = _ = type.elements.keys
ts, args = args.consume_keys(keys, node: a.node)
case ts
when KeywordArgs::UnexpectedKeyword
yield ts
errors << ts
when Array
pairs = keys.zip(ts) #: Array[[Symbol, AST::Types::t]]
record = AST::Types::Record.new(elements: Hash[pairs])
yield KeywordArgs::ArgTypePairs.new(pairs: [[a.node, record]])
end
else
# Any other splat type: step past the splat node, then check it against a
# Symbol-keyed hash type when a rest type exists.
args = args.update(index: args.index + 1)
if args.rest_type
type = AST::Builtin::Hash.instance_type(AST::Builtin::Symbol.instance_type, args.possible_value_type)
yield KeywordArgs::ArgTypePairs.new(pairs: [[a.node, type]])
else
# NOTE(review): this UnexpectedKeyword is yielded but not added to
# `errors`, unlike the record branch above, so it produces no diagnostic
# here. Confirm whether that is intentional.
yield KeywordArgs::UnexpectedKeyword.new(keyword: nil, node: a.node)
end
end
end
end
end
end
# Turn the collected errors into diagnostics. Missing keywords are merged
# into a single diagnostic.
diagnostics = [] #: Array[Diagnostic::Ruby::Base]
missing_keywords = [] #: Array[Symbol]
errors.each do |error|
case error
when KeywordArgs::UnexpectedKeyword
diagnostics << Diagnostic::Ruby::UnexpectedKeywordArgument.new(node: error.node, params: params)
when KeywordArgs::MissingKeyword
missing_keywords.push(*error.keywords.to_a)
when PositionalArgs::UnexpectedArg
diagnostics << Diagnostic::Ruby::UnexpectedPositionalArgument.new(node: error.node, params: params)
when PositionalArgs::MissingArg
# Reported on this object's own `node`, not on an argument node.
diagnostics << Diagnostic::Ruby::InsufficientPositionalArguments.new(node: node, params: params)
end
end
unless missing_keywords.empty?
diagnostics << Diagnostic::Ruby::InsufficientKeywordArguments.new(node: node, params: params, missing_keywords: missing_keywords)
end
[forwarded_args, diagnostics]
else
enum_for :each
end
end
end
end
end