Last active
May 7, 2022 01:33
-
-
Save MilesCranmer/639d6d34bb4d7c91e94f77499a36338f to your computer and use it in GitHub Desktop.
Enable a valid Python file to serve as a config.gin file, so that code analysis and syntax highlighting work.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Python literals that gin parses as values; they must never become
# ``@``-references (``@True`` would name a configurable that does not exist).
_GIN_LITERALS = frozenset({"True", "False", "None"})


def preprocess_config(s: str) -> str:
    """Convert a "Python-flavoured" gin config into real gin syntax.

    The input is written as valid Python so editors can resolve the names
    it uses. It is turned back into gin by:

    * dropping single-line ``import x`` / ``from x import y`` statements
      (multi-line imports are NOT supported);
    * prefixing a bare name on the right-hand side of ``=`` with ``@``, e.g.
      ``MlpTowerFactory.nonlinearity = softplus`` becomes
      ``MlpTowerFactory.nonlinearity = @softplus``.

    A line is left unchanged when any of the following holds:

    * it is a comment;
    * it contains ``%`` (a gin macro);
    * it has no ``=``;
    * its value does not start with a letter or ``_`` (numbers, strings,
      lists, existing ``@`` references);
    * its value is one of the literals ``True``, ``False`` or ``None``.

    Args:
        s: Text of the Python-style config file.

    Returns:
        The gin config text, with lines joined by ``"\\n"``.
    """
    out_lines = []
    for line in s.splitlines():
        stripped = line.strip()

        # Match the leading keyword rather than the substring "import", so
        # bindings such as ``Model.important = 1`` survive.
        first_word = stripped.split(maxsplit=1)[0] if stripped else ""
        if first_word in ("import", "from"):
            continue

        if stripped.startswith("#") or "%" in line or "=" not in line:
            out_lines.append(line)
            continue

        # Split on the first "=" only: values such as 'a=b' may contain more.
        key, _, value = line.partition("=")
        value = value.strip()
        # The value's leading name, ignoring any trailing "# comment".
        name = value.split("#", 1)[0].strip()

        # value[:1] is "" for an empty right-hand side, which fails both tests.
        starts_like_name = value[:1].isalpha() or value[:1] == "_"
        if starts_like_name and name not in _GIN_LITERALS:
            line = f"{key.strip()} = @{value}"
        out_lines.append(line)
    return "\n".join(out_lines)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
For example, in `config.gin`:
Now, when you hover over a referenced function, Python analysis tools can correctly find its source, rather than the `@` prefix breaking the reference. Simply load a config file, run `preprocess_config` on it, then pass the result to gin.