Created
December 6, 2023 04:47
-
-
Save madebyollin/69440ecb9805ebd60aeafaf533008a9e to your computer and use it in GitHub Desktop.
Mamba Diffusion (IADB)
Code with the fixes
class Denoiser(nn.Module):
    """Denoiser for the Mamba diffusion (IADB) experiment.

    Pipeline, as wired up below:
      1. ``enc``: 1x1 convs over the noisy input plus one noise-level channel,
         followed by ``PixelUnshuffle(2)`` (channels x4, spatial size halved).
      2. ``mid``: ``n_b`` ``Block`` modules applied to the feature map flattened
         into a sequence, with the scan order transposed after every block and
         reversed after every fourth block.
      3. ``dec``: 1x1 convs over the concatenation of the encoder output and the
         two block outputs, with ``PixelShuffle(2)`` restoring the resolution.
    """

    def __init__(self, n_io=Config.channels, n_f=128, n_b=8):
        """Build the encoder, the block stack and the decoder.

        Args:
            n_io: Number of input/output image channels.
            n_f: Base feature width; the blocks run on ``n_f * 4`` features
                because of the pixel-unshuffle in the encoder.
            n_b: Number of blocks. Must be a positive multiple of 8 so that the
                per-block transposes and the every-fourth-block flips in
                ``forward`` cancel out by the end of the stack.

        Raises:
            AssertionError: If ``n_b`` is not a positive multiple of 8.
        """
        super().__init__()
        # Zero (or a negative multiple of 8) passes the divisibility check below
        # but builds an empty stack, which leaves `z` as None in forward().
        # Fail here with a clear message instead.
        assert n_b > 0, "n_b must be a positive multiple of 8"
        assert n_b % 8 == 0, "Silly flipping logic breaks if n_b is not divisible by 8"
        # +1 input channel: forward() appends the noise level as an extra plane.
        self.enc = nn.Sequential(
            nn.Conv2d(n_io + 1, n_f, 1),
            nn.ReLU(),
            nn.Conv2d(n_f, n_f, 1, bias=False),
            nn.PixelUnshuffle(2),
        )
        self.mid = nn.ModuleList(Block(n_f * 4, Mamba) for _ in range(n_b))
        # n_f * 12 input channels: forward() concatenates three n_f * 4 maps (x, y, z).
        self.dec = nn.Sequential(
            nn.Conv2d(n_f * 12, n_f * 4, 1),
            nn.ReLU(),
            nn.PixelShuffle(2),
            nn.Conv2d(n_f, n_io, 1),
        )

    def transpose_xy(self, *args):
        """Swap the x/y axes of each N[XY]C tensor and return them as a list.

        Each tensor's sequence axis (dim 1) is treated as a flattened *square*
        grid: the side length is taken as the integer square root of its
        length. The result has the same shape as the input.
        """
        swapped = []
        for a in args:
            side = int(a.shape[1] ** 0.5)
            grid = a.view(a.shape[0], side, side, a.shape[2])
            swapped.append(grid.transpose(1, 2).reshape(a.shape))
        return swapped

    def flip_s(self, *args):
        """Reverse the sequence axis (dim 1) of each NSE tensor; returns a list."""
        return [a.flip(1) for a in args]

    def forward(self, x_noisy, noise_level):
        """Run the denoiser on a noisy batch.

        Args:
            x_noisy: Noisy input batch in NCHW layout. The encoder expects
                ``n_io`` channels; ``transpose_xy`` additionally assumes the
                unshuffled feature map is square (so H == W) - TODO confirm
                against callers.
            noise_level: Tensor expandable to the shape of one channel of
                ``x_noisy`` (``x_noisy[:, :1].shape``).

        Returns:
            ``Prediction`` built from the detached result of
            ``IADB.target_to_denoised(out, x_noisy, noise_level)`` and the raw
            decoder output ``out``.
        """
        # Feed the noise level to the network as one extra constant input plane.
        level_plane = noise_level.expand(x_noisy[:, :1].shape)
        x = self.enc(th.cat([x_noisy, level_plane], 1))
        # NCHW -> N(HW)C so the blocks see a sequence of pixels.
        y = x.flatten(2).transpose(-2, -1)
        # Second stream threaded through the blocks; each Block takes and returns
        # (y, z). Its exact meaning is defined by Block, outside this class.
        z = None
        for i, mid in enumerate(self.mid):
            y, z = mid(y, z)
            # make mamba's 1d conv alternate axes (possible alternative: make mamba use a 2d conv...somehow...)
            y, z = self.transpose_xy(y, z)
            if (i + 1) % 4 == 0:
                # let the network also process the sequence of both directions, by reversing every 4 layers
                y, z = self.flip_s(y, z)
        # With n_b a multiple of 8 the transposes and flips have each been applied
        # an even number of times, so both streams are back in the original pixel
        # order and can be viewed as feature maps shaped like x.
        y, z = y.transpose(-2, -1).view(x.shape), z.transpose(-2, -1).view(x.shape)
        out = self.dec(th.cat([x, y, z], 1))
        return Prediction(IADB.target_to_denoised(out, x_noisy, noise_level).detach(), out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Still needs more training I guess