bloom_fp16.py
import torch
from transformers import AutoModelForSequenceClassification


class HuggingFaceLanguage(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = AutoModelForSequenceClassification.from_pretrained(
            "bigscience/bloom-560m",  # The pretrained model.
            num_labels=2,  # The number of output labels: 2 for binary classification.
            output_attentions=False,  # Whether the model returns attention weights.
            output_hidden_states=False,  # Whether the model returns all hidden states.
            torchscript=True,
        )

    def forward(self, tokens):
        return self.model.forward(tokens)[0]


torch.manual_seed(0)
model = HuggingFaceLanguage()
test_input = torch.randint(2, (1, 128))
actual_out = model(test_input)
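# Added (a small sketch, not in the original gist): print the fp32 baseline here;
# this is the "golden_out" value that the result comments further down refer to.
print("golden_out: ", actual_out)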
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions

fx_g = make_fx(
    model,
    decomposition_table=get_decompositions(
        [
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
        ]
    ),
)(test_input)
# print(fx_g.graph)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()


def strip_overloads(gm):
    """
    Modifies the target of graph nodes in :attr:`gm` to strip overloads.

    Args:
        gm (fx.GraphModule): The input FX graph module to be modified.
    """
    for node in gm.graph.nodes:
        if isinstance(node.target, torch._ops.OpOverload):
            node.target = node.target.overloadpacket
    gm.recompile()


strip_overloads(fx_g)
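# Added sanity check (a sketch, not in the original gist): after decomposition and
# overload stripping, the FX module should still reproduce the eager fp32 output.
print("fx graph matches eager: ", torch.allclose(fx_g(test_input), actual_out, atol=1e-4))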
# Check whether the module is symbolically traceable:
# https://pytorch.org/docs/stable/fx.html#module-torch.fx
from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(fx_g)
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)  # the graph prints successfully, so it is symbolically traceable
# QUANTIZATION SETUP
model = fx_g

# # test1: static quantization SUCCESS
# # golden_out VS shark_out
# # tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[ 7.2043, -17.0265]])
# model.eval()
# model.qconfig = torch.quantization.qconfig.float16_static_qconfig
#
# # Prepare the model for static quantization. This inserts observers in
# # the model that will observe activation tensors during calibration.
# model_fp32_prepared = torch.quantization.prepare(model)
#
# # Calibrate the prepared model to determine quantization parameters for activations.
# # In a real-world setting, the calibration would be done with a representative dataset.
# model_fp32_prepared(test_input)
#
# # Convert the observed model to a quantized model. This does several things:
# # quantizes the weights, computes and stores the scale and bias value to be
# # used with each activation tensor, and replaces key operators with quantized
# # implementations.
# model_fp16 = torch.quantization.convert(model_fp32_prepared)
#
# # Run the model; relevant calculations will happen in fp16.
# res = model_fp16(test_input)
# print("res: ", res)
# fx_g = model_fp16
# print("model_fp16: ", model_fp16.graph)
# # test2: dynamic quantization SUCCESS
# # golden_out VS shark_out
# # tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[ 7.2043, -17.0265]])
# model.eval()
# model_fp16 = torch.quantization.quantize_dynamic(
#     model,                # the original model
#     {torch.nn.Linear},    # a set of layers to dynamically quantize
#     dtype=torch.float16)  # the target dtype for quantized weights
# # res: tensor([[ 7.2050, -17.0270]])
# # Run the model; relevant calculations will happen in fp16.
# res = model_fp16(test_input)
# print("res: ", res)
# fx_g = model_fp16
# test3: quantize_fx SUCCESS, but shark_out varies across runs
# golden_out VS shark_out
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[-5.0346, -3.0476]])
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[-8.9322, -2.2289]])
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[-7.9967, -6.8403]])
import torch.quantization.quantize_fx as quantize_fx

# import copy
#
# model_to_quantize = copy.deepcopy(model)
# code from https://github.com/pytorch/pytorch/blob/master/test/quantization/fx/test_quantize_fx.py#:~:text=qconfig_mapping%20%3D%20get_default_qconfig_mapping().set_global(float16_static_qconfig)
qconfig_dict = {"": torch.quantization.qconfig.float16_static_qconfig}  # unused; prepare_fx below takes the QConfigMapping instead
qconfig_mapping_ = torch.ao.quantization.get_default_qconfig_mapping().set_global(
    torch.quantization.qconfig.float16_static_qconfig
)
model.eval()
# prepare
test_input.to(dtype=torch.float16)  # no-op: the result is never assigned, and the token ids must stay integral anyway
model_prepared = quantize_fx.prepare_fx(model, qconfig_mapping=qconfig_mapping_, example_inputs=(test_input,))
# calibrate (not shown)
# quantize
model_fp16 = quantize_fx.convert_fx(model_prepared)
# Run the model; relevant calculations will happen in fp16.
res = model_fp16(test_input)
print("res.dtype: ", res.dtype)  # res.dtype: torch.float32
fx_g = model_fp16
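# Added sketch (an assumption, not in the original gist): compare the converted
# module's eager output against the fp32 baseline before handing the graph to
# SHARK, to see how much error the fp16 conversion alone introduces.
print("max abs diff vs fp32 baseline: ", (res - actual_out).abs().max().item())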
# print("model_fp16: ", model_fp16.graph) | |
# print("########model_fp16: ", model_fp16) | |
# model_fp16.print_readable() | |
# print("model_fp16: ", model_fp16.code) | |
# Maybe try to print the weight with fp16_repr() ???? | |
# https://discuss.pytorch.org/t/how-to-print-the-weight-after-torch-quantization-convert/161801 | |
# model_fp16.layer[0].weight().int_repr().data | |
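# A minimal inspection sketch (an assumption, not part of the original flow): the
# thread above uses int_repr() for int8 models; for this fp16 path, dumping each
# state_dict tensor's dtype shows which weights were actually converted.
for name, tensor in list(model_fp16.state_dict().items())[:10]:  # first few entries only
    print(name, tensor.dtype)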
ts_g = torch.jit.script(fx_g)
# print("ts_g: ", ts_g.graph)

# import torch_mlir
# module = torch_mlir.compile(
#     ts_g, [test_input], torch_mlir.OutputType.LINALG_ON_TENSORS, use_tracing=True, verbose=True
# )
# # import pdb
# # pdb.set_trace()
# # module.dump()
# from shark.shark_inference import SharkInference
# mlir_model = module
# func_name = "forward"
# shark_module = SharkInference(
#     mlir_model, func_name, device="cpu", mlir_dialect="tm_tensor"
# )
# shark_module.compile()
# def shark_result(x):
#     x_ny = x.detach().numpy()
#     inputs = (x_ny,)
#     result = shark_module.forward(inputs)
#     return torch.from_numpy(result)
# observed_out = shark_result(test_input)
# print(actual_out, observed_out)
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[3.5580e-33, 0.0000e+00]])
# QUANTIZATION RESULTS (golden_out VS shark_out)
# test1: static quantization:
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[ 7.2043, -17.0265]])
# test2: dynamic quantization:
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[ 7.2043, -17.0265]])
# test3: quantize_fx:
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[-5.0346, -3.0476]])
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[-8.9322, -2.2289]])
# tensor([[ 7.2041, -17.0263]], grad_fn=<IndexBackward0>) tensor([[-7.9967, -6.8403]])