@pashu123
Created June 15, 2023 16:26
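# Bisects divergence between eager execution of a Vicuna-7B FX graph (scripted
# with TorchScript and run on CUDA) and the SHARK / torch-mlir backend: the
# graph is truncated at a user-chosen call_function node and the outputs of the
# two paths are compared at that point.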
import sys
import torch
import torch_mlir
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

# Custom SHARK backend: lowers the FX graph to linalg-on-tensors via torch-mlir
# and runs it through SharkInference on the requested device.
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
    mlir_module = torch_mlir.compile(fx_g, inputs, output_type="linalg-on-tensors")
    bytecode_stream = io.BytesIO()
    mlir_module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()
    from shark.shark_inference import SharkInference

    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor",
    )
    shark_module.compile(extra_args=[])
    output = shark_module("forward", inputs)
    return output

# Counts the total number of call_function nodes in the graph.
def count_total_nodes(fx_g: torch.fx.GraphModule):
    count: int = 0
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            count += 1
    return count

# Breaks the graph at the required position.
def break_at_pos(fx_g: torch.fx.GraphModule, pos: int):
    count: int = 0
    output_node = None
    # First capture the output node, since we have to add the new output node
    # before the previous output_node.
    for node in fx_g.graph.nodes:
        if node.op == "output":
            output_node = node
            break
    # Break at the required position given by the search.
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            # TODO: Check here that the node is not of the form of empty tensor etc.
            if count == pos:
                with fx_g.graph.inserting_before(output_node):
                    fx_g.graph.output(node)
                break
            count += 1
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
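
# For example (hypothetical position), break_at_pos(copy.deepcopy(fx_graph), 5)
# yields a copy of the graph whose forward returns the value produced by
# call_function node 5 (0-indexed) instead of the original model output.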

# Compares the eager outputs against the backend outputs within a loose tolerance.
def check_output(orig_out, comp_out):
    if type(orig_out) == tuple:
        for i, j in zip(orig_out, comp_out):
            get_val = np.allclose(i.cpu().detach().numpy(), j, rtol=1e-2, atol=1e-3)
            if not get_val:
                return get_val
        # All elements matched.
        return True
    else:
        get_val = np.allclose(
            orig_out.cpu().detach().numpy(), comp_out, rtol=1e-2, atol=1e-3
        )
        return get_val

# Rewrites any call_function node that creates tensors on the CPU so that it
# targets CUDA instead, letting the scripted graph run on the GPU.
def transform_cuda(fx_g):
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.kwargs.get("device") == torch.device(type="cpu"):
                new_kwargs = node.kwargs.copy()
                new_kwargs["device"] = torch.device(type="cuda")
                node.kwargs = new_kwargs
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g

# Runs the (possibly truncated) graph both through TorchScript on CUDA and
# through the given backend, prints both outputs, and reports whether they match.
def binary_search_faulty_graph(fx_g: torch.fx.GraphModule, inputs, backend):
    fx_g = transform_cuda(fx_g)
    cuda_jit = torch.jit.script(fx_g).cuda()
    cuda_inputs = [i.cuda() for i in inputs]
    fx_out = cuda_jit(*cuda_inputs)
    back_out = backend(fx_g, inputs)
    print(fx_out)
    print(back_out)
    match = check_output(fx_out, back_out)
    print(match)
    return match

tokenizer = AutoTokenizer.from_pretrained(
    "TheBloke/vicuna-7B-1.1-HF", use_fast=False
)

class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
inputs_model = (inputs["input_ids"], inputs["attention_mask"])

class SLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Vicuna-7B-CoT-fp16"
        )

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)[0]

slm_model = SLM()
inputs = tokenizer(prompt, return_tensors="pt")
inputs_model = (inputs["input_ids"], inputs["attention_mask"])

fx_graph = import_with_fx(
    slm_model, inputs_model, is_f16=True, f16_input_mask=[False, False]
)
print(f" The total nodes in the graph is: {count_total_nodes(fx_graph)}")

# Break the graph at a user-supplied position and compare the two backends there.
user = 0
while True:
    try:
        user = int(input("Please enter a node position (or -1 to exit): "))
        if user == -1:
            break
        fx_g = break_at_pos(copy.deepcopy(fx_graph), user)
        binary_search_faulty_graph(fx_g, inputs_model, shark_backend)
    except ValueError:
        print("Invalid integer. Please enter again.")
        continue
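

# Hypothetical extension (not part of the original gist): instead of probing
# positions by hand, the same helpers can drive an automated bisection, assuming
# binary_search_faulty_graph returns check_output's boolean (as above) and that
# outputs match at every position before the first faulty node and mismatch at
# every position from it onward.
def bisect_first_divergence(full_graph, inputs, backend, total_nodes):
    lo, hi = 0, total_nodes - 1
    first_bad = None
    while lo <= hi:
        mid = (lo + hi) // 2
        fx_g = break_at_pos(copy.deepcopy(full_graph), mid)
        if binary_search_faulty_graph(fx_g, inputs, backend):
            # Outputs still agree at this node; the divergence must be later.
            lo = mid + 1
        else:
            # Mismatch at this node; remember it and search the earlier half.
            first_bad = mid
            hi = mid - 1
    return first_bad


# Example usage (hypothetical):
# pos = bisect_first_divergence(
#     fx_graph, inputs_model, shark_backend, count_total_nodes(fx_graph)
# )
# print(f"First diverging call_function node: {pos}")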