This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
https://github.com/iree-org/iree-turbine/blob/main/iree/turbine/aot/fx_programs.py | |
Also ai-edge torch exporter |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from ai_edge_torch.odml_torch.export import exported_program_to_mlir | |
import torch | |
class PowModel(torch.nn.Module): | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
return x ** 0.5 | |
model = PowModel() | |
print(model(torch.tensor(2))) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from onnx_diagnostic import torch_export_patches | |
from onnxscript.ir.passes.common import clear_metadata_and_docstring | |
from transformers import AttentionInterface, AutoModelForCausalLM, AutoTokenizer | |
from transformers.cache_utils import DynamicCache | |
# Get position_ids from attention_mask | |
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Owner(s): ["module: onnx"] | |
"""Unit LLM tests for the onnx dynamo exporter.""" | |
from __future__ import annotations | |
from typing import Any | |
import logging | |
import transformers |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import torch | |
from torch_geometric.nn import GAT | |
logger = logging.getLogger(__name__) | |
logging.getLogger('torch.onnx').setLevel(logging.INFO) | |
logger.info("Prepare model") | |
num_features = 23 | |
num_classes = 12 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Display all PyTorch ONNX exporter supported ops. | |
NOTE: This is using internal methods. Do not use it in production code. | |
NOTE: Ops implemented via decomp may not be supported because they may still be decomposed | |
into ops that are without native implementation. They include some backward ops, | |
svd, sq, and some others. | |
""" | |
from torch.onnx._internal.exporter import _decomp, _registration |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# ************************************************************************************** | |
# NOTE: Users can now use https://github.com/microsoft/onnxscript/blob/main/onnxscript/ir/passes/common/constant_manipulation.py | |
# aka. onnxscript.ir.passes.common.constant_manipulation.LiftConstantsToInitializersPass | |
# ************************************************************************************** | |
from onnxscript import ir | |
def convert_constants_to_initizliers(model: ir.Model, size_limit: int = 1024): | |
"""Convert constant nodes to initializers.""" | |
for node in model.graph: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from onnxscript import ir | |
def fold_transpose_initializers(model: ir.Model): | |
for name, initializer in model.graph.initializers.items(): | |
user_nodes = initializer.consumers() | |
if len(user_nodes) == 1 and user_nodes[0].op_type == "Transpose": | |
transpose_node = user_nodes[0] | |
perm = transpose_node.attributes.get("perm") | |
if perm is None: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from onnxscript import ir | |
def _all_values(model: ir.Model): | |
"""Yield all values in a model.""" | |
yield from model.graph.inputs | |
yield from model.graph.initializers.values() | |
for node in ir.traversal.RecursiveGraphIterator(model.graph): | |
yield from node.outputs |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Metadata-Version: 2.2 | |
Name: onnxscript | |
Version: 0.1.0.dev20250113 | |
Summary: Naturally author ONNX functions and models using a subset of Python | |
Author-email: Microsoft Corporation <[email protected]> | |
License: MIT License | |
Copyright (c) Microsoft Corporation | |
Permission is hereby granted, free of charge, to any person obtaining a copy |
NewerOlder