class RuboCop::AST::NodePattern::Compiler

# Builds Ruby code which implements a pattern
# @private

def self.tokens(pattern)

# Lexes +pattern+ into a flat token list, discarding matches that are
# purely separators (which TOKEN also captures).
def self.tokens(pattern)
  separator_only = /\A#{SEPARATORS}\Z/
  pattern.scan(TOKEN).reject { |tok| tok.match?(separator_only) }
end

def access_unify(name)

# Looks up the temp variable bound to wildcard +name+, failing when the
# wildcard was only seen in a subset of a union's branches.
def access_unify(name)
  bound = @unify[name]
  return bound unless bound == :forbidden_unification

  fail_due_to "Wildcard #{name} was first seen in a subset of a" \
              " union and can't be used outside that union"
end

def auto_use_temp_node?(code)

# True when +code+ references the current node more than once, in which
# case the node expression should be hoisted into a temp variable.
def auto_use_temp_node?(code)
  placeholder_uses = code.scan(CUR_PLACEHOLDER).size
  placeholder_uses > 1
end

def compile_any_order(capture_all = nil)

# rubocop:disable Metrics/MethodLength
# rubocop:disable Metrics/AbcSize
# Compiles `<...>` / `$<...>`: match the enclosed patterns against the
# node's children in any order.
#
# @param capture_all [String, nil] capture slot receiving all children
#   when compiling `$<...>` (nil for plain `<...>`)
# @return [(Range|Integer, Proc)] arity and a generator of the match code
def compile_any_order(capture_all = nil)
  rest = capture_rest = nil
  patterns = []
  with_temp_variables do |child, matched|
    tokens_until('>', 'any child') do
      fail_due_to 'ellipsis must be at the end of <>' if rest
      token = tokens.shift
      case token
      when CAPTURED_REST then rest = capture_rest = next_capture
      when REST          then rest = true
      else patterns << compile_expr(token)
      end
    end
    # NOTE: the locals (child, matched, patterns, rest, capture_rest,
    # capture_all, range) are consumed by ANY_ORDER_TEMPLATE via
    # `binding` — do not rename them.
    [rest ? patterns.size..Float::INFINITY : patterns.size,
     ->(range) { ANY_ORDER_TEMPLATE.result(binding) }]
  end
end

def compile_arg(token)

# Compiles a single token appearing inside a funcall/predicate arglist
# into the Ruby expression evaluating that argument.
#
# @param token [String, nil] next arglist token
# @return [String] Ruby code for the argument's value
def compile_arg(token)
  case token
  when WILDCARD  then
    name = token[1..-1]
    # a named wildcard used as an argument must have been bound earlier
    access_unify(name) || fail_due_to("invalid in arglist: #{token}")
  when LITERAL   then token
  when PARAM     then get_param(token[1..-1])
  when CLOSING   then fail_due_to("#{token} in invalid position")
  when nil       then fail_due_to('pattern ended prematurely')
  else fail_due_to("invalid token in arglist: #{token.inspect}")
  end
end

def compile_args(tokens)

# Destructively consumes tokens up to and including the closing ')' from
# +tokens+, compiling each argument and skipping separators.
def compile_args(tokens)
  close = tokens.index(')')
  consumed = tokens.slice!(0..close)
  consumed.reject { |tok| tok == ')' || tok == ',' }
          .map { |tok| compile_arg(tok) }
end

def compile_ascend

# Compiles `^`: match the following expression against the node's parent,
# guarding against a nil current node.
def compile_ascend
  guarded = "#{CUR_NODE} && #{compile_expr}"
  with_context(guarded, "#{CUR_NODE}.parent")
end

def compile_capture

# Compiles `$`: store the current element into the next capture slot,
# then match the following expression.
def compile_capture
  store = "#{next_capture} = #{CUR_ELEMENT}"
  "(#{store}; #{compile_expr})"
end

def compile_captured_ellipsis

# Compiles `$...`: capture the node's remaining children.
def compile_captured_ellipsis
  capture = next_capture
  matcher = lambda do |range|
    # Treat ($...) exactly like (_ $...):
    range = 0..range.end if range.begin == SEQ_HEAD_INDEX
    "(#{capture} = #{CUR_NODE}.children[#{range}])"
  end
  [0..Float::INFINITY, matcher]
end

def compile_descend

# Compiles backtick-descend: match the following expression against the
# current element or any of its descendants.
def compile_descend
  with_temp_variables do |descendant|
    inner = with_context(compile_expr, descendant,
                         use_temp_node: false)
    lines = ["RuboCop::AST::NodePattern.descend(#{CUR_ELEMENT}).",
             "any? do |#{descendant}|",
             "  #{inner}",
             'end']
    lines.join("\n")
  end
end

def compile_ellipsis

# Compiles `...`: any number (including zero) of children match
# unconditionally.
def compile_ellipsis
  arity = 0..Float::INFINITY
  [arity, 'true']
end

def compile_expr(token = tokens.shift)

# rubocop:disable Metrics/MethodLength, Metrics/AbcSize
# @param token [String, nil] the token to compile (consumes the next
#   token from the stream by default)
# @return [String] Ruby code performing the corresponding match
def compile_expr(token = tokens.shift)
  # read a single pattern-matching expression from the token stream,
  # return Ruby code which performs the corresponding matching operation
  #
  # the 'pattern-matching' expression may be a composite which
  # contains an arbitrary number of sub-expressions, but that composite
  # must all have precedence higher or equal to that of `&&`
  #
  # Expressions may use placeholders like:
  #   CUR_NODE: Ruby code that evaluates to an AST node
  #   CUR_ELEMENT: Either the node or the type if in first element of
  #   a sequence (aka seq_head, e.g. "(seq_head first_node_arg ...")
  case token
  when '('       then compile_seq
  when '{'       then compile_union
  when '['       then compile_intersect
  when '!'       then compile_negation
  when '$'       then compile_capture
  when '^'       then compile_ascend
  when '`'       then compile_descend
  when WILDCARD  then compile_wildcard(token[1..-1])
  when FUNCALL   then compile_funcall(token)
  when LITERAL   then compile_literal(token)
  when PREDICATE then compile_predicate(token)
  when NODE      then compile_nodetype(token)
  when PARAM     then compile_param(token[1..-1])
  when CLOSING   then fail_due_to("#{token} in invalid position")
  when nil       then fail_due_to('pattern ended prematurely')
  else                fail_due_to("invalid token #{token.inspect}")
  end
end

def compile_funcall(method)

# Compiles `#method` / `#method(...)`: call a helper defined in the
# context the generated matcher runs in, passing the current element as
# the first argument.
def compile_funcall(method)
  name = method[1..-1] # drop the leading #
  return "#{name}(#{CUR_ELEMENT})" unless name.end_with?('(') # no arglist

  args = compile_args(tokens)
  name = name[0..-2] # drop the trailing (
  "#{name}(#{CUR_ELEMENT},#{args.join(',')})"
end

def compile_guard_clause

# Code ensuring the current node is an actual AST node (and not, say, a
# Symbol child value).
def compile_guard_clause
  format('%s.is_a?(RuboCop::AST::Node)', CUR_NODE)
end

def compile_intersect

# Compiles `[...]` (intersection): every sub-expression must match.
def compile_intersect
  subexprs = tokens_until(']', 'intersection').map { compile_expr }
  subexprs.join(' && ')
end

def compile_literal(literal)

# Compiles a literal token into an equality check on the current element.
def compile_literal(literal)
  [CUR_ELEMENT, literal].join(' == ')
end

def compile_negation

# Compiles `!`: invert the match result of the following expression.
def compile_negation
  inner = compile_expr
  "!(#{inner})"
end

def compile_nodetype(type)

# Compiles a node-type token (e.g. `send`, `int-range`) into a call to
# the corresponding `_type?` predicate, guarded by a real-node check.
def compile_nodetype(type)
  predicate = "#{type.tr('-', '_')}_type?"
  "#{compile_guard_clause} && #{CUR_NODE}.#{predicate}"
end

def compile_param(number)

# Compiles `%N`: compare the current element against parameter N.
def compile_param(number)
  param = get_param(number)
  "#{CUR_ELEMENT} == #{param}"
end

def compile_predicate(predicate)

# Compiles `method?` / `method?(...)`: call the predicate on the current
# element, with compiled arguments when an arglist is present.
def compile_predicate(predicate)
  return "#{CUR_ELEMENT}.#{predicate}" unless predicate.end_with?('(') # no arglist

  args = compile_args(tokens)
  name = predicate[0..-2] # drop the trailing (
  "#{CUR_ELEMENT}.#{name}(#{args.join(',')})"
end

def compile_repeated_expr(token)

# Compiles a sequence term possibly followed by a repetition marker
# (`*`, `+` or `?`), returning [arity, code-or-generator].
def compile_repeated_expr(token)
  before = @captures
  expr = compile_expr(token)
  min, max = parse_repetition_token
  return [1, expr] if min.nil? # no repetition marker follows
  if @captures != before
    # the repeated expression contains captures: the values accumulate
    # into arrays, gathered via a dedicated temp variable
    captured = "captures[#{before}...#{@captures}]"
    accumulate = next_temp_variable(:accumulate)
  end
  arity = min..max || Float::INFINITY # `?` bounds max at 1; `*`/`+` are unbounded
  [arity, repeated_generator(expr, captured, accumulate)]
end

def compile_seq

# Compiles `(...)` (sequence): compile each term, then delegate assembly
# to Sequence.
def compile_seq
  enum = tokens_until(')', 'sequence')
  terms = enum.map { variadic_seq_term }
  Sequence.new(self, *terms).compile
end

def compile_union

# Compiles `{...}` (union / alternation).
# @return [String] Ruby code OR-ing the compiled branches
def compile_union
  # we need to ensure that each branch of the {} contains the same
  # number of captures (since only one branch of the {} can actually
  # match, the same variables are used to hold the captures for each
  # branch)
  # NOTE: unify_in_union and insure_same_captures return lazy
  # Enumerators here; the work happens as `.map { compile_expr }` pulls
  # each branch through the chain.
  enum = tokens_until('}', 'union')
  enum = unify_in_union(enum)
  terms = insure_same_captures(enum, 'branch of {}')
          .map { compile_expr }
  "(#{terms.join(' || ')})"
end

def compile_wildcard(name)

# Compiles `_` (anonymous) or `_name` (named) wildcards. Named wildcards
# unify: every occurrence of the same name must match the same value.
def compile_wildcard(name)
  return 'true' if name.empty? # anonymous: matches anything

  if @unify.key?(name)
    # the wildcard was seen before; its first value is stored in a temp,
    # so compare the current element against it
    "#{CUR_ELEMENT} == #{access_unify(name)}"
  else
    temp = @unify[name] = "unify_#{name.gsub('-', '__')}"
    # double assign to avoid "assigned but unused variable"
    "(#{temp} = #{CUR_ELEMENT}; " \
    "#{temp} = #{temp}; true)"
  end
end

def emit_method_code

# Emits the body of the generated matcher method: bail out unless the
# pattern matched, then either yield the captures or return the result.
def emit_method_code
  lines = [
    "return unless #{@match_code}",
    "block_given? ? #{emit_yield_capture} : (return #{emit_retval})"
  ]
  "#{lines.join("\n")}\n"
end

def emit_param_list

# Comma-separated parameter names for the generated matcher, one per
# `%N` parameter the pattern uses.
def emit_param_list
  names = (1..@params).map { |index| "param#{index}" }
  names.join(',')
end

def emit_retval

# Expression for the matcher's return value: `true` with no captures,
# the single capture, or the whole captures array.
def emit_retval
  case @captures
  when 0 then 'true'
  when 1 then 'captures[0]'
  else        'captures'
  end
end

def emit_trailing_params

# Trailing parameter list (with leading comma) for the generated method
# signature, or an empty string when the pattern uses no params.
def emit_trailing_params
  list = emit_param_list
  return '' if list.empty?

  ",#{list}"
end

def emit_yield_capture(when_no_capture = '')

# Code yielding the captures to the caller's block.
#
# @param when_no_capture [String] expression yielded when the pattern
#   captures nothing
def emit_yield_capture(when_no_capture = '')
  yielded =
    case @captures
    when 0 then when_no_capture
    when 1 then 'captures[0]' # circumvents https://github.com/jruby/jruby/issues/5710
    else        '*captures'
    end
  "yield(#{yielded})"
end

def fail_due_to(message)

# Aborts compilation with context about the failing pattern.
# @raise [Invalid] always
def fail_due_to(message)
  details = "Couldn't compile due to #{message}. Pattern: #{@string}"
  raise Invalid, details
end

def forbid_unification(*names)

# Marks the given wildcard names as unusable for unification (they were
# only bound in a subset of a union's branches).
def forbid_unification(*names)
  names.each { |wildcard| @unify[wildcard] = :forbidden_unification }
end

def get_param(number)

# Resolves a `%N` parameter reference, tracking the highest number seen.
# `%` alone means `%1`; `%0` refers to the root node.
def get_param(number)
  index = number.empty? ? 1 : Integer(number)
  @params = index if index > @params
  index.zero? ? @root : "param#{index}"
end

def initialize(str, node_var = 'node0')

# Compiles the pattern immediately upon construction.
#
# @param str [String] the node-pattern source
# @param node_var [String] name of the Ruby variable holding the root
#   node the generated code matches against
def initialize(str, node_var = 'node0')
  @string   = str
  @root     = node_var
  @temps    = 0  # avoid name clashes between temp variables
  @captures = 0  # number of captures seen
  @unify    = {} # named wildcard -> temp variable
  @params   = 0  # highest % (param) number seen
  run(node_var)
end

def insure_same_captures(enum, what)

# Yields once per element of +enum+ while verifying that every branch
# consumes the same number of captures (only one branch matches, so all
# branches share the same capture slots). Returns a lazy Enumerator when
# called without a block so it can be chained.
def insure_same_captures(enum, what)
  return to_enum __method__, enum, what unless block_given?
  captures_before = captures_after = nil
  enum.each do
    captures_before ||= @captures
    @captures = captures_before # every branch starts from the same slot
    yield
    captures_after ||= @captures
    fail_due_to("each #{what} must have same # of captures") if captures_after != @captures
  end
end

def next_capture

# Reserves the next capture slot and returns the code accessing it.
def next_capture
  "captures[#{@captures}]".tap { @captures += 1 }
end

def next_temp_value

# Bumps and returns the counter that keeps temp-variable names unique.
def next_temp_value
  @temps = @temps.succ
end

def next_temp_variable(name)

def next_temp_variable(name)
  "#{name}#{next_temp_value}"
end

def parse_repetition_token

# Consumes a repetition marker (`*`, `+`, `?`) if one is next, returning
# its [min, max] arity (max nil means unbounded); nil when no marker.
def parse_repetition_token
  min, max =
    case tokens.first
    when '*' then [0, nil]
    when '+' then [1, nil]
    when '?' then [0, 1]
    else return
    end
  tokens.shift
  [min, max]
end

def repeated_generator(expr, captured, accumulate)

# Builds the match-code generator for a repeated expression.
# NOTE: the locals (expr, captured, accumulate, child, range) are
# consumed by REPEATED_TEMPLATE via `binding` — do not rename them.
def repeated_generator(expr, captured, accumulate)
  with_temp_variables do |child|
    lambda do |range|
      fail_due_to 'repeated pattern at beginning of sequence' if range.begin == SEQ_HEAD_INDEX
      REPEATED_TEMPLATE.result(binding)
    end
  end
end

def run(node_var)

# Tokenizes the pattern and compiles it into @match_code.
# Fails when tokens remain after the top-level expression (unbalanced
# pattern).
def run(node_var)
  @tokens = Compiler.tokens(@string)
  @match_code = with_context(compile_expr, node_var, use_temp_node: false)
  # allocate the captures array up-front so capture slots can be assigned
  @match_code.prepend("(captures = Array.new(#{@captures})) && ") \
    if @captures.positive?
  fail_due_to('unbalanced pattern') unless tokens.empty?
end

def substitute_cur_node(code, cur_node, first_cur_node: cur_node)

# Replaces the placeholders in +code+ with the expression for the
# current node. The first occurrence may substitute a different
# expression (e.g. a temp-variable initializer from with_temp_node).
# NOTE: CUR_ELEMENT must be rewritten to CUR_NODE before CUR_NODE is
# substituted, and SEQ_HEAD_GUARD markers are stripped last.
def substitute_cur_node(code, cur_node, first_cur_node: cur_node)
  iter = 0
  code
    .gsub(CUR_ELEMENT, CUR_NODE)
    .gsub(CUR_NODE) do
      iter += 1
      iter == 1 ? first_cur_node : cur_node
    end
    .gsub(SEQ_HEAD_GUARD, '')
end

def tokens_until(stop, what)

# Yields repeatedly until the +stop+ token is next, then consumes it.
# Returns a lazy Enumerator when called without a block.
# NOTE: the `&& what` guard skips the emptiness check when +what+ is
# nil/false.
def tokens_until(stop, what)
  return to_enum __method__, stop, what unless block_given?
  fail_due_to("empty #{what}") if tokens.first == stop && what
  yield until tokens.first == stop
  tokens.shift
end

def unify_in_union(enum)

# rubocop:disable Metrics/MethodLength, Metrics/AbcSize
# Tracks wildcard unification across the branches of a union; wildcards
# bound in only some branches are forbidden outside the union.
# Returns a lazy Enumerator when called without a block.
def unify_in_union(enum)
  # We need to reset @unify before each branch is processed.
  # Moreover we need to keep track of newly encountered wildcards.
  # Var `new_unify_intersection` will hold those that are encountered
  # in all branches; these are not a problem.
  # Var `partial_unify` will hold those encountered in only a subset
  # of the branches; these can't be used outside of the union.
  return to_enum __method__, enum unless block_given?
  new_unify_intersection = nil
  partial_unify = []
  unify_before = @unify.dup
  result = enum.each do |e|
    @unify = unify_before.dup if new_unify_intersection
    yield e
    new_unify = @unify.keys - unify_before.keys
    if new_unify_intersection.nil?
      # First iteration
      new_unify_intersection = new_unify
    else
      union = new_unify_intersection | new_unify
      new_unify_intersection &= new_unify
      partial_unify |= union - new_unify_intersection
    end
  end
  # At this point, all members of `new_unify_intersection` can be used
  # for unification outside of the union, but partial_unify may not
  forbid_unification(*partial_unify)
  result
end

def variadic_seq_term

# Compiles one top-level term of a sequence, dispatching on the variadic
# markers (`...`, `$...`, `<...>`, `$<...>`) before falling back to a
# (possibly repeated) ordinary expression.
# NOTE(review): CAPTURED_REST is tested before REST — presumably the two
# token patterns overlap ('$...' vs '...'); confirm against the token
# constants before reordering these clauses.
def variadic_seq_term
  token = tokens.shift
  case token
  when CAPTURED_REST then compile_captured_ellipsis
  when REST          then compile_ellipsis
  when '$<'          then compile_any_order(next_capture)
  when '<'           then compile_any_order
  else                    compile_repeated_expr(token)
  end
end

def with_child_context(code, child_index)

# Runs +code+ with the current node rebound to the child at
# +child_index+.
def with_child_context(code, child_index)
  child_access = "#{CUR_NODE}.children[#{child_index}]"
  with_context(code, child_access)
end

def with_context(code, cur_node,

# Substitutes the current-node placeholders in +code+ with +cur_node+.
# When +use_temp_node+ is true (by default: when the placeholder occurs
# more than once), the node expression is hoisted into a temp variable.
def with_context(code, cur_node,
                 use_temp_node: auto_use_temp_node?(code))
  return substitute_cur_node(code, cur_node) unless use_temp_node

  with_temp_node(cur_node) do |init, temp_var|
    substitute_cur_node(code, temp_var, first_cur_node: init)
  end
end

def with_seq_head_context(code)

# Rewrites +code+ to match against the node's type (sequence-head
# position); parenthesized sub-sequences are invalid there.
def with_seq_head_context(code)
  fail_due_to('parentheses at sequence head') if code.include?(SEQ_HEAD_GUARD)

  type_access = "#{CUR_NODE}.type"
  code.gsub(CUR_ELEMENT, type_access)
end

def with_temp_node(cur_node)

# Allocates a temp variable initialized to +cur_node+ and yields both
# the initializer code and the variable name.
def with_temp_node(cur_node)
  code = with_temp_variables do |node|
    yield "(#{node} = #{cur_node})", node
  end
  code.gsub("\n", "\n  ") # nicer indent for debugging
end

def with_temp_variables(&block)

# Yields one fresh temp-variable name per parameter of the given block,
# each named after the corresponding block parameter.
def with_temp_variables(&block)
  fresh_names = block.parameters.map { |_kind, name| next_temp_variable(name) }
  block.call(*fresh_names)
end