
@taylanbil
Created October 28, 2020 17:57
Changes to sample negatives correctly when batches contain padded positions.
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index baafe0f9..2114c93a 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -185,6 +185,7 @@ class RawAudioDataset(FairseqDataset):
                 (B, T, self._C), padding_mask_reshaped,
             )
             input["mask_indices"] = mask_indices
+            input['padding_counts'] = input['mask_indices'].sum(-1).tolist()
             input["mask_channel_indices"] = mask_channel_indices
             out['sample_size'] = mask_indices.sum().item()
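
Note on the collater change above: mask_indices has one row per sample, so .sum(-1).tolist() yields one plain integer per sample. The exact meaning of those integers depends on how this fork builds mask_indices; the toy values below only illustrate the shape and type of what gets stored as input['padding_counts'] and later consumed by sample_negatives.

import torch

# Toy batch of 2 samples, 6 frames after feature extraction (values made up).
mask_indices = torch.tensor([
    [True, True, False, False, True, False],
    [True, False, False, False, False, False],
])

padding_counts = mask_indices.sum(-1).tolist()
print(padding_counts)  # [3, 1] -- a plain Python list, one count per sample
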
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py
index 21239d8b..be3f8df0 100644
--- a/fairseq/models/wav2vec/wav2vec2.py
+++ b/fairseq/models/wav2vec/wav2vec2.py
@@ -450,7 +450,21 @@ class Wav2Vec2Model(BaseFairseqModel):
         return x, mask_indices

-    def sample_negatives(self, y, num):
+    def _get_neg_idxs(self, high, size, padding_counts=None):
+        if padding_counts is None:
+            neg_idxs = torch.randint(low=0, high=high-1, size=size)
+        else:
+            bsz, l = size
+            num = l // self.n_negatives
+            assert len(padding_counts) == bsz
+            neg_idxs = [
+                torch.randint(low=0, high=num-pc-1, size=(1, l))
+                for pc in padding_counts
+            ]
+            neg_idxs = torch.stack(neg_idxs)
+        return neg_idxs
+
+    def sample_negatives(self, y, num, padding_counts=None):

         if self.n_negatives == 0 and self.cross_sample_negatives == 0:
             return y.new(0)

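A note on the new helper: when padding_counts is given, torch.stack over the per-sample (1, l) rows returns a (bsz, 1, l) tensor, one singleton dimension wider than the (bsz, l) result of the padding_counts-is-None branch; downstream this appears to be absorbed when sample_negatives flattens the indices, but the shapes are easy to verify with a toy check (sizes below are made up, with l = n_negatives * num as in the helper):

import torch

l, num = 12, 6
padding_counts = [0, 2, 4]
rows = [torch.randint(0, num - pc - 1, (1, l)) for pc in padding_counts]
print(torch.stack(rows).shape)       # torch.Size([3, 1, 12]) -- what the patch builds
print(torch.cat(rows, dim=0).shape)  # torch.Size([3, 12])    -- shape of the no-padding branch
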
@@ -471,8 +485,9 @@ class Wav2Vec2Model(BaseFairseqModel):
                     .flatten()
                 )

-                neg_idxs = torch.randint(
-                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
+                neg_idxs = self._get_neg_idxs(
+                    high, (bsz, self.n_negatives * num),
+                    padding_counts=padding_counts,
                 )
                 neg_idxs[neg_idxs >= tszs] += 1

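For context on the unchanged "neg_idxs[neg_idxs >= tszs] += 1" line: indices are drawn from one fewer value than the number of candidates, and every draw at or above the positive's own position is shifted up by one, so a timestep can never be sampled as its own negative. A toy version of that shift, using torch.arange in place of fairseq's buffered_arange:

import torch

num, n_negatives = 4, 2
# For each negative slot, the index of the positive it belongs to: [0,0,1,1,2,2,3,3]
tszs = torch.arange(num).unsqueeze(-1).expand(-1, n_negatives).flatten()

neg_idxs = torch.randint(low=0, high=num - 1, size=(n_negatives * num,))  # draws in 0..2
neg_idxs[neg_idxs >= tszs] += 1  # skip over the positive's own index
assert not (neg_idxs == tszs).any()  # no slot ever points at its own positive
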
@@ -529,7 +544,7 @@ class Wav2Vec2Model(BaseFairseqModel):
     def forward(
         self, source, padding_mask=None, mask=True, features_only=False,
-        mask_indices=None, mask_channel_indices=None,
+        mask_indices=None, mask_channel_indices=None, padding_counts=None,
     ):

         if self.feature_grad_mult > 0:
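Because the collater stores padding_counts inside the same dict that becomes the batch's net_input, the new keyword reaches forward() without touching the criterion, assuming the usual fairseq pattern of calling the model with the unpacked net input. A minimal, hypothetical illustration of that plumbing (TinyModel and the sample dict below are stand-ins, not fairseq classes):

# Any key the collater adds to net_input arrives in forward() as a keyword argument.
class TinyModel:
    def forward(self, source, padding_mask=None, padding_counts=None, **unused):
        return {"padding_counts": padding_counts}
    __call__ = forward

sample = {"net_input": {"source": [[0.0] * 16], "padding_counts": [3]}}
print(TinyModel()(**sample["net_input"]))  # {'padding_counts': [3]}
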
@@ -608,12 +623,18 @@ class Wav2Vec2Model(BaseFairseqModel):
             y = self.project_q(y)

             if self.negatives_from_everywhere:
-                neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False)
-                negs, _ = self.sample_negatives(neg_cands, y.size(1))
+                neg_cands, *_ = self.quantizer(
+                    unmasked_features, produce_targets=False,
+                )
+                negs, _ = self.sample_negatives(
+                    neg_cands, y.size(1), padding_counts=padding_counts,
+                )
                 negs = self.project_q(negs)
             else:
-                negs, _ = self.sample_negatives(y, y.size(1))
+                negs, _ = self.sample_negatives(
+                    y, y.size(1), padding_counts=padding_counts,
+                )

             if self.codebook_negatives > 0:
                 cb_negs = self.quantizer.sample_from_codebook(
@@ -628,10 +649,14 @@ class Wav2Vec2Model(BaseFairseqModel):
             y = self.project_q(y)

             if self.negatives_from_everywhere:
-                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
+                negs, _ = self.sample_negatives(
+                    unmasked_features, y.size(1), padding_counts=padding_counts,
+                )
                 negs = self.project_q(negs)
             else:
-                negs, _ = self.sample_negatives(y, y.size(1))
+                negs, _ = self.sample_negatives(
+                    y, y.size(1), padding_counts=padding_counts,
+                )

         if x.device.type != 'xla':
             # tpu-comment: reducing the size in a dynamic way causes
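
Putting the pieces together, a self-contained toy version of the corrected sampling is sketched below: per-sample counts shrink each row's index range so negatives are never drawn from a sample's padded tail. Names, sizes and the random feature tensor are illustrative only, not the fairseq implementation.

import torch

def sample_negatives_toy(y, num, n_negatives, padding_counts=None):
    # y: (bsz, num, fsz) features for the masked timesteps of each sample.
    # padding_counts[i]: how many of sample i's num timesteps are padding.
    bsz, _, fsz = y.shape
    tszs = torch.arange(num).unsqueeze(-1).expand(-1, n_negatives).flatten()
    if padding_counts is None:
        neg_idxs = torch.randint(0, num - 1, size=(bsz, n_negatives * num))
    else:
        neg_idxs = torch.cat([
            torch.randint(0, num - pc - 1, size=(1, n_negatives * num))
            for pc in padding_counts
        ])
    neg_idxs = neg_idxs + (neg_idxs >= tszs).long()  # never pick the positive itself
    flat = y.reshape(-1, fsz)
    rows = torch.arange(bsz).unsqueeze(1) * num + neg_idxs  # offsets into the flattened batch
    return flat[rows.view(-1)].view(bsz, num, n_negatives, fsz), neg_idxs

y = torch.randn(2, 6, 4)
negs, idxs = sample_negatives_toy(y, num=6, n_negatives=3, padding_counts=[0, 2])
print(negs.shape)               # torch.Size([2, 6, 3, 4])
print(int(idxs[1].max()) <= 3)  # True: with 2 padded steps, indices stay within 0..3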