# Lint as: python3
"""SHARK Importer"""

import sys
import tempfile
import os
import hashlib
def create_hash(file_name):
    with open(file_name, "rb") as f:
        file_hash = hashlib.blake2b(digest_size=64)
        while chunk := f.read(2**10):
            file_hash.update(chunk)
        return file_hash.hexdigest()
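
# Usage sketch (not part of the original gist; the path below is hypothetical).
# The 2**10-byte chunked read keeps memory flat even for large artifacts:
#   digest = create_hash("/tmp/model_torch.mlir")  # 128-char blake2b hex digest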

# List of the supported frontends.
supported_frontends = {
    "tensorflow",
    "tf",
    "pytorch",
    "torch",
    "tf-lite",
    "tflite",
}

class SharkImporter:
    """
    SharkImporter converts frontend modules into an
    mlir_module. The supported frameworks are tensorflow,
    pytorch, and tf-lite.

    ...

    Attributes
    ----------
    module :
        torch, tensorflow, or tf-lite module.
    inputs :
        inputs to the module; may be required for shape
        information.
    frontend : str
        frontend to which the module belongs.
    raw_model_file : str
        temp tflite model path.

    Methods
    -------
    import_mlir(is_dynamic, tracing_required, func_name):
        is_dynamic: whether the input shapes should be fully dynamic (pytorch specific).
        tracing_required: whether tracing is required (pytorch specific).
        func_name: the function to be traced out or imported to mlir.
    import_debug(is_dynamic, tracing_required, func_name):
        returns the converted (mlir_module, func_name) along with the inputs and golden
        outputs. The inputs and outputs are converted into np arrays.
    """
    def __init__(
        self,
        module,
        inputs: tuple = (),
        frontend: str = "torch",
        raw_model_file: str = "",
        return_str: bool = False,
    ):
        self.module = module
        self.inputs = None if len(inputs) == 0 else inputs
        self.frontend = frontend
        if self.frontend not in supported_frontends:
            print(
                f"The frontend is not in the supported_frontends: {supported_frontends}"
            )
            sys.exit(1)
        self.raw_model_file = raw_model_file
        self.return_str = return_str

    # NOTE: The default function for torch is "forward" and for tf-lite it is "main".
    def _torch_mlir(self, is_dynamic, tracing_required, mlir_type):
        from shark.torch_mlir_utils import get_torch_mlir_module

        return get_torch_mlir_module(
            self.module,
            self.inputs,
            is_dynamic,
            tracing_required,
            self.return_str,
            mlir_type,
        )
    def _tf_mlir(self, func_name, save_dir="."):
        from iree.compiler import tf as tfc

        return tfc.compile_module(
            self.module,
            exported_names=[func_name],
            import_only=True,
            output_file=save_dir,
        )
    def _tflite_mlir(self, func_name, save_dir="."):
        from iree.compiler import tflite as tflitec

        self.mlir_model = tflitec.compile_file(
            self.raw_model_file,  # for tflite, this is a path to the .tflite file, not a tflite interpreter
            input_type="tosa",
            import_only=True,
            output_file=save_dir,
        )
        return self.mlir_model
    # Dispatches the conversion of the frontend to the matching private method.
    def import_mlir(
        self,
        is_dynamic=False,
        tracing_required=False,
        func_name="forward",
        save_dir="./shark_tmp/",
        mlir_type="linalg",
    ):
        if self.frontend in ["torch", "pytorch"]:
            if self.inputs is None:
                print(
                    "Please pass in the inputs; they are required to determine the shape of the mlir_module"
                )
                sys.exit(1)
            return (
                self._torch_mlir(is_dynamic, tracing_required, mlir_type),
                func_name,
            )
        if self.frontend in ["tf", "tensorflow"]:
            return self._tf_mlir(func_name, save_dir), func_name
        if self.frontend in ["tflite", "tf-lite"]:
            func_name = "main"
            return self._tflite_mlir(func_name, save_dir), func_name
    # Converts the frontend-specific tensors into np arrays.
    def convert_to_numpy(self, array_tuple: tuple):
        if self.frontend in ["torch", "pytorch"]:
            return [x.detach().cpu().numpy() for x in array_tuple]
        if self.frontend in ["tf", "tensorflow"]:
            return [x.numpy() for x in array_tuple]
    # Saves `function_name.npy`, `inputs.npz`, `golden_out.npz`, and `model_name.mlir` in the directory `dir`.
    def save_data(
        self, dir, model_name, mlir_data, func_name, inputs, outputs
    ):
        import numpy as np

        inputs_name = "inputs.npz"
        outputs_name = "golden_out.npz"
        func_file_name = "function_name"
        model_name_mlir = model_name + "_" + self.frontend + ".mlir"
        print(f"saving {model_name_mlir} to {dir}")
        try:
            inputs = [x.cpu().detach() for x in inputs]
        except AttributeError:
            try:
                inputs = [x.numpy() for x in inputs]
            except AttributeError:
                inputs = [x for x in inputs]
        np.savez(os.path.join(dir, inputs_name), *inputs)
        np.savez(os.path.join(dir, outputs_name), *outputs)
        np.save(os.path.join(dir, func_file_name), np.array(func_name))
        if self.frontend == "torch":
            with open(os.path.join(dir, model_name_mlir), "wb") as mlir_file:
                mlir_file.write(mlir_data)
        hash_gen_attempts = 2
        for i in range(hash_gen_attempts):
            try:
                mlir_hash = create_hash(os.path.join(dir, model_name_mlir))
            except FileNotFoundError as err:
                # Retry before giving up; `i` is 0-based, so the last attempt
                # is `hash_gen_attempts - 1`.
                if i < hash_gen_attempts - 1:
                    continue
                else:
                    raise err
            else:
                break
        np.save(os.path.join(dir, "hash"), np.array(mlir_hash))
        return
    def import_debug(
        self,
        is_dynamic=False,
        tracing_required=False,
        func_name="forward",
        dir=tempfile.gettempdir(),
        model_name="model",
        golden_values=None,
    ):
        if self.inputs is None:
            print(
                f"There is no input provided: {self.inputs}; please provide inputs or simply run import_mlir."
            )
            sys.exit(1)
        model_name_mlir = model_name + "_" + self.frontend + ".mlir"
        artifact_path = os.path.join(dir, model_name_mlir)
        imported_mlir = self.import_mlir(
            is_dynamic,
            tracing_required,
            func_name,
            save_dir=artifact_path,
        )
        # TODO: Make sure that any generic function name is accepted. Currently takes in the default function names.
        # TODO: Check for multiple outputs.
if self.frontend in ["torch", "pytorch"]: | |
import torch | |
golden_out = None | |
if golden_values is not None: | |
golden_out = golden_values | |
else: | |
golden_out = self.module(*self.inputs) | |
if torch.is_tensor(golden_out): | |
golden_out = tuple( | |
golden_out.detach().cpu().numpy(), | |
) | |
else: | |
golden_out = self.convert_to_numpy(golden_out) | |
# Save the artifacts in the directory dir. | |
self.save_data( | |
dir, | |
model_name, | |
imported_mlir[0], | |
imported_mlir[1], | |
self.inputs, | |
golden_out, | |
) | |
return ( | |
imported_mlir, | |
self.convert_to_numpy(self.inputs), | |
golden_out, | |
) | |
if self.frontend in ["tf", "tensorflow"]: | |
import tensorflow as tf | |
golden_out = self.module.forward(*self.inputs) | |
if tf.is_tensor(golden_out): | |
golden_out = tuple( | |
golden_out.numpy(), | |
) | |
elif golden_out is tuple: | |
golden_out = self.convert_to_numpy(golden_out) | |
elif hasattr(golden_out, "logits"): | |
# from transformers import TFSequenceClassifierOutput | |
golden_out = golden_out.logits | |
else: | |
golden_out = golden_out.last_hidden_state | |
# Save the artifacts in the directory dir. | |
self.save_data( | |
dir, | |
model_name, | |
imported_mlir[0], | |
imported_mlir[1], | |
self.inputs, | |
golden_out, | |
) | |
return ( | |
imported_mlir, | |
self.convert_to_numpy(self.inputs), | |
golden_out, | |
) | |
if self.frontend in ["tflite", "tf-lite"]: | |
# TODO(Chi): Validate it for tflite models. | |
golden_out = self.module.invoke_tflite(self.inputs) | |
self.save_data( | |
dir, | |
model_name, | |
imported_mlir[0], | |
imported_mlir[1], | |
self.inputs, | |
golden_out, | |
) | |
return ( | |
imported_mlir, | |
self.inputs, | |
golden_out, | |
) | |
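
# A minimal usage sketch of SharkImporter (not part of the original gist). The
# toy module and input shapes below are hypothetical; it assumes the SHARK
# torch-mlir utilities used by _torch_mlir are installed.
def _example_torch_import():
    import torch

    class TwoLayer(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 4)

        def forward(self, x):
            return self.fc(x)

    model = TwoLayer().eval()
    inputs = (torch.randn(1, 8),)
    importer = SharkImporter(model, inputs, frontend="torch")
    # Returns the linalg-level MLIR module and the traced function name.
    return importer.import_mlir(tracing_required=True)
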
def get_f16_inputs(inputs, is_f16, f16_input_mask):
    if not is_f16:
        return inputs
    if f16_input_mask is None:
        return tuple([x.half() for x in inputs])
    f16_masked_inputs = []
    for i in range(len(inputs)):
        if f16_input_mask[i]:
            f16_masked_inputs.append(inputs[i].half())
        else:
            f16_masked_inputs.append(inputs[i])
    return tuple(f16_masked_inputs)
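
# Hypothetical example (not from the original gist): with
# inputs = (token_ids, pixel_values) and f16_input_mask = [False, True],
# get_f16_inputs casts only pixel_values to float16 and leaves the integer
# token_ids untouched.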

def change_sdpa_op(fx_g):
    import torch

    # Replace aten._scaled_dot_product_efficient_attention calls with the plain
    # aten.scaled_dot_product_attention op, keeping only (query, key, value).
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.target in [
                torch.ops.aten._scaled_dot_product_efficient_attention
            ]:
                print(node)
                with fx_g.graph.inserting_after(node):
                    scaled_dot_product_attention = fx_g.graph.call_function(
                        torch.ops.aten.scaled_dot_product_attention,
                        args=(node.args[0], node.args[1], node.args[2]),
                        kwargs={},
                    )
                    node.replace_all_uses_with(scaled_dot_product_attention)
                    fx_g.graph.erase_node(node)
    fx_g.graph.lint()
    # The efficient-attention op returned a tuple, so drop the getitem nodes
    # that now point at the single-output replacement.
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.name.startswith("getitem"):
                with fx_g.graph.inserting_before(node):
                    if node.args[0].target in [
                        torch.ops.aten.scaled_dot_product_attention
                    ]:
                        node.replace_all_uses_with(node.args[0])
                        fx_g.graph.erase_node(node)
    fx_g.graph.lint()
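
# Illustrative effect of change_sdpa_op (hypothetical graph snippet, not from
# the original gist):
#   before:  out = aten._scaled_dot_product_efficient_attention(q, k, v, ...); res = out[0]
#   after:   res = aten.scaled_dot_product_attention(q, k, v)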

def transform_fx(fx_g):
    import torch

    kwargs_dict = {
        "dtype": torch.float16,
        "device": torch.device(type="cpu"),
        "pin_memory": False,
    }
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            # Force tensor-creating ops to produce fp16 CPU tensors.
            if node.target in [
                torch.ops.aten.arange,
                torch.ops.aten.empty,
                torch.ops.aten.zeros,
            ]:
                node.kwargs = kwargs_dict
            # Inputs and outputs of aten.var_mean should be upcast to fp32.
            if node.target in [torch.ops.aten.var_mean]:
                with fx_g.graph.inserting_before(node):
                    new_node = fx_g.graph.call_function(
                        torch.ops.prims.convert_element_type,
                        args=(node.args[0], torch.float32),
                        kwargs={},
                    )
                    node.args = (new_node, node.args[1])
            if node.name.startswith("getitem"):
                with fx_g.graph.inserting_before(node):
                    if node.args[0].target in [torch.ops.aten.var_mean]:
                        new_node = fx_g.graph.call_function(
                            torch.ops.aten._to_copy,
                            args=(node,),
                            kwargs={"dtype": torch.float16},
                        )
                        node.append(new_node)
                        node.replace_all_uses_with(new_node)
                        new_node.args = (node,)
                        new_node.kwargs = {"dtype": torch.float16}
            # aten.empty should be filled with zeros.
            if node.target in [torch.ops.aten.empty]:
                with fx_g.graph.inserting_after(node):
                    new_node = fx_g.graph.call_function(
                        torch.ops.aten.zero_,
                        args=(node,),
                    )
                    node.append(new_node)
                    node.replace_all_uses_with(new_node)
                    new_node.args = (node,)
    fx_g.graph.lint()
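
# Illustrative effect of transform_fx on a var_mean call after the module has
# been cast to fp16 (hypothetical graph snippet, not from the original gist):
#   before:  var, mean = aten.var_mean(x_f16, [2])
#   after:   x_f32 = prims.convert_element_type(x_f16, torch.float32)
#            var, mean = aten.var_mean(x_f32, [2])
#            var_f16 = aten._to_copy(var, dtype=torch.float16)  # likewise for mean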

# Doesn't replace the None type.
def change_fx_graph_return_to_tuple(fx_g):
    for node in fx_g.graph.nodes:
        if node.op == "output":
            # output nodes always have one argument
            node_arg = node.args[0]
            out_nodes = []
            if isinstance(node_arg, list):
                # Don't return NoneType elements.
                for out_node in node_arg:
                    if not isinstance(out_node, type(None)):
                        out_nodes.append(out_node)
                # If there is a single tensor/element to be returned, don't
                # create a tuple for it.
                if len(out_nodes) == 1:
                    node.args = out_nodes
                else:
                    node.args = (tuple(out_nodes),)
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g

def flatten_training_input(inputs):
    flattened_input = []
    for i in inputs:
        if isinstance(i, dict):
            for value in i.values():
                flattened_input.append(value.detach())
        elif isinstance(i, tuple):
            for value in i:
                flattened_input.append(value)
        else:
            flattened_input.append(i)
    return tuple(flattened_input)
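
# Hypothetical example (not from the original gist):
#   flatten_training_input(({"input_ids": ids, "labels": y}, (w, b), lr))
# returns (ids.detach(), y.detach(), w, b, lr): dict values are detached,
# tuples are unpacked, and everything else is passed through unchanged.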

# Applies fx conversion to the model and imports the mlir.
def import_with_fx(
    model,
    inputs,
    is_f16=False,
    f16_input_mask=None,
    debug=False,
    training=False,
    return_str=False,
    save_dir=tempfile.gettempdir(),
    model_name="model",
):
    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._decomp import get_decompositions

    golden_values = None
    if debug:
        try:
            golden_values = model(*inputs)
        except Exception:
            golden_values = None
    # TODO: Control the decompositions.
    fx_g = make_fx(
        model,
        decomposition_table=get_decompositions(
            [
                torch.ops.aten.embedding_dense_backward,
                torch.ops.aten.native_layer_norm_backward,
                torch.ops.aten.slice_backward,
                torch.ops.aten.select_backward,
                torch.ops.aten.norm.ScalarOpt_dim,
                torch.ops.aten.native_group_norm,
                torch.ops.aten.upsample_bilinear2d.vec,
                torch.ops.aten.split.Tensor,
                torch.ops.aten.split_with_sizes,
                torch.ops.aten.native_layer_norm,
            ]
        ),
    )(*inputs)
    fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
    fx_g.recompile()

    def strip_overloads(gm):
        """
        Modifies the target of graph nodes in :attr:`gm` to strip overloads.

        Args:
            gm (fx.GraphModule): The input FX graph module to be modified.
        """
        for node in gm.graph.nodes:
            if isinstance(node.target, torch._ops.OpOverload):
                node.target = node.target.overloadpacket
        gm.recompile()

    strip_overloads(fx_g)
    if is_f16:
        fx_g = fx_g.half()
        transform_fx(fx_g)
        fx_g.recompile()
    if training:
        change_fx_graph_return_to_tuple(fx_g)
        inputs = flatten_training_input(inputs)
    change_sdpa_op(fx_g)
    fx_g.recompile()
    ts_graph = torch.jit.trace(fx_g, inputs)
    inputs = get_f16_inputs(inputs, is_f16, f16_input_mask)
    mlir_importer = SharkImporter(
        ts_graph,
        inputs,
        frontend="torch",
        return_str=return_str,
    )
    if debug:  # and not is_f16:
        (mlir_module, func_name), _, _ = mlir_importer.import_debug(
            dir=save_dir, model_name=model_name, golden_values=golden_values
        )
        return mlir_module, func_name
    mlir_module, func_name = mlir_importer.import_mlir()
    return mlir_module, func_name
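
# A minimal end-to-end sketch of import_with_fx (not part of the original
# gist); the toy model, input shapes, and model_name below are hypothetical.
def _example_import_with_fx():
    import torch

    model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.GELU())
    inputs = (torch.randn(2, 16),)
    # Trace with make_fx, apply the graph rewrites above, trace to TorchScript,
    # and import the result into MLIR.
    mlir_module, func_name = import_with_fx(
        model,
        inputs,
        is_f16=False,
        debug=False,
        model_name="toy_mlp",
    )
    return mlir_module, func_name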