import torch
import shark
from shark.shark_importer import import_with_fx
from shark.shark_inference import SharkInference
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)
import torch_mlir
from apps.stable_diffusion.src.models.model_wrappers import replace_shape_str
import os
from io import BytesIO
tokenizer = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-3b"
)
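
# Generation stopping criterion: halt as soon as the model emits one of the
# StableLM tuned checkpoints' special tokens (the role markers such as
# <|USER|>/<|ASSISTANT|>/<|SYSTEM|> plus the end-of-text/padding ids).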
class StopOnTokens(StoppingCriteria):
    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        stop_ids = [50278, 50279, 50277, 1, 0]
        for stop_id in stop_ids:
            if input_ids[0][-1] == stop_id:
                return True
        return False

system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
"""
prompt = f"{system_prompt}<|USER|>What's your mood today?<|ASSISTANT|>"
inputs = tokenizer(prompt, return_tensors="pt")
inputs_model = (inputs["input_ids"], inputs["attention_mask"])
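
# The tuple order (input_ids, attention_mask) must match SLM.forward below;
# the same pair is handed to the SHARK module at the end of the script.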
class SLM(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            "stabilityai/stablelm-tuned-alpha-3b"
        )

    def forward(self, input_ids, attention_mask):
        # Index [0] picks the logits out of the CausalLMOutput.
        return self.model(input_ids, attention_mask)[0]

slm_model = SLM().cuda().half()
res_pytorch = slm_model(inputs_model[0].cuda(), inputs_model[1].cuda())
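
# res_pytorch holds the eager-mode fp16 logits, used as the reference for the
# SHARK result printed at the end of the script.
# StopOnTokens is defined above but never exercised here; a minimal sketch of
# how it would plug into eager HF generation (illustrative only, not part of
# the SHARK path):
#
#   slm_model.model.generate(
#       inputs["input_ids"].cuda(),
#       attention_mask=inputs["attention_mask"].cuda(),
#       max_new_tokens=64,
#       stopping_criteria=StoppingCriteriaList([StopOnTokens()]),
#   )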
# import torch
# from torch.fx.experimental.proxy_tensor import make_fx
# from torch._decomp import get_decompositions
# from typing import List
# fx_g = make_fx(
#     slm_model,
#     decomposition_table=get_decompositions(
#         [
#             torch.ops.aten.embedding_dense_backward,
#             torch.ops.aten.native_layer_norm_backward,
#             torch.ops.aten.slice_backward,
#             torch.ops.aten.select_backward,
#             torch.ops.aten.norm.ScalarOpt_dim,
#             torch.ops.aten.native_group_norm,
#             torch.ops.aten.upsample_bilinear2d.vec,
#             torch.ops.aten.split.Tensor,
#             torch.ops.aten.split_with_sizes,
#             torch.ops.aten.native_layer_norm,
#         ]
#     ),
# )(inputs["input_ids"].cuda(), inputs["attention_mask"].cuda())
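# The decomposition table above breaks ATen ops that torch-mlir cannot lower
# directly (backward ops, group/layer norms, splits, etc.) into simpler
# primitives before the FX graph is handed to the compiler. The helpers below
# then clean the traced graph up for TorchScript and torch-mlir.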
# def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
#     removed_indexes = []
#     for node in fx_g.graph.nodes:
#         if node.op == "output":
#             assert (
#                 len(node.args) == 1
#             ), "Output node must have a single argument"
#             node_arg = node.args[0]
#             if isinstance(node_arg, (list, tuple)):
#                 node_arg = list(node_arg)
#                 node_args_len = len(node_arg)
#                 for i in range(node_args_len):
#                     curr_index = node_args_len - (i + 1)
#                     if node_arg[curr_index] is None:
#                         removed_indexes.append(curr_index)
#                         node_arg.pop(curr_index)
#                 node.args = (tuple(node_arg),)
#                 break
#     if len(removed_indexes) > 0:
#         fx_g.graph.lint()
#         fx_g.graph.eliminate_dead_code()
#         fx_g.recompile()
#     removed_indexes.sort()
#     return removed_indexes
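# HF causal-LM forwards return optional fields that trace as `None` outputs;
# those have no tensor type to lower, so _remove_nones prunes them (recording
# their indexes) before compilation.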
# def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
#     """
#     Replace tuple with tuple element in functions that return one-element tuples.
#     Returns true if an unwrapping took place, and false otherwise.
#     """
#     unwrapped_tuple = False
#     for node in fx_g.graph.nodes:
#         if node.op == "output":
#             assert (
#                 len(node.args) == 1
#             ), "Output node must have a single argument"
#             node_arg = node.args[0]
#             if isinstance(node_arg, tuple):
#                 if len(node_arg) == 1:
#                     node.args = (node_arg[0],)
#                     unwrapped_tuple = True
#                 break
#     if unwrapped_tuple:
#         fx_g.graph.lint()
#         fx_g.recompile()
#     return unwrapped_tuple
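# After pruning, a single-output graph still returns a 1-tuple; unwrapping it
# lets the compiled module hand back a bare logits tensor.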
# def transform_fx(fx_g):
#     for node in fx_g.graph.nodes:
#         if node.op == "call_function":
#             # aten.empty returns uninitialized memory and should be zero-filled.
#             if node.target in [torch.ops.aten.empty]:
#                 with fx_g.graph.inserting_after(node):
#                     new_node = fx_g.graph.call_function(
#                         torch.ops.aten.zero_,
#                         args=(node,),
#                     )
#                     node.append(new_node)
#                     node.replace_all_uses_with(new_node)
#                     new_node.args = (node,)
#     fx_g.graph.lint()
# transform_fx(fx_g)
# fx_g.recompile()
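# transform_fx makes aten.empty deterministic by chaining zero_ after it; the
# passes below then prune None outputs, unwrap the 1-tuple return, and reset
# the FX codegen so the module uses a plain calling convention.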
# removed_none_indexes = _remove_nones(fx_g)
# was_unwrapped = _unwrap_single_tuple_return(fx_g)
# fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
# fx_g.recompile()
# def strip_overloads(gm):
#     """
#     Modifies the target of graph nodes in :attr:`gm` to strip overloads.
#     Args:
#         gm (fx.GraphModule): The input FX graph module to be modified.
#     """
#     for node in gm.graph.nodes:
#         if isinstance(node.target, torch._ops.OpOverload):
#             node.target = node.target.overloadpacket
#     gm.recompile()
# strip_overloads(fx_g)
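# torch.jit.trace below cannot script OpOverload targets (e.g. aten.add.Tensor),
# so strip_overloads rewrites them to their overload packets first.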
# ts_g = torch.jit.trace(
#     fx_g, (inputs["input_ids"].cuda(), inputs["attention_mask"].cuda())
# )
# module = torch_mlir.compile(
#     ts_g,
#     [inputs["input_ids"].cuda(), inputs["attention_mask"].cuda()],
#     torch_mlir.OutputType.LINALG_ON_TENSORS,
#     use_tracing=False,
#     verbose=False,
# )
# bytecode_stream = BytesIO()
# module.operation.write_bytecode(bytecode_stream)
# bytecode = bytecode_stream.getvalue()
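# The bytecode above is the serialized MLIR module; dumping it to disk (e.g.
# the "xyz.mlir" loaded below) lets later runs skip tracing and importing.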
# Alternatively, load a previously saved .mlir file and compile that instead:
with open("xyz.mlir", "rb") as file:
    bytecode = file.read()

shark_module = SharkInference(
    mlir_module=bytecode, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()
shark_module.save_module(module_name="stable3b_cuda")
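# compile() JITs the MLIR for the CUDA backend; save_module caches the compiled
# artifact (a .vmfb in SHARK) under the given name so it can be reloaded later.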
result_shark = shark_module("forward", inputs_model)

print("Result PyTorch")
print(res_pytorch)
print("Result SHARK")
print(result_shark)
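
# For a quantitative check (assuming SHARK returns a numpy array of logits),
# something like this would compare the two results:
#
#   import numpy as np
#   diff = np.abs(res_pytorch.detach().cpu().numpy() - np.asarray(result_shark))
#   print("max abs diff:", diff.max())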