@pashu123
Created June 8, 2023 17:13
import sys
import torch
import torch_mlir
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
# Custom shark backend.
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
    mlir_module = torch_mlir.compile(fx_g, inputs, output_type="linalg-on-tensors")
    bytecode_stream = io.BytesIO()
    mlir_module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()
    from shark.shark_inference import SharkInference
    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor",
    )
    shark_module.compile(extra_args=[])
    output = shark_module("forward", inputs)
    return output

# Counts the total no. of callable nodes in the graph.
def count_total_nodes(fx_g: torch.fx.GraphModule):
    count: int = 0
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            count += 1
    return count

# Breaks the graph at the required position.
def break_at_pos(fx_g: torch.fx.GraphModule, pos: int):
    count: int = 0
    output_node = None
    # First capture the output node since we have to add the new output node before the previous output_node.
    for node in fx_g.graph.nodes:
        if node.op == "output":
            output_node = node
            break
    # Break at the required position given by the search.
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            # TODO: Check here that the node is not of the form of empty tensor etc.
            if count == pos:
                with fx_g.graph.inserting_before(output_node):
                    fx_g.graph.output(node)
                break
            count += 1
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g

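# Purely illustrative sketch (not part of the original gist): after break_at_pos,
# the recompiled module returns the value of the chosen call_function node rather
# than the model's final output, so the same intermediate value can be computed
# eagerly and through the backend and then compared, e.g.:
#
#   truncated = break_at_pos(copy.deepcopy(fx_graph), 10)
#   intermediate = truncated(*input)  # value of the call_function node at position 10
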
# Compares the reference (eager) output with the compiled output.
def check_output(orig_out, comp_out):
    if isinstance(orig_out, tuple):
        for i, j in zip(orig_out, comp_out):
            if not np.allclose(i.cpu().detach().numpy(), j, rtol=1e-2, atol=1e-3):
                return False
        return True
    return np.allclose(orig_out.cpu().detach().numpy(), comp_out, rtol=1e-2, atol=1e-3)

def binary_search_faulty_graph(fx_g: torch.fx.GraphModule, inputs, backend):
    orig_out = fx_g(*inputs)
    out = backend(fx_g, inputs)
    print(check_output(orig_out, out))

resnet18 = models.resnet18(pretrained=True)
resnet18.train(False)
input = (torch.randn(1,3,224,224),)
fx_graph = import_with_fx(resnet18, input)
print(f" The total nodes in the graph is: {count_total_nodes(fx_graph)}")
# Break the graph at the position.
fx_g = break_at_pos(copy.deepcopy(fx_graph), 3)
binary_search_faulty_graph(fx_g, input, shark_backend)
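
# A possible bisection driver (a sketch, not part of the original gist): assuming a
# single faulty op whose effect persists at every later position, binary-search the
# call_function positions, truncating the graph at the midpoint each time and
# comparing the eager output against the shark backend with check_output. Uncomment
# to run; each step recompiles the truncated graph, so it can be slow.
#
# total = count_total_nodes(fx_graph)
# lo, hi = 0, total - 1
# while lo < hi:
#     mid = (lo + hi) // 2
#     g = break_at_pos(copy.deepcopy(fx_graph), mid)
#     eager_out = g(*input)
#     compiled_out = shark_backend(g, input)
#     if check_output(eager_out, compiled_out):
#         lo = mid + 1  # outputs match up to mid; the fault is later
#     else:
#         hi = mid      # mismatch already visible at mid
# print(f"First mismatching call_function position: {lo}")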