class Steep::Subtyping::VariableVariance

# Records the variance of every type variable occurring in a parameter list.
#
# Method parameters are inputs to the method, so they sit in contravariant
# position. Block parameters are values the method passes *to* the block, so
# from the method's point of view they are outputs and sit in covariant
# position. This mirrors from_method_type, which classifies a method's return
# type as covariant and a block's return type as contravariant.
#
# @param params [#each_type] parameter list; every yielded type is classified
# @param block [Boolean] true when +params+ belong to a block rather than to the method
# @param covariants [#<<] collector receiving covariant variable names
# @param contravariants [#<<] collector receiving contravariant variable names
# @return [void]
def self.add_params(params, block:, covariants:, contravariants:)
  # Block parameters flip relative to method parameters (see above).
  variance = block ? :covariant : :contravariant

  params.each_type do |type|
    add_type(type, variance: variance, covariants: covariants, contravariants: contravariants)
  end
end

# Walks +type+ and records the name of every type variable it contains,
# according to the variance of the position the variable occurs in.
#
# - A variable in :covariant position goes to +covariants+, one in
#   :contravariant position goes to +contravariants+, and one in :invariant
#   position goes to both.
# - Members of unions, intersections and tuples inherit the enclosing variance.
# - Type arguments of interface / instance / alias names are always treated as
#   invariant; the declared variance of those generic parameters is not consulted.
# - Any other kind of type contributes nothing.
#
# NOTE(review): under the compact `class Steep::Subtyping::VariableVariance`
# header, `AST` is not looked up inside `Steep`/`Steep::Subtyping`; confirm
# these constants resolve (qualify as `Steep::AST` if that is where they live).
#
# @param type the type to walk
# @param variance [Symbol] :covariant, :contravariant or :invariant; a variable
#   reached with any other value is ignored
# @param covariants [#<<] collector receiving covariant variable names
# @param contravariants [#<<] collector receiving contravariant variable names
# @return [void]
def self.add_type(type, variance:, covariants:, contravariants:)
  case type
  when AST::Types::Var
    case variance
    when :covariant
      covariants << type.name
    when :contravariant
      contravariants << type.name
    when :invariant
      covariants << type.name
      contravariants << type.name
    end
  when AST::Types::Union, AST::Types::Intersection, AST::Types::Tuple
    type.types.each do |ty|
      add_type(ty, variance: variance, covariants: covariants, contravariants: contravariants)
    end
  when AST::Types::Name::Interface, AST::Types::Name::Instance, AST::Types::Name::Alias
    # Generic arguments are conservatively treated as invariant.
    type.args.each do |arg|
      add_type(arg, variance: :invariant, covariants: covariants, contravariants: contravariants)
    end
  end
end

# Builds a VariableVariance describing how each type variable of
# +method_type+ is used.
#
# The method's parameters and return type are always classified. If the method
# has a block with a non-nil +type+, the block's parameters and return type are
# classified too. add_params / add_type decide the actual positions.
#
# @param method_type method type exposing +params+, +return_type+ and an
#   optional +block+ whose +type+ exposes +params+ and +return_type+
# @return [VariableVariance] built from the collected covariant/contravariant sets
def self.from_method_type(method_type)
  covariants = Set.new
  contravariants = Set.new

  add_params(method_type.params, block: false, contravariants: contravariants, covariants: covariants)
  add_type(method_type.return_type, variance: :covariant, covariants: covariants, contravariants: contravariants)

  # Named block_type rather than `proc` to avoid shadowing Kernel#proc.
  method_type.block&.type&.yield_self do |block_type|
    add_params(block_type.params, block: true, contravariants: contravariants, covariants: covariants)
    # A block's return value flows back into the method: contravariant position.
    add_type(block_type.return_type, variance: :contravariant, covariants: covariants, contravariants: contravariants)
  end

  new(covariants: covariants, contravariants: contravariants)
end

# Whether +var+ occurs only in contravariant positions, i.e. it is in the
# contravariant set and absent from the covariant set.
#
# @param var type variable name
# @return [Boolean]
def contravariant?(var)
  contravariants.member?(var) && !covariants.member?(var)
end

# Whether +var+ occurs only in covariant positions, i.e. it is in the
# covariant set and absent from the contravariant set.
#
# @param var type variable name
# @return [Boolean]
def covariant?(var)
  covariants.member?(var) && !contravariants.member?(var)
end

# Stores the two collections of variable names that the variance predicates
# query.
#
# NOTE(review): the predicates call `covariants` / `contravariants` readers
# that are not defined in the visible source; presumably attr_readers declared
# elsewhere in the class — confirm.
#
# @param covariants [#member?] names seen in covariant positions
#   (from_method_type passes a Set)
# @param contravariants [#member?] names seen in contravariant positions
#   (from_method_type passes a Set)
def initialize(covariants:, contravariants:)
  @covariants = covariants
  @contravariants = contravariants
end

# Whether +var+ occurs in both covariant and contravariant positions.
#
# @param var type variable name
# @return [Boolean]
def invariant?(var)
  covariants.member?(var) && contravariants.member?(var)
end