TCN experiment with correct residual connection
# %%
import torch
import numpy as np


def make_adding_dataset(num_seqs, seq_len, num_terms=2, seed=43141):
    assert 0 <= num_terms <= seq_len
    rng = np.random.default_rng(seed=seed)
    numbers = rng.uniform(0, 1, (num_seqs, seq_len))  # B x T
    mask = np.zeros_like(numbers)  # B x T
    # Pick `num_terms` distinct positions per sequence; replace=False keeps the
    # positions distinct so the target is always a sum of `num_terms` numbers.
    non_zero = np.stack([rng.choice(seq_len, num_terms, replace=False) for _ in range(num_seqs)])  # B x num_terms
    mask[np.arange(num_seqs)[:, None], non_zero] = 1  # mask[i, non_zero[i, j]] = 1
    X = np.stack([numbers, mask], -1)  # B x T x 2
    Y = (numbers * mask).sum(-1)  # B
    return X, Y
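
# Quick sanity check of the dataset shapes (illustrative only; the underscored
# names below are not part of the original experiment): each input sequence has
# two channels (the random numbers and the 0/1 mask), and each target is the
# sum of the masked numbers.
_X_demo, _Y_demo = make_adding_dataset(num_seqs=4, seq_len=10)
assert _X_demo.shape == (4, 10, 2) and _Y_demo.shape == (4,)
assert np.allclose((_X_demo[..., 0] * _X_demo[..., 1]).sum(-1), _Y_demo)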

num_train, num_test = 100_000, 10_000
seq_len = 200
X, Y = make_adding_dataset(num_train + num_test, seq_len)
X, Y = torch.as_tensor(X, dtype=torch.float), torch.as_tensor(Y, dtype=torch.float)
X_train, Y_train = X[:num_train], Y[:num_train]
X_test, Y_test = X[-num_test:], Y[-num_test:]

def mse_loss(Y_pred, Y_true):
    """
    Inputs:
        Y_pred: (B,) shaped tensor of predicted sums (real-valued)
        Y_true: (B,) shaped tensor of true sums (real-valued)
    """
    assert Y_true.ndim == 1, "Y_true must be (B,)"
    assert Y_pred.ndim == 1, "Y_pred must be (B,)"
    return (Y_true - Y_pred).pow(2).mean(0)  # average over the batch
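
# Illustrative check (not part of the original gist): this mse_loss matches
# torch.nn.functional.mse_loss with its default mean reduction.
import torch.nn.functional as F
assert torch.allclose(mse_loss(Y_train[:8], Y_test[:8]), F.mse_loss(Y_train[:8], Y_test[:8]))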

def update(model, loss_fn, optimizer, X_batch, Y_batch):
    """
    Inputs:
        model: A PyTorch nn.Module
        loss_fn: A loss function callable that takes in Y_pred and Y_true
        optimizer: A PyTorch torch.optim optimizer
        X_batch: A batch of inputs of shape (B, T, input_size)
        Y_batch: A batch of labels of shape (B,)
    Output:
        loss: A scalar tensor of the loss averaged over this batch.
    """
    # 1) Reset the gradients left over from the previous update
    # 2) Forward pass: compute the predictions given the inputs
    # 3) Compute the loss: difference between the predictions and the true labels
    # 4) Backward pass: compute the gradients of the loss w.r.t. the weights
    # 5) Clip the gradient norm
    # 6) Optimizer step: update the weights
    optimizer.zero_grad()
    Y_pred = model(X_batch)
    loss = loss_fn(Y_pred.squeeze(-1), Y_batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss

def train(model, loss_fn, optimizer, X_train, Y_train, X_test, Y_test, num_updates, batch_size, seed=17, device="mps"):
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    model = model.to(device)
    X_train = X_train.to(device)
    Y_train = Y_train.to(device)
    X_test = X_test.to(device)
    Y_test = Y_test.to(device)
    model.train()
    train_losses, test_losses = [], []
    for i in range(int(num_updates)):
        model.train()
        batch_idx = torch.randint(len(X_train), (batch_size,), generator=rng, device=device)
        X_batch = X_train[batch_idx]
        Y_batch = Y_train[batch_idx]
        train_loss = update(model, loss_fn, optimizer, X_batch, Y_batch).item()
        if i % 1000 == 0:
            with torch.no_grad():
                model.eval()
                test_loss = loss_fn(model(X_test).squeeze(-1), Y_test).item()
                print(f"Step {i}, Train Batch Loss: {train_loss}, Test Loss: {test_loss}")
                train_losses.append(train_loss)
                test_losses.append(test_loss)
    return train_losses, test_losses

import torch
import torch.nn as nn


class Chomp1d(nn.Module):
    """Remove the trailing `chomp_size` steps added by symmetric padding, keeping the convolution causal."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()
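
# Minimal sketch (assumed, not from the original gist) of why Chomp1d is needed:
# Conv1d with padding=(kernel_size - 1) * dilation pads on both sides, so the
# output is `padding` steps too long; chomping the trailing steps restores the
# input length and makes each output depend only on current and past inputs.
_conv = nn.Conv1d(1, 1, kernel_size=3, padding=2, dilation=1)
_x = torch.randn(1, 1, 10)
assert _conv(_x).shape[-1] == 12                    # padded on both sides
assert Chomp1d(2)(_conv(_x)).shape[-1] == 10        # back to the input length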

class TemporalBlock(nn.Module):
    def __init__(self, input_size, output_size, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(input_size, output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(output_size, output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)
        self.downsample = nn.Conv1d(input_size, output_size, 1) if input_size != output_size else None

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        # Identity residual connection: no ReLU after the sum.
        return out + res

class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, channel_sizes, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_blocks = len(channel_sizes)
        for i in range(num_blocks):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else channel_sizes[i - 1]
            out_channels = channel_sizes[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size - 1) * dilation_size, dropout=dropout)]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
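
# Rough receptive-field check (assumed, not part of the original gist): each
# TemporalBlock applies two causal convolutions with dilation 2**i, so the
# receptive field of the stack is
#     1 + 2 * (kernel_size - 1) * sum(2**i for i in range(num_blocks)).
# With kernel_size=6 and 7 blocks (the settings used below), this gives
# 1 + 2 * 5 * 127 = 1271 >= seq_len = 200, so the last time step can see the
# whole sequence.
def receptive_field(kernel_size, num_blocks):
    return 1 + 2 * (kernel_size - 1) * sum(2 ** i for i in range(num_blocks))

assert receptive_field(kernel_size=6, num_blocks=7) >= 200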

class TCN(nn.Module):
    def __init__(self, input_size, channel_sizes, output_size, kernel_size, dropout, channel_last=False):
        super(TCN, self).__init__()
        self.tcn = TemporalConvNet(input_size, channel_sizes, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(channel_sizes[-1], output_size)
        self.channel_last = channel_last

    def forward(self, x):
        if not self.channel_last:
            x = x.transpose(-1, -2)  # (B, T, C) -> (B, C, T) for Conv1d
        y1 = self.tcn(x)
        return self.linear(y1[:, :, -1])  # predict from the features at the last time step

num_updates = 1e5
batch_size = 64
input_size, kernel_size, output_size = 2, 6, 1
channel_sizes = (8,) * 7
model = TCN(input_size=input_size, channel_sizes=channel_sizes, output_size=output_size, kernel_size=kernel_size, dropout=0.)
print(sum(p.numel() for p in model.parameters()))
lr = 3e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
device = "mps"
train_losses, test_losses = train(model, mse_loss, optimizer, X_train, Y_train, X_test, Y_test, num_updates, batch_size, device=device)
# %%
Version with a GLU at the end of each TCN block:
Might be useful for sequence modeling instead of sequence classification.
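
A minimal sketch of what that could look like, assuming the GLU replaces the second ReLU of each block; the `GLUTemporalBlock` name and the doubled channel count on the second convolution are my own choices, not from the gist:

class GLUTemporalBlock(nn.Module):
    def __init__(self, input_size, output_size, kernel_size, stride, dilation, padding, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv1d(input_size, output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        # The second conv produces 2 * output_size channels so that
        # nn.GLU(dim=1) can split them into values and gates.
        self.conv2 = nn.Conv1d(output_size, 2 * output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.glu = nn.GLU(dim=1)
        self.dropout2 = nn.Dropout(dropout)
        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.glu, self.dropout2)
        self.downsample = nn.Conv1d(input_size, output_size, 1) if input_size != output_size else None

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return out + res

For sequence modeling you would also drop the `y1[:, :, -1]` selection in `TCN.forward` and apply the linear head to every time step instead of only the last one.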