Skip to content

Instantly share code, notes, and snippets.

@pashu123
Created June 5, 2023 14:11
Show Gist options
  • Save pashu123/954c7e136818af27f24b535c85dc0988 to your computer and use it in GitHub Desktop.
Save pashu123/954c7e136818af27f24b535c85dc0988 to your computer and use it in GitHub Desktop.
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
from PIL import Image
import requests
import torch
import torchvision.models as models
from torchvision import transforms
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
class MHA(torch.nn.Module):
    """Thin wrapper around ``torch.nn.MultiheadAttention`` for compilation tests.

    The wrapper exists so that a plain ``nn.Module`` with a three-argument
    ``forward`` can be handed to ``torch_mlir.compile``.

    Args:
        embed_dim: Model/embedding dimension of the attention layer. The last
            dimension of every input tensor must equal this. Defaults to 128,
            the value this example originally hard-coded.
        num_heads: Number of attention heads. PyTorch requires ``embed_dim``
            to be divisible by ``num_heads``. Defaults to 8.
    """

    def __init__(self, embed_dim: int = 128, num_heads: int = 8):
        super().__init__()
        # batch_first is left at PyTorch's default (False), so inputs are
        # interpreted as (seq_len, batch, embed_dim) for batched input.
        self.mha = torch.nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, query, key, value):
        """Run multi-head attention on ``query``/``key``/``value``.

        Returns:
            The tuple produced by ``nn.MultiheadAttention`` unchanged, i.e.
            ``(attn_output, attn_weights)``. Weights are returned because
            ``need_weights`` is left at its default.
        """
        return self.mha(query, key, value)
# Instantiate the wrapper with its default configuration.
mha_model = MHA()

# A single random tensor is fed as query, key and value (self-attention).
# Its last dimension (128) must match the attention layer's embed_dim.
query = torch.randn(128, 128, 128)

# Lower the model to the linalg-on-tensors dialect and print the MLIR.
module = torch_mlir.compile(
    mha_model,
    (query,) * 3,
    output_type="linalg-on-tensors",
)
module.dump()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment