@KeAWang
Last active November 20, 2023 17:37
TCN experiment with correct residual connection
# %%
import torch
import numpy as np


def make_adding_dataset(num_seqs, seq_len, num_terms=2, seed=43141):
    """Adding problem: each timestep carries (value, mask); the label is the sum of the masked values."""
    assert 0 <= num_terms <= seq_len
    rng = np.random.default_rng(seed=seed)
    numbers = rng.uniform(0, 1, (num_seqs, seq_len))  # B x T
    mask = np.zeros_like(numbers)  # B x T
    # sample num_terms distinct positions per sequence (without replacement, otherwise a
    # position could be drawn twice and fewer than num_terms values would be summed)
    non_zero = np.stack([rng.choice(seq_len, num_terms, replace=False) for _ in range(num_seqs)])  # B x num_terms
    mask[np.arange(num_seqs)[:, None], non_zero] = 1  # mask[i, non_zero[i, j]] = 1
    X = np.stack([numbers, mask], -1)  # B x T x 2
    Y = (numbers * mask).sum(-1)  # B
    return X, Y


num_train, num_test = 100_000, 10_000
seq_len = 200
X, Y = make_adding_dataset(num_train + num_test, seq_len)
X, Y = torch.as_tensor(X, dtype=torch.float), torch.as_tensor(Y, dtype=torch.float)
X_train, Y_train = X[:num_train], Y[:num_train]
X_test, Y_test = X[-num_test:], Y[-num_test:]
def mse_loss(Y_pred, Y_true):
    """
    Inputs:
        Y_pred: (B,) shaped tensor of predicted labels (real-valued)
        Y_true: (B,) shaped tensor of true labels (real-valued)
    """
    assert Y_true.ndim == 1, "Y_true must be (B,)"
    assert Y_pred.ndim == 1, "Y_pred must be (B,)"
    return (Y_true - Y_pred).pow(2).mean(0)  # average over the batch
def update(model, loss_fn, optimizer, X_batch, Y_batch):
    """
    Inputs:
        model: A PyTorch nn.Module
        loss_fn: A loss function callable that takes in Y_pred and Y_true
        optimizer: A PyTorch torch.optim optimizer
        X_batch: A batch of inputs of shape (B, T, input_size)
        Y_batch: A batch of labels of shape (B,)
    Output:
        loss: A scalar tensor of the loss averaged over this batch.
    """
    # 1) Reset gradients from the previous step
    # 2) Forward pass: compute the predictions given the inputs
    # 3) Compute loss: difference between the predictions and the true labels
    # 4) Backward pass: compute the gradients of the loss w.r.t. the weights
    # 5) Clip the gradient norm
    # 6) Optimizer step: update the weights
    optimizer.zero_grad()
    Y_pred = model(X_batch)
    loss = loss_fn(Y_pred.squeeze(-1), Y_batch)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    return loss
def train(model, loss_fn, optimizer, X_train, Y_train, X_test, Y_test, num_updates, batch_size, seed=17, device="mps"):
    # device="mps" targets Apple-silicon GPUs; use "cuda" or "cpu" elsewhere
    rng = torch.Generator(device=device)
    rng.manual_seed(seed)
    model = model.to(device)
    X_train = X_train.to(device)
    Y_train = Y_train.to(device)
    X_test = X_test.to(device)
    Y_test = Y_test.to(device)
    model.train()
    train_losses, test_losses = [], []
    for i in range(int(num_updates)):
        model.train()
        batch_idx = torch.randint(len(X_train), (batch_size,), generator=rng, device=device)
        X_batch = X_train[batch_idx]
        Y_batch = Y_train[batch_idx]
        train_loss = update(model, loss_fn, optimizer, X_batch, Y_batch).item()
        if i % 1000 == 0:
            with torch.no_grad():
                model.eval()
                test_loss = loss_fn(model(X_test).squeeze(-1), Y_test).item()
            print(f"Step {i}, Train Batch Loss: {train_loss}, Test Loss: {test_loss}")
            train_losses.append(train_loss)
            test_losses.append(test_loss)
    return train_losses, test_losses
import torch
import torch.nn as nn


class Chomp1d(nn.Module):
    """Removes the trailing padding so the causal convolution's output has the same length as its input."""
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :-self.chomp_size].contiguous()
class TemporalBlock(nn.Module):
    def __init__(self, input_size, output_size, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(input_size, output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(output_size, output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2)

        self.downsample = nn.Conv1d(input_size, output_size, 1) if input_size != output_size else None

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        # identity-style residual: add the skip path without a nonlinearity on the sum
        return out + res
class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, channel_sizes, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_blocks = len(channel_sizes)
        for i in range(num_blocks):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else channel_sizes[i - 1]
            out_channels = channel_sizes[i]
            layers += [TemporalBlock(in_channels, out_channels, kernel_size, stride=1, dilation=dilation_size,
                                     padding=(kernel_size - 1) * dilation_size, dropout=dropout)]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)
class TCN(nn.Module):
    def __init__(self, input_size, channel_sizes, output_size, kernel_size, dropout, channel_last=False):
        super(TCN, self).__init__()
        self.tcn = TemporalConvNet(input_size, channel_sizes, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(channel_sizes[-1], output_size)
        self.channel_last = channel_last

    def forward(self, x):
        if not self.channel_last:
            x = x.transpose(-1, -2)  # (B, T, C) -> (B, C, T) for Conv1d
        y1 = self.tcn(x)
        return self.linear(y1[:, :, -1])  # predict from the last timestep only
num_updates = 1e5
batch_size = 64
input_size, kernel_size, output_size = 2, 6, 1
channel_sizes = (8,) * 7
# two convs per block with dilations 1..64: receptive field = 1 + 2 * (kernel_size - 1) * (2**7 - 1) = 1271 >= seq_len = 200,
# so the last timestep sees the whole sequence
model = TCN(input_size=input_size, channel_sizes=channel_sizes, output_size=output_size, kernel_size=kernel_size, dropout=0.)
print(sum(p.numel() for p in model.parameters()))  # number of parameters
lr = 3e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
device = "mps"
train_losses, test_losses = train(model, mse_loss, optimizer, X_train, Y_train, X_test, Y_test, num_updates, batch_size, device=device)
# %%
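# Point of reference for the printed test losses (hypothetical check, added for context):
# each label is the sum of two independent U(0, 1) draws, so Var(Y) = 2/12 ~= 0.167,
# and a model only beats the trivial mean predictor once its test MSE falls well below that.
baseline_pred = torch.full_like(Y_test, Y_test.mean().item())
print(f"Mean-prediction baseline MSE: {mse_loss(baseline_pred, Y_test).item():.4f}")  # ~= 0.167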
KeAWang commented Nov 18, 2023

Better to keep two convolutions in a block than double the number of channels per layer
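For comparison, a minimal sketch of the single-convolution alternative. The class name SingleConvBlock is hypothetical (not part of the gist); it reuses nn and Chomp1d from above and would be paired with wider channel_sizes such as (16,) * 7:

class SingleConvBlock(nn.Module):
    """Hypothetical one-convolution block, for comparison with TemporalBlock above."""
    def __init__(self, input_size, output_size, kernel_size, dilation, dropout=0.2):
        super().__init__()
        padding = (kernel_size - 1) * dilation
        self.net = nn.Sequential(
            nn.Conv1d(input_size, output_size, kernel_size, padding=padding, dilation=dilation),
            Chomp1d(padding),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.downsample = nn.Conv1d(input_size, output_size, 1) if input_size != output_size else None

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return out + res

At the same depth this grows the receptive field by only (kernel_size - 1) * dilation per level instead of twice that, so matching the two-convolution stack's reach needs more levels or a larger kernel even if the channel count is doubled.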

KeAWang commented Nov 18, 2023

Version with a GLU at the end of each TCN block:

class TemporalBlock(nn.Module):
    def __init__(self, input_size, output_size, kernel_size, stride, dilation, padding, dropout=0.2):
        super(TemporalBlock, self).__init__()
        self.conv1 = nn.Conv1d(input_size, output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(output_size, 2 * output_size, kernel_size, stride=stride, padding=padding, dilation=dilation)
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)

        self.glu = nn.GLU(dim=-2)

        self.net = nn.Sequential(self.conv1, self.chomp1, self.relu1, self.dropout1,
                                 self.conv2, self.chomp2, self.relu2, self.dropout2, self.glu)

        self.downsample = nn.Conv1d(input_size, output_size, 1) if input_size != output_size else None

    def forward(self, x):
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return out + res
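(On the channel arithmetic: conv2 now emits 2 * output_size channels, and nn.GLU(dim=-2) splits the (B, C, T) tensor along the channel dimension into a value half and a sigmoid gate half, so the block output returns to output_size channels and the residual addition still lines up.)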

Might be useful for sequence modeling instead of sequence classification.
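A minimal sketch of that per-timestep use, reusing TemporalConvNet from above (the class name SequenceTCN is hypothetical):

class SequenceTCN(nn.Module):
    """Hypothetical per-step head: one prediction per position instead of one per sequence."""
    def __init__(self, input_size, channel_sizes, output_size, kernel_size, dropout):
        super().__init__()
        self.tcn = TemporalConvNet(input_size, channel_sizes, kernel_size=kernel_size, dropout=dropout)
        self.linear = nn.Linear(channel_sizes[-1], output_size)

    def forward(self, x):                        # x: (B, T, input_size)
        y = self.tcn(x.transpose(-1, -2))        # (B, channels, T)
        return self.linear(y.transpose(-1, -2))  # (B, T, output_size)

Because each block is causal (the padding-plus-Chomp1d pattern amounts to left-only padding), the prediction at position t depends only on inputs up to t, so the per-step outputs can be trained autoregressively.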
