module Steep
module Interface
class Function
class Params
module Utils
  # Builds a union type out of +types+. When +null+ is true, the builtin nil
  # type is appended to the members before building.
  def union(*types, null: false)
    members = null ? types + [AST::Builtin.nil_type] : types
    AST::Types::Union.build(types: members)
  end

  # Builds an intersection type out of +types+.
  def intersection(*types)
    AST::Types::Intersection.build(types: types)
  end
end
class PositionalParams
# Common behaviour of one positional parameter: a wrapper around its type.
class Base
  attr_reader :type

  def initialize(type)
    @type = type
  end

  # Equal when +other+ is an instance of the receiver's class (or a subclass)
  # and wraps an equal type.
  def ==(other)
    return false unless other.is_a?(self.class)

    other.type == type
  end

  alias eql? ==

  def hash
    self.class.hash ^ type.hash
  end

  # Applies substitution +s+ to the wrapped type. Returns +self+ when the
  # type is unchanged, otherwise a new parameter of the same class.
  def subst(s)
    replaced = type.subst(s)
    return self if replaced == type

    _ = self.class.new(replaced)
  end

  def var_type
    type
  end

  # Yields the wrapped type and returns a new parameter of the same class
  # holding the block's result. Without a block, returns an Enumerator.
  def map_type(&block)
    return enum_for(:map_type) unless block_given?

    _ = self.class.new(yield type)
  end
end
# Parameter kind: a positional parameter that must be given.
class Required < Base; end
# Parameter kind: a positional parameter that may be omitted.
class Optional < Base; end
# Parameter kind: the rest (splat) positional parameter.
class Rest < Base; end
# +head+ is the first parameter; +tail+ is the remaining list, or nil at the end.
attr_reader :head, :tail
# Stores the first parameter and the remainder of the list.
def initialize(head:, tail:)
  @head, @tail = head, tail
end
# Prepends a Required parameter of +type+ to +tail+.
def self.required(type, tail = nil)
  head = Required.new(type)
  PositionalParams.new(head: head, tail: tail)
end
# Prepends an Optional parameter of +type+ to +tail+.
def self.optional(type, tail = nil)
  head = Optional.new(type)
  PositionalParams.new(head: head, tail: tail)
end
# Prepends a Rest parameter of +type+ to +tail+.
def self.rest(type, tail = nil)
  head = Rest.new(type)
  PositionalParams.new(head: head, tail: tail)
end
# Returns the list cell as the two-element array [head, tail].
# Defining to_ary lets Ruby destructure the cell implicitly
# (e.g. `head, tail = params`).
def to_ary
[head, tail]
end
# Yields every parameter from head to tail and rebuilds the list from the
# block's results. Returns +self+ when neither the head nor the tail changed.
def map(&block)
  new_head = yield(head)
  new_tail = tail&.map(&block)
  return self if head == new_head && tail == new_tail

  PositionalParams.new(head: new_head, tail: new_tail)
end
# Maps the type of every parameter through the block (via #map).
# Without a block, returns an Enumerator.
def map_type(&block)
  return enum_for(:map_type) unless block

  map { |param| param.map_type(&block) }
end
# Applies substitution +s+ to every parameter type. A type that is equal
# after substitution is passed through as the original object.
def subst(s)
  map_type do |type|
    substituted = type.subst(s)
    substituted == type ? type : substituted
  end
end
# Two lists are equal when +other+ is a PositionalParams with an equal head
# and an equal tail.
def ==(other)
  return false unless other.is_a?(PositionalParams)

  other.head == head && other.tail == tail
end

alias eql? ==
# Hash code combining the class, the head and the tail with XOR.
def hash
  digest = self.class.hash
  digest ^= head.hash
  digest ^ tail.hash
end
# Yields each parameter from head to the end of the list.
# Without a block, returns an Enumerator.
def each(&block)
  return enum_for(:each) unless block

  yield head
  tail&.each(&block)
end
# Yields the type of each parameter in order.
# Without a block, returns an Enumerator.
def each_type
  return enum_for(:each_type) unless block_given?

  each { |param| yield param.type }
end
# Number of parameters in the list: this cell plus everything in the tail.
def size
  tail_size = tail&.size || 0
  1 + tail_size
end
# Builds the list "required..., optional..., rest" from arrays of types.
# The list is assembled back to front; returns nil when all three are empty.
def self.build(required:, optional:, rest:)
  tail = self.rest(rest) if rest
  tail = optional.reverse_each.reduce(tail) { |acc, type| self.optional(type, acc) }
  required.reverse_each.reduce(tail) { |acc, type| self.required(type, acc) }
end
extend Utils
# Calculates xs + ys: one positional parameter list covering both overloads.
# Never fails.
#
# xs and ys are PositionalParams lists, or nil for an exhausted/empty list.
# The heads are compared pairwise and the result is built recursively:
# * Required on both sides stays Required with the union of the two types.
# * Every other pairing becomes Optional (or Rest for Rest + Rest).
#   nil is added to the union (null: true) whenever a Required parameter is
#   paired with something that is not Required.
# * A Rest head is not consumed: the list holding it is passed unchanged to
#   the recursive call so it keeps pairing with the other side's parameters.
# Returns nil when both lists are exhausted.
#
# NOTE(review): the `xs or raise` / `ys or raise` lines cannot fire in their
# branches (the head was just read from the list); presumably they exist to
# narrow nil for the type checker -- confirm.
def self.merge_for_overload(xs, ys)
x = xs&.head
y = ys&.head
case
# --- xs starts with a Required parameter ---
when x.is_a?(Required) && y.is_a?(Required)
xs or raise
ys or raise
required(
union(x.type, y.type),
merge_for_overload(xs.tail, ys.tail)
)
when x.is_a?(Required) && y.is_a?(Optional)
xs or raise
ys or raise
optional(
union(x.type, y.type, null: true),
merge_for_overload(xs.tail, ys.tail)
)
when x.is_a?(Required) && y.is_a?(Rest)
xs or raise
ys or raise
optional(
union(x.type, y.type, null: true),
merge_for_overload(xs.tail, ys)
)
when x.is_a?(Required) && !y
xs or raise
optional(
union(x.type, null: true),
merge_for_overload(xs.tail, nil)
)
# --- xs starts with an Optional parameter ---
when x.is_a?(Optional) && y.is_a?(Required)
xs or raise
ys or raise
optional(
union(x.type, y.type, null: true),
merge_for_overload(xs.tail, ys.tail)
)
when x.is_a?(Optional) && y.is_a?(Optional)
xs or raise
ys or raise
optional(
union(x.type, y.type),
merge_for_overload(xs.tail, ys.tail)
)
when x.is_a?(Optional) && y.is_a?(Rest)
xs or raise
ys or raise
optional(
union(x.type, y.type),
merge_for_overload(xs.tail, ys)
)
when x.is_a?(Optional) && !y
xs or raise
optional(
x.type,
merge_for_overload(xs.tail, nil)
) # == xs
# --- xs starts with a Rest parameter ---
when x.is_a?(Rest) && y.is_a?(Required)
xs or raise
ys or raise
optional(
union(x.type, y.type, null: true),
merge_for_overload(xs, ys.tail)
)
when x.is_a?(Rest) && y.is_a?(Optional)
xs or raise
ys or raise
optional(
union(x.type, y.type),
merge_for_overload(xs, ys.tail)
)
when x.is_a?(Rest) && y.is_a?(Rest)
xs or raise
ys or raise
rest(union(x.type, y.type))
when x.is_a?(Rest) && !y
# Only the rest of xs remains: return xs itself.
xs or raise
# --- xs is exhausted ---
when !x && y.is_a?(Required)
ys or raise
optional(
union(y.type, null: true),
merge_for_overload(nil, ys.tail)
)
when !x && y.is_a?(Optional)
ys or raise
optional(
y.type,
merge_for_overload(nil, ys.tail)
) # == ys
when !x && y.is_a?(Rest)
# Only the rest of ys remains: return ys itself.
ys or raise
when !x && !y
nil
end
end
# Calculates xs | ys: positional parameters that accept every argument list
# accepted by +xs+ or by +ys+.
#
# +xs+ and +ys+ are PositionalParams lists, or nil for an exhausted/empty
# list. The heads are compared pairwise and the result is built recursively:
# * Required on both sides stays Required with the union of the two types.
# * Every other pairing becomes Optional (or Rest for Rest | Rest).
# * A Rest head is not consumed: the list holding it is passed unchanged to
#   the recursive call, so the rest parameter keeps pairing with the other
#   side's remaining parameters and is still present in the result.
# Returns nil when both lists are exhausted.
def self.merge_for_union(xs, ys)
  x = xs&.head
  y = ys&.head

  case
  when x.is_a?(Required) && y.is_a?(Required)
    xs or raise
    ys or raise
    required(
      union(x.type, y.type),
      merge_for_union(xs.tail, ys.tail)
    )
  when x.is_a?(Required) && !y
    xs or raise
    optional(
      x.type,
      merge_for_union(xs.tail, nil)
    )
  when x.is_a?(Required) && y.is_a?(Optional)
    xs or raise
    ys or raise
    optional(
      union(x.type, y.type),
      merge_for_union(xs.tail, ys.tail)
    )
  when x.is_a?(Required) && y.is_a?(Rest)
    xs or raise
    ys or raise
    optional(
      union(x.type, y.type),
      merge_for_union(xs.tail, ys)
    )
  when !x && y.is_a?(Required)
    ys or raise
    optional(
      y.type,
      merge_for_union(nil, ys.tail)
    )
  when !x && !y
    nil
  when !x && y.is_a?(Optional)
    ys or raise
    PositionalParams.new(head: y, tail: merge_for_union(nil, ys.tail))
  when !x && y.is_a?(Rest)
    # Only the rest of ys remains: return ys itself.
    ys or raise
  when x.is_a?(Optional) && y.is_a?(Required)
    xs or raise
    ys or raise
    optional(
      union(x.type, y.type),
      merge_for_union(xs.tail, ys.tail)
    )
  when x.is_a?(Optional) && !y
    xs or raise
    PositionalParams.new(head: x, tail: merge_for_union(xs.tail, nil)) # == xs
  when x.is_a?(Optional) && y.is_a?(Optional)
    xs or raise
    ys or raise
    optional(
      union(x.type, y.type),
      merge_for_union(xs.tail, ys.tail)
    )
  when x.is_a?(Optional) && y.is_a?(Rest)
    xs or raise
    ys or raise
    # Keep ys (not ys.tail): the rest parameter must survive the merge, as in
    # the Required/Rest branch above and the mirrored Rest/Optional branch.
    optional(
      union(x.type, y.type),
      merge_for_union(xs.tail, ys)
    )
  when x.is_a?(Rest) && y.is_a?(Required)
    xs or raise
    ys or raise
    optional(
      union(x.type, y.type),
      merge_for_union(xs, ys.tail)
    )
  when x.is_a?(Rest) && !y
    # Only the rest of xs remains: return xs itself.
    xs or raise
  when x.is_a?(Rest) && y.is_a?(Optional)
    xs or raise
    ys or raise
    optional(
      union(x.type, y.type),
      merge_for_union(xs, ys.tail)
    )
  when x.is_a?(Rest) && y.is_a?(Rest)
    xs or raise
    ys or raise
    rest(
      union(x.type, y.type)
    )
  end
end
# Calculates xs & ys: positional parameters accepted by both lists.
# Raises when failed.
#
# xs and ys are PositionalParams lists, or nil for an exhausted/empty list.
# The heads are compared pairwise and the result is built recursively:
# * A pairing that involves a Required parameter yields Required.
# * Optional/Rest pairings yield Optional; Rest & Rest yields Rest.
# * A Rest head is not consumed: the list holding it is passed unchanged to
#   the recursive call.
# * A Required parameter facing an exhausted list is a failure (bare `raise`,
#   i.e. RuntimeError).
# * Optional or Rest parameters facing an exhausted list are dropped (nil).
def self.merge_for_intersection(xs, ys)
x = xs&.head
y = ys&.head
case
when x.is_a?(Required) && y.is_a?(Required)
xs or raise
ys or raise
required(
intersection(x.type, y.type),
merge_for_intersection(xs.tail, ys.tail)
)
when x.is_a?(Required) && !y
# ys cannot take the argument that xs requires.
raise
when x.is_a?(Required) && y.is_a?(Optional)
xs or raise
ys or raise
required(
intersection(x.type, y.type),
merge_for_intersection(xs.tail, ys.tail)
)
when x.is_a?(Required) && y.is_a?(Rest)
xs or raise
ys or raise
required(
intersection(x.type, y.type),
merge_for_intersection(xs.tail, ys)
)
when !x && y.is_a?(Required)
# xs cannot take the argument that ys requires.
raise
when !x && !y
nil
when !x && y.is_a?(Optional)
nil
when !x && y.is_a?(Rest)
nil
when x.is_a?(Optional) && y.is_a?(Required)
xs or raise
ys or raise
required(
intersection(x.type, y.type),
merge_for_intersection(xs.tail, ys.tail)
)
when x.is_a?(Optional) && !y
nil
when x.is_a?(Optional) && y.is_a?(Optional)
xs or raise
ys or raise
optional(
intersection(x.type, y.type),
merge_for_intersection(xs.tail, ys.tail)
)
when x.is_a?(Optional) && y.is_a?(Rest)
xs or raise
ys or raise
optional(
intersection(x.type, y.type),
merge_for_intersection(xs.tail, ys)
)
when x.is_a?(Rest) && y.is_a?(Required)
xs or raise
ys or raise
required(
intersection(x.type, y.type),
merge_for_intersection(xs, ys.tail)
)
when x.is_a?(Rest) && !y
nil
when x.is_a?(Rest) && y.is_a?(Optional)
xs or raise
ys or raise
optional(
intersection(x.type, y.type),
merge_for_intersection(xs, ys.tail)
)
when x.is_a?(Rest) && y.is_a?(Rest)
rest(intersection(x.type, y.type))
end
end
end
class KeywordParams
# +requireds+ and +optionals+ map keyword names to types; +rest+ is the type
# of the keyword-rest parameter, or nil when there is none.
attr_reader :requireds, :optionals, :rest
# Stores the required keywords, the optional keywords and the keyword-rest
# type. Defaults describe "no keyword parameters".
def initialize(requireds: {}, optionals: {}, rest: nil)
  @requireds, @optionals, @rest = requireds, optionals, rest
end
# Equal when +other+ is a KeywordParams with equal required keywords,
# optional keywords and keyword-rest type.
def ==(other)
  return false unless other.is_a?(KeywordParams)

  other.requireds == requireds && other.optionals == optionals && other.rest == rest
end

alias eql? ==
# Hash code combining the class and all three components with XOR.
def hash
  digest = self.class.hash
  digest ^= requireds.hash
  digest ^= optionals.hash
  digest ^ rest.hash
end
# Returns a new KeywordParams; any component not given is taken from self.
def update(requireds: self.requireds, optionals: self.optionals, rest: self.rest)
  attributes = { requireds: requireds, optionals: optionals, rest: rest }
  KeywordParams.new(**attributes)
end
# True when there are no required keywords, no optional keywords and no
# keyword-rest parameter.
def empty?
  return false unless requireds.empty?
  return false unless optionals.empty?

  rest.nil?
end
# Yields [name, type] for every required and then every optional keyword,
# followed by [nil, rest] when a keyword-rest type is present.
# Without a block, returns an Enumerator.
def each(&block)
  return enum_for(:each) unless block

  requireds.each(&block)
  optionals.each(&block)
  yield [nil, rest] if rest
end
# Yields only the type of each entry produced by #each.
# Without a block, returns an Enumerator.
def each_type
  return enum_for(:each_type) unless block_given?

  each { |_, type| yield type }
end
# Maps every keyword type (and the keyword-rest type, when present) through
# the block. Returns +self+ when nothing changed, otherwise an updated copy.
# Without a block, returns an Enumerator.
def map_type(&block)
  return enum_for(:map_type) unless block

  new_requireds = requireds.transform_values(&block)
  new_optionals = optionals.transform_values(&block)
  new_rest = rest&.yield_self(&block)
  unchanged = requireds == new_requireds && optionals == new_optionals && rest == new_rest
  return self if unchanged

  update(requireds: new_requireds, optionals: new_optionals, rest: new_rest)
end
# Applies substitution +s+ to every keyword type. A type that is equal after
# substitution is passed through as the original object.
def subst(s)
  map_type do |type|
    substituted = type.subst(s)
    substituted == type ? type : substituted
  end
end
# Number of keyword parameters; the keyword-rest parameter counts as one.
def size
  count = requireds.size + optionals.size
  rest ? count + 1 : count
end
# Set of all keyword names, required first and then optional.
def keywords
  Set.new(requireds.keys).merge(optionals.keys)
end
include Utils
# For overloading
#
# Merges the keyword parameters of two overloads into one KeywordParams.
# For each keyword named (required or optional) on either side:
# * required on both sides -> required, union of the two types;
# * anything else -> optional. nil is added to the union (null: true) when a
#   required keyword is paired with an optional keyword, a keyword-rest, or
#   nothing at all on the other side.
# The other side's keyword-rest type stands in for a keyword it does not name.
# The resulting keyword-rest is the union of both rests when both exist,
# otherwise whichever one is present (or nil).
def +(other)
requireds = {} #: Hash[Symbol, AST::Types::t]
optionals = {} #: Hash[Symbol, AST::Types::t]
all_keys = Set[] + self.requireds.keys + self.optionals.keys + other.requireds.keys + other.optionals.keys
all_keys.each do |key|
case
# key is required in self
when t = self.requireds[key]
case
when s = other.requireds[key]
requireds[key] = union(t, s)
when s = other.optionals[key]
optionals[key] = union(t, s, null: true)
when s = other.rest
optionals[key] = union(t, s, null: true)
else
optionals[key] = union(t, null: true)
end
# key is optional in self
when t = self.optionals[key]
case
when s = other.requireds[key]
optionals[key] = union(t, s, null: true)
when s = other.optionals[key]
optionals[key] = union(t, s)
when s = other.rest
optionals[key] = union(t, s)
else
optionals[key] = t
end
# key is not named in self, but self has a keyword-rest
when t = self.rest
case
when s = other.requireds[key]
optionals[key] = union(t, s, null: true)
when s = other.optionals[key]
optionals[key] = union(t, s)
when s = other.rest
# cannot happen
else
# nop
end
# key is unknown to self
else
case
when s = other.requireds[key]
optionals[key] = union(s, null: true)
when s = other.optionals[key]
optionals[key] = s
when s = other.rest
# nop
else
# cannot happen
end
end
end
if self.rest && other.rest
rest = union(self.rest, other.rest)
else
rest = self.rest || other.rest
end
KeywordParams.new(requireds: requireds, optionals: optionals, rest: rest)
end
# For union
#
# Merges two KeywordParams for a union of signatures.
# For each keyword named (required or optional) on either side:
# * required on both sides -> required, union of the two types;
# * anything else -> optional, with the union of the types available for it.
#   Unlike `+`, nil is never added to the type.
# The other side's keyword-rest type stands in for a keyword it does not name.
# The resulting keyword-rest is the union of both rests when both exist,
# otherwise whichever one is present (or nil).
def |(other)
requireds = {} #: Hash[Symbol, AST::Types::t]
optionals = {} #: Hash[Symbol, AST::Types::t]
all_keys = Set[] + self.requireds.keys + self.optionals.keys + other.requireds.keys + other.optionals.keys
all_keys.each do |key|
case
# key is required in self
when t = self.requireds[key]
case
when s = other.requireds[key]
requireds[key] = union(t, s)
when s = other.optionals[key]
optionals[key] = union(t, s)
when s = other.rest
optionals[key] = union(t, s)
else
optionals[key] = t
end
# key is optional in self
when t = self.optionals[key]
case
when s = other.requireds[key]
optionals[key] = union(t, s)
when s = other.optionals[key]
optionals[key] = union(t, s)
when s = other.rest
optionals[key] = union(t, s)
else
optionals[key] = t
end
# key is not named in self, but self has a keyword-rest
when t = self.rest
case
when s = other.requireds[key]
optionals[key] = union(t, s)
when s = other.optionals[key]
optionals[key] = union(t, s)
when s = other.rest
# cannot happen
else
# nop
end
# key is unknown to self
else
case
when s = other.requireds[key]
optionals[key] = s
when s = other.optionals[key]
optionals[key] = s
when s = other.rest
# nop
else
# cannot happen
end
end
end
rest =
if self.rest && other.rest
union(self.rest, other.rest)
else
self.rest || other.rest
end
KeywordParams.new(requireds: requireds, optionals: optionals, rest: rest)
end
# For intersection
#
# Intersects two KeywordParams. Returns nil when no intersection exists:
# a keyword is required on one side while the other side neither names it
# nor has a keyword-rest.
# For each keyword named (required or optional) on either side:
# * required on at least one side (and available on the other) -> required,
#   intersection of the two types;
# * optional/keyword-rest on both sides -> optional, intersection of types;
# * optional on one side and unavailable on the other -> dropped.
# The resulting keyword-rest is the intersection of both rests when both
# exist, otherwise nil.
def &(other)
requireds = {} #: Hash[Symbol, AST::Types::t]
optionals = {} #: Hash[Symbol, AST::Types::t]
all_keys = Set[] + self.requireds.keys + self.optionals.keys + other.requireds.keys + other.optionals.keys
all_keys.each do |key|
case
# key is required in self
when t = self.requireds[key]
case
when s = other.requireds[key]
requireds[key] = intersection(t, s)
when s = other.optionals[key]
requireds[key] = intersection(t, s)
when s = other.rest
requireds[key] = intersection(t, s)
else
# other cannot accept the keyword that self requires
return nil
end
# key is optional in self
when t = self.optionals[key]
case
when s = other.requireds[key]
requireds[key] = intersection(t, s)
when s = other.optionals[key]
optionals[key] = intersection(t, s)
when s = other.rest
optionals[key] = intersection(t, s)
else
# nop
end
# key is not named in self, but self has a keyword-rest
when t = self.rest
case
when s = other.requireds[key]
requireds[key] = intersection(t, s)
when s = other.optionals[key]
optionals[key] = intersection(t, s)
when s = other.rest
# cannot happen
else
# nop
end
# key is unknown to self
else
case
when s = other.requireds[key]
# self cannot accept the keyword that other requires
return nil
when s = other.optionals[key]
# nop
when s = other.rest
# nop
else
# cannot happen
end
end
end
rest =
if self.rest && other.rest
intersection(self.rest, other.rest)
else
nil
end
KeywordParams.new(requireds: requireds, optionals: optionals, rest: rest)
end
end
# Types of the leading Required positional parameters. Collection stops at
# the first parameter that is not Required.
def required
  types = [] #: Array[AST::Types::t]
  positional_params&.each do |param|
    break unless param.is_a?(PositionalParams::Required)

    types << param.type
  end
  types
end
# Types of the Optional positional parameters. Required parameters are
# skipped; collection stops at the first parameter of any other kind.
def optional
  types = [] #: Array[AST::Types::t]
  positional_params&.each do |param|
    next if param.is_a?(PositionalParams::Required)
    break unless param.is_a?(PositionalParams::Optional)

    types << param.type
  end
  types
end
# Type of the first Rest positional parameter, or nil when there is none.
def rest
  positional_params&.each do |param|
    return param.type if param.is_a?(PositionalParams::Rest)
  end
  nil
end
# +positional_params+ is a PositionalParams list or nil; +keyword_params+ is
# a KeywordParams.
attr_reader :positional_params, :keyword_params
# Builds a Params from plain arrays of positional types and hashes of
# keyword types.
def self.build(required: [], optional: [], rest: nil, required_keywords: {}, optional_keywords: {}, rest_keywords: nil)
  new(
    positional_params: PositionalParams.build(required: required, optional: optional, rest: rest),
    keyword_params: KeywordParams.new(requireds: required_keywords, optionals: optional_keywords, rest: rest_keywords)
  )
end
# Stores the positional parameter list and the keyword parameters.
def initialize(positional_params:, keyword_params:)
  @positional_params, @keyword_params = positional_params, keyword_params
end
# Returns a new instance of the same class; any component not given is
# taken from self.
def update(positional_params: self.positional_params, keyword_params: self.keyword_params)
  klass = self.class
  klass.new(positional_params: positional_params, keyword_params: keyword_params)
end
# Returns the head of the positional parameter list, or nil when
# positional_params is nil.
def first_param
positional_params&.head
end
# Returns a copy whose positional list starts with +param+, followed by the
# current positional parameters.
def with_first_param(param)
  prepended = PositionalParams.new(head: param, tail: positional_params)
  update(positional_params: prepended)
end
# True when a positional parameter list is present.
def has_positional?
  !!positional_params
end
# Params with no positional parameters and blank keyword parameters.
def self.empty
  new(positional_params: nil, keyword_params: KeywordParams.new)
end
# Equal when +other+ is an instance of the receiver's class (or a subclass)
# with equal positional and keyword parameters.
def ==(other)
  return false unless other.is_a?(self.class)

  other.positional_params == positional_params && other.keyword_params == keyword_params
end

alias eql? ==
# Hash code combining the class and both components with XOR.
def hash
  digest = self.class.hash
  digest ^= positional_params.hash
  digest ^ keyword_params.hash
end
# Array of [:required, type] / [:optional, type] pairs for the Required and
# Optional positional parameters, in order. Other kinds are left out.
def flat_unnamed_params
  return [] unless positional_params

  positional_params.each.with_object([]) do |param, pairs|
    kind =
      if param.is_a?(PositionalParams::Required)
        :required
      elsif param.is_a?(PositionalParams::Optional)
        :optional
      end
    pairs << [kind, param.type] if kind
  end
end
# A new hash of all keywords; an optional keyword wins over a required one
# of the same name.
def flat_keywords
  merged = required_keywords.dup
  merged.update(optional_keywords)
end
# Returns the required keywords of keyword_params (its `requireds`).
def required_keywords
keyword_params.requireds
end
# Returns the optional keywords of keyword_params (its `optionals`).
def optional_keywords
keyword_params.optionals
end
# Returns the keyword-rest type of keyword_params (its `rest`), or nil.
def rest_keywords
keyword_params.rest
end
# True when keyword_params is not empty.
def has_keywords?
  keyword_params.empty? ? false : true
end
# Yields each positional parameter; does nothing when there are none.
# Without a block, returns an Enumerator.
def each_positional_param(&block)
  return enum_for(:each_positional_param) unless block_given?

  positional_params.each(&block) if positional_params
end
# Returns a copy whose keyword parameters are replaced with a blank
# KeywordParams.
def without_keywords
  blank = KeywordParams.new
  update(keyword_params: blank)
end
# Drops the first parameter: the head of the positional list when there is
# one, otherwise all keyword parameters. Raises when there is nothing to drop.
def drop_first
  return update(positional_params: positional_params.tail) if positional_params
  return without_keywords if has_keywords?

  raise "Cannot drop from empty params"
end
# Yields every positional parameter type, then every keyword parameter type.
# Without a block, returns an Enumerator.
def each_type(&block)
  return enum_for(:each_type) unless block

  positional_params&.each_type(&block)
  keyword_params.each_type(&block)
end
# Set of the free variables of every parameter type. Computed on first call
# and cached in @fvs.
def free_variables
  @fvs ||= begin
    vars = Set.new
    each_type { |type| vars.merge(type.free_variables) }
    vars
  end
end
# True when no parameter type has free variables.
def closed?
  each_type.all? do |type|
    type.free_variables.empty?
  end
end
# Applies substitution +s+. Returns +self+ when the substitution is empty,
# the params are empty, or +s+ applies to none of the types. Otherwise only
# the parts that +s+ applies to are substituted.
def subst(s)
  return self if s.empty? || empty?
  return self if each_type.none? { |t| s.apply?(t) }

  affected = ->(part) { part && part.each_type.any? { |t| s.apply?(t) } }
  pp = affected.call(positional_params) ? positional_params.subst(s) : positional_params
  kp = affected.call(keyword_params) ? keyword_params.subst(s) : keyword_params
  self.class.new(positional_params: pp, keyword_params: kp)
end
# Total number of positional and keyword parameters.
def size
  positional_count = positional_params&.size || 0
  positional_count + keyword_params.size
end
# Renders the parameter list in parentheses: required, `?optional`, `*rest`,
# `name: type`, `?name: type`, `**rest`, joined with ", ".
def to_s
  parts = required.map { |ty| ty.to_s }
  parts.concat(optional.map { |ty| "?#{ty}" })
  parts << "*#{rest}" if rest
  keyword_params.requireds.each { |name, type| parts << "#{name}: #{type}" }
  keyword_params.optionals.each { |name, type| parts << "?#{name}: #{type}" }
  parts << "**#{keyword_params.rest}" if keyword_params.rest
  "(#{parts.join(", ")})"
end
# Returns a new instance of the same class with every positional and keyword
# parameter type mapped through the block.
#
# Without a block, returns an Enumerator, like the other map_type methods in
# this file. Previously a block-less call built an instance whose components
# were Enumerators.
def map_type(&block)
  return enum_for(:map_type) unless block

  self.class.new(
    positional_params: positional_params&.map_type(&block),
    keyword_params: keyword_params.map_type(&block)
  )
end
# True when there are neither positional nor keyword parameters.
def empty?
  !(has_positional? || has_keywords?)
end
# Returns true if all arguments are non-required, i.e. there is no required
# positional parameter and no required keyword.
def optional?
  return false unless required.empty?

  required_keywords.empty?
end
# self + params returns a new params for overloading.
#
# Positional parameters are merged with PositionalParams.merge_for_overload,
# keyword parameters with KeywordParams#+.
def +(other)
  Params.new(
    positional_params: PositionalParams.merge_for_overload(positional_params, other.positional_params),
    keyword_params: keyword_params + other.keyword_params
  )
end
# Returns the intersection between self and other.
# Returns nil if the intersection cannot be computed.
#
# (self & other) <: self
# (self & other) <: other
#
# `self & other` accept `arg` if `arg` is acceptable for both of `self` and `other`.
#
def &(other)
  pp =
    begin
      PositionalParams.merge_for_intersection(positional_params, other.positional_params)
    rescue StandardError
      # merge_for_intersection signals "no intersection" by raising.
      return nil
    end
  kp = keyword_params & other.keyword_params
  return nil unless kp

  Params.new(positional_params: pp, keyword_params: kp)
end
# Returns the union between self and other.
# Returns nil when merging the positional parameters raises or when the
# keyword merge yields nil.
#
# self <: (self | other)
# other <: (self | other)
#
# `self | other` accept `arg` if `self` accepts `arg` or `other` accepts `arg`.
#
def |(other)
  pp =
    begin
      PositionalParams.merge_for_union(positional_params, other.positional_params)
    rescue StandardError
      return nil
    end
  kp = keyword_params | other.keyword_params
  return nil unless kp

  Params.new(positional_params: pp, keyword_params: kp)
end
end
# +params+ may be nil (rendered as "(?)" by #to_s); +return_type+ is the
# result type; +location+ is carried along unchanged by the copying methods.
attr_reader :params, :return_type, :location
# Stores the parameters, the return type and the source location.
def initialize(params:, return_type:, location:)
  @params, @return_type, @location = params, return_type, location
end
# Equal when +other+ is a Function with equal params and return type.
# The location is not compared.
def ==(other)
  return false unless other.is_a?(Function)

  other.params == params && other.return_type == return_type
end

alias eql? ==
# Hash code combining the class, params and return type with XOR
# (the location is left out, matching #==).
def hash
  digest = self.class.hash
  digest ^= params.hash
  digest ^ return_type.hash
end
# Set of the free variables of the params (when present) and of the return
# type. Computed on first call and cached in @fvs.
def free_variables
  @fvs ||= begin
    fvs = Set[] #: Set[AST::Types::variable]
    fvs.merge(params.free_variables) if params
    fvs.merge(return_type.free_variables)
    fvs
  end
end
# Applies substitution +s+ to the params (when present) and the return type.
# Returns +self+ when +s+ is empty or nothing changed; otherwise a new
# Function with the same location.
def subst(s)
  return self if s.empty?

  new_params = params ? params.subst(s) : nil
  new_return = return_type.subst(s)
  return self if new_params == params && new_return == return_type

  Function.new(params: new_params, return_type: new_return, location: location)
end
# Yields every parameter type (when params is present) and then the return
# type. Without a block, returns an Enumerator.
def each_type(&block)
  return enum_for(:each_type) unless block

  params&.each_type(&block)
  yield return_type
end

alias each_child each_type
# Returns a new Function with every parameter type and the return type
# mapped through the block; the location is kept.
def map_type(&block)
  mapped_params = params&.map_type(&block)
  mapped_return = yield(return_type)
  Function.new(params: mapped_params, return_type: mapped_return, location: location)
end
# Returns a new Function with the given params and/or return type; anything
# not given is taken from self. The location is kept.
def with(params: self.params, return_type: self.return_type)
  Function.new(params: params, return_type: return_type, location: location)
end
# Renders "<params> -> <return_type>"; "(?)" stands in for missing params.
def to_s
  prefix = params ? "#{params}" : "(?)"
  "#{prefix} -> #{return_type}"
end
# True when the params (if present) are closed and the return type has no
# free variables.
def closed?
  params_closed = params ? params.closed? : true
  params_closed && return_type.free_variables.empty?
end
end
end
end