Deep model with built-in self-attention alignment for acoustic echo cancellation
#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright 2023 Lucky Wong
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional, Tuple

import torch

def make_pad_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
    """Make a mask tensor marking the padded part of each sequence.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int, optional): Maximum sequence length. Defaults to the
            largest value in ``lengths``.

    Returns:
        torch.Tensor: Boolean mask (B, max_len), True at padded positions.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        tensor([[False, False, False, False, False],
                [False, False, False,  True,  True],
                [False, False,  True,  True,  True]])
    """
    batch_size = int(lengths.size(0))
    if max_len is None:
        max_len = int(lengths.max().item())
    # Compare each position index against the sequence length:
    # positions at or beyond the length are padding.
    seq_range = torch.arange(0, max_len, dtype=torch.int64, device=lengths.device)
    seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
    seq_length_expand = lengths.unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    return mask

def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    left_chunk_size: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create a chunk-wise attention mask of shape (size, size) for a streaming encoder.

    Args:
        size (int): size of the mask (number of frames)
        chunk_size (int): number of frames per chunk
        left_chunk_size (int): number of history frames each chunk may attend to
            before its own chunk; -1 means unlimited history
        device (torch.device): "cpu", "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: Boolean mask (size, size); True marks attendable positions.

    Examples:
        >>> subsequent_chunk_mask(4, 2, left_chunk_size=1)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 1, 1, 1],
         [0, 1, 1, 1]]
    """
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        # Each frame can attend up to the end of its own chunk ...
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, 0:ending] = True
        # ... but only left_chunk_size frames of history before the chunk start.
        if left_chunk_size != -1:
            left_start = max(0, (i // chunk_size) * chunk_size - left_chunk_size)
            ret[i, 0:left_start] = False
    return ret

class AttAlignBlock(torch.nn.Module):
    """Attention Align Block.

    Reference: Deep model with built-in self-attention alignment for acoustic echo cancellation
    Link: https://arxiv.org/pdf/2208.11308.pdf

    Args:
        mdim (int): The number of microphone (query) features.
        fdim (int): The number of far-end (key) features.
        pdim (int): The projection size shared by query and key.
        chunk_size (int): The chunk size used for the alignment mask.
        max_delay_blocks (int): The maximum number of history blocks (echo delay)
            the attention may look back over.
    """

    def __init__(
        self,
        mdim: int = 2048,
        fdim: int = 256,
        pdim: int = 64,
        chunk_size: int = 1,
        max_delay_blocks: int = 80,
    ):
        """Construct an AttAlignBlock object."""
        super().__init__()
        self.chunk_size = chunk_size
        self.max_delay_blocks = max_delay_blocks
        self.pdim = pdim
        self.linear_q = torch.nn.Linear(mdim, pdim)
        self.linear_k = torch.nn.Linear(fdim, pdim)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute scaled dot-product attention for delay alignment.

        Args:
            query (torch.Tensor): Microphone feature tensor (#batch, time, mdim).
            key (torch.Tensor): Far-end feature tensor (#batch, time, fdim).
            lengths (torch.Tensor): Lengths tensor (#batch,).

        Returns:
            torch.Tensor: Delay-aligned far-end tensor (#batch, time, fdim).
            torch.Tensor: Attention weights (#batch, time, time).
        """
        q = self.linear_q(query)
        k = self.linear_k(key)
        # (#batch, time, pdim) x (#batch, pdim, time) -> (#batch, time, time)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.pdim)
        # Padding mask: True at valid frames. (B, 1, L)
        masks = ~make_pad_mask(lengths).unsqueeze(1).to(query.device)
        # Delay mask: each frame may attend to at most max_delay_blocks past blocks. (L, L)
        chunk_masks = subsequent_chunk_mask(
            query.size(1), self.chunk_size, self.max_delay_blocks, query.device
        )
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
        chunk_masks = chunk_masks.eq(0)  # True at positions to be masked out
        scores = scores.masked_fill(chunk_masks, -float("inf"))
        # Zero masked positions again so fully masked (padded) rows do not yield NaNs.
        attn = torch.softmax(scores, dim=-1).masked_fill(
            chunk_masks, 0.0
        )  # (#batch, time, time)
        # (#batch, time, time) x (#batch, time, fdim) -> (#batch, time, fdim)
        return torch.matmul(attn, key), attn
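
A minimal usage sketch, not part of the gist itself: it assumes the microphone branch has already been encoded to 2048-dim features and the far-end reference to 256-dim features (matching the default mdim/fdim above), and simply checks the output shapes. The tensor sizes and names are illustrative only.

# Hypothetical smoke test: batch of 2 utterances, 100 frames each,
# with feature sizes matching the defaults (mdim=2048, fdim=256).
torch.manual_seed(0)
batch, frames = 2, 100
mic_feats = torch.randn(batch, frames, 2048)  # microphone (near-end) branch features
far_feats = torch.randn(batch, frames, 256)   # far-end reference branch features
lengths = torch.tensor([100, 80])             # valid frames per utterance

align = AttAlignBlock(mdim=2048, fdim=256, pdim=64, chunk_size=1, max_delay_blocks=80)
aligned_far, attn = align(mic_feats, far_feats, lengths)

print(aligned_far.shape)  # torch.Size([2, 100, 256]): delay-aligned far-end features
print(attn.shape)         # torch.Size([2, 100, 100]): per-frame alignment weights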