Created
June 8, 2023 17:13
-
-
Save pashu123/c94119736ab0584b509d79ebaee79841 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import torch | |
import torch_mlir | |
from shark.shark_importer import import_with_fx | |
import torchvision.models as models | |
import copy | |
import io | |
import numpy as np | |
# Custom shark backend.
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
    """Compile ``fx_g`` with Torch-MLIR and execute it through SHARK.

    The graph is lowered to the ``linalg-on-tensors`` form, serialized to
    MLIR bytecode in memory, compiled by ``SharkInference`` for ``device``
    and then invoked once on ``inputs`` through its ``"forward"`` entry point.

    Args:
        fx_g: The FX graph module to compile.
        inputs: Example inputs used for compilation and then for execution.
        device: Target device name handed to ``SharkInference``.

    Returns:
        Whatever the compiled ``SharkInference`` module returns for the
        ``"forward"`` call.
    """
    lowered = torch_mlir.compile(fx_g, inputs, output_type="linalg-on-tensors")

    # Serialize the MLIR module to bytecode without touching the filesystem.
    buffer = io.BytesIO()
    lowered.operation.write_bytecode(buffer)
    payload = buffer.getvalue()

    # Imported lazily, exactly as before, so the dependency is only needed
    # when this backend is actually used.
    from shark.shark_inference import SharkInference

    runner = SharkInference(
        mlir_module=payload,
        device=device,
        mlir_dialect="tm_tensor",
    )
    runner.compile(extra_args=[])
    return runner("forward", inputs)


# Counts the call_function nodes in the graph.
def count_total_nodes(fx_g: torch.fx.GraphModule):
    """Return the number of ``call_function`` nodes in ``fx_g.graph``.

    Nodes with any other opcode (placeholders, attribute fetches, module or
    method calls, the output node) are not counted.
    """
    return sum(1 for node in fx_g.graph.nodes if node.op == "call_function")


# Breaks the graph at the required position.
def break_at_pos(fx_g: torch.fx.GraphModule, pos: int):
    """Cut ``fx_g`` so that it returns the result of its ``pos``-th call node.

    Only ``call_function`` nodes are counted, zero-based. A new output node is
    placed directly after the chosen node. Every node after it, including the
    previous output, is then erased. The recompiled module therefore computes
    only that prefix of the graph, and a backend given the module never sees
    the ops beyond the break point.

    The module is modified in place and also returned for convenience.

    Args:
        fx_g: The graph module to truncate (callers that need the original
            should pass a copy).
        pos: Zero-based index among the graph's ``call_function`` nodes.

    Returns:
        ``fx_g`` after truncation, lint and recompilation.

    Raises:
        ValueError: If ``pos`` is not a valid index into the graph's
            ``call_function`` nodes.
    """
    graph = fx_g.graph
    call_nodes = [node for node in graph.nodes if node.op == "call_function"]
    if not 0 <= pos < len(call_nodes):
        # Silently returning an unmodified graph would make a bisection
        # compare the full model and mislabel the position.
        raise ValueError(
            f"pos must be in [0, {len(call_nodes)}) for this graph, got {pos}"
        )
    target = call_nodes[pos]
    # TODO: Check here that the node is not of the form of empty tensor etc.

    # Make the chosen node the graph's result.
    with graph.inserting_after(target):
        new_output = graph.output(target)

    # Erase everything after the new output. Walking backwards means each
    # node's users are already gone when it is erased (erase_node refuses to
    # remove a node that still has users).
    for node in reversed(list(graph.nodes)):
        if node is new_output:
            break
        graph.erase_node(node)

    graph.lint()
    fx_g.recompile()
    return fx_g


def _as_numpy(value):
    """Return ``value`` as a NumPy array, detaching and moving torch tensors to CPU."""
    if hasattr(value, "detach"):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def check_output(orig_out, comp_out, rtol=1e-2, atol=1e-3):
    """Compare a reference output against a compiled backend's output.

    ``orig_out`` is either a single tensor/array or a tuple/list of them.
    When it is a sequence, ``comp_out`` must supply the same number of
    results. A single non-sequence ``comp_out`` is treated as one result.
    Elements are compared pairwise with ``numpy.allclose``.

    Args:
        orig_out: Reference result(s), e.g. from running the FX graph eagerly.
        comp_out: Result(s) from the compiled backend.
        rtol: Relative tolerance forwarded to ``numpy.allclose``.
        atol: Absolute tolerance forwarded to ``numpy.allclose``.

    Returns:
        ``True`` if every pair is close (vacuously ``True`` for two empty
        sequences), ``False`` otherwise, including on a length mismatch.
    """
    if isinstance(orig_out, (tuple, list)):
        if not isinstance(comp_out, (tuple, list)):
            comp_out = (comp_out,)
        # zip() would silently ignore missing or extra results.
        if len(orig_out) != len(comp_out):
            return False
        return all(
            np.allclose(_as_numpy(ref), _as_numpy(got), rtol=rtol, atol=atol)
            for ref, got in zip(orig_out, comp_out)
        )
    return np.allclose(
        _as_numpy(orig_out), _as_numpy(comp_out), rtol=rtol, atol=atol
    )


def binary_search_faulty_graph(fx_g: torch.fx.GraphModule, inputs, backend):
    """Run ``fx_g`` eagerly and through ``backend``, then print whether they agree.

    Args:
        fx_g: Graph module to evaluate; it is called as ``fx_g(*inputs)``.
        inputs: Tuple of positional inputs for the graph.
        backend: Callable invoked as ``backend(fx_g, inputs)`` that returns
            the compiled result.

    NOTE(review): despite the name, this performs one comparison on the graph
    it is given; it does not search over break positions itself.
    """
    reference = fx_g(*inputs)
    candidate = backend(fx_g, inputs)
    agree = check_output(reference, candidate)
    print(agree)


# Load a pretrained ResNet-18 and put it in inference mode
# (Module.eval() is documented as equivalent to train(False)).
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=`.
resnet18 = models.resnet18(pretrained=True)
resnet18.eval()

# A single random 1x3x224x224 tensor, wrapped as a tuple of positional inputs.
# NOTE(review): `input` shadows the builtin of the same name at module level.
input = (torch.randn(1, 3, 224, 224),)
fx_graph = import_with_fx(resnet18, input)
print(f" The total nodes in the graph is: {count_total_nodes(fx_graph)}")

# Break a deep copy at call_function index 3 (zero-based) so that `fx_graph`
# itself stays untouched, then compare that piece against the SHARK backend.
fx_g = break_at_pos(copy.deepcopy(fx_graph), 3)
binary_search_faulty_graph(fx_g, input, shark_backend)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment