import torch
from torch import nn


def convBlock(inc, outc, ksz, conv_or_deconv):
    # A strided (de)convolution followed by LeakyReLU and batch norm.
    # conv_or_deconv: truthy -> downsampling Conv2d, falsy -> upsampling ConvTranspose2d.
    return nn.Sequential(
        nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=ksz, stride=2)
        if conv_or_deconv
        else nn.ConvTranspose2d(in_channels=inc, out_channels=outc, kernel_size=ksz, stride=2),
        nn.LeakyReLU(),
        nn.BatchNorm2d(num_features=outc),
    )


class UNet(nn.Module):
    def __init__(self, number_of_layers=6, ksz=(3, 3)):
        super(UNet, self).__init__()
        self.ksz = ksz
        self.number_of_layers = number_of_layers
        # Channel counts double at every encoder step: (1, 2), (2, 4), ..., (32, 64).
        self.sizes = [(2 ** i, 2 ** (i + 1)) for i in range(self.number_of_layers)]
        self.encoder_layers = nn.ModuleList(
            [convBlock(inc, outc, self.ksz, 1) for inc, outc in self.sizes])
        # Decoder blocks take twice the input channels because a skip connection
        # is concatenated onto the feature map before each deconvolution.
        self.decoder_layers = nn.ModuleList(
            [convBlock(2 * inc, outc, self.ksz, 0) for outc, inc in reversed(self.sizes)])

    def forward(self, x):
        # Collect encoder outputs locally so repeated forward calls don't accumulate state.
        residuals = []
        for layer in self.encoder_layers:
            x = layer(x)
            residuals.append(x)
        for residual, layer in zip(reversed(residuals), self.decoder_layers):
            x = torch.cat((residual, x), 1)  # skip connection along the channel dimension
            x = layer(x)
        return x


unet = UNet(number_of_layers=6, ksz=(3, 3))

''' TEST:
import numpy as np
outputs = unet(torch.tensor(np.zeros((1, 1, 256, 256))).float())
print(outputs.shape)
'''
I'm not an expert, but maybe there are better choices for this task than a UNet. Depending on your requirements, maybe a WaveNet?
Can you point me to a paper that elaborates on what you have in mind?
What makes you think that a UNet is a bad choice for segmentation?
Well, first I'd have to know what kind of segmentation you are doing. Are you separating different audio events from each other, or segmenting multiple speakers? And what does your data look like? As I said, I'm not an expert, but since a UNet uses regular convolutions, your data will need to have a specific shape: a specific length in seconds if you're working in the time domain, or, if you're using spectrograms, a specific resolution that you train on. This limits the length (in seconds) of audio you can feed to your network. Also, regular convolutions (1D or 2D) generally struggle with audio.
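To make the fixed-shape point concrete, here is a minimal sketch using the unet instance defined in this gist (stride-2, kernel-3 convolutions with no padding). The 256x256 input matches the commented-out test above; the 300x300 input and the 16 kHz sample rate / 256-sample hop figures are assumptions for illustration, not part of the gist.

import torch

# The gist's UNet was tested on a 1x1x256x256 input. A 256-frame spectrogram
# fixes the audio length: at an assumed 16 kHz sample rate with a hop of
# 256 samples, 256 frames cover roughly 256 * 256 / 16000, about 4.1 seconds.
ok = unet(torch.zeros(1, 1, 256, 256))
print(ok.shape)  # torch.Size([1, 1, 255, 255])

# A different resolution breaks the skip connections: the encoder and decoder
# feature maps no longer line up, so torch.cat raises a size-mismatch error.
try:
    unet(torch.zeros(1, 1, 300, 300))
except RuntimeError as err:
    print("mismatched skip connection:", err)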
If you're not trying to write something that competes with SOTA papers, then I think the UNet can be fine. Otherwise, probably not.
Some papers you could check out:
WaveNet -> mainly for generative audio tasks
Adversarial Audio Synthesis -> they use UNets to generate audio
I wrote a paper recently on generating audio with conditional GANs, but it's still in the process of being published. If you would like to talk more about this, shoot me an email [email protected]
I was looking for flexible/scalable UNet implementations, as I will have to search for appropriate proportions for my task at hand. I am applying it to audio segmentation.
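In case it helps with searching for proportions, here is a minimal sketch (assuming the UNet class from this gist) of how one might sweep over depths and check which configurations run cleanly; the 1x1x256x256 spectrogram-like input and the depth range are assumptions for illustration.

import torch

# Sweep over network depths and report the output shape for a fixed
# 1x1x256x256 input. Depths whose skip connections don't line up are
# reported instead of crashing the sweep.
for depth in range(2, 7):
    model = UNet(number_of_layers=depth, ksz=(3, 3))
    try:
        out = model(torch.zeros(1, 1, 256, 256))
        print(f"depth={depth}: output shape {tuple(out.shape)}")
    except RuntimeError as err:
        print(f"depth={depth}: shape mismatch ({err})")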
I will push it to my git in a week or two