import os
from typing import Sequence

import torch
import torch_onnx
import torch_onnx.tools.diff_model

import onnxscript
import onnxscript.rewriter.pattern as orp
from onnxscript import ir


def predecessors(node: ir.Node) -> Sequence[ir.Node]:
    """Return the producer nodes of all non-None inputs of `node`."""
    results = []
    for inp in node.inputs:
        if inp is None:
            continue
        producer = inp.producer()
        if producer is not None:
            results.append(producer)
    return results


def successors(node: ir.Node) -> Sequence[ir.Node]:
    """Return the consumer nodes of all outputs of `node`."""
    results = []
    for out in node.outputs:
        for user, _ in out.uses():
            results.append(user)
    return results


# def find_names_dynamo(model: ir.Model):
#     names = []
#     for node in model.graph:
#         if node.op_type == "LayerNormalization":
#             names.append(predecessors(node)[0].name)
#             names.append(successors(node)[0].name)
#     return names


def find_names(model: ir.Model):
    """Collect the input/output value names of every LayerNormalization node."""
    names = []
    for node in model.graph:
        if node.op_type == "LayerNormalization":
            names.append(node.inputs[0].name)
            names.append(node.outputs[0].name)
    return names


# Target pattern: Cast(to fp32) -> Pow -> Cast(to fp16)
def cast_pow_cast(op, x, y):
    x = op.Cast(x, to=ir.DataType.FLOAT)
    pow_ = op.Pow(x, y)
    return op.Cast(pow_, to=ir.DataType.FLOAT16)


# Replacement: keep the computation in fp16 and cast only the exponent.
def fp16_pow(op, x, y):
    y = op.Cast(y, to=ir.DataType.FLOAT16)
    return op.Pow(x, y)


cast_pow_cast_rule = orp.RewriteRule(cast_pow_cast, fp16_pow)
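

# --- Optional sanity check (a hedged sketch, not part of the original flow) ---
# Builds a tiny fp16 model containing the Cast -> Pow -> Cast chain and applies
# the rule above, so you can confirm it collapses to a single fp16 Pow. The
# graph and value names here are illustrative, and `ir.from_proto` is assumed
# to be available in this version of onnxscript.
def check_cast_pow_cast_rule() -> None:
    import onnx
    import onnx.helper as oh

    graph = oh.make_graph(
        nodes=[
            oh.make_node("Cast", ["x"], ["x_fp32"], to=onnx.TensorProto.FLOAT),
            oh.make_node("Pow", ["x_fp32", "y"], ["pow_fp32"]),
            oh.make_node("Cast", ["pow_fp32"], ["out"], to=onnx.TensorProto.FLOAT16),
        ],
        name="cast_pow_cast_repro",
        inputs=[
            oh.make_tensor_value_info("x", onnx.TensorProto.FLOAT16, [2]),
            oh.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2]),
        ],
        outputs=[oh.make_tensor_value_info("out", onnx.TensorProto.FLOAT16, [2])],
    )
    model_proto = oh.make_model(graph, opset_imports=[oh.make_opsetid("", 18)])
    model = ir.from_proto(model_proto)
    model = onnxscript.rewriter.rewrite(
        model, pattern_rewrite_rules=[cast_pow_cast_rule]
    )
    # Expect ["Cast", "Pow"] (exponent cast + fp16 Pow) instead of the
    # original ["Cast", "Pow", "Cast"].
    print([node.op_type for node in model.graph])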


def rewrite(path: str):
    """Apply the Cast-Pow-Cast rule to the model at `path` and save the result."""
    model = ir.load(path)
    model = onnxscript.rewriter.rewrite(
        model, pattern_rewrite_rules=[cast_pow_cast_rule]
    )
    dirname = os.path.dirname(path)
    filename = os.path.basename(path)
    rewritten_file = os.path.join(dirname, f"rewritten_{filename}")
    ir.save(model, rewritten_file)
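

# Example usage (sketch): writes rewritten_model.onnx next to the input model,
# mirroring the commented-out call in main() below.
# rewrite("accuracy_investigation/gpt15/model.onnx")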


def main():
    # The models take a single 1x512 tensor of token ids as input.
    data = (torch.randint(1, 128, (1, 512)),)
    # # rewrite("accuracy_investigation/gpt15/model.onnx")
    # dynamo_model = ir.load("accuracy_investigation/gpt15/model.onnx")
    # # dynamo_model = ir.load("accuracy_investigation/gpt15/rewritten_model.onnx")
    # torchscript_model = ir.load("accuracy_investigation/gpt15/torchscript/model.onnx")
    # dynamo_names = find_names(dynamo_model)
    # print(dynamo_names)
    # torchscript_names = find_names(torchscript_model)
    # print(torchscript_names)
    # assert len(dynamo_names) == len(torchscript_names)

    # Compare intermediate values between the dynamo and torchscript exports.
    # Each pair maps a value name in the dynamo model to its counterpart in
    # the torchscript model; trailing comments record the max differences
    # observed for pairs checked earlier.
    results, _ = torch_onnx.tools.diff_model.diff(
        "accuracy_investigation/gpt15/model.onnx",
        "accuracy_investigation/gpt15/torchscript/model.onnx",
        [
            # ("val_0", "model/model/layers.0/self_attn/q_proj/MatMul_output_0"),  # 0.0
            # ("add", "model/model/layers.0/self_attn/Add_output_0"),  # 0.0
            # ("cat_2", "model/model/layers.0/self_attn/Concat_5_output_0"),  # 0.0
            # ("add_1", "model/model/layers.0/self_attn/Add_1_output_0"),  # 0.0
            # ("cat_3", "key_states"),  # 0.0
            # ("val_237", "model/model/layers.0/self_attn/Transpose_3_output_0"),  # 0.0
            # ("val_241", "model/model/layers.0/self_attn/Sqrt_2_output_0"),  # 0.0
            ("val_242", "model/model/layers.0/self_attn/Mul_5_output_0"),
            # ("val_239", "model/model/layers.0/self_attn/Mul_4_output_0"),  # 0.0009765625
            # ("val_243", "model/model/layers.0/self_attn/MatMul_output_0"),  # 0.00390625
            # ("val_244", "model/model/layers.0/self_attn/Softmax_output_0"),  # 0.999
            # ("view_10", "model/model/layers.0/self_attn/Reshape_3_output_0"),  # 3.49
        ],
        # list(zip(dynamo_names, torchscript_names)),
        data,
        keep_original_outputs=False,
    )
    for i, result in enumerate(results):
        print(f"Result {i}----------------------------")
        print(result)


if __name__ == "__main__":
    main()

Collected value names (dynamo export first, then the torchscript counterparts):
['view_14', 'convert_element_type_default_2', 'view_30', 'convert_element_type_default_4', 'view_46', 'convert_element_type_default_6', 'view_62', 'convert_element_type_default_8', 'view_78', 'convert_element_type_default_10', 'view_94', 'convert_element_type_default_12', 'view_110', 'convert_element_type_default_14', 'view_126', 'convert_element_type_default_16', 'view_142', 'convert_element_type_default_18', 'view_158', 'convert_element_type_default_20', 'view_174', 'convert_element_type_default_22', 'view_190', 'convert_element_type_default_24', 'view_206', 'convert_element_type_default_26', 'view_222', 'convert_element_type_default_28', 'view_238', 'convert_element_type_default_30', 'view_254', 'convert_element_type_default_32', 'view_270', 'convert_element_type_default_34', 'view_286', 'convert_element_type_default_36', 'view_302', 'convert_element_type_default_38', 'view_318', 'convert_element_type_default_40', 'view_334', 'convert_element_type_default_42', 'view_350', 'convert_element_type_default_44', 'view_366', 'convert_element_type_default_46', 'view_382', 'convert_element_type_default_48']
['model/model/layers.0/mlp/fc1/Add_output_0', 'model/model/layers.0/mlp/activation_fn/Pow_output_0', 'model/model/layers.1/mlp/fc1/Add_output_0', 'model/model/layers.1/mlp/activation_fn/Pow_output_0', 'model/model/layers.2/mlp/fc1/Add_output_0', 'model/model/layers.2/mlp/activation_fn/Pow_output_0', 'model/model/layers.3/mlp/fc1/Add_output_0', 'model/model/layers.3/mlp/activation_fn/Pow_output_0', 'model/model/layers.4/mlp/fc1/Add_output_0', 'model/model/layers.4/mlp/activation_fn/Pow_output_0', 'model/model/layers.5/mlp/fc1/Add_output_0', 'model/model/layers.5/mlp/activation_fn/Pow_output_0', 'model/model/layers.6/mlp/fc1/Add_output_0', 'model/model/layers.6/mlp/activation_fn/Pow_output_0', 'model/model/layers.7/mlp/fc1/Add_output_0', 'model/model/layers.7/mlp/activation_fn/Pow_output_0', 'model/model/layers.8/mlp/fc1/Add_output_0', 'model/model/layers.8/mlp/activation_fn/Pow_output_0', 'model/model/layers.9/mlp/fc1/Add_output_0', 'model/model/layers.9/mlp/activation_fn/Pow_output_0', 'model/model/layers.10/mlp/fc1/Add_output_0', 'model/model/layers.10/mlp/activation_fn/Pow_output_0', 'model/model/layers.11/mlp/fc1/Add_output_0', 'model/model/layers.11/mlp/activation_fn/Pow_output_0', 'model/model/layers.12/mlp/fc1/Add_output_0', 'model/model/layers.12/mlp/activation_fn/Pow_output_0', 'model/model/layers.13/mlp/fc1/Add_output_0', 'model/model/layers.13/mlp/activation_fn/Pow_output_0', 'model/model/layers.14/mlp/fc1/Add_output_0', 'model/model/layers.14/mlp/activation_fn/Pow_output_0', 'model/model/layers.15/mlp/fc1/Add_output_0', 'model/model/layers.15/mlp/activation_fn/Pow_output_0', 'model/model/layers.16/mlp/fc1/Add_output_0', 'model/model/layers.16/mlp/activation_fn/Pow_output_0', 'model/model/layers.17/mlp/fc1/Add_output_0', 'model/model/layers.17/mlp/activation_fn/Pow_output_0', 'model/model/layers.18/mlp/fc1/Add_output_0', 'model/model/layers.18/mlp/activation_fn/Pow_output_0', 'model/model/layers.19/mlp/fc1/Add_output_0', 'model/model/layers.19/mlp/activation_fn/Pow_output_0', 'model/model/layers.20/mlp/fc1/Add_output_0', 'model/model/layers.20/mlp/activation_fn/Pow_output_0', 'model/model/layers.21/mlp/fc1/Add_output_0', 'model/model/layers.21/mlp/activation_fn/Pow_output_0', 'model/model/layers.22/mlp/fc1/Add_output_0', 'model/model/layers.22/mlp/activation_fn/Pow_output_0', 'model/model/layers.23/mlp/fc1/Add_output_0', 'model/model/layers.23/mlp/activation_fn/Pow_output_0']