"""" | |
Proof of concept "DiM" - nousr | |
general structure was "transpiled" from DiT by meta | |
bi-direction idea comes from DifuSSM (https://arxiv.org/abs/2311.18257) | |
""" | |
import torch | |
import math | |
from timm.models.vision_transformer import PatchEmbed | |
from einops import rearrange | |
from torch import nn | |
from mamba_ssm.modules.mamba_simple import Mamba, Block as MambaBlock | |
def exists(x):
    return x is not None


def default(x, default):
    return x if exists(x) else default


def modulate(x, shift, scale):
    # x: (B, L, D); shift/scale: (B, D), broadcast over the sequence dimension
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
class AdaLNModulation(nn.Module):
    """
    implements AdaLN-Zero from DiT
    """

    def __init__(self, hidden_dim, expansion_factor=2) -> None:
        super().__init__()
        self.layers = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_dim, expansion_factor * hidden_dim, bias=True)
        )
        self.init_weights()

    def init_weights(self):
        nn.init.constant_(self.layers[-1].weight, 0)
        nn.init.constant_(self.layers[-1].bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
class FinalLayer(nn.Module):
    def __init__(self, hidden_dim, patch_size, out_channels) -> None:
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_dim, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_dim, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = AdaLNModulation(hidden_dim=hidden_dim)
        self.init_weights()

    def init_weights(self):
        nn.init.constant_(self.linear.weight, 0)
        nn.init.constant_(self.linear.bias, 0)

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
class MLP(nn.Module):
    def __init__(self, dim, output_dim=None, mult=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.SiLU(),
            nn.Linear(dim * mult, output_dim or dim),
        )

    def forward(self, x):
        return self.layers(x)
class ModulatedMambaBlock(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.mixer = Mamba(hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.adaLN_mod = AdaLNModulation(hidden_dim=hidden_dim, expansion_factor=3)

    def forward(self, hidden_states, residual, c):
        # apply the residual unless we're in the first block
        # cast to fp32 (was the default for the example mamba models)
        # TODO: try adding shift, scale, gate for the residual?
        residual = (
            (hidden_states + residual) if residual is not None else hidden_states
        ).to(torch.float32)

        # find the shift & scale for the conditioning
        shift, scale, gate = self.adaLN_mod(c).chunk(3, dim=1)

        # norm the residual to create the new hidden state
        hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))

        # apply shift and scale to the normed hidden state
        hidden_states = modulate(hidden_states, shift, scale)

        # actually send through the mamba layer after it's been modulated
        hidden_states = self.mixer(hidden_states, inference_params=None)

        # apply the gate
        hidden_states = hidden_states + gate.unsqueeze(1) * hidden_states

        return hidden_states, residual
class BidirectionalModulatedMambaBlock(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.mamba_fwd = ModulatedMambaBlock(hidden_dim)
        self.mamba_bwd = ModulatedMambaBlock(hidden_dim)
        self.proj_x = nn.Linear(hidden_dim * 2, hidden_dim)

    def forward(self, x, residual, bwd_residual, c):
        # send through both blocks
        x_fwd, x_fwd_residual = self.mamba_fwd(x, residual, c)
        x_bwd, x_bwd_residual = self.mamba_bwd(x.flip(dims=[1]), bwd_residual, c)

        # flip the bwd output back to the original token order
        x_bwd = x_bwd.flip(dims=[1])

        # combine along the embedding dimension
        x = torch.cat([x_fwd, x_bwd], dim=-1)

        # project back down to hidden_dim
        x = self.proj_x(x)

        return x, x_fwd_residual, x_bwd_residual
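
# A minimal CPU-only sketch of the flip -> scan -> flip-back -> concat -> project
# fusion pattern used above. The nn.Identity() mixers are a hypothetical stand-in
# for the two Mamba blocks (which require the CUDA kernels); the shapes below are
# illustrative assumptions, not values used by the model. Never called at import.
def _bidirectional_fusion_demo():
    batch, seq_len, dim = 2, 16, 8
    x = torch.randn(batch, seq_len, dim)
    mixer_fwd, mixer_bwd = nn.Identity(), nn.Identity()
    proj = nn.Linear(dim * 2, dim)

    x_fwd = mixer_fwd(x)                  # scan in token order
    x_bwd = mixer_bwd(x.flip(dims=[1]))   # scan the reversed sequence
    x_bwd = x_bwd.flip(dims=[1])          # restore the original token order

    fused = proj(torch.cat([x_fwd, x_bwd], dim=-1))  # (batch, seq_len, dim)
    return fused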
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb
class DiM(nn.Module):
    def __init__(
        self,
        hidden_dim=512,
        image_size=64,
        patch_size=4,
        in_channels=3,
        out_channels=None,
        depth=4,
    ):
        super().__init__()
        # NOTE: stuff for lucidrains wrapper
        self.random_or_learned_sinusoidal_cond = True
        self.self_condition = False

        self.in_channels = in_channels
        self.out_channels = default(out_channels, in_channels)
        self.patch_size = patch_size

        self.x_embedder = PatchEmbed(
            image_size, patch_size, in_channels, hidden_dim, bias=True
        )
        self.blocks = nn.ModuleList(
            [BidirectionalModulatedMambaBlock(hidden_dim) for _ in range(depth)]
        )
        self.t_embedder = TimestepEmbedder(hidden_dim)
        self.final_layer = FinalLayer(
            hidden_dim=hidden_dim, patch_size=patch_size, out_channels=self.out_channels
        )
    def init_weights(self):
        """
        init weights according to some reference implementations

        note: this is not called in __init__; the AdaLN modulation and final
        layers already zero-init themselves in their constructors, and the
        xavier pass below would overwrite that if run afterwards
        """

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # initialize the patch embed like nn.Linear (instead of nn.Conv2d)
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        nn.init.constant_(self.x_embedder.proj.bias, 0)

        # initialize timestep embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs
    def forward(self, x, t):
        """
        forward pass of DiM

        x: (N, C, H, W) tensor of images
        t: (N,) tensor of diffusion timesteps
        """
        x = self.x_embedder(x)
        t = self.t_embedder(t)

        residual = None
        bwd_residual = None
        for block in self.blocks:
            x, residual, bwd_residual = block(x, residual, bwd_residual, t)

        x = self.final_layer(x, t)
        x = self.unpatchify(x)
        return x
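

# A minimal smoke-test sketch (assumptions: a CUDA device is available, since the
# mamba_ssm kernels are GPU-only, and the default DiM hyperparameters above).
# The batch size and value ranges are illustrative, not prescribed by the gist.
if __name__ == "__main__":
    device = torch.device("cuda")
    model = DiM().to(device)

    x = torch.randn(2, 3, 64, 64, device=device)     # (N, C, H, W) images
    t = torch.randint(0, 1000, (2,), device=device)  # (N,) timesteps

    with torch.no_grad():
        out = model(x, t)

    print(out.shape)  # expected: torch.Size([2, 3, 64, 64])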