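# Compile stabilityai/stablelm-tuned-alpha-3b with SHARK / torch-mlir for CUDA
# and compare the compiled output against the eager PyTorch result.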
import torch
import shark
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch_mlir
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
import os
from io import BytesIO

tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-3b"
)
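# Halt generation once the last generated token is one of StableLM's special
# chat / end-of-text token ids.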
class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False
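# StoppingCriteriaList is imported but never wired up in this script. A
# minimal sketch (commented out; uses the SLM wrapper and inputs defined
# below) of how StopOnTokens would plug into eager generation:
#
# out_ids = slm_model.model.generate(
#     inputs["input_ids"].cuda(),
#     attention_mask=inputs["attention_mask"].cuda(),
#     max_new_tokens=64,
#     stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
# )
# print(tokenizer.decode(out_ids[0], skip_special_tokens=True))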
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""

prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
inputs_model = (inputs["input_ids"], inputs["attention_mask"])
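# The (input_ids, attention_mask) ordering matches SLM.forward's positional
# signature below; the same tuple feeds both the eager run and the compiled
# module.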
class SLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            "stabilityai/stablelm-tuned-alpha-3b"
        )

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids, attention_mask)[0]


slm_model = SLM().cuda().half()
res_pytorch = slm_model(inputs_model[0].cuda(), inputs_model[1].cuda())
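# res_pytorch holds the eager fp16 logits; used below as the reference output.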
# import torch
# from torch.fx.experimental.proxy_tensor import make_fx
# from torch._decomp import get_decompositions
# from typing import List

# fx_g = make_fx(
#     slm_model,
#     decomposition_table=get_decompositions(
#         [
#             torch.ops.aten.embedding_dense_backward,
#             torch.ops.aten.native_layer_norm_backward,
#             torch.ops.aten.slice_backward,
#             torch.ops.aten.select_backward,
#             torch.ops.aten.norm.ScalarOpt_dim,
#             torch.ops.aten.native_group_norm,
#             torch.ops.aten.upsample_bilinear2d.vec,
#             torch.ops.aten.split.Tensor,
#             torch.ops.aten.split_with_sizes,
#             torch.ops.aten.native_layer_norm,
#         ]
#     ),
# )(inputs["input_ids"].cuda(), inputs["attention_mask"].cuda())
# def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
#     removed_indexes = []
#     for node in fx_g.graph.nodes:
#         if node.op == "output":
#             assert (
#                 len(node.args) == 1
#             ), "Output node must have a single argument"
#             node_arg = node.args[0]
#             if isinstance(node_arg, (list, tuple)):
#                 node_arg = list(node_arg)
#                 node_args_len = len(node_arg)
#                 for i in range(node_args_len):
#                     curr_index = node_args_len - (i + 1)
#                     if node_arg[curr_index] is None:
#                         removed_indexes.append(curr_index)
#                         node_arg.pop(curr_index)
#                 node.args = (tuple(node_arg),)
#                 break
#     if len(removed_indexes) > 0:
#         fx_g.graph.lint()
#         fx_g.graph.eliminate_dead_code()
#         fx_g.recompile()
#     removed_indexes.sort()
#     return removed_indexes
# def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
#     """
#     Replace tuple with tuple element in functions that return one-element tuples.
#     Returns true if an unwrapping took place, and false otherwise.
#     """
#     unwrapped_tuple = False
#     for node in fx_g.graph.nodes:
#         if node.op == "output":
#             assert (
#                 len(node.args) == 1
#             ), "Output node must have a single argument"
#             node_arg = node.args[0]
#             if isinstance(node_arg, tuple):
#                 if len(node_arg) == 1:
#                     node.args = (node_arg[0],)
#                     unwrapped_tuple = True
#                     break
#     if unwrapped_tuple:
#         fx_g.graph.lint()
#         fx_g.recompile()
#     return unwrapped_tuple
# def transform_fx(fx_g):
#     for node in fx_g.graph.nodes:
#         if node.op == "call_function":
#             # aten.empty should be filled with zeros.
#             if node.target in [torch.ops.aten.empty]:
#                 with fx_g.graph.inserting_after(node):
#                     new_node = fx_g.graph.call_function(
#                         torch.ops.aten.zero_,
#                         args=(node,),
#                     )
#                     node.append(new_node)
#                     node.replace_all_uses_with(new_node)
#                     new_node.args = (node,)
#     fx_g.graph.lint()

# transform_fx(fx_g)
# fx_g.recompile()
# removed_none_indexes = _remove_nones(fx_g)
# was_unwrapped = _unwrap_single_tuple_return(fx_g)
# fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
# fx_g.recompile()
# def strip_overloads(gm):
#     """
#     Modifies the target of graph nodes in :attr:`gm` to strip overloads.
#     Args:
#         gm (fx.GraphModule): The input FX graph module to be modified.
#     """
#     for node in gm.graph.nodes:
#         if isinstance(node.target, torch._ops.OpOverload):
#             node.target = node.target.overloadpacket
#     gm.recompile()

# strip_overloads(fx_g)
# ts_g = torch.jit.trace(
#     fx_g, (inputs["input_ids"].cuda(), inputs["attention_mask"].cuda())
# )
# module = torch_mlir.compile(
#     ts_g,
#     [inputs["input_ids"].cuda(), inputs["attention_mask"].cuda()],
#     torch_mlir.OutputType.LINALG_ON_TENSORS,
#     use_tracing=False,
#     verbose=False,
# )
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
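# If the pipeline above is re-enabled, its bytecode can be written out so the
# load path below picks it up (a sketch; "xyz.mlir" is the filename this
# script already expects):
#
# with open("xyz.mlir", "wb") as f:
#     f.write(bytecode)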
# Alternatively, load a previously exported .mlir file and compile that.
with open("xyz.mlir", "rb") as file:
    bytecode = file.read()

shark_module = SharkInference(
    mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
shark_module.save_module(module_name="stable3b_cuda")
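# save_module writes the compiled artifact to disk under the given name, so
# later runs can reload "stable3b_cuda" instead of recompiling.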
result_shark = shark_module("forward", inputs_model)
print("Result PyTorch") | |
print(res_pytorch) | |
print("Result SHARK") | |
print(result_shark) |
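# A rough numerical cross-check (a sketch: assumes the SHARK call returns a
# numpy array and uses a loose fp16 tolerance):
#
# import numpy as np
# print(
#     "outputs close:",
#     np.allclose(
#         res_pytorch.detach().cpu().float().numpy(), result_shark, atol=1e-2
#     ),
# )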