Skip to content

Instantly share code, notes, and snippets.

@ThoenigAdrian
Created February 20, 2023 23:59
Show Gist options
  • Save ThoenigAdrian/1af63a11637264a509e7dd73bd2c1a18 to your computer and use it in GitHub Desktop.
Save ThoenigAdrian/1af63a11637264a509e7dd73bd2c1a18 to your computer and use it in GitHub Desktop.
import torch.nn
import timeit
from torch.nn import Linear
# Benchmark setup: two identical Sigmoid MLPs so the foreach and
# non-foreach Adam variants each optimize their own copy of the network.
shape = [784, 50, 50, 50, 10]  # layer widths: input, three hidden, output
batch_size = 500

# Random data stands in for a real dataset; only optimizer speed is measured.
x = torch.rand((batch_size, shape[0]), device="cuda")
y = torch.rand((batch_size, shape[-1]), device="cuda")


def _make_model(widths):
    """Return a CUDA Sequential MLP with a Sigmoid after every Linear layer."""
    layers = []
    for n_in, n_out in zip(widths[:-1], widths[1:]):
        layers.append(Linear(n_in, n_out, device="cuda"))
        layers.append(torch.nn.Sigmoid())
    return torch.nn.Sequential(*layers)


model_1 = _make_model(shape)
model_2 = _make_model(shape)

# foreach=True fuses per-parameter optimizer ops into multi-tensor kernels.
adam_original = torch.optim.Adam(model_1.parameters(), foreach=False)
adam_foreach = torch.optim.Adam(model_2.parameters(), foreach=True)
def train(xx, yy, model, iters, adam):
    """Run `iters` full-batch training steps of `model` on (xx, yy).

    Args:
        xx: input batch tensor.
        yy: target batch tensor (same device as xx).
        model: the network to train (updated in place).
        iters: number of optimization steps.
        adam: the optimizer holding `model`'s parameters.
    """
    crit = torch.nn.MSELoss()
    for _ in range(iters):
        # Bug fix: gradients must be cleared each step, otherwise they
        # accumulate across iterations and corrupt both the training
        # and the benchmark.
        adam.zero_grad()
        out = model(xx)  # __call__ instead of .forward() so hooks run
        loss = crit(out, yy)
        loss.backward()
        adam.step()
    if xx.is_cuda:
        # CUDA kernels launch asynchronously; wait for them so the
        # caller's wall-clock timing reflects the actual work.
        torch.cuda.synchronize()
# Time one full run of each variant (number=1: a single timed execution).
time_without_foreach = timeit.timeit("train(x, y, model_1, 8000, adam_original)", globals=globals(), number=1)
time_with_foreach = timeit.timeit("train(x, y, model_2, 8000, adam_foreach)", globals=globals(), number=1)
print(f"{time_without_foreach=:.2f} 100%")
# Relative wall-clock time of the foreach run, as a percentage of the
# baseline (lower is better). The old name "relative_performance_improvement"
# was misleading: this is a time ratio, not an improvement figure.
relative_time_pct = time_with_foreach / time_without_foreach * 100
print(f"{time_with_foreach=:.2f} {relative_time_pct:.2f} %")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment