# Source: GitHub Gist kddnewton/935f761951863228223524bca4e64fb5 by @kddnewton
# (last active April 27, 2021 15:58). The page chrome captured along with the
# gist ("Skip to content", "Instantly share code, notes, and snippets.",
# "Show Gist options", "Save ... to your computer and use it in GitHub Desktop.")
# is not part of the program.
# frozen_string_literal: true
require 'ripper'
# A Ripper parser that will replace usage of Sorbet patterns with whitespace so
# that location information is maintained but Sorbet methods aren't called.
class Eraser < Ripper
# Offset lookup for a source line in which every character occupies exactly one
# byte. On such a line a byte offset and a character offset are the same thing,
# so the answer is just the line's starting index plus the offset.
class SingleByteString
  # start - index of the line's first character within the whole source.
  def initialize(start)
    @line_start = start
  end

  # Converts a byte offset within this line into a character index within the
  # whole source.
  def [](byteindex)
    line_start = @line_start
    line_start + byteindex
  end
end
# Represents a line in the source that contains at least one multi-byte
# character. Builds a table such that table[byteindex] is the index, within the
# whole source, of the character that owns that byte.
class MultiByteString
  # start - index of the line's first character within the whole source.
  # line  - the text of the line.
  def initialize(start, line)
    @start = start
    @indices = []
    line.each_char.with_index(start) do |char, index|
      # One table entry per byte, all pointing at the same character.
      char.bytesize.times { @indices << index }
    end
    # Extra entry for the position just past the final character, so that an
    # offset equal to the line's byte length (for instance the end of a last
    # line that has no trailing newline) resolves to an index instead of nil.
    @indices << (start + line.size)
  end

  # Converts a byte offset within this line into a character index within the
  # whole source.
  def [](byteindex)
    # A negative offset would silently wrap around to the end of the table;
    # clamp it to the start of the line instead.
    # NOTE(review): a negative column is presumably only seen when the file
    # begins with a byte order mark -- confirm against Ripper's behaviour.
    return @start if byteindex.negative?

    @indices[byteindex]
  end
end
# A node in the tree built while parsing. It remembers the event that produced
# it, the children (or raw values) it was built from, and the range it covers
# in the source.
class Node
  attr_reader :event, :body, :range

  def initialize(event, body, range)
    @event, @body, @range = event, body, range
  end

  # True when the node's string form (see #to_s) matches the given pattern.
  def match?(pattern)
    rendered = to_s
    rendered.match?(pattern)
  end

  # Bracketed string form: the event name followed by the string form of each
  # child, separated by single spaces. Computed once and then reused.
  def to_s
    @to_s ||= begin
      children = body.map(&:to_s).join(' ')
      "<#{event} #{children}>"
    end
  end
end
# A statements node that additionally records, at construction time, whether
# any of its statements is an `extend T::Sig` command.
class StmtsNode < Node
  def initialize(event, body, range)
    super
    extend_sig = '<command <@ident extend> <args_add_block <args <const_path_ref <var_ref <@const T>> <@const Sig>>> false>>'
    @sigs = body.any? { |statement| statement.match?(extend_sig) }
  end

  # Whether one of the statements in this node matched `extend T::Sig`.
  def sigs?
    @sigs
  end
end
# Base class for a span of source that holds a call to a special Sorbet method.
# Subclasses override #replace to decide what the span becomes; the base
# implementation leaves the text untouched.
class Pattern
  attr_reader :range

  # range - the range of the source this pattern covers.
  def initialize(range)
    @range = range
  end

  # Overwrites this pattern's range inside source (in place) with the result of
  # #replace, then returns source.
  def erase(source)
    source.tap { |text| text[range] = replace(text[range]) }
  end

  # Returns the text that should stand in for segment.
  def replace(segment)
    segment
  end
end
attr_reader :line_counts, :patterns

# Sets up the parser for source and precomputes, for every line, an object
# that converts an offset on that line into a character index in the source.
#
# source - the Ruby source text to parse.
def initialize(source)
  super(source)

  offset = 0
  @line_counts =
    source.lines.map do |line|
      entry =
        if line.size == line.bytesize
          # Every character is one byte wide, so offsets need no translation.
          SingleByteString.new(offset)
        else
          MultiByteString.new(offset, line)
        end
      offset += line.size
      entry
    end

  # The location lookup indexes this list by line number. An empty source
  # yields no lines at all, which would make that lookup hit nil, so make sure
  # a first line always exists.
  @line_counts << SingleByteString.new(0) if @line_counts.empty?
  # Also provide an entry for a position on the line after the last one.
  # NOTE(review): presumably Ripper can report such a line number once it has
  # reached the end of the input -- confirm.
  @line_counts << SingleByteString.new(offset)

  @patterns = []
end
# Parses source and returns a copy of it with every recorded pattern erased.
#
# source - the Ruby source text to process. It is not modified.
#
# Raises RuntimeError ('Invalid source') when parsing returns nil or the parser
# reports an error.
def self.erase(source)
  parser = new(source)
  raise 'Invalid source' if parser.parse.nil? || parser.error?

  # Each pattern is handed the string and hands a string back. Start from a
  # copy so the caller's string is never rewritten in place, and so a frozen
  # input can be processed as well.
  parser.patterns.inject(source.dup) do |current, pattern|
    pattern.erase(current)
  end
end
private

# Index in the source of the parser's current position: picks the lookup entry
# for the current line and asks it to translate the current column.
def loc
  current_line = line_counts[lineno - 1]
  current_line[column]
end
# For each scanner event, generate a handler that wraps the token's value in a
# Node. The node's range starts at the current location and its end is the
# start plus the value's size (or just the start when there is no value).
SCANNER_EVENTS.each do |event|
  handler = :"on_#{event}"
  node_event = :"@#{event}"

  define_method(handler) do |value|
    start = loc
    width = value&.size || 0
    Node.new(node_event, [value], start..(start + width))
  end
end
# Generate a handler for every parser event:
#   * stmts_new builds an empty StmtsNode at the current location.
#   * any other *_new event builds an empty Node named after the prefix.
#   * an *_add event returns a new node of the same class with the value
#     appended and the range stretched to cover it.
#   * everything else builds a Node around its arguments, spanning from the
#     first child node to the later of the last child's end and the current
#     location.
PARSER_EVENT_TABLE.each_key do |event|
  handler = :"on_#{event}"

  case event
  when :stmts_new
    define_method(handler) do
      here = loc
      StmtsNode.new(:stmts, [], here..here)
    end
  when /\A(.+)_new\z/
    kind = Regexp.last_match(1).to_sym
    define_method(handler) do
      here = loc
      Node.new(kind, [], here..here)
    end
  when /_add\z/
    define_method(handler) do |node, value|
      span =
        if node.body.empty?
          # First element: the list covers exactly what the element covers.
          value.range
        else
          node.range.begin..value.range.end
        end
      node.class.new(node.event, node.body + [value], span)
    end
  else
    define_method(handler) do |*args|
      spans = args.grep(Node).map(&:range)
      # With no child nodes, fall back to an empty range at the current spot.
      head = spans.first || loc.then { |here| here..here }
      tail = spans.length > 1 ? spans.last : head
      Node.new(event, args, head.begin..[tail.end, loc].max)
    end
  end
end
# Hooks layered over the generated event handlers. Each hook inspects the nodes
# it receives, records a Pattern when they look like a Sorbet construct, and
# then defers to the underlying handler through super. Every replacement swaps
# the erased text for the same number of spaces.
module Patterns
  # T.must(foo)
  class TMustParensPattern < Pattern
    def replace(segment)
      segment.gsub(/(T\s*\.must\(\s*)(.+)(\s*\))(.*)/) do
        opening, value, closing, rest = Regexp.last_match.captures
        "#{' ' * opening.length}#{value}#{' ' * closing.length}#{rest}"
      end
    end
  end

  # T.let(foo, bar)
  #
  # NOTE(review): the value group is greedy, so when the segment holds several
  # commas the split falls on the last one (e.g. a type argument such as
  # T::Hash[String, Integer]) -- confirm this is acceptable.
  class TLetParensPattern < Pattern
    def replace(segment)
      segment.gsub(/(T\s*\.let\(\s*)(.+)(\s*,.+\))(.*)/) do
        opening, value, closing, rest = Regexp.last_match.captures
        "#{' ' * opening.length}#{value}#{' ' * closing.length}#{rest}"
      end
    end
  end

  def on_method_add_arg(call, arg_paren)
    with_arguments = /<arg_paren <args_add_block <args .+> false>>/

    # T.must(foo)
    if call.match?('<call <var_ref <@const T>> <@period .> <@ident must>>') && arg_paren.match?(with_arguments)
      patterns << TMustParensPattern.new(call.range.begin..arg_paren.range.end)
    end

    # T.let(foo, bar)
    if call.match?('<call <var_ref <@const T>> <@period .> <@ident let>>') && arg_paren.match?(with_arguments)
      patterns << TLetParensPattern.new(call.range.begin..arg_paren.range.end)
    end

    super
  end

  # extend T::Sig
  class ExtendTSigPattern < Pattern
    def replace(segment)
      segment.gsub(/(extend\s+T::Sig)(.*)/) do
        erased, rest = Regexp.last_match.captures
        "#{' ' * erased.length}#{rest}"
      end
    end
  end

  def on_command(ident, args_add_block)
    # extend T::Sig
    extends_sig =
      ident.match?('<@ident extend>') &&
      args_add_block.match?('<args_add_block <args <const_path_ref <var_ref <@const T>> <@const Sig>>> false>')
    patterns << ExtendTSigPattern.new(ident.range.begin..args_add_block.range.end) if extends_sig

    super
  end

  # T.must foo
  class TMustNoParensPattern < Pattern
    def replace(segment)
      segment.gsub(/(T\s*\.must\s*)(.+)/) do
        erased, value = Regexp.last_match.captures
        "#{' ' * erased.length}#{value}"
      end
    end
  end

  def on_command_call(var_ref, period, ident, args_add_block)
    # T.must foo -- only when exactly one argument is passed
    must_without_parens =
      var_ref.match?('<var_ref <@const T>>') &&
      period.match?('<@period .>') &&
      ident.match?('<@ident must>') &&
      args_add_block.match?(/<args_add_block <args <.+>> false>/) &&
      args_add_block.body[0].body.length == 1
    patterns << TMustNoParensPattern.new(var_ref.range.begin..args_add_block.range.end) if must_without_parens

    super
  end

  # sig { foo }
  class SigBracesPattern < Pattern
    def replace(segment)
      segment.gsub(/(sig\s*\{.+\})(.*)/) do
        erased, rest = Regexp.last_match.captures
        "#{' ' * erased.length}#{rest}"
      end
    end
  end

  def on_stmts_add(node, value)
    # sig { foo } -- only inside a statement list that already extends T::Sig
    sig_with_braces =
      node.sigs? &&
      value.match?(/<method_add_block <method_add_arg <fcall <@ident sig>> <args >> <brace_block <stmts .+>>>/)
    patterns << SigBracesPattern.new(value.range) if sig_with_braces

    super
  end
end

prepend Patterns
end
# Hook into bootsnap so that before the source is compiled through RubyVM::ISeq
# it gets erased through the eraser.
#
# load_iseq is looked up as a method on RubyVM::InstructionSequence itself (see
# the .method call below), so its presence has to be tested with respond_to?;
# method_defined? would ask about instance methods instead.
if RubyVM::InstructionSequence.respond_to?(:load_iseq)
  # source_location gives [path, line]; only the path is of interest. It can be
  # nil, in which case there is nothing to inspect.
  load_iseq, = RubyVM::InstructionSequence.method(:load_iseq).source_location

  if load_iseq&.include?('/bootsnap/')
    module Patch
      # Erases Sorbet patterns from contents, then compiles the result and
      # returns its binary form. A SyntaxError is reported to bootsnap as an
      # uncompilable file.
      def input_to_storage(contents, filepath)
        erased = Eraser.erase(contents)
        RubyVM::InstructionSequence.compile(erased, filepath, filepath).to_binary
      rescue SyntaxError
        raise ::Bootsnap::CompileCache::Uncompilable, 'syntax error'
      end
    end

    Bootsnap::CompileCache::ISeq.singleton_class.prepend(Patch)
  end
end
# Self-check, run only when this file is executed directly: erases the sample
# program stored after __END__ and raises unless the result equals the expected
# text in the heredoc below.
# NOTE(review): the patterns replace erased text with spaces, so the expected
# text presumably needs whitespace-only lines and padding where the Sorbet
# calls were; the whitespace here looks like it was lost -- confirm against the
# original before relying on this check.
if $0 == __FILE__
raise if Eraser.erase(DATA.read) != <<~RUBY
class Foo
def foo
'foo'
end
def bar
foo
end
def baz
foo
end
end
RUBY
end
__END__
class Foo
extend T::Sig
sig { returns(T.nilable(String)) }
def foo
'foo'
end
sig { returns(String) }
def bar
T.must(foo)
end
sig { returns(String) }
def baz
T.must foo
end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment