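# Benchmark: per-step training time for a residual MLP on synthetic data,
# comparing torch.compile (PyTorch 2) and TF32 tensorcore settings on A100/H100.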
from argparse import ArgumentParser
import time

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from torch.nn import functional as F
parser = ArgumentParser()
parser.add_argument("--pytorch_2", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--tensorcores", action="store_true")
args = parser.parse_args()

pytorch_2 = args.pytorch_2
compile = args.compile
tensorcores = args.tensorcores
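# Example invocation (the script filename here is hypothetical):
#   python benchmark_compile.py --pytorch_2 --tensorcores --compile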
# For tensorcore speedup:
if tensorcores:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
else:
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

if pytorch_2 and tensorcores:
    torch.set_float32_matmul_precision("high")
class MLP(nn.Module):
    def __init__(self, n_in, n_out, n_hidden=128, n_layers=2, activation=F.relu):
        super().__init__()
        self.activation = activation
        self.layers = nn.ModuleList([nn.Linear(n_in, n_hidden)])
        self.layers.extend([nn.Linear(n_hidden, n_hidden) for _ in range(n_layers - 1)])
        self.layers.append(nn.Linear(n_hidden, n_out))

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = self.activation(layer(x))
        return self.layers[-1](x)
class ResidualConnection(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return x + self.module(x)
class DeepMLP(nn.Module):
    def __init__(self, n_in, n_out, n_hidden=128, blocks=2, activation=F.relu):
        super().__init__()
        self.activation = activation
        self.net = nn.Sequential(
            MLP(n_in, n_hidden, n_hidden, n_layers=2, activation=activation),
            *[
                ResidualConnection(
                    MLP(n_hidden, n_hidden, n_hidden, n_layers=2, activation=activation)
                )
                for _ in range(blocks)
            ],
            nn.Linear(n_hidden, n_out),
        )

    def forward(self, x):
        return self.net(x)
device = torch.device("cuda")

# Dataset
N = 1_000_000
m = 100
X = torch.rand(N, m, device=device) * 20 - 10
y = torch.cos(X)
dataset = TensorDataset(X, y)
loader = DataLoader(dataset, batch_size=1024, shuffle=True)
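# Note: X and y are allocated directly on the GPU, so batches are assembled from
# device memory and no host-to-device copy happens inside the training loop.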
# Model
n_hidden = 256
model = DeepMLP(m, m, n_hidden=n_hidden, blocks=2, activation=F.relu)
model = model.to(device)

# Optimizer
opt = Adam(model.parameters(), lr=1e-3)
def train(model, X_batch, y_batch):
    opt.zero_grad()
    y_pred = model(X_batch)
    # assert y_pred.shape == y_batch.shape
    loss = F.mse_loss(y_pred, y_batch)
    loss.backward()
    opt.step()
    return loss.item()


if compile:
    train = torch.compile(train, mode="reduce-overhead")
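# "reduce-overhead" mode targets Python and kernel-launch overhead (using CUDA
# graphs where applicable), which is why the gains below show up in per-step latency.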
losses = []
times = []

# Training:
for epoch in range(10):
    for X_batch, y_batch in loader:
        start = time.time()
        loss = train(model, X_batch, y_batch)
        end = time.time()
        losses.append(loss)
        times.append(end - start)
    print(
        f"Epoch {epoch}: loss={np.median(losses[-100:]):.3f}, timing={np.median(times[-100:]):.3e}"
    )
# Results (timing = median seconds per training step):
# a100, with tensorcores, with compilation:        2.47e-3
# a100, with tensorcores, without compilation:     4.60e-3
# a100, without tensorcores, with compilation:     2.47e-3
# a100, without tensorcores, without compilation:  4.51e-3
# h100, with tensorcores, with compilation:        2.37e-3
# h100, with tensorcores, without compilation:     4.42e-3
# h100, without tensorcores, with compilation:     2.37e-3
# h100, without tensorcores, without compilation:  5.08e-3
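# In these runs, torch.compile roughly halves the median step time on both GPUs
# (~4.5e-3 s -> ~2.4e-3 s); the TF32 flags matter little here (identical timings
# once compiled, a modest difference only in the uncompiled h100 run).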