import sys
import torch
import torch_mlir
from shark.shark_importer import import_with_fx
import torchvision.models as models
import copy
import io
import numpy as np
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
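
# Overview (descriptive note): this script is a manual bisection harness. It imports a
# causal LM as a torch.fx graph, truncates the graph at a user-chosen call_function
# node, and compares the TorchScript/CUDA result at that point against the SHARK
# (torch-mlir) backend, to help locate the first op whose lowering diverges numerically.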
# Custom SHARK backend: lowers the FX graph to linalg-on-tensors via torch-mlir
# and runs it through SharkInference.
def shark_backend(fx_g: torch.fx.GraphModule, inputs, device: str = "cpu"):
    mlir_module = torch_mlir.compile(fx_g, inputs, output_type="linalg-on-tensors")
    bytecode_stream = io.BytesIO()
    mlir_module.operation.write_bytecode(bytecode_stream)
    bytecode = bytecode_stream.getvalue()
    from shark.shark_inference import SharkInference

    shark_module = SharkInference(
        mlir_module=bytecode,
        device=device,
        mlir_dialect="tm_tensor",
    )
    shark_module.compile(extra_args=[])
    output = shark_module("forward", inputs)
    return output
# Counts the total number of callable nodes in the graph.
def count_total_nodes(fx_g: torch.fx.GraphModule):
    count: int = 0
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            count += 1
    return count
# Breaks the graph at the required position by emitting an extra output node just
# before the original one, so the recompiled forward returns the intermediate value.
def break_at_pos(fx_g: torch.fx.GraphModule, pos: int):
    count: int = 0
    output_node = None
    # First capture the output node, since the new output node has to be added
    # before the existing one.
    for node in fx_g.graph.nodes:
        if node.op == "output":
            output_node = node
            break
    # Break at the required position given by the search.
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            # TODO: Check here that the node is not of the form of empty tensor etc.
            if count == pos:
                with fx_g.graph.inserting_before(output_node):
                    fx_g.graph.output(node)
                break
            count += 1
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
# Checks whether the reference output and the backend output agree within tolerance.
def check_output(orig_out, comp_out):
    if isinstance(orig_out, tuple):
        for i, j in zip(orig_out, comp_out):
            get_val = np.allclose(i.cpu().detach().numpy(), j, rtol=1e-2, atol=1e-3)
            if not get_val:
                return get_val
        return True
    else:
        get_val = np.allclose(orig_out.cpu().detach().numpy(), comp_out, rtol=1e-2, atol=1e-3)
        return get_val
# Rewrites any explicit device="cpu" kwargs in the graph to CUDA so the graph can be
# scripted and run on GPU.
def transform_cuda(fx_g):
    for node in fx_g.graph.nodes:
        if node.op == "call_function":
            if node.kwargs.get("device") == torch.device(type="cpu"):
                new_kwargs = node.kwargs.copy()
                new_kwargs["device"] = torch.device(type="cuda")
                node.kwargs = new_kwargs
    fx_g.graph.lint()
    fx_g.recompile()
    return fx_g
# Runs the (truncated) graph through TorchScript on CUDA and through the given backend,
# then prints both outputs and whether they agree.
def binary_search_faulty_graph(fx_g: torch.fx.GraphModule, inputs, backend):
    fx_g = transform_cuda(fx_g)
    cuda_jit = torch.jit.script(fx_g).cuda()
    cuda_inputs = [i.cuda() for i in inputs]
    fx_out = cuda_jit(*cuda_inputs)
    back_out = backend(fx_g, inputs)
    print(fx_out)
    print(back_out)
    print(check_output(fx_out, back_out))
tokenizer = AutoTokenizer.from_pretrained(
    "TheBloke/vicuna-7B-1.1-HF", use_fast=False
)
class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"

inputs = tokenizer(prompt, return_tensors="pt")
inputs_model = (inputs["input_ids"], inputs["attention_mask"])
class SLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            "TheBloke/Vicuna-7B-CoT-fp16"
        )

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)[0]

slm_model = SLM()

inputs = tokenizer(prompt, return_tensors="pt")
inputs_model = (inputs["input_ids"], inputs["attention_mask"])

fx_graph = import_with_fx(
    slm_model, inputs_model, is_f16=True, f16_input_mask=[False, False]
)
print(f" The total nodes in the graph is: {count_total_nodes(fx_graph)}") | |
# Break the graph at the position. | |
user = 0 | |
while True: | |
try: | |
user = int(input("Please enter a number: ")) | |
if(user == -1): | |
break | |
fx_g = break_at_pos(copy.deepcopy(fx_graph), user) | |
binary_search_faulty_graph(fx_g, inputs_model, shark_backend) | |
except ValueError: | |
print("Integer not found. Please enter again..") | |
continue | |
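
# Example session (illustrative): with the total node count printed above, entering
# roughly half that value truncates the graph there; narrowing the range by hand then
# bisects to the first node where check_output prints False. Enter -1 to exit.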