Skip to content

Instantly share code, notes, and snippets.

@TeaPoly
Last active June 10, 2025 06:06
Show Gist options
  • Save TeaPoly/bd643cad1f53a01be637ca0f5856dbb6 to your computer and use it in GitHub Desktop.
Save TeaPoly/bd643cad1f53a01be637ca0f5856dbb6 to your computer and use it in GitHub Desktop.
RESOURCE-CONSTRAINED STEREO SINGING VOICE CANCELLATION
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2025 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.nn as nn
class S_Conv(nn.Module):
    """S-Conv block with a residual connection.

    Pipeline: 1x1 conv -> PReLU -> LayerNorm -> depthwise dilated conv
    -> PReLU -> LayerNorm, then ``input + result``.

    Input and output are ``(batch, channels, time)`` tensors (the layout
    ``nn.Conv1d`` expects). The time length is preserved for any
    ``kernel_size``/``dilation``, so the residual add is always valid.

    Args:
        channels: Number of input/output channels.
        kernel_size: Kernel size of the depthwise (``groups=channels``) conv.
        dilation: Dilation of the depthwise conv.
        causal: If True, output frame ``t`` depends only on input frames
            ``<= t``.
    """

    def __init__(self, channels, kernel_size, dilation, causal: bool = False):
        super(S_Conv, self).__init__()
        # 1x1 Conv
        self.conv1x1 = nn.Conv1d(channels, channels, 1)
        self.prelu_1 = nn.PReLU()
        self.norm_1 = nn.LayerNorm(channels)
        # D Conv
        self.causal = causal
        # Total padding that keeps the sequence length unchanged.
        total_pad = dilation * (kernel_size - 1)
        if causal:
            # Conv1d pads both sides by `total_pad`. The right-hand surplus
            # (frames that would see the future) is trimmed in forward().
            self.pad = total_pad
        else:
            # Round up: an odd `total_pad` (even kernel, odd dilation) then
            # yields one extra frame that forward() trims. Rounding down
            # would yield one frame too few and break the residual add.
            self.pad = (total_pad + 1) // 2
        self.dconv = nn.Conv1d(
            channels,
            channels,
            kernel_size,
            groups=channels,
            dilation=dilation,
            padding=self.pad)
        self.prelu_2 = nn.PReLU()
        self.norm_2 = nn.LayerNorm(channels)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, channels, time)``."""
        residual = x
        num_frames = x.size(-1)
        # 1x1 Conv. LayerNorm normalizes the last dim, so move channels last.
        x = self.conv1x1(x)
        x = self.prelu_1(x)
        x = self.norm_1(x.transpose(1, 2)).transpose(1, 2)
        # D Conv. The padded conv produces >= num_frames outputs; keeping the
        # first `num_frames` drops exactly the trailing surplus. Unlike
        # `x[:, :, :-self.pad]`, this is also correct when self.pad == 0.
        x = self.dconv(x)
        x = x[:, :, :num_frames]
        x = self.prelu_2(x)
        x = self.norm_2(x.transpose(1, 2)).transpose(1, 2)
        return residual + x
class Separator(nn.Module):
    """Mask estimator built from stacks of S-Conv blocks between 1x1 convs.

    Maps a ``(batch, N, frames)`` tensor to a mask of the same shape whose
    values lie in ``(0, 1)`` (sigmoid output).

    Args:
        N: Number of channels in the input (and in the output mask).
        B: Number of channels in the bottleneck and in the residual paths'
            1x1-conv blocks.
        R: Number of repeats of the dilated S-Conv stack.
        P: Kernel size of the depthwise convolutions.
        X: Number of S-Conv blocks per repeat; block ``j`` of a repeat uses
            dilation ``2**j``.

    The first repeat is built non-causal; every later repeat is causal.
    """

    def __init__(
        self,
        N: int = 384,
        B: int = 448,
        R: int = 4,
        P: int = 3,
        X: int = 9,
    ):
        super(Separator, self).__init__()
        self.norm = nn.LayerNorm(N)
        self.conv1x1_1 = nn.Conv1d(N, B, 1)
        # Build repeats in the same order as before (repeat-major, then
        # block), so parameter creation and state_dict keys are unchanged.
        repeats = []
        for repeat_idx in range(R):
            blocks = []
            for block_idx in range(X):
                blocks.append(
                    S_Conv(B, P, dilation=2 ** block_idx,
                           causal=(repeat_idx > 0)))
            repeats.append(nn.Sequential(*blocks))
        self.s_conv_layers = nn.ModuleList(repeats)
        self.conv1x1_2 = nn.Conv1d(B, N, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a sigmoid mask for ``x`` of shape ``(batch, N, frames)``."""
        # LayerNorm normalizes the last dim: move channels last and back.
        normed = self.norm(x.transpose(1, 2)).transpose(1, 2)
        # (batch, N, frames) -> (batch, B, frames)
        hidden = self.conv1x1_1(normed)
        # (batch, B, frames) -> (batch, B, frames), one repeat at a time
        for repeat in self.s_conv_layers:
            hidden = repeat(hidden)
        # (batch, B, frames) -> (batch, N, frames), squashed into (0, 1)
        return self.sigmoid(self.conv1x1_2(hidden))
class MonoVoxTasNet(nn.Module):
    """ RESOURCE-CONSTRAINED STEREO SINGING VOICE CANCELLATION

    Single-channel encoder / mask / decoder model:

    - A strided ``Conv1d`` encoder maps the waveform to ``N`` feature channels.
    - The ``Separator`` predicts a sigmoid mask over those features.
    - A ``ConvTranspose1d`` decoder maps the masked features back to a
      one-channel waveform.

    Args:
        L: Length of the encoder/decoder filters, in samples. The hop
            (stride) is ``L // 2``.
        N: Number of encoder output channels.
        B: Bottleneck channels inside the separator.
        R: Number of repeats inside the separator.
        P: Kernel size of the separator's depthwise convolutions.
        X: Number of S-Conv blocks per separator repeat.
        quantize: Accepted for API compatibility; not used by this class.
    """

    def __init__(
        self,
        L: int = 64,
        N: int = 384,
        B: int = 448,
        R: int = 4,
        P: int = 3,
        X: int = 9,
        quantize: bool = False,
    ) -> None:
        super(MonoVoxTasNet, self).__init__()
        self.encoder = nn.Conv1d(
            1, N, kernel_size=L, stride=L//2)
        self.separator = Separator(N=N, B=B, R=R, P=P, X=X)
        self.decoder = nn.ConvTranspose1d(
            N, 1, kernel_size=L, stride=L//2)

    def forward(self, x):
        """Encode ``x`` of shape ``(batch, 1, samples)``, mask it, decode it.

        The output has ``(frames - 1) * (L // 2) + L`` samples, where
        ``frames = (samples - L) // (L // 2) + 1``. Trailing input samples that
        do not fill a whole frame are dropped (e.g. 176400 -> 176384 with the
        defaults).
        """
        # (batch, 1, samples) -> (batch, N, frames)
        encoded = self.encoder(x)
        # (batch, N, frames) -> (batch, N, frames)
        mask = self.separator(encoded)
        # (batch, N, frames) -> (batch, 1, output samples). The transposed
        # conv is registered as `self.decoder` in __init__.
        return self.decoder(encoded * mask)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment