Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save leslie-fang-intel/69c3f1b693620dc9b0c8f964018d0b75 to your computer and use it in GitHub Desktop.
# Dependencies for the autoquant repro script below.
# NOTE(review): several of these (requests, nn, os, pickle, np, gc, time,
# psutil, refcycle, dynamo_config, ALL_AUTOQUANT_CLASS_LIST) are not used by
# the visible code -- presumably leftovers from a larger experiment; confirm
# before pruning.
import requests
import torch
# Print the torch version up front so logs identify the build under test.
print(torch.__version__, flush=True)
import torch.nn as nn
import os, pickle
import numpy as np
import torch._inductor.config as config
import torch._dynamo.config as dynamo_config
import gc
import time
import psutil
import refcycle
import torchao
from torchao import autoquant
from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST
config.freezing = True
config.max_autotune = True
output_channels = 1024
dtype = torch.bfloat16
class M(torch.nn.Module):
    """A single bias-free Linear projection: 1024 -> output_channels, cast to dtype."""

    def __init__(self, output_channels, dtype):
        super().__init__()
        # bias=False keeps the module a pure matmul; .to(dtype) casts the weight.
        self.lin = torch.nn.Linear(1024, output_channels, bias=False).to(dtype)

    def forward(self, attn_weights):
        # Project the last dimension from 1024 to output_channels.
        return self.lin(attn_weights)
if __name__ == "__main__":
with torch.no_grad():
model = M(output_channels, dtype).eval()
## Optional: invoke torch.compile
model = torch.compile(model)
model = torchao.autoquant(model, manual=True)
x = torch.randn(2, 1024).to(dtype)
model(x)
model.finalize_autoquant()
# Do we need to invoke model = torch.compile(model)
model(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment