import os
from typing import Sequence

import torch
import torch_onnx
import torch_onnx.tools.diff_model
import onnxscript
import onnxscript.rewriter.pattern as orp
from onnxscript import ir


def predecessors(node: ir.Node) -> Sequence[ir.Node]:
    """Return the producer nodes of all non-None inputs of `node`."""
    results = []
    for inp in node.inputs:
        if inp is None:
            continue
        if inp.producer() is not None:
            results.append(inp.producer())
    return results


def successors(node: ir.Node) -> Sequence[ir.Node]:
    """Return every node that consumes one of `node`'s outputs."""
    results = []
    for out in node.outputs:
        for user, _ in out.uses():
            results.append(user)
    return results
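

# Usage sketch for the two helpers above: walk a loaded model and print the
# immediate graph neighbors of every LayerNormalization node. Illustrative
# only; uses nothing beyond this file's own imports and is not called from
# main().
def _demo_graph_walk(model: ir.Model) -> None:
    for node in model.graph:
        if node.op_type == "LayerNormalization":
            print("predecessors:", [n.op_type for n in predecessors(node)])
            print("successors:", [n.op_type for n in successors(node)])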


# Alternative that keyed off each LayerNormalization's graph neighbors instead
# of the node's own input/output values:
# def find_names_dynamo(model: ir.Model):
#     names = []
#     for node in model.graph:
#         if node.op_type == "LayerNormalization":
#             names.append(predecessors(node)[0].name)
#             names.append(successors(node)[0].name)
#     return names


def find_names(model: ir.Model):
    """Collect the input and output value names of every LayerNormalization."""
    names = []
    for node in model.graph:
        if node.op_type == "LayerNormalization":
            names.append(node.inputs[0].name)
            names.append(node.outputs[0].name)
    return names
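

# Note: the name lists returned for the dynamo and torchscript models line up
# one-to-one (see the assert in main() below), so they can be zipped into the
# (dynamo_name, torchscript_name) pairs that torch_onnx.tools.diff_model.diff
# expects, as in the commented-out list(zip(...)) argument in main().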


# Pattern: a Pow whose fp16 input is upcast to fp32 and whose result is cast
# back down to fp16.
def cast_pow_cast(op, x, y):
    x = op.Cast(x, to=ir.DataType.FLOAT)
    pow_out = op.Pow(x, y)
    return op.Cast(pow_out, to=ir.DataType.FLOAT16)


# Replacement: run Pow directly in fp16, casting only the exponent.
def fp16_pow(op, x, y):
    y = op.Cast(y, to=ir.DataType.FLOAT16)
    return op.Pow(x, y)


cast_pow_cast_rule = orp.RewriteRule(cast_pow_cast, fp16_pow)
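

# A minimal sketch of exercising the rule in isolation: build a tiny fp16
# model containing the exact Cast -> Pow -> Cast chain with onnx.helper (an
# extra import this file does not otherwise use) and check which ops survive
# the rewrite. Assumes ir.from_proto is available in this onnxscript version.
# Illustrative only; not called from main().
def _demo_cast_pow_cast_rule() -> None:
    import onnx.helper
    from onnx import TensorProto

    graph = onnx.helper.make_graph(
        [
            onnx.helper.make_node("Cast", ["x"], ["x32"], to=TensorProto.FLOAT),
            onnx.helper.make_node("Pow", ["x32", "two"], ["p32"]),
            onnx.helper.make_node("Cast", ["p32"], ["y"], to=TensorProto.FLOAT16),
        ],
        "toy",
        [onnx.helper.make_tensor_value_info("x", TensorProto.FLOAT16, [2])],
        [onnx.helper.make_tensor_value_info("y", TensorProto.FLOAT16, [2])],
        initializer=[onnx.helper.make_tensor("two", TensorProto.FLOAT, [], [2.0])],
    )
    model = ir.from_proto(onnx.helper.make_model(graph))
    model = onnxscript.rewriter.rewrite(
        model, pattern_rewrite_rules=[cast_pow_cast_rule]
    )
    # Expected, if the rule matches as intended: the fp32 round-trip on x is
    # gone, leaving a single Cast on the exponent followed by an fp16 Pow.
    print([node.op_type for node in model.graph])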


def rewrite(path: str):
    """Apply the Cast/Pow/Cast rule to the model at `path` and save the result
    as rewritten_<filename> in the same directory."""
    model = ir.load(path)
    model = onnxscript.rewriter.rewrite(
        model, pattern_rewrite_rules=[cast_pow_cast_rule]
    )
    dirname = os.path.dirname(path)
    filename = os.path.basename(path)
    rewritten_file = os.path.join(dirname, f"rewritten_{filename}")
    ir.save(model, rewritten_file)
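

# Usage sketch: rewrite("accuracy_investigation/gpt15/model.onnx") would save
# the result as "accuracy_investigation/gpt15/rewritten_model.onnx", the path
# the commented-out ir.load in main() below points at.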


def main():
    # The input is a 1x512 batch of token ids.
    data = (torch.randint(1, 128, (1, 512)),)

    # rewrite("accuracy_investigation/gpt15/model.onnx")
    # dynamo_model = ir.load("accuracy_investigation/gpt15/model.onnx")
    # # dynamo_model = ir.load("accuracy_investigation/gpt15/rewritten_model.onnx")
    # torchscript_model = ir.load("accuracy_investigation/gpt15/torchscript/model.onnx")
    # dynamo_names = find_names(dynamo_model)
    # print(dynamo_names)
    # torchscript_names = find_names(torchscript_model)
    # print(torchscript_names)
    # assert len(dynamo_names) == len(torchscript_names)

    # Run both exports on the same input and compare intermediate values
    # pairwise (dynamo value name, torchscript value name). The trailing
    # numbers record the difference observed for each pair.
    results, _ = torch_onnx.tools.diff_model.diff(
        "accuracy_investigation/gpt15/model.onnx",
        "accuracy_investigation/gpt15/torchscript/model.onnx",
        [
            # ("val_0", "model/model/layers.0/self_attn/q_proj/MatMul_output_0"),  # 0.0
            # ("add", "model/model/layers.0/self_attn/Add_output_0"),  # 0.0
            # ("cat_2", "model/model/layers.0/self_attn/Concat_5_output_0"),  # 0.0
            # ("add_1", "model/model/layers.0/self_attn/Add_1_output_0"),  # 0.0
            # ("cat_3", "key_states"),  # 0.0
            # ("val_237", "model/model/layers.0/self_attn/Transpose_3_output_0"),  # 0.0
            # ("val_241", "model/model/layers.0/self_attn/Sqrt_2_output_0"),  # 0.0
            ("val_242", "model/model/layers.0/self_attn/Mul_5_output_0"),
            # ("val_239", "model/model/layers.0/self_attn/Mul_4_output_0"),  # 0.0009765625
            # ("val_243", "model/model/layers.0/self_attn/MatMul_output_0"),  # 0.00390625
            # ("val_244", "model/model/layers.0/self_attn/Softmax_output_0"),  # 0.999
            # ("view_10", "model/model/layers.0/self_attn/Reshape_3_output_0"),  # 3.49
        ],
        # list(zip(dynamo_names, torchscript_names)),
        data,
        keep_original_outputs=False,
    )
    for i, result in enumerate(results):
        print(f"Result {i}----------------------------")
        print(result)


if __name__ == "__main__":
    main()
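
Output of the commented-out find_names prints (dynamo names first, then torchscript names):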
['view_14', 'convert_element_type_default_2', 'view_30', 'convert_element_type_default_4', 'view_46', 'convert_element_type_default_6', 'view_62', 'convert_element_type_default_8', 'view_78', 'convert_element_type_default_10', 'view_94', 'convert_element_type_default_12', 'view_110', 'convert_element_type_default_14', 'view_126', 'convert_element_type_default_16', 'view_142', 'convert_element_type_default_18', 'view_158', 'convert_element_type_default_20', 'view_174', 'convert_element_type_default_22', 'view_190', 'convert_element_type_default_24', 'view_206', 'convert_element_type_default_26', 'view_222', 'convert_element_type_default_28', 'view_238', 'convert_element_type_default_30', 'view_254', 'convert_element_type_default_32', 'view_270', 'convert_element_type_default_34', 'view_286', 'convert_element_type_default_36', 'view_302', 'convert_element_type_default_38', 'view_318', 'convert_element_type_default_40', 'view_334', 'convert_element_type_default_42', 'view_350', 'convert_element_type_default_44', 'view_366', 'convert_element_type_default_46', 'view_382', 'convert_element_type_default_48']
['model/model/layers.0/mlp/fc1/Add_output_0', 'model/model/layers.0/mlp/activation_fn/Pow_output_0', 'model/model/layers.1/mlp/fc1/Add_output_0', 'model/model/layers.1/mlp/activation_fn/Pow_output_0', 'model/model/layers.2/mlp/fc1/Add_output_0', 'model/model/layers.2/mlp/activation_fn/Pow_output_0', 'model/model/layers.3/mlp/fc1/Add_output_0', 'model/model/layers.3/mlp/activation_fn/Pow_output_0', 'model/model/layers.4/mlp/fc1/Add_output_0', 'model/model/layers.4/mlp/activation_fn/Pow_output_0', 'model/model/layers.5/mlp/fc1/Add_output_0', 'model/model/layers.5/mlp/activation_fn/Pow_output_0', 'model/model/layers.6/mlp/fc1/Add_output_0', 'model/model/layers.6/mlp/activation_fn/Pow_output_0', 'model/model/layers.7/mlp/fc1/Add_output_0', 'model/model/layers.7/mlp/activation_fn/Pow_output_0', 'model/model/layers.8/mlp/fc1/Add_output_0', 'model/model/layers.8/mlp/activation_fn/Pow_output_0', 'model/model/layers.9/mlp/fc1/Add_output_0', 'model/model/layers.9/mlp/activation_fn/Pow_output_0', 'model/model/layers.10/mlp/fc1/Add_output_0', 'model/model/layers.10/mlp/activation_fn/Pow_output_0', 'model/model/layers.11/mlp/fc1/Add_output_0', 'model/model/layers.11/mlp/activation_fn/Pow_output_0', 'model/model/layers.12/mlp/fc1/Add_output_0', 'model/model/layers.12/mlp/activation_fn/Pow_output_0', 'model/model/layers.13/mlp/fc1/Add_output_0', 'model/model/layers.13/mlp/activation_fn/Pow_output_0', 'model/model/layers.14/mlp/fc1/Add_output_0', 'model/model/layers.14/mlp/activation_fn/Pow_output_0', 'model/model/layers.15/mlp/fc1/Add_output_0', 'model/model/layers.15/mlp/activation_fn/Pow_output_0', 'model/model/layers.16/mlp/fc1/Add_output_0', 'model/model/layers.16/mlp/activation_fn/Pow_output_0', 'model/model/layers.17/mlp/fc1/Add_output_0', 'model/model/layers.17/mlp/activation_fn/Pow_output_0', 'model/model/layers.18/mlp/fc1/Add_output_0', 'model/model/layers.18/mlp/activation_fn/Pow_output_0', 'model/model/layers.19/mlp/fc1/Add_output_0', 'model/model/layers.19/mlp/activation_fn/Pow_output_0', 'model/model/layers.20/mlp/fc1/Add_output_0', 'model/model/layers.20/mlp/activation_fn/Pow_output_0', 'model/model/layers.21/mlp/fc1/Add_output_0', 'model/model/layers.21/mlp/activation_fn/Pow_output_0', 'model/model/layers.22/mlp/fc1/Add_output_0', 'model/model/layers.22/mlp/activation_fn/Pow_output_0', 'model/model/layers.23/mlp/fc1/Add_output_0', 'model/model/layers.23/mlp/activation_fn/Pow_output_0']