Skip to content

Instantly share code, notes, and snippets.

@Zac-HD
Created August 10, 2024 23:27
Show Gist options
  • Save Zac-HD/e791386964c1cff6c222f6f15ec9434f to your computer and use it in GitHub Desktop.
Save Zac-HD/e791386964c1cff6c222f6f15ec9434f to your computer and use it in GitHub Desktop.
Speed up Shrinkray by pre-stripping comments, type annotations, and docstrings from code
#!/usr/bin/env python
import sys
from io import StringIO
from tokenize import COMMENT, generate_tokens, untokenize
import libcst as cst
import libcst.matchers as m
from libcst.codemod import VisitorBasedCodemodCommand
def cut(code: str) -> str:
    """Return a slimmed-down version of ``code`` that is mostly equivalent.

    Two passes are applied:

    1. Every comment token is dropped at the tokenizer level.
    2. The ``CutJunk`` codemod is run over the result, which strips
       annotations (parameter, return, and variable), ``TYPE_CHECKING``
       blocks, and docstrings.

    The point is to reduce the amount of text -- and, by dropping
    annotations, the dependencies they pull in -- since both slow down
    shrinkray.
    """
    tokens = generate_tokens(StringIO(code).readline)
    uncommented = untokenize(tok for tok in tokens if tok.type != COMMENT)

    tree = cst.parse_module(uncommented)
    stripper = CutJunk(cst.codemod.CodemodContext())
    return stripper.transform_module(tree).code
class CutJunk(VisitorBasedCodemodCommand):
    """Codemod that strips annotations, TYPE_CHECKING blocks, and docstrings."""

    DESCRIPTION = "Remove stuff to speed up shrinkray"

    def leave_AnnAssign(self, original_node, updated_node):
        """Rewrite ``x: T = v`` as ``x = v``; drop bare ``x: T`` declarations."""
        if updated_node.value is None:
            return cst.RemoveFromParent()
        return cst.Assign(
            targets=[cst.AssignTarget(target=updated_node.target)],
            value=updated_node.value,
        )

    def leave_FunctionDef(self, original_node, updated_node):
        """Drop the ``-> T`` return annotation."""
        return updated_node.with_changes(returns=None)

    def leave_If(self, original_node, updated_node):
        """Remove the body guarded by ``if TYPE_CHECKING:``.

        Both the bare name and the attribute spelling (for example
        ``typing.TYPE_CHECKING``) are recognised.
        """
        guard = m.OneOf(
            m.Name("TYPE_CHECKING"),
            m.Attribute(attr=m.Name("TYPE_CHECKING")),
        )
        if not m.matches(updated_node, m.If(test=guard)):
            return updated_node
        if updated_node.orelse is None:
            return cst.RemoveFromParent()
        # An else/elif branch is what actually executes at runtime, so it
        # must survive: keep the statement and only empty the guarded body.
        placeholder = cst.IndentedBlock(
            body=[cst.SimpleStatementLine(body=[cst.Pass()])]
        )
        return updated_node.with_changes(body=placeholder)

    def leave_Param(self, original_node, updated_node):
        """Drop the parameter's annotation."""
        return updated_node.with_changes(annotation=None)

    def leave_IndentedBlock(self, original_node, updated_node):
        """Remove docstrings.

        Any indented block whose first statement is a lone simple string
        literal has that statement dropped.
        """
        body = updated_node.body
        if body and m.matches(
            body[0],
            m.SimpleStatementLine(
                body=[m.Expr(value=m.SimpleString())],
            ),
        ):
            return updated_node.with_changes(body=body[1:])
        return updated_node
if __name__ == "__main__":
    import pathlib
    import tokenize

    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} FILE")

    target = pathlib.Path(sys.argv[1])

    # tokenize.open() decodes the file the way Python itself would,
    # honouring a BOM or a PEP 263 coding cookie.
    with tokenize.open(target) as fh:
        source = fh.read()

    # cut() removes every comment, including any coding cookie, so the
    # rewritten file must use Python's default source encoding (UTF-8).
    target.write_text(cut(source), encoding="utf-8")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment