Skip to content

Instantly share code, notes, and snippets.

@takuma104
Created May 17, 2023 14:50
Show Gist options
  • Save takuma104/1263383cdab8f54bb14f389facdbe960 to your computer and use it in GitHub Desktop.
Save takuma104/1263383cdab8f54bb14f389facdbe960 to your computer and use it in GitHub Desktop.
import torch
import torch.nn
def test_monkey_patch() -> None:
    """Check that assigning a closure to ``forward`` overrides a module's output.

    A ``torch.nn.Linear(2, 2)`` is run once to get a reference output, its
    ``forward`` attribute is then replaced by a plain function that calls the
    original ``forward`` and doubles the result, and the module is run again.
    The second output must equal twice the first.

    Raises:
        AssertionError: If the reference output has the wrong shape or the
            patched output is not twice the reference output.
    """
    # Fall back to the CPU so the test also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn((2, 2)).to(device)
    target = torch.nn.Linear(2, 2).to(device)

    # Only forward outputs are compared, so no autograd graph is needed.
    with torch.no_grad():
        y = target(x)
        assert y.shape == (2, 2)
        print(y)

        # Capture the original bound method before it is shadowed below;
        # the closure keeps it reachable.
        old_forward = target.forward

        def new_forward(x):
            return old_forward(x) * 2.0

        # Plain function stored on the instance: it receives no ``self`` and
        # shadows the class-level ``forward`` for this object only.
        target.forward = new_forward

        yy = target(x)
        assert torch.allclose(yy, y * 2.0)
        print(yy)
def test_monkey_patch_instance_method() -> None:
    """Check that a bound replacement ``forward`` overrides a module's output.

    A ``torch.nn.Linear(2, 2)`` is run once to get a reference output. The
    original ``forward`` is stashed on the instance as ``old_forward``, and a
    function taking ``self`` is bound to the instance and installed as the
    new ``forward``. The replacement doubles the original result, so the
    second output must equal twice the first.

    Raises:
        AssertionError: If the reference output has the wrong shape or the
            patched output is not twice the reference output.
    """
    # Fall back to the CPU so the test also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn((2, 2)).to(device)
    target = torch.nn.Linear(2, 2).to(device)

    # Only forward outputs are compared, so no autograd graph is needed.
    with torch.no_grad():
        y = target(x)
        assert y.shape == (2, 2)
        print(y)

        # Keep the original bound method on the instance so the replacement
        # can reach it through ``self``.
        target.old_forward = target.forward

        def new_forward(self, x):
            return self.old_forward(x) * 2.0

        # ``function.__get__(obj)`` yields a method bound to ``obj``, so the
        # replacement receives ``target`` as ``self`` when called.
        target.forward = new_forward.__get__(target)

        yy = target(x)
        assert torch.allclose(yy, y * 2.0)
        print(yy)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment