# Source: GitHub gist pashu123/1f18e8789147c5044b4631cd96472ed8
# by @pashu123, created April 25, 2022 16:23.
import torch
from shark.shark_runner import SharkInference
from bert_pytorch import BERT
# Pin torch's global RNG to seed 0 so the random example input drawn below
# (and any torch-RNG-based weight initialisation) is reproducible across runs.
torch.random.manual_seed(0)
class BERT_torch(torch.nn.Module):
    """Wrap a ``bert_pytorch.BERT`` so it can be called with a single tensor.

    ``BERT.forward`` is called here with two arguments, whereas this script
    hands SharkInference a one-element example-input tuple; the wrapper bridges
    the two by feeding the same tensor to both of BERT's inputs.

    Args:
        vocab_size: First positional argument passed to ``BERT`` (presumably
            its vocabulary size -- confirm against bert_pytorch). Defaults to
            512, the value this script originally hard-coded.
    """

    def __init__(self, vocab_size=512):
        super().__init__()
        self.model = BERT(vocab_size)

    def forward(self, tokens):
        """Run the wrapped BERT with ``tokens`` as both of its inputs.

        NOTE(review): BERT's second argument is presumably the segment-label
        tensor; reusing ``tokens`` only works if every token id is also a valid
        segment label (the script's example input is drawn from {0, 1}).
        """
        # Call the submodule instance rather than ``.forward()`` so that
        # nn.Module.__call__ runs any registered forward hooks; calling
        # ``.forward()`` directly silently skips them.
        return self.model(tokens, tokens)
# Random int32 example input of shape (1, 128) with every value in {0, 1}.
# NOTE(review): BERT_torch passes this tensor as both of BERT's inputs, which is
# presumably why the values are limited to {0, 1} -- confirm before widening.
test_input = torch.randint(low=0, high=2, size=(1, 128)).to(dtype=torch.int32)

# Hand the wrapper and a one-element example-input tuple to SharkInference on
# CPU; jit_trace=True presumably makes it capture the model by tracing.
shark_module = SharkInference(
    BERT_torch(),
    (test_input,),
    device="cpu",
    jit_trace=True,
)

# Run the SHARK module on the same example input.
results = shark_module.forward((test_input,))
# End of gist.