@justinchuby
Last active August 28, 2024 19:24

import os
from typing import Sequence
import torch
import torch_onnx
import torch_onnx.tools.diff_model
from onnxscript import ir
import onnxscript
import onnxscript.rewriter.pattern as orp


def predecessors(node: ir.Node) -> Sequence[ir.Node]:
    results = []
    for inp in node.inputs:
        if inp is None:
            continue
        if inp.producer() is not None:
            results.append(inp.producer())
    return results


def successors(node: ir.Node) -> Sequence[ir.Node]:
    results = []
    for out in node.outputs:
        for user, _ in out.uses():
            results.append(user)
    return results


# def find_names_dynamo(model: ir.Model):
#     names = []
#     for node in model.graph:
#         if node.op_type == "LayerNormalization":
#             names.append(predecessors(node)[0].name)
#             names.append(successors(node)[0].name)
#     return names


def find_names(model: ir.Model):
    names = []
    for node in model.graph:
        if node.op_type == "LayerNormalization":
            names.append(node.inputs[0].name)
            names.append(node.outputs[0].name)
    return names


# Pattern: cast the input up to float32, run Pow, then cast the result back to float16.
def cast_pow_cast(op, x, y):
    x = op.Cast(x, to=ir.DataType.FLOAT)
    pow = op.Pow(x, y)
    return op.Cast(pow, to=ir.DataType.FLOAT16)


# Replacement: cast the exponent to float16 and compute Pow directly in float16.
def fp16_pow(op, x, y):
    y = op.Cast(y, to=ir.DataType.FLOAT16)
    return op.Pow(x, y)


cast_pow_cast_rule = orp.RewriteRule(cast_pow_cast, fp16_pow)


def rewrite(path: str):
    model = ir.load(path)
    model = onnxscript.rewriter.rewrite(
        model, pattern_rewrite_rules=[cast_pow_cast_rule]
    )
    dir = os.path.dirname(path)
    file = os.path.basename(path)
    rewritten_file = os.path.join(dir, f"rewritten_{file}")
    ir.save(model, rewritten_file)
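
# Usage sketch (assumes the same gpt15 layout used in main() below):
# rewrite("accuracy_investigation/gpt15/model.onnx") would write the rewritten
# model to accuracy_investigation/gpt15/rewritten_model.onnx next to the original.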


def main():
    # data is 1x512
    data = (torch.randint(1, 128, (1, 512)),)
    # # rewrite("accuracy_investigation/gpt15/model.onnx")
    # dynamo_model = ir.load("accuracy_investigation/gpt15/model.onnx")
    # # dynamo_model = ir.load("accuracy_investigation/gpt15/rewritten_model.onnx")
    # torchscript_model = ir.load("accuracy_investigation/gpt15/torchscript/model.onnx")
    # dynamo_names = find_names(dynamo_model)
    # print(dynamo_names)
    # torchscript_names = find_names(torchscript_model)
    # print(torchscript_names)
    # assert len(dynamo_names) == len(torchscript_names)
    results, _ = torch_onnx.tools.diff_model.diff(
        "accuracy_investigation/gpt15/model.onnx",
        "accuracy_investigation/gpt15/torchscript/model.onnx",
        [
            # ("val_0", "model/model/layers.0/self_attn/q_proj/MatMul_output_0"), # 0.0
            # ("add", "model/model/layers.0/self_attn/Add_output_0"), # 0.0
            # ("cat_2", "model/model/layers.0/self_attn/Concat_5_output_0"), # 0.0
            # ("add_1", "model/model/layers.0/self_attn/Add_1_output_0"), # 0.0
            # ("cat_3", "key_states"), # 0.0
            # ("val_237", "model/model/layers.0/self_attn/Transpose_3_output_0"), # 0.0
            # ("val_241", "model/model/layers.0/self_attn/Sqrt_2_output_0") # 0.0
            ("val_242", "model/model/layers.0/self_attn/Mul_5_output_0")
            # ("val_239", "model/model/layers.0/self_attn/Mul_4_output_0"), # 0.0009765625
            # ("val_243", "model/model/layers.0/self_attn/MatMul_output_0"), # 0.00390625
            # ("val_244", "model/model/layers.0/self_attn/Softmax_output_0"), # 0.999
            # ("view_10", "model/model/layers.0/self_attn/Reshape_3_output_0"), # 3.49
        ],
        # list(zip(dynamo_names, torchscript_names)),
        data,
        keep_original_outputs=False,
    )
    for i, result in enumerate(results):
        print(f"Result {i}----------------------------")
        print(result)


if __name__ == "__main__":
    main()
@justinchuby (Author) commented:

['view_14', 'convert_element_type_default_2', 'view_30', 'convert_element_type_default_4', 'view_46', 'convert_element_type_default_6', 'view_62', 'convert_element_type_default_8', 'view_78', 'convert_element_type_default_10', 'view_94', 'convert_element_type_default_12', 'view_110', 'convert_element_type_default_14', 'view_126', 'convert_element_type_default_16', 'view_142', 'convert_element_type_default_18', 'view_158', 'convert_element_type_default_20', 'view_174', 'convert_element_type_default_22', 'view_190', 'convert_element_type_default_24', 'view_206', 'convert_element_type_default_26', 'view_222', 'convert_element_type_default_28', 'view_238', 'convert_element_type_default_30', 'view_254', 'convert_element_type_default_32', 'view_270', 'convert_element_type_default_34', 'view_286', 'convert_element_type_default_36', 'view_302', 'convert_element_type_default_38', 'view_318', 'convert_element_type_default_40', 'view_334', 'convert_element_type_default_42', 'view_350', 'convert_element_type_default_44', 'view_366', 'convert_element_type_default_46', 'view_382', 'convert_element_type_default_48']
['model/model/layers.0/mlp/fc1/Add_output_0', 'model/model/layers.0/mlp/activation_fn/Pow_output_0', 'model/model/layers.1/mlp/fc1/Add_output_0', 'model/model/layers.1/mlp/activation_fn/Pow_output_0', 'model/model/layers.2/mlp/fc1/Add_output_0', 'model/model/layers.2/mlp/activation_fn/Pow_output_0', 'model/model/layers.3/mlp/fc1/Add_output_0', 'model/model/layers.3/mlp/activation_fn/Pow_output_0', 'model/model/layers.4/mlp/fc1/Add_output_0', 'model/model/layers.4/mlp/activation_fn/Pow_output_0', 'model/model/layers.5/mlp/fc1/Add_output_0', 'model/model/layers.5/mlp/activation_fn/Pow_output_0', 'model/model/layers.6/mlp/fc1/Add_output_0', 'model/model/layers.6/mlp/activation_fn/Pow_output_0', 'model/model/layers.7/mlp/fc1/Add_output_0', 'model/model/layers.7/mlp/activation_fn/Pow_output_0', 'model/model/layers.8/mlp/fc1/Add_output_0', 'model/model/layers.8/mlp/activation_fn/Pow_output_0', 'model/model/layers.9/mlp/fc1/Add_output_0', 'model/model/layers.9/mlp/activation_fn/Pow_output_0', 'model/model/layers.10/mlp/fc1/Add_output_0', 'model/model/layers.10/mlp/activation_fn/Pow_output_0', 'model/model/layers.11/mlp/fc1/Add_output_0', 'model/model/layers.11/mlp/activation_fn/Pow_output_0', 'model/model/layers.12/mlp/fc1/Add_output_0', 'model/model/layers.12/mlp/activation_fn/Pow_output_0', 'model/model/layers.13/mlp/fc1/Add_output_0', 'model/model/layers.13/mlp/activation_fn/Pow_output_0', 'model/model/layers.14/mlp/fc1/Add_output_0', 'model/model/layers.14/mlp/activation_fn/Pow_output_0', 'model/model/layers.15/mlp/fc1/Add_output_0', 'model/model/layers.15/mlp/activation_fn/Pow_output_0', 'model/model/layers.16/mlp/fc1/Add_output_0', 'model/model/layers.16/mlp/activation_fn/Pow_output_0', 'model/model/layers.17/mlp/fc1/Add_output_0', 'model/model/layers.17/mlp/activation_fn/Pow_output_0', 'model/model/layers.18/mlp/fc1/Add_output_0', 'model/model/layers.18/mlp/activation_fn/Pow_output_0', 'model/model/layers.19/mlp/fc1/Add_output_0', 'model/model/layers.19/mlp/activation_fn/Pow_output_0', 'model/model/layers.20/mlp/fc1/Add_output_0', 'model/model/layers.20/mlp/activation_fn/Pow_output_0', 'model/model/layers.21/mlp/fc1/Add_output_0', 'model/model/layers.21/mlp/activation_fn/Pow_output_0', 'model/model/layers.22/mlp/fc1/Add_output_0', 'model/model/layers.22/mlp/activation_fn/Pow_output_0', 'model/model/layers.23/mlp/fc1/Add_output_0', 'model/model/layers.23/mlp/activation_fn/Pow_output_0']
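
Assuming the two lists above are the dynamo_names and torchscript_names printed by find_names in main(), the commented-out list(zip(...)) would pair them up for the diff roughly like this (sketch):

pairs = list(zip(dynamo_names, torchscript_names))
# e.g. [('view_14', 'model/model/layers.0/mlp/fc1/Add_output_0'),
#       ('convert_element_type_default_2', 'model/model/layers.0/mlp/activation_fn/Pow_output_0'),
#       ...]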
