Created
November 21, 2025 23:43
-
-
Save msullivan/6840f5c38b30b21b708d8b4ff553a95b to your computer and use it in GitHub Desktop.
jury-rigged static analyzer for rust to find tokio::main misuses
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import json | |
| import sys | |
| """ | |
| Jury-rigged static analyzer to detect when we are spawning new tokio runtimes | |
| with tokio::main from within an async context. | |
| # This uses some weird deps: | |
| # https://github.com/sourcegraph/scip (seems basically plausible) | |
| # https://github.com/Beneficial-AI-Foundation/scip-callgraph (was a pain in my ass but probably saved some time) | |
| # Prep commands | |
| rust-analyzer scip . | |
| ~/tmp/scip print --json index.scip > index_scip.json | |
| ~/tmp/scip-callgraph/target/debug/export_call_graph_d3 -- ~/src/e/edgedb-cli/index_scip.json && mv call_graph_d3.json call_graph.json | |
| git grep -n -A1 'tokio'::main src > tokio_mains.txt | |
| # And then | |
| ./analyze.py index_scip.json call_graph.json tokio_mains.txt | |
| """ | |
# Trimmed symbol names (the form produced by trim_symbol) of call targets
# that main() refuses to follow while propagating "called from async code".
# NOTE(review): presumably these are known-safe callees -- confirm before
# adding more entries.
EXEMPT = {'portable/windows/_get_wsl_distro().'}
def trim_symbol(s):
    """Return the part of *s* after its last space (all of *s* if none)."""
    _head, _sep, tail = s.rpartition(' ')
    return tail
def main(args):
    """Report tokio-main functions that are reachable from async code.

    Args:
        args: argv-style sequence of exactly four items: the program name
            (ignored), the path of the SCIP index dumped as JSON, the path
            of the call-graph JSON (an object with a ``links`` list of
            ``source``/``target`` edges), and the path of the
            ``git grep -n`` output listing the attribute lines.

    Prints summary counts and then, for every flagged function, the caller
    chains that lead to it.

    Raises:
        ValueError: if *args* does not contain exactly four items, or a
            matching grep line does not have a numeric line-number field.
    """
    _, scipfile, jfile, mains_file = args

    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(scipfile, encoding='utf-8') as f:
        scip_data = json.load(f)
    with open(jfile, encoding='utf-8') as f:
        graph = json.load(f)
    with open(mains_file, encoding='utf-8') as f:
        mains = [s.strip() for s in f]

    edges = graph['links']

    bad_funcs = set()
    documents = {
        doc['relative_path'].replace('\\', '/'): doc
        for doc in scip_data['documents']
    }

    # Match the line numbers with symbol definitions
    for grep_line in mains:
        # NOTE(review): the needle is written as two adjacent literals,
        # presumably so that grepping for the attribute never matches this
        # script itself -- keep it split.
        if 'tokio''::main' not in grep_line:
            continue
        filename, num, _rest = grep_line.split(':', 2)
        num = int(num)
        # A file that is absent from the index, or a document without an
        # occurrence list, is reported as MISSED instead of crashing.
        doc = documents.get(filename)
        occurrences = doc.get('occurrences', ()) if doc else ()
        # Bit 1 of symbol_roles marks a definition in the SCIP schema.
        # NOTE(review): `num` is grep's line number for the attribute; the
        # comparison with range[0] presumably relies on SCIP lines being
        # 0-based so that it lands on the following `fn` line -- confirm.
        definition = next(
            (
                occ['symbol']
                for occ in occurrences
                if occ.get('symbol_roles', 0) & 1 and occ['range'][0] == num
            ),
            None,
        )
        if definition is None:
            print('MISSED', filename, num)
        else:
            # print(filename, num, "->", definition)
            bad_funcs.add(definition)

    async_funcs = set()
    # We have to look through the original scip data to get symbols
    for doc in scip_data['documents']:
        for info in doc.get('symbols', ()):
            sig = info.get('signature_documentation')
            # A signature without text simply does not count as async.
            if sig and 'async fn' in (sig.get('text') or ''):
                async_funcs.add(info['symbol'])

    # The tokio::main "bad functions" are async but don't have it in their signature anymore!
    async_funcs.update(bad_funcs)

    # Adjacency list: caller symbol -> list of callee symbols.
    sgraph = {}
    for edge in edges:
        sgraph.setdefault(edge["source"], []).append(edge["target"])

    # async_called maps each function reachable from an async function to
    # the list of callers through which it was reached.
    async_called = {}
    worklist = list(async_funcs)
    while worklist:
        caller = worklist.pop()
        for tgt in sgraph.get(caller, ()):
            if trim_symbol(tgt) in EXEMPT:
                continue
            # need to do the checks at the outbound side, not the inbound one,
            # because we need to tell if the bad functions are getting *called*
            if tgt not in async_called:
                async_called[tgt] = [caller]
                worklist.append(tgt)
            else:
                async_called[tgt].append(caller)

    danger = async_called.keys() & bad_funcs

    # These f-strings echo their own expression text, so the expressions
    # (and therefore the variable names) are part of the output format.
    print(f'{len(bad_funcs)=}')
    print(f'{len(async_funcs)=}')
    print(f'{len(async_called)=}')
    print(f'{len(async_called.keys() | async_funcs)=}')
    print(f'{len(danger)=}')

    for bad in sorted(danger):
        print()
        print('==', trim_symbol(bad))
        # Sorted so that the report is identical from run to run; plain
        # set iteration order depends on string hash randomisation.
        for path in sorted(set(get_all_paths(async_called, bad, (bad,)))):
            print(' <- '.join([trim_symbol(s) for s in path]))
        # Having all the paths doesn't really help!
        # break
def get_all_paths(graph, node, path):
    """Yield every cycle-free path from *node* to a node absent from *graph*.

    *graph* maps a node to an iterable of the nodes to step to next (main()
    passes a callee -> callers mapping). *path* is the tuple of nodes
    visited so far and should already end with *node*. Each yielded value
    is *path* extended by the further nodes visited. A branch whose every
    next step is already on the path yields nothing.
    """
    pending = [(node, path)]
    while pending:
        current, trail = pending.pop()
        if current not in graph:
            yield trail
            continue
        # Push in reverse so the first-listed neighbour is popped (and so
        # explored) first, giving the usual depth-first order.
        for nxt in reversed(tuple(graph[current])):
            if nxt not in trail:
                pending.append((nxt, trail + (nxt,)))
if __name__ == '__main__':
    # main() unpacks the full argv itself, program name included.
    main(sys.argv)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment