This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Convenience methods for constructing (and manipulating?) the IR.""" | |
from __future__ import annotations | |
import collections.abc | |
from typing import Any, Mapping, Sequence | |
from onnxrewriter.experimental.intermediate_representation import _ir | |
This file has been truncated, but you can view the full file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
< | |
ir_version=8, | |
opset_imports={'pkg.onnxscript.torch_lib': 1, 'pkg.torch.2.4.0a0+gitd56ab7b': 1, 'pkg.transformers.4.37.2': 1, '': 18, 'pkg.onnxscript.torch_lib.common': 1}, | |
producer_name='pytorch', | |
producer_version='2.4.0', | |
domain=None, | |
model_version=None, | |
> | |
graph( | |
name=main_graph, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from onnxscript import ir | |
import onnx | |
model_proto = onnx.load("model.onnx") | |
# (not const) -> cast to 16 -> cast to 32 -> Op | |
model = ir.serde.deserialize_model(model_proto) | |
def is_cast(node: ir.Node, dtype: ir.DataType): | |
if node.op_type != "Cast": |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
ir_version: 8 | |
producer_name: "pytorch" | |
producer_version: "2.3.0" | |
graph { | |
node { | |
output: "_val_2" | |
name: "Constant_0" | |
op_type: "Constant" | |
attribute { | |
name: "value_floats" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Utilities for traversing the IR graph.""" | |
from __future__ import annotations | |
__all__ = [ | |
"RecursiveGraphIterator", | |
] | |
from typing import Callable, Iterator, Reversible |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
value_name = self.name if self.name is not None else "anonymous:" + str(id(self)) | |
if producer is not None: | |
producer_text = producer.name if producer.name is not None else "anonymous:" + str(id(producer)) | |
else: | |
producer_text = "None" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Adapted from | |
# https://github.com/pytorch/pytorch/blob/b505e8647547f029d0f7df408ee5f2968f757f89/test/test_public_bindings.py#L523 | |
# Original code PyTorch license https://github.com/pytorch/pytorch/blob/main/LICENSE | |
# Modifications Copyright (c) Microsoft Corporation. All rights reserved. | |
# Licensed under the MIT License. | |
from __future__ import annotations | |
import importlib | |
import itertools | |
import os |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class M(torch.nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.linear = torch.nn.Linear(5, 10) | |
def forward(self, x): | |
return self.linear(x) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
class UpsampleModel(torch.nn.Module): | |
def forward(self, x): | |
return torch.nn.functional.upsample_bilinear(x, scale_factor=2) | |
model = UpsampleModel() | |
ep = torch.export.export(model, (torch.randn(1, 3, 224, 224),)) | |
ep.run_decompositions() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_node_data_for_model_explorer(verification_infos: Collection[VerificationInfo], node_names: list[str], model_name: str = "model" | |
): | |
# https://github.com/google-ai-edge/model-explorer/wiki/4.-API-Guide#create-custom-node-data | |
# This API is unstable and may change in the future. | |
from model_explorer import node_data_builder as ndb | |
for field in ("max_abs_diff", "max_rel_diff"): | |
# Populate values for the main graph in a model. | |
main_graph_results: dict[str, ndb.NodeDataResult] = {} | |
for info, node_name in zip(verification_infos, node_names): |