Last active
April 27, 2021 15:58
-
-
Save kddnewton/935f761951863228223524bca4e64fb5 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# frozen_string_literal: true | |
require 'ripper' | |
# A Ripper parser that will replace usage of Sorbet patterns with whitespace so
# that location information is maintained but Sorbet methods aren't called.
#
# Eraser.erase(source) returns a copy of +source+ in which every erased
# character other than a line break has been replaced with a space, so the
# result has the same length and the same line numbering as the input.
class Eraser < Ripper
  # Node#to_s form of the argument list of `extend T::Sig`.
  EXTEND_T_SIG_ARGS =
    '<args_add_block <args <const_path_ref <var_ref <@const T>> <@const Sig>>> false>'

  # Node#to_s form of a complete `extend T::Sig` command.
  EXTEND_T_SIG = "<command <@ident extend> #{EXTEND_T_SIG_ARGS}>".freeze

  # Represents a line in the source. If this class is being used, it means that
  # every character in the string is 1 byte in length, so we can just return the
  # start of the line + the index.
  class SingleByteString
    def initialize(start)
      @start = start
    end

    # Character index (in the whole source) of byte offset +byteindex+.
    def [](byteindex)
      @start + byteindex
    end
  end

  # Represents a line in the source. If this class is being used, it means that
  # there are characters in the string that are multi-byte, so we will build up
  # an array of indices, such that array[byteindex] will be equal to the index
  # of the character within the string.
  class MultiByteString
    def initialize(start, line)
      @indices = []
      line.each_char.with_index(start) do |char, index|
        char.bytesize.times { @indices << index }
      end
      # Ripper reports a column equal to the line's bytesize once the whole
      # line has been consumed (e.g. at the end of the input); map that
      # one-past-the-end offset as well instead of returning nil.
      @indices << start + line.size
    end

    # Character index (in the whole source) of byte offset +byteindex+.
    def [](byteindex)
      @indices[byteindex]
    end
  end

  # Represents a node in the AST. Keeps track of the event that generated it,
  # any child nodes that descend from it, and the location in the source.
  class Node
    attr_reader :event, :body, :range

    def initialize(event, body, range)
      @event = event
      @body = body
      @range = range
    end

    def match?(pattern)
      to_s.match?(pattern)
    end

    # Compact serialization used for pattern matching, e.g. "<var_ref <@const T>>".
    # A nil child serializes as "", which leaves a double space behind.
    def to_s
      @to_s ||= "<#{event} #{body.map(&:to_s).join(' ')}>"
    end
  end

  # A stmts node that records whether any of its statements' serializations
  # contains `extend T::Sig`.
  class StmtsNode < Node
    def initialize(event, body, range)
      super
      @sigs = body.any? { |child| child.match?(EXTEND_T_SIG) }
    end

    def sigs?
      @sigs
    end
  end

  # A pattern in code that represents a call to a special Sorbet method.
  # +range+ is the span of character indices in the source that it covers.
  class Pattern
    attr_reader :range

    def initialize(range)
      @range = range
    end

    # Replaces source[range] in place with #replace's result; returns +source+.
    def erase(source)
      source[range] = replace(source[range])
      source
    end

    # Overridden by subclasses; must return a string as long as +segment+.
    def replace(segment)
      segment
    end

    private

    # Turns every character except line breaks into a space, so that both
    # character offsets and line numbers stay the same.
    def blank(text)
      text.gsub(/[^\r\n]/, ' ')
    end

    # Walks +segment+ from +start+ (just past an opening parenthesis) and
    # returns [comma, close]: the index of the first comma outside any nested
    # (), [] or {} (nil if none) and the index of the ")" closing the call
    # (nil if not found). '...'/"..." strings and # comments are skipped;
    # other literal forms are not understood, so callers must treat a nil
    # close as "leave the segment unchanged".
    def scan_arguments(segment, start)
      depth = 0
      comma = nil
      index = start

      while index < segment.length
        char = segment[index]

        case char
        when '(', '[', '{'
          depth += 1
        when ')', ']', '}'
          return [comma, char == ')' ? index : nil] if depth.zero?

          depth -= 1
        when ','
          comma ||= index if depth.zero?
        when '"', "'"
          index = closing_quote(segment, index)
          return [comma, nil] unless index
        when '#'
          index = segment.index("\n", index) || segment.length
        end

        index += 1
      end

      [comma, nil]
    end

    # Index of the quote closing the string that opens at +index+ (honoring
    # backslash escapes), or nil if it is not closed within +segment+.
    def closing_quote(segment, index)
      quote = segment[index]
      index += 1
      while index < segment.length
        case segment[index]
        when quote then return index
        when '\\' then index += 1
        end
        index += 1
      end
      nil
    end
  end

  attr_reader :line_counts, :patterns

  def initialize(source)
    super(source)
    @line_counts = []
    last_index = 0

    source.lines.each do |line|
      @line_counts <<
        if line.size == line.bytesize
          SingleByteString.new(last_index)
        else
          MultiByteString.new(last_index, line)
        end
      last_index += line.size
    end

    # Character length of the source; #loc falls back to it when Ripper
    # reports a line that doesn't exist (e.g. line 1 of an empty source).
    @eof_index = last_index
    @patterns = []
  end

  # Parses +source+ and returns a copy with every recognized Sorbet construct
  # blanked out. +source+ itself is never modified (frozen strings are fine).
  #
  # Raises RuntimeError ('Invalid source') when Ripper cannot parse +source+.
  def self.erase(source)
    parser = new(source)
    raise 'Invalid source' if parser.parse.nil? || parser.error?

    parser.patterns.inject(source.dup) do |current, pattern|
      pattern.erase(current)
    end
  end

  private

  # Character index in the source of Ripper's current lineno/column (column
  # is a byte offset within the line).
  def loc
    line = line_counts[lineno - 1]
    line ? line[column] : @eof_index
  end

  # Loop through all of the scanner events and define a basic method that wraps
  # everything into a node class.
  SCANNER_EVENTS.each do |event|
    define_method(:"on_#{event}") do |value|
      start = loc
      Node.new(:"@#{event}", [value], start..(start + (value&.size || 0)))
    end
  end

  # Loop through the parser events and generate a method for each event. If it's
  # one of the _new methods, then use arrays like SexpBuilderPP. If it's an _add
  # method then just append to the array. If it's a normal method, then create a
  # new node and determine its bounds.
  PARSER_EVENT_TABLE.each do |event, _arity|
    case event
    when :stmts_new
      define_method(:on_stmts_new) do
        StmtsNode.new(:stmts, [], loc.then { |start| start..start })
      end
    when /\A(.+)_new\z/
      prefix = Regexp.last_match(1).to_sym
      define_method(:"on_#{event}") do
        Node.new(prefix, [], loc.then { |start| start..start })
      end
    when /_add\z/
      define_method(:"on_#{event}") do |node, value|
        range =
          node.body.empty? ? value.range : (node.range.begin..value.range.end)
        node.class.new(node.event, node.body + [value], range)
      end
    else
      define_method(:"on_#{event}") do |*args|
        first, *, last = args.grep(Node).map(&:range)
        first ||= loc.then { |start| start..start }
        last ||= first
        Node.new(event, args, first.begin..[last.end, loc].max)
      end
    end
  end

  # Event hooks (prepended, so they run before the generic handlers above and
  # hand off via super) that record a Pattern for each Sorbet construct found.
  module Patterns
    # T.must(foo)
    class TMustParensPattern < Pattern
      PREFIX = /T\s*\.must\(\s*/

      def replace(segment)
        match = PREFIX.match(segment)
        return segment unless match

        close = scan_arguments(segment, match.end(0)).last
        return segment unless close

        segment[0...match.begin(0)] +
          blank(match[0]) +
          segment[match.end(0)...close] +
          blank(segment[close]) +
          segment[(close + 1)..]
      end
    end

    # T.let(foo, bar)
    class TLetParensPattern < Pattern
      PREFIX = /T\s*\.let\(\s*/

      # Keeps the first argument; blanks `T.let(` and `, type)`. The split is
      # at the first top-level comma so commas inside the value or the type
      # (e.g. T::Hash[Symbol, String]) are handled.
      def replace(segment)
        match = PREFIX.match(segment)
        return segment unless match

        comma, close = scan_arguments(segment, match.end(0))
        return segment unless comma && close

        segment[0...match.begin(0)] +
          blank(match[0]) +
          segment[match.end(0)...comma] +
          blank(segment[comma..close]) +
          segment[(close + 1)..]
      end
    end

    def on_method_add_arg(call, arg_paren)
      # T.must(foo)
      if call.match?('<call <var_ref <@const T>> <@period .> <@ident must>>') &&
         arg_paren.match?(/<arg_paren <args_add_block <args .+> false>>/)
        patterns << TMustParensPattern.new(call.range.begin..arg_paren.range.end)
      end

      # T.let(foo, bar)
      if call.match?('<call <var_ref <@const T>> <@period .> <@ident let>>') &&
         arg_paren.match?(/<arg_paren <args_add_block <args .+> false>>/)
        patterns << TLetParensPattern.new(call.range.begin..arg_paren.range.end)
      end

      super
    end

    # extend T::Sig
    class ExtendTSigPattern < Pattern
      def replace(segment)
        segment.gsub(/(extend\s+T::Sig)(.*)/) do
          keyword, rest = Regexp.last_match.captures
          blank(keyword) + rest
        end
      end
    end

    def on_command(ident, args_add_block)
      # extend T::Sig
      if ident.match?('<@ident extend>') && args_add_block.match?(EXTEND_T_SIG_ARGS)
        patterns << ExtendTSigPattern.new(ident.range.begin..args_add_block.range.end)
      end

      super
    end

    # T.must foo
    class TMustNoParensPattern < Pattern
      def replace(segment)
        segment.gsub(/(T\s*\.must\s*)(.+)/) do
          keyword, rest = Regexp.last_match.captures
          blank(keyword) + rest
        end
      end
    end

    def on_command_call(var_ref, period, ident, args_add_block)
      # T.must foo
      if var_ref.match?('<var_ref <@const T>>') && period.match?('<@period .>') &&
         ident.match?('<@ident must>') &&
         args_add_block.match?(/<args_add_block <args <.+>> false>/) &&
         args_add_block.body[0].body.length == 1
        patterns << TMustNoParensPattern.new(var_ref.range.begin..args_add_block.range.end)
      end

      super
    end

    # sig { foo }
    class SigBracesPattern < Pattern
      def replace(segment)
        segment.gsub(/(sig\s*\{.+\})(.*)/) do
          keyword, rest = Regexp.last_match.captures
          blank(keyword) + rest
        end
      end
    end

    # `\s+` after brace_block: the nil block-params child serializes as "",
    # so Node#to_s emits two spaces there.
    SIG_BLOCK =
      /<method_add_block <method_add_arg <fcall <@ident sig>> <args >> <brace_block\s+<stmts .+>>>/

    def on_stmts_add(node, value)
      # sig { foo }
      patterns << SigBracesPattern.new(value.range) if node.sigs? && value.match?(SIG_BLOCK)

      super
    end
  end

  prepend Patterns
end
# Hook into bootsnap so that before the source is compiled through RubyVM::ISeq
# it gets erased through the eraser.
#
# `load_iseq` is looked up on RubyVM::InstructionSequence itself (a singleton
# method, fetched below with `.method`), so it must be detected with
# respond_to?; Module#method_defined? only sees instance methods. RubyVM only
# exists on MRI, hence the defined? guard.
if defined?(RubyVM::InstructionSequence) &&
   RubyVM::InstructionSequence.respond_to?(:load_iseq)
  load_iseq, = RubyVM::InstructionSequence.method(:load_iseq).source_location

  if load_iseq&.include?('/bootsnap/')
    # Prepended to Bootsnap::CompileCache::ISeq's singleton class so it takes
    # over `input_to_storage`: erase Sorbet calls, then compile to binary ISeq.
    module Patch
      # Raises Bootsnap::CompileCache::Uncompilable when the source cannot be
      # parsed, either by Ripper (inside Eraser.erase) or by the compiler.
      def input_to_storage(contents, filepath)
        erased = Eraser.erase(contents)
        RubyVM::InstructionSequence.compile(erased, filepath, filepath).to_binary
      rescue SyntaxError
        raise ::Bootsnap::CompileCache::Uncompilable, 'syntax error'
      rescue RuntimeError => e
        # Eraser.erase reports Ripper parse failures as RuntimeError with the
        # message 'Invalid source'; any other RuntimeError is re-raised as is.
        raise unless e.message == 'Invalid source'

        raise ::Bootsnap::CompileCache::Uncompilable, 'syntax error'
      end
    end

    Bootsnap::CompileCache::ISeq.singleton_class.prepend(Patch)
  end
end
# Self-check when run directly: erase the Sorbet constructs from the sample
# class after __END__ and compare the result with the plain-Ruby version.
if $PROGRAM_NAME == __FILE__
  source = DATA.read
  erased = Eraser.erase(source)

  # Erasure must not move anything: same length and same number of lines.
  unless erased.length == source.length && erased.count("\n") == source.count("\n")
    raise 'Eraser did not preserve source locations'
  end

  # Compare ignoring indentation and the whitespace-only lines erasure leaves
  # behind; exact trailing whitespace is too easily lost when editing the file.
  normalize = ->(text) { text.lines.map(&:strip).reject(&:empty?) }
  expected = <<~RUBY
    class Foo
      def foo
        'foo'
      end

      def bar
        foo
      end

      def baz
        foo
      end
    end
  RUBY

  unless normalize.call(erased) == normalize.call(expected)
    raise "Unexpected erasure result:\n#{erased}"
  end
end

__END__
class Foo
  extend T::Sig

  sig { returns(T.nilable(String)) }
  def foo
    'foo'
  end

  sig { returns(String) }
  def bar
    T.must(foo)
  end

  sig { returns(String) }
  def baz
    T.must foo
  end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment