Skip to content

Instantly share code, notes, and snippets.

@remi-or
Created June 2, 2025 16:26
Show Gist options
  • Save remi-or/33359e5435c4de74d8146d85ba50e485 to your computer and use it in GitHub Desktop.
Save remi-or/33359e5435c4de74d8146d85ba50e485 to your computer and use it in GitHub Desktop.
Minimal reproducible example (MRE) for a torch.compile output discrepancy in resize
import torch
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.v2 import functional as F
import matplotlib.pyplot as plt
# Shape of the synthetic input image, channels-first: (C, H, W).
INPUT_SIZE = (3, 41, 70)


class Resizer(nn.Module):
    """Resize an image tensor to a fixed ``(height, width)``.

    Wraps torchvision's ``resize`` with bicubic interpolation and
    antialiasing, so the same module can be run both eagerly and under
    ``torch.compile`` and the two outputs compared.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.
    """

    def __init__(self, height: int, width: int):
        super().__init__()
        self.height = height
        self.width = width

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resized to ``(self.height, self.width)``.

        torchvision resizes the trailing (H, W) dimensions and keeps any
        leading ones (here, the channel dimension).
        """
        # NOTE: at this point ``F`` is torchvision.transforms.v2.functional --
        # the torchvision import rebinds the name that torch.nn.functional was
        # imported under just before it.
        return F.resize(
            x,
            (self.height, self.width),
            interpolation=F.InterpolationMode.BICUBIC,
            antialias=True,
        )
if __name__ == "__main__":
    # Draw uint8 pixel values directly on the first CUDA device. Casting
    # out-of-range floats (e.g. negative randn samples) to uint8 is undefined
    # behaviour in C++ and can differ between devices.
    x = torch.randint(0, 256, INPUT_SIZE, dtype=torch.uint8, device="cuda:0")

    resizer = Resizer(384, 384)
    y = resizer(x)  # eager reference output

    # Compile the same module instance so both outputs come from one Resizer.
    compiled_resizer = torch.compile(resizer)
    y_compiled = compiled_resizer(x)

    # Compare in float: subtracting uint8 tensors wraps modulo 256, so e.g.
    # 3 - 5 would appear as 254 instead of a difference of 2.
    abs_diff = (y.float() - y_compiled.float()).abs()
    print(f"Max difference: {abs_diff.max().item()}")

    # Draw the original output, the compiled output and their difference.
    fig, axs = plt.subplots(1, 3)
    axs[0].imshow(y.permute(1, 2, 0).numpy(force=True))
    axs[0].set_title("eager")
    axs[1].imshow(y_compiled.permute(1, 2, 0).numpy(force=True))
    axs[1].set_title("compiled")
    # Collapse channels to the largest per-pixel deviation and render it as a
    # heatmap with a colorbar so that small (1-2 level) differences are visible.
    diff_im = axs[2].imshow(abs_diff.amax(dim=0).numpy(force=True))
    axs[2].set_title("|eager - compiled|")
    fig.colorbar(diff_im, ax=axs[2])
    plt.savefig("resizer.png")
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment