class Steep::TypeInference::TypeEnv

# Looks up the type recorded for a variable name or a supported expression node.
#
# @param name [Symbol, Parser::AST::Node]
#   * Symbol: classified by spelling. Local variables read the current type
#     (first element of their `[type, enforced_type]` entry); instance and
#     global variables read their tables directly.
#   * `:lvar` node: same as looking up the variable's Symbol.
#   * `:send` node: the refined type of a registered pure method call, falling
#     back to the call's `return_type` when no refinement is recorded.
# @return the type, or nil when nothing is recorded (any other node type also
#   yields nil).
# @raise [RuntimeError] for a Symbol that is not a local, instance, or global
#   variable name (e.g. a constant name).
def [](name)
  case name
  when Symbol
    if local_variable_name?(name)
      local_variable_types.dig(name, 0)
    elsif instance_variable_name?(name)
      instance_variable_types[name]
    elsif global_name?(name)
      global_types[name]
    else
      raise "Unexpected variable name: #{name}"
    end
  when ::Parser::AST::Node
    case name.type
    when :lvar
      self[name.children[0]]
    when :send
      entry = pure_method_calls[name]
      if entry
        call, type = entry
        type || call.return_type
      end
    end
  end
end

# Returns an environment in which `node` is registered as a pure method call
# with typed call `call` and optional refined type `type` (nil means "use the
# call's return type").
#
# If `node` is already registered with a call equal (`==`) to `call`, the
# receiver itself is returned unchanged, so any existing refinement is kept.
# Otherwise, every registered pure call that has `node` among its descendant
# nodes loses its refinement (reset to `[call, nil]`) before `node` is recorded.
def add_pure_call(node, call, type)
  existing = pure_method_calls[node]
  return self if existing && existing[0] == call

  updates = pure_node_invalidation(invalidated_pure_nodes(node))
  updates[node] = [call, type]
  merge(pure_method_calls: updates)
end

# Returns the entry stored in `constant_types` for `name`, or nil when there is
# none. NOTE(review): presumably these come from type annotations, per the
# method name; confirm with the code that populates `constant_types`.
def annotated_constant(name)
  constant_types[name]
end

# Returns a new environment in which local variable `name` has been assigned.
#
# When `enforced_type` is given it is stored as both the current type and the
# enforced type; otherwise `var_type` is stored with no enforced type.
# Every registered pure call that reads `name` (has an `:lvar` node for it
# among its descendants) has its refined type dropped.
#
# @raise [RuntimeError] if `name` is not a local variable name
def assign_local_variable(name, var_type, enforced_type)
  local_variable_name!(name)
  lvar_node = ::Parser::AST::Node.new(:lvar, [name])
  merge(
    local_variable_types: { name => [enforced_type || var_type, enforced_type] },
    pure_method_calls: pure_node_invalidation(invalidated_pure_nodes(lvar_node))
  )
end

# Batch assignment: stores `new_type` as the current type of each local
# variable in `assignments` (a `name => new_type` hash), keeping that
# variable's existing enforced type. Registered pure calls that read any of
# the assigned variables have their refined type dropped.
#
# @raise [RuntimeError] if any key is not a local variable name
def assign_local_variables(assignments)
  # Named `lvar_updates` so it does not shadow the `local_variable_types` reader.
  lvar_updates = {}
  invalidated_nodes = Set[]

  assignments.each do |name, new_type|
    local_variable_name!(name)
    lvar_updates[name] = [new_type, enforced_type(name)]
    invalidated_nodes.merge(invalidated_pure_nodes(::Parser::AST::Node.new(:lvar, [name])))
  end

  merge(
    local_variable_types: lvar_updates,
    pure_method_calls: pure_node_invalidation(invalidated_nodes)
  )
end

# Resolves a constant through `constant_env`. Supported call shapes:
#
# * `constant(type_name, :Name)` with an RBS::TypeName and a Symbol:
#   `constant_env.resolve_child(type_name, :Name)`
# * `constant(:Name, truthy)`: `constant_env.toplevel(:Name)`
# * `constant(:Name, falsy)`: `constant_env.resolve(:Name)`
#
# Any other combination of arguments returns nil.
def constant(arg1, arg2)
  if arg1.is_a?(RBS::TypeName) && arg2.is_a?(Symbol)
    constant_env.resolve_child(arg1, arg2)
  elsif arg1.is_a?(Symbol)
    arg2 ? constant_env.toplevel(arg1) : constant_env.resolve(arg1)
  end
end

# Returns the enforced type of local variable `name` (second element of its
# `[type, enforced_type]` entry), or nil if the variable is unknown or has no
# enforced type.
def enforced_type(name)
  local_variable_types.dig(name, 1)
end

# True when `name` is spelled like a global variable (leading `$`).
def global_name?(name)
  name.start_with?('$')
end

# @param constant_env used by #constant to resolve constant names
# @param local_variable_types [Hash] `name => [type, enforced_type_or_nil]`
# @param instance_variable_types [Hash] `name => type`
# @param global_types [Hash] `name => type`
# @param constant_types [Hash] read by #annotated_constant
# @param pure_method_calls [Hash] `node => [call, refined_type_or_nil]`
#
# The hashes are stored as given, not copied.
def initialize(constant_env, local_variable_types: {}, instance_variable_types: {}, global_types: {}, constant_types: {}, pure_method_calls: {})
  @constant_env = constant_env
  @local_variable_types = local_variable_types
  @instance_variable_types = instance_variable_types
  @global_types = global_types
  @constant_types = constant_types
  @pure_method_calls = pure_method_calls
  # Memo of pure node => Set of its descendant nodes, filled lazily by
  # #invalidated_pure_nodes; each new instance starts with an empty memo.
  @pure_node_descendants = {}
end

# Compact inspection: class, zero-padded hex object id, and the sorted names of
# all instance variables with their values elided as `...`.
def inspect
  ivar_list = instance_variables.map(&:to_s).sort.map { |ivar| "#{ivar}=..." }.join(", ")
  "#<%s:%#018x %s>" % [self.class, object_id, ivar_list]
end

# True when `name` is spelled like an instance variable: one leading `@`
# followed by a non-`@` character, so class variables (`@@x`) are excluded.
def instance_variable_name?(name)
  name.start_with?(/@[^@]/)
end

# Returns a new environment in which every registered pure call that has
# `node` among its descendant nodes loses its refined type (reset to
# `[call, nil]`).
def invalidate_pure_node(node)
  merge(pure_method_calls: pure_node_invalidation(invalidated_pure_nodes(node)))
end

# Returns the Set of registered pure call nodes whose descendants (as yielded
# by each_descendant_node) include `invalidated_node`.
#
# Each pure node's descendant Set is memoized in @pure_node_descendants on this
# instance only.
def invalidated_pure_nodes(invalidated_node)
  affected = pure_method_calls.each_key.select do |pure_node|
    descendants = (@pure_node_descendants[pure_node] ||= each_descendant_node(pure_node).to_set)
    descendants.include?(invalidated_node)
  end
  affected.to_set
end

# Joins the environments of several branches into one, relative to this env.
#
# Local variables: each variable known to any env in `envs` gets the union of
# its type in every env (nil type when an env does not know it). Variables
# whose joined type equals their current type here are skipped; the rest go
# through assign_local_variables, which keeps enforced types and drops the
# refinements of dependent pure calls.
#
# Pure calls: nodes registered in every env are set to the union of each env's
# refined type (or the call's return type). This env's other pure call entries
# are kept, apart from the invalidation above.
#
# AST constants are fully qualified: under a compact
# `class Steep::TypeInference::TypeEnv` header, bare `AST` does not reach Steep::AST.
def join(*envs)
  # @type var all_lvar_types: Hash[Symbol, Array[AST::Types::t]]
  all_lvar_types = envs.each_with_object({}) do |env, hash|
    env.local_variable_types.each_key { |name| hash[name] = [] }
  end

  envs.each do |env|
    all_lvar_types.each do |name, types|
      types << (env[name] || Steep::AST::Builtin.nil_type)
    end
  end

  assignments =
    all_lvar_types
      .transform_values { |types| Steep::AST::Types::Union.build(types: types) }
      .reject { |var, type| self[var] == type }

  common_pure_nodes = envs
    .map { |env| Set.new(env.pure_method_calls.each_key) }
    .inject { |s1, s2| s1.intersection(s2) } || Set[]

  pure_call_updates = common_pure_nodes.each_with_object({}) do |node, hash|
    # `fetch` is safe: every node here is registered in every env.
    pairs = envs.map { |env| env.pure_method_calls.fetch(node) }
    refined_type = Steep::AST::Types::Union.build(types: pairs.map { |call, type| type || call.return_type })
    # Any env's call object can be used because the call is *pure*.
    call, _ = envs.fetch(0).pure_method_calls.fetch(node)
    hash[node] = [call, refined_type]
  end

  assign_local_variables(assignments).merge(pure_method_calls: pure_call_updates)
end

# Asserts that `name` is a local variable name (see local_variable_name?).
#
# @return the truthy result of local_variable_name?
# @raise [RuntimeError] "<name> is not a local variable" otherwise
def local_variable_name!(name)
  local_variable_name?(name) || raise("#{name} is not a local variable")
end

# True when `name` may name a local variable.
#
# Rejects names that start like a constant, an instance/class variable, or a
# global variable, and the special names in
# Steep::TypeConstruction::SPECIAL_LVAR_NAMES.
def local_variable_name?(name)
  # Constants start with a Uppercase_Letter or Titlecase_Letter (Unicode property).
  # A leading `@` is an instance, class-instance, or class (`@@`) variable.
  # A leading `$` is a global variable.
  return false if name.start_with?(/[\p{Uppercase_Letter}\p{Titlecase_Letter}@$]/)
  # Fully qualified so it resolves regardless of how the enclosing class is opened.
  return false if Steep::TypeConstruction::SPECIAL_LVAR_NAMES.include?(name)

  true
end

# Returns a new environment whose tables are this environment's tables with
# the given entries merged in; given entries win on key collisions. Unlike
# #update, each argument is a partial overlay rather than a replacement.
def merge(local_variable_types: {}, instance_variable_types: {}, global_types: {}, constant_types: {}, pure_method_calls: {})
  # Fully qualified: under the compact class header, bare `TypeEnv` is not
  # resolvable through lexical scope.
  Steep::TypeInference::TypeEnv.new(
    constant_env,
    local_variable_types: self.local_variable_types.merge(local_variable_types),
    instance_variable_types: self.instance_variable_types.merge(instance_variable_types),
    global_types: self.global_types.merge(global_types),
    constant_types: self.constant_types.merge(constant_types),
    pure_method_calls: self.pure_method_calls.merge(pure_method_calls)
  )
end

# Computes "pinned" entries: each selected local variable without an enforced
# type gets its current type as its enforced type.
#
# @param names [Enumerable<Symbol>, nil] variables to pin; nil selects all
# @return [Hash] `name => [type, type]` for the newly pinned variables only.
#   NOTE(review): unlike #unpin_local_variables this returns the entries, not
#   a new TypeEnv; callers presumably merge them. Confirm before changing.
# @raise [RuntimeError] if a stored key is not a local variable name
def pin_local_variables(names)
  selected = names && Set.new(names)
  local_variable_types.each_with_object({}) do |(name, (type, enforced)), pins|
    local_variable_name!(name)
    next unless selected.nil? || selected.include?(name)

    pins[name] = [type, type] unless enforced
  end
end

# Builds a pure_method_calls overlay that drops the refined type of each
# registered node in `invalidated_nodes` (mapping it to `[call, nil]`).
# Unregistered nodes are ignored.
def pure_node_invalidation(invalidated_nodes)
  # @type var invalidation: Hash[Parser::AST::Node, [MethodCall::Typed, AST::Types::t?]]
  invalidation = {}
  invalidated_nodes.each do |node|
    entry = pure_method_calls[node]
    invalidation[node] = [entry[0], nil] if entry
  end
  invalidation
end

# Returns a new environment in which the given local variables and pure call
# nodes take the given types.
#
# Local variables keep their enforced type; only the current type changes.
# Pure calls that read a refined local variable have their refinement dropped.
# Then every node in `pure_call_types` is set to its given type.
#
# NOTE(review): a node in `pure_call_types` that is not registered in
# pure_method_calls ends up stored as `[nil, type]` (no call object). Confirm
# callers only pass registered nodes.
#
# @raise [RuntimeError] if a key of `local_variable_types` is not a local
#   variable name
def refine_types(local_variable_types: {}, pure_call_types: {})
  # `local_variable_types` here is the keyword argument; enforced_type reads
  # this environment's own table.
  lvar_updates = local_variable_types.each_with_object({}) do |(name, type), updates|
    local_variable_name!(name)
    updates[name] = [type, enforced_type(name)]
  end

  invalidated_nodes = Set.new(pure_call_types.each_key)
  local_variable_types.each_key do |name|
    invalidated_nodes.merge(invalidated_pure_nodes(::Parser::AST::Node.new(:lvar, [name])))
  end

  pure_call_updates = pure_node_invalidation(invalidated_nodes)
  pure_call_types.each do |node, type|
    call, _ = pure_call_updates[node]
    pure_call_updates[node] = [call, type]
  end

  merge(local_variable_types: lvar_updates, pure_method_calls: pure_call_updates)
end

# Returns a new environment in which registered pure call `node` has refined
# type `type` (nil drops the refinement); its call object is kept.
#
# @raise [RuntimeError] if `node` is not a registered pure call
def replace_pure_call_type(node, type)
  entry = pure_method_calls[node]
  # Previously a bare `raise` ("unhandled exception"); same class, useful message.
  raise "Unregistered pure method call: #{node}" unless entry

  calls = pure_method_calls.dup
  calls[node] = [entry[0], type]
  update(pure_method_calls: calls)
end

# Returns a new environment with substitution `s` applied to every local
# variable's current and enforced type. The other tables are carried over
# unchanged.
def subst(s)
  update(
    local_variable_types: local_variable_types.transform_values do |entry|
      # @type block: local_variable_entry
      type, enforced_type = entry
      [type.subst(s), enforced_type&.subst(s)]
    end
  )
end

# Human-readable one-line dump: `{ entry, entry, ... }`.
#
# Local variables render as `name: type`, or `name: type <enforced>` when an
# enforced type exists. Instance variables, globals, and constants render as
# `name: type`. Pure calls render as the first line of their source in
# backticks, followed by their refined type (or the call's return type).
def to_s
  entries = []

  local_variable_types.each do |name, (type, enforced)|
    entries << (enforced ? "#{name}: #{type} <#{enforced}>" : "#{name}: #{type}")
  end

  [instance_variable_types, global_types, constant_types].each do |table|
    table.each { |name, type| entries << "#{name}: #{type}" }
  end

  pure_method_calls.each do |node, (call, type)|
    # chomp so a multi-line call does not embed a newline in the output.
    source = node.loc.expression.source.lines.first.to_s.chomp
    entries << "`#{source}`: #{type || call.return_type}"
  end

  "{ #{entries.join(", ")} }"
end

# Returns a new environment in which the selected local variables have no
# enforced type; their current types are kept.
#
# @param names [Enumerable<Symbol>, nil] variables to unpin; nil selects all
# @raise [RuntimeError] if a stored key is not a local variable name
def unpin_local_variables(names)
  selected = names && Set.new(names)
  unpinned = local_variable_types.each_with_object({}) do |(name, (type, _enforced)), acc|
    local_variable_name!(name)
    acc[name] = [type, nil] if selected.nil? || selected.include?(name)
  end
  merge(local_variable_types: unpinned)
end

# Returns a new environment that replaces (does not merge) the given tables.
# Omitted tables default to this environment's hashes, which are shared, not
# copied.
def update(local_variable_types: self.local_variable_types, instance_variable_types: self.instance_variable_types, global_types: self.global_types, constant_types: self.constant_types, pure_method_calls: self.pure_method_calls)
  # Fully qualified: under the compact class header, bare `TypeEnv` is not
  # resolvable through lexical scope.
  Steep::TypeInference::TypeEnv.new(
    constant_env,
    local_variable_types: local_variable_types,
    instance_variable_types: instance_variable_types,
    global_types: global_types,
    constant_types: constant_types,
    pure_method_calls: pure_method_calls
  )
end