Skip to content

Instantly share code, notes, and snippets.

@saliksyed
Created October 29, 2023 14:12
Show Gist options
  • Save saliksyed/6831fe421f3e30e5a4bc7de51cf0f7dd to your computer and use it in GitHub Desktop.
Save saliksyed/6831fe421f3e30e5a4bc7de51cf0f7dd to your computer and use it in GitHub Desktop.
ViT encoder/decoder
import random
import torchvision
import torch
import torch.nn as nn
import skimage
from torch.utils.data import Dataset
from typing import Callable
from collections import OrderedDict
from functools import partial
import torchvision.transforms as transforms
from torch.nn import functional as F
class MLPBlock(torchvision.ops.MLP):
def __init__(self, in_dim: int, mlp_dim: int, dropout: float):
super().__init__(
in_dim,
[mlp_dim, in_dim],
activation_layer=nn.GELU,
inplace=None,
dropout=dropout,
)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.normal_(m.bias, std=1e-6)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
version = local_metadata.get("version", None)
if version is None or version < 2:
# Replacing legacy MLPBlock with MLP. See https://github.com/pytorch/vision/pull/6053
for i in range(2):
for type in ["weight", "bias"]:
old_key = f"{prefix}linear_{i+1}.{type}"
new_key = f"{prefix}{3*i}.{type}"
if old_key in state_dict:
state_dict[new_key] = state_dict.pop(old_key)
super()._load_from_state_dict(
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
class EncoderBlock(nn.Module):
"""Transformer encoder block."""
def __init__(
self,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.num_heads = num_heads
# Attention block
self.ln_1 = norm_layer(hidden_dim)
self.self_attention = nn.MultiheadAttention(
hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
)
self.dropout = nn.Dropout(dropout)
# MLP block
self.ln_2 = norm_layer(hidden_dim)
self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)
def forward(self, input: torch.Tensor):
torch._assert(
input.dim() == 3,
f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}",
)
x = self.ln_1(input)
x, _ = self.self_attention(x, x, x, need_weights=False)
x = self.dropout(x)
x = x + input
y = self.ln_2(x)
y = self.mlp(y)
return x + y
class Encoder(nn.Module):
"""Transformer Model Encoder for sequence to sequence translation."""
def __init__(
self,
seq_length: int,
num_layers: int,
num_heads: int,
hidden_dim: int,
mlp_dim: int,
dropout: float,
attention_dropout: float,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
# Note that batch_size is on the first dim because
# we have batch_first=True in nn.MultiAttention() by default
self.pos_embedding = nn.Parameter(
torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)
) # from BERT
self.dropout = nn.Dropout(dropout)
layers: OrderedDict[str, nn.Module] = OrderedDict()
for i in range(num_layers):
layers[f"encoder_layer_{i}"] = EncoderBlock(
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
self.layers = nn.Sequential(layers)
self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
torch._assert(
input.dim() == 3,
f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}",
)
input = input + self.pos_embedding
return self.ln(self.layers(self.dropout(input)))
class ViTEncoder(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
hidden_dim: int,
mlp_dim: int,
num_layers: int,
num_heads: int,
dropout: float = 0.0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
attention_dropout: float = 0.0,
):
super().__init__()
self.image_size = image_size
self.hidden_dim = hidden_dim
self.patch_size = patch_size
self.conv_proj = nn.Conv2d(
in_channels=3,
out_channels=hidden_dim,
kernel_size=patch_size,
stride=patch_size,
)
seq_length = (image_size // patch_size) ** 2
self.mlp_dim = mlp_dim
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
def forward(self, input: torch.Tensor):
x = input
n, c, h, w = x.shape
p = self.patch_size
torch._assert(
h == self.image_size,
f"Wrong image height! Expected {self.image_size} but got {h}!",
)
torch._assert(
w == self.image_size,
f"Wrong image width! Expected {self.image_size} but got {w}!",
)
n_h = h // p
n_w = w // p
x = self.conv_proj(x)
x = x.reshape(n, self.hidden_dim, n_h * n_w)
x = x.permute(0, 2, 1)
x = self.encoder(x)
return x
class ViTDecoder(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
hidden_dim: int,
mlp_dim: int,
num_layers: int,
num_heads: int,
dropout: float = 0.0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
attention_dropout: float = 0.0,
):
super().__init__()
self.image_size = image_size
self.hidden_dim = hidden_dim
seq_length = (image_size // patch_size) ** 2
self.mlp_dim = mlp_dim
self.encoder = Encoder(
seq_length,
num_layers,
num_heads,
hidden_dim,
mlp_dim,
dropout,
attention_dropout,
norm_layer,
)
def forward(self, input: torch.Tensor):
x = input
n, t, w = x.shape
x = self.encoder(x)
return x
class DepthPredictor(nn.Module):
def __init__(
self,
image_size: int,
patch_size: int,
hidden_dim: int,
mlp_dim: int,
decoder_dim: int,
num_layers: int,
num_heads: int,
dropout: float = 0.0,
norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6),
):
super().__init__()
self.image_size = image_size
self.encoder = ViTEncoder(
image_size,
patch_size,
hidden_dim,
mlp_dim,
num_layers,
num_heads,
dropout,
norm_layer,
)
self.transform = nn.Linear(hidden_dim, decoder_dim)
self.decoder = ViTDecoder(
image_size,
patch_size,
decoder_dim,
mlp_dim,
8,
16,
dropout,
norm_layer,
)
self.reconstruct = nn.Linear(decoder_dim, 768)
def forward(self, input: torch.Tensor, targets: torch.Tensor = None):
n, c, h, w = input.shape
x = input
torch._assert(
h == self.image_size,
f"Wrong image height! Expected {self.image_size} but got {h}!",
)
torch._assert(
w == self.image_size,
f"Wrong image width! Expected {self.image_size} but got {w}!",
)
x = self.encoder(x)
x = self.transform(x)
x = self.decoder(x)
x = self.reconstruct(x)
patches = F.unfold(targets, kernel_size=16, stride=16)
patches = patches.permute(0, 2, 1)
# split the targets into patches
error = F.mse_loss(x, patches)
return x, error
class CustomDataset(Dataset):
def __init__(self, data, transforms=None):
self.data = data
self.transforms = transforms
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
image = self.data[idx]
if self.transforms != None:
image = self.transforms(image)
return image
albedo_images = []
normal_images = []
for i in range(0, 1000):
albedo = f"data/example_{i}_albedo0001.png"
target = f"data/example_{i}_normal0001.png"
albedo_images.append(skimage.io.imread(albedo))
normal_images.append(skimage.io.imread(target))
curr_transforms = [
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
inputs = CustomDataset(
albedo_images,
transforms=transforms.Compose(curr_transforms),
)
outputs = CustomDataset(
normal_images,
transforms=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
),
)
epochs = 10
batch_size = 32
device = "mps"
network = DepthPredictor(128, 16, 768, 3072, 768, 12, 12, 0.0)
network.train()
network.to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(network.parameters(), lr=1.5e-4)
checkpoint_path = "weights/chkpt.pt"
# # # ### Train
for epoch in range(epochs):
print(f"Epoch {epoch+1}/{epochs}")
print(f"Training on {len(inputs)} samples")
for i in range(0, 1000):
train = []
target = []
for j in range(0, batch_size):
r = random.randint(0, len(inputs) - 1)
train.append(inputs[r])
target.append(inputs[r])
images = torch.stack(train).to(device)
target = torch.stack(target).to(device)
# # zeroing gradients
optimizer.zero_grad()
output, loss = network(images, target)
if i % 10 == 0:
print(loss)
loss.backward()
optimizer.step()
checkpoint = {
"model": network.state_dict(),
"optimizer": optimizer.state_dict(),
}
print(f"saving checkpoint")
torch.save(checkpoint, checkpoint_path)
# ### Test
checkpoint = torch.load(checkpoint_path, map_location=device)
state_dict = checkpoint["model"]
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
network.load_state_dict(state_dict)
albedo_images = []
normal_images = []
for i in range(1001, 1100):
albedo = f"data/example_{i}_albedo0001.png"
target = f"data/example_{i}_normal0001.png"
albedo_images.append(skimage.io.imread(albedo))
normal_images.append(skimage.io.imread(target))
curr_transforms = [
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
test_inputs = CustomDataset(
albedo_images,
transforms=transforms.Compose(curr_transforms),
)
test_outputs = CustomDataset(
normal_images,
transforms=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
),
)
idx = 28
test = torch.stack([inputs[idx]]).to(device)
result = torch.stack([inputs[idx]]).to(device)
output, loss = network(test, result)
estimate = F.fold(
output.permute(0, 2, 1), output_size=(128, 128), kernel_size=16, stride=16
)
import matplotlib.pyplot as plt
viz = estimate.detach().to("cpu").permute(0, 2, 3, 1).numpy().reshape(128, 128, 3)
f, ax = plt.subplots(1, 2)
ax[0].imshow(viz)
ax[1].imshow(inputs[1].permute(1, 2, 0))
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment