require "rubocop"

# Edit distance between two sequences: the bottom-right cell of the table
# built by `tree_distance_matrix`.
#
# @param object_a [#size, #[]] the "before" sequence
# @param object_b [#size, #[]] the "after" sequence
# @return [Integer] minimum number of single-element edits from a to b
def tree_distance(object_a, object_b)
  matrix = tree_distance_matrix(object_a, object_b)
  matrix[-1][-1]
end

# Builds the full Levenshtein (edit-distance) table between two sequences.
#
# Cell [i][j] holds the minimum number of single-element insertions,
# deletions and replacements needed to turn the first i elements of
# `object_a` into the first j elements of `object_b`. The table therefore has
# `object_a.size + 1` rows and `object_b.size + 1` columns; index 0 stands for
# the empty prefix. Elements are compared with `==`.
#
# @param object_a [#size, #[]] the "before" sequence (e.g. an Array of AST nodes)
# @param object_b [#size, #[]] the "after" sequence
# @return [Array<Array<Integer>>] the completed table; the bottom-right cell
#   is the overall distance.
def tree_distance_matrix(object_a, object_b)
  row_count = object_a.size
  col_count = object_b.size

  table = Array.new(row_count + 1) { Array.new(col_count + 1, 0) }

  # Border: building j elements from nothing takes j insertions (row 0), and
  # reducing i elements to nothing takes i deletions (column 0).
  0.upto(col_count) { |j| table[0][j] = j }
  0.upto(row_count) { |i| table[i][0] = i }

  1.upto(row_count) do |i|
    1.upto(col_count) do |j|
      left_item  = object_a[i - 1]
      right_item = object_b[j - 1]

      table[i][j] =
        if left_item == right_item
          # Matching elements cost nothing: carry the diagonal forward.
          table[i - 1][j - 1]
        else
          # One more step than the cheapest neighbouring sub-problem.
          1 + [
            table[i][j - 1],     # insert object_b[j - 1]
            table[i - 1][j - 1], # replace object_a[i - 1] with object_b[j - 1]
            table[i - 1][j],     # delete object_a[i - 1]
          ].min
        end
    end
  end

  table
end

# Walks the table from `tree_distance_matrix` backwards (bottom-right to
# top-left) to recover one cheapest edit script turning `object_a` into
# `object_b`.
#
# Elements are compared with `==`. Every element that ends up in an edit must
# respond to `each_descendant` (as RuboCop AST nodes do). Its first descendant
# is used to drop a child entry once the parent is recorded for the same kind
# of edit.
#
# @param object_a [Array] the "before" sequence (e.g. AST nodes in pre-order)
# @param object_b [Array] the "after" sequence
# @return [Hash{Symbol => Array}] only the keys that saw an edit:
#   :changed => [{ from: node_a, to: node_b }, ...]
#   :deleted => [node_a, ...]
#   :added   => [node_b, ...]
#   Entries appear in the order they were found, i.e. from the end of the
#   sequences towards the start.
#
# Also prints one "CHANGE:", "DELETE:" or "ADD:" line per recorded edit.
def tree_difference(object_a, object_b)
  distance_matrix = tree_distance_matrix(object_a, object_b)
  x, y = object_a.size, object_b.size

  # Sets make the "is the child already recorded?" check and removal cheap.
  diff = Hash.new { |h, k| h[k] = Set.new }

  # Keep going until *both* sequences are consumed. Once one side reaches 0,
  # the rest of the other side is pure deletions (y == 0) or pure additions
  # (x == 0), and those must still be recorded. The x/y > 0 guards also stop
  # `pre_x`/`pre_y` from being -1, which Ruby would read as the last row/column.
  while x > 0 || y > 0
    pre_x, pre_y = x - 1, y - 1
    current_cell = distance_matrix[x][y]
    both_left    = x > 0 && y > 0

    # Same node, nothing to record
    if both_left && object_a[pre_x] == object_b[pre_y]
      x, y = pre_x, pre_y
    # Replaced
    elsif both_left && current_cell == distance_matrix[pre_x][pre_y] + 1
      node_a, node_b = object_a[pre_x], object_b[pre_y]

      # A change already recorded for the first descendant pair is redundant
      # once the parent pair is recorded (Set#delete is a no-op if absent).
      child_change = { from: node_a.each_descendant.first, to: node_b.each_descendant.first }
      diff[:changed].delete(child_change)
      diff[:changed].add({ from: node_a, to: node_b })
      puts "CHANGE: #{node_a} to #{node_b}"

      x, y = pre_x, pre_y
    # Deleted (only object_a advances)
    elsif x > 0 && current_cell == distance_matrix[pre_x][y] + 1
      node_a = object_a[pre_x]

      # The parent's deletion makes its first descendant's deletion redundant.
      diff[:deleted].delete(node_a.each_descendant.first)
      diff[:deleted].add(node_a)
      puts "DELETE: #{node_a}"

      x = pre_x
    # Added (only object_b advances)
    elsif y > 0 && current_cell == distance_matrix[x][pre_y] + 1
      node_b = object_b[pre_y]

      # The parent's addition makes its first descendant's addition redundant.
      diff[:added].delete(node_b.each_descendant.first)
      diff[:added].add(node_b)
      puts "ADD: #{node_b}"

      y = pre_y
    # Unreachable for a table built by tree_distance_matrix; bail out rather
    # than loop forever.
    else
      break
    end
  end

  # Back to arrays for readability
  diff.transform_values(&:to_a)
end

# Parses Ruby source into a RuboCop AST.
#
# This script was written against Ruby 3.1.x. Pass `ruby_version` to parse as
# a specific target instead of the running interpreter's version.
#
# @param string [String] Ruby source code
# @param ruby_version [Float] target Ruby version handed to RuboCop's parser
#   (e.g. 3.1); defaults to the running interpreter's major.minor version.
# @return the root node from `RuboCop::ProcessedSource#ast`
#   (NOTE(review): may be nil for empty/unparseable source — confirm)
def ast_from(string, ruby_version = RUBY_VERSION.to_f)
  RuboCop::ProcessedSource.new(string, ruby_version).ast
end

# Diffs two ASTs by flattening each into a pre-order list of nodes (the root
# first, then its `descendants`) and handing both lists to `tree_difference`.
#
# The root is included on purpose: `descendants` alone omits it, which made
# single-node programs (e.g. `foo` vs `bar`) and root-level changes look
# identical. A `nil` AST is treated as an empty tree.
#
# @param a the "before" AST root node (or nil)
# @param b the "after" AST root node (or nil)
# @return whatever `tree_difference` returns for the two node lists
def ast_difference(a, b)
  nodes_a = a ? [a, *a.descendants] : []
  nodes_b = b ? [b, *b.descendants] : []
  tree_difference(nodes_a, nodes_b)
end

# Demo: diff an explicit block against its Symbol#to_proc shorthand.
standard_block  = "[1, 2, 3].select { |v| v.even? }"
shorthand_block = "[1, 2, 3].select(&:even?)"

standard_ast, shorthand_ast = [standard_block, shorthand_block].map { |source| ast_from(source) }

differences = ast_difference(standard_ast, shorthand_ast)
pp differences
# `tree_difference` first prints a CHANGE/DELETE/ADD line per edit. Then
# `differences` is pretty-printed: a Hash holding only the keys that saw an
# edit (:changed, :deleted and/or :added). For this pair, its :changed list
# contains
#   { from: s(:send, s(:lvar, :v), :even?),
#     to:   s(:block_pass, s(:sym, :even?)) }
# and its :deleted list contains
#   s(:args, s(:arg, :v))