Created
November 25, 2021 22:11
-
-
Save zhangqiaorjc/b52aeed6c5b1181c18d5fe089e69f485 to your computer and use it in GitHub Desktop.
make_hlo
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def make_hlo(f, optimize=False, metadata=False, platform=None):
    """Wrap a JAX function so that calling it returns HLO text instead of results.

    The returned wrapper traces ``f`` with ``jax.xla_computation`` on whatever
    arguments it is called with, and prints the resulting HLO module. The module
    is printed either as emitted by JAX or after XLA compilation for the chosen
    platform.

    NOTE(review): this relies on ``jax.xla_computation``, ``jax.lib.xla_bridge``
    and the private ``jax.lib.xla_client._xla`` module. These are legacy APIs that
    newer JAX releases deprecate or remove; confirm against the JAX version in use.

    Args:
      f: JAX-traceable function to return HLO for.
      optimize: bool: if True, return the platform-specific, XLA-optimized HLO
        (the first HLO module of the compiled executable). Otherwise return the
        HLO emitted by JAX before XLA optimization.
      metadata: bool: whether to include JAX metadata information in the
        printed HLO.
      platform: Optional[str]: None, 'cpu', 'gpu' or 'tpu' - platform to lower
        and compile for. None uses the default backend.

    Returns:
      A callable that accepts the same arguments as ``f`` and returns, as a
      ``str``, the HLO of ``f`` specialized to those arguments.
    """
    import functools

    @functools.wraps(f)
    def wrapped_fn(*args, **kwargs):
        # Deferred so that merely building the wrapper does not import JAX or
        # initialize an XLA backend.
        import jax

        client = jax.lib.xla_bridge.get_backend(platform)
        print_opts = jax.lib.xla_client._xla.HloPrintOptions.short_parsable()
        print_opts.print_metadata = metadata
        # Lower for the same backend we compile on. Without ``backend=`` the
        # unoptimized path always used the default backend and ignored
        # ``platform``. Passing None keeps the default-backend behavior.
        computation = jax.xla_computation(f, backend=platform)(*args, **kwargs)
        if optimize:
            executable = client.compile(computation)
            return executable.hlo_modules()[0].to_string(print_opts)
        return computation.as_hlo_module().to_string(print_opts)

    return wrapped_fn
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment