lib/steep/type_inference/block_params.rb
module Steep
module TypeInference
class BlockParams
class Param
attr_reader :var
attr_reader :type
attr_reader :value
attr_reader :node
def initialize(var:, type:, value:, node:)
@var = var
@type = type
@value = value
@node = node
end
def ==(other)
other.is_a?(self.class) && other.var == var && other.type == type && other.value == value && other.node == node
end
alias eql? ==
def hash
self.class.hash ^ var.hash ^ type.hash ^ value.hash ^ node.hash
end
end
attr_reader :leading_params
attr_reader :optional_params
attr_reader :rest_param
attr_reader :trailing_params
def initialize(leading_params:, optional_params:, rest_param:, trailing_params:)
@leading_params = leading_params
@optional_params = optional_params
@rest_param = rest_param
@trailing_params = trailing_params
end
def params
[].tap do |params|
params.push(*leading_params)
params.push(*optional_params)
params.push rest_param if rest_param
params.push(*trailing_params)
end
end
def self.from_node(node, annotations:)
leading_params = []
optional_params = []
rest_param = nil
trailing_params = []
default_params = leading_params
node.children.each do |arg|
var = arg.children.first
type = annotations.var_type(lvar: var.name)
case arg.type
when :arg, :procarg0
default_params << Param.new(var: var, type: type, value: nil, node: arg)
when :optarg
default_params = trailing_params
optional_params << Param.new(var: var, type: type, value: arg.children.last, node: arg)
when :restarg
default_params = trailing_params
rest_param = Param.new(var: var, type: type, value: nil, node: arg)
end
end
new(
leading_params: leading_params,
optional_params: optional_params,
rest_param: rest_param,
trailing_params: trailing_params
)
end
def params_type(hint: nil)
params_type0(hint: hint) or params_type0(hint: nil)
end
def params_type0(hint:)
if hint
case
when leading_params.size == hint.required.size
leadings = leading_params.map.with_index do |param, index|
param.type || hint.required[index]
end
when !hint.rest && hint.optional.empty? && leading_params.size > hint.required.size
leadings = leading_params.take(hint.required.size).map.with_index do |param, index|
param.type || hint.required[index]
end
when !hint.rest && hint.optional.empty? && leading_params.size < hint.required.size
leadings = leading_params.map.with_index do |param, index|
param.type || hint.required[index]
end + hint.required.drop(leading_params.size)
else
return nil
end
case
when optional_params.size == hint.optional.size
optionals = optional_params.map.with_index do |param, index|
param.type || hint.optional[index]
end
when !hint.rest && optional_params.size > hint.optional.size
optionals = optional_params.take(hint.optional.size).map.with_index do |param, index|
param.type || hint.optional[index]
end
when !hint.rest && optional_params.size < hint.optional.size
optionals = optional_params.map.with_index do |param, index|
param.type || hint.optional[index]
end + hint.optional.drop(optional_params.size)
else
return nil
end
if rest_param && hint.rest
rest = rest_param.type&.yield_self {|ty| ty.args&.first } || hint.rest
else
rest = hint.rest
end
else
leadings = leading_params.map {|param| param.type || AST::Types::Any.new }
optionals = optional_params.map {|param| param.type || AST::Types::Any.new }
rest = rest_param&.yield_self {|param| param.type.args[0] }
end
Interface::Params.new(
required: leadings,
optional: optionals,
rest: rest,
required_keywords: {},
optional_keywords: {},
rest_keywords: nil
)
end
def zip(params_type)
if trailing_params.any?
Steep.logger.error "Block definition with trailing required parameters are not supported yet"
end
[].tap do |zip|
if expandable_params?(params_type) && expandable?
type = params_type.required[0]
case
when AST::Builtin::Array.instance_type?(type)
type_arg = type.args[0]
params.each do |param|
unless param == rest_param
zip << [param, AST::Types::Union.build(types: [type_arg, AST::Builtin.nil_type])]
else
zip << [param, AST::Builtin::Array.instance_type(type_arg)]
end
end
when type.is_a?(AST::Types::Tuple)
types = type.types.dup
(leading_params + optional_params).each do |param|
ty = types.shift
if ty
zip << [param, ty]
else
zip << [param, AST::Types::Nil.new]
end
end
if rest_param
if types.any?
union = AST::Types::Union.build(types: types)
zip << [rest_param, AST::Builtin::Array.instance_type(union)]
else
zip << [rest_param, AST::Types::Nil.new]
end
end
end
else
types = params_type.flat_unnamed_params
(leading_params + optional_params).each do |param|
type = types.shift&.last || params_type.rest
if type
zip << [param, type]
else
zip << [param, AST::Builtin.nil_type]
end
end
if rest_param
if types.empty?
array = AST::Builtin::Array.instance_type(params_type.rest || AST::Builtin.any_type)
zip << [rest_param, array]
else
union = AST::Types::Union.build(types: types.map(&:last) + [params_type.rest])
array = AST::Builtin::Array.instance_type(union)
zip << [rest_param, array]
end
end
end
end
end
def expandable_params?(params_type)
if params_type.flat_unnamed_params.size == 1
case (type = params_type.required.first)
when AST::Types::Tuple
true
when AST::Types::Name::Base
AST::Builtin::Array.instance_type?(type)
end
end
end
def expandable?
case
when leading_params.size + trailing_params.size > 1
true
when (leading_params.any? || trailing_params.any?) && rest_param
true
when params.size == 1 && params[0].node.type == :arg
true
end
end
def each(&block)
if block_given?
params.each(&block)
else
enum_for :each
end
end
end
end
end