@taylanbil
Created June 13, 2020 00:27
EmbeddingBag backward error repro
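import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Acquire the default XLA device (for example a TPU core).
device = xm.xla_device()
# 10-row, 10-dim embedding table; each bag's output is the sum of its rows.
# sparse=False, so the weight gradient is a dense tensor.
d = nn.EmbeddingBag(10, 10, mode="sum", sparse=False).to(device)
# One bag (offsets=[0]) containing embedding indices 1, 5 and 9.
inp = torch.LongTensor([1, 5, 9]).to(device)
x = d(inp, offsets=torch.LongTensor([0]).to(device))
# Reduce to a scalar and backpropagate; the repro targets the error raised here.
loss = x.sum()
loss.backward()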
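To check what a correct backward pass should produce, here is a minimal pure-Python sketch of what `mode="sum"` computes, assuming the usual offsets convention (bag `b` covers `indices[offsets[b]:offsets[b+1]]`, and the last bag runs to the end). The helper names `embedding_bag_sum` and `embedding_bag_sum_grad` are hypothetical and not part of PyTorch or torch_xla. For the repro's inputs, the gradient that `loss.backward()` should leave in `d.weight.grad` is all ones on rows 1, 5 and 9 and zeros elsewhere.

```python
# Hypothetical reference helpers for EmbeddingBag(mode="sum"); not a PyTorch API.
def embedding_bag_sum(weight, indices, offsets):
    """Forward: for each bag, sum the rows of `weight` selected by `indices`."""
    dim = len(weight[0])
    bounds = list(offsets) + [len(indices)]  # last bag runs to the end
    out = []
    for b in range(len(offsets)):
        row = [0.0] * dim
        for idx in indices[bounds[b]:bounds[b + 1]]:
            for j in range(dim):
                row[j] += weight[idx][j]
        out.append(row)
    return out


def embedding_bag_sum_grad(num_embeddings, dim, indices, offsets, grad_out):
    """Backward w.r.t. weight: scatter-add each bag's output gradient into
    every row that bag used, giving a dense gradient (as with sparse=False)."""
    grad_w = [[0.0] * dim for _ in range(num_embeddings)]
    bounds = list(offsets) + [len(indices)]
    for b in range(len(offsets)):
        for idx in indices[bounds[b]:bounds[b + 1]]:
            for j in range(dim):
                grad_w[idx][j] += grad_out[b][j]
    return grad_w


# Same shapes as the repro: 10x10 table, one bag holding indices 1, 5, 9.
num_embeddings, dim = 10, 10
indices, offsets = [1, 5, 9], [0]
# loss = x.sum(), so d(loss)/d(x) is all ones with x's shape (1 bag x dim).
grad_out = [[1.0] * dim for _ in offsets]
expected = embedding_bag_sum_grad(num_embeddings, dim, indices, offsets, grad_out)
print([i for i, row in enumerate(expected) if any(row)])  # → [1, 5, 9]
```

Running the same repro on CPU (dropping `.to(device)`) and comparing `d.weight.grad` against this expectation is one way to separate an XLA-specific failure from a problem in the model code.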