Skip to content

Instantly share code, notes, and snippets.

@pashu123
Last active January 24, 2024 19:03
Show Gist options
  • Save pashu123/932ddc07e311808805535cd26d211cc3 to your computer and use it in GitHub Desktop.
Save pashu123/932ddc07e311808805535cd26d211cc3 to your computer and use it in GitHub Desktop.
import torch
import torch_mlir
class ANET(torch.nn.Module):
    """Thin ``nn.Module`` wrapper around torchvision's pretrained AlexNet.

    The wrapped network is fetched through ``torch.hub`` from the
    ``pytorch/vision`` repository pinned at tag ``v0.10.0``. It is put in
    eval mode so the module can be handed straight to a compiler/exporter.
    """

    def __init__(self):
        super().__init__()
        # torch.hub resolves the repo/tag and loads the model with pretrained
        # weights; this may hit the network on first use (hub cache).
        # NOTE(review): the v0.10.0 hub entry point takes `pretrained=`;
        # newer torchvision tags use `weights=` instead.
        self.alexnet = torch.hub.load(
            'pytorch/vision:v0.10.0', 'alexnet', pretrained=True
        )
        # Inference mode: dropout layers become no-ops, so repeated forward
        # passes (and any traced/compiled graph) are deterministic.
        self.alexnet.eval()

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run the wrapped AlexNet on *image* and return its output.

        Assumes *image* is a float batch shaped (N, 3, H, W). TODO confirm
        the accepted sizes against the torchvision AlexNet docs.
        """
        return self.alexnet(image)
# Build the wrapper; constructing it performs the torch.hub load of AlexNet.
alexnet_model = ANET()

# Random example input used only to drive compilation, laid out as
# (batch, channels, height, width) = (1, 3, 224, 224).
rnd_inp = torch.randn((1, 3, 224, 224))

# Lower the model through torch-mlir down to the linalg-on-tensors level,
# then dump the resulting MLIR module.
module = torch_mlir.compile(
    alexnet_model,
    (rnd_inp,),
    output_type="linalg-on-tensors",
)
module.dump()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment