Created
October 28, 2020 17:57
-
-
Save taylanbil/bd04d55751d4119e5268bb35d924e483 to your computer and use it in GitHub Desktop.
Changes so that negatives are sampled correctly.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py | |
index baafe0f9..2114c93a 100644 | |
--- a/fairseq/data/audio/raw_audio_dataset.py | |
+++ b/fairseq/data/audio/raw_audio_dataset.py | |
@@ -185,6 +185,7 @@ class RawAudioDataset(FairseqDataset): | |
(B, T, self._C), padding_mask_reshaped, | |
) | |
input["mask_indices"] = mask_indices | |
+ input['padding_counts'] = input['mask_indices'].sum(-1).tolist() | |
input["mask_channel_indices"] = mask_channel_indices | |
out['sample_size'] = mask_indices.sum().item() | |
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py | |
index 21239d8b..be3f8df0 100644 | |
--- a/fairseq/models/wav2vec/wav2vec2.py | |
+++ b/fairseq/models/wav2vec/wav2vec2.py | |
@@ -450,7 +450,21 @@ class Wav2Vec2Model(BaseFairseqModel): | |
return x, mask_indices | |
- def sample_negatives(self, y, num): | |
def _get_neg_idxs(self, high, size, padding_counts=None):
    """Draw random candidate indices used for negative sampling.

    Args:
        high: Number of candidate positions per sample. Only used when
            ``padding_counts`` is None; indices are then drawn uniformly
            from ``[0, high - 1)``.
        size: ``(bsz, width)`` shape of the index tensor to return.
        padding_counts: Optional sequence holding one count per sample.
            When given, row ``i`` is drawn uniformly from
            ``[0, width // self.n_negatives - padding_counts[i] - 1)``.

    Returns:
        An integer tensor of shape ``size``.

    Raises:
        ValueError: If ``padding_counts`` does not have exactly one entry
            per sample.
    """
    if padding_counts is None:
        return torch.randint(low=0, high=high - 1, size=size)

    bsz, width = size
    if len(padding_counts) != bsz:
        raise ValueError(
            f"expected {bsz} padding counts, got {len(padding_counts)}"
        )
    num = width // self.n_negatives
    # Each row is drawn as a 1-D tensor so that stacking produces
    # (bsz, width), the same shape the unpadded branch returns. Drawing
    # (1, width) rows would stack to (bsz, 1, width) instead.
    # NOTE: torch.randint needs high > low, so every count must be
    # smaller than num - 1 or the draw for that row raises.
    rows = [
        torch.randint(low=0, high=num - pc - 1, size=(width,))
        for pc in padding_counts
    ]
    return torch.stack(rows)
+ | |
+ def sample_negatives(self, y, num, padding_counts=None): | |
if self.n_negatives == 0 and self.cross_sample_negatives == 0: | |
return y.new(0) | |
@@ -471,8 +485,9 @@ class Wav2Vec2Model(BaseFairseqModel): | |
.flatten() | |
) | |
- neg_idxs = torch.randint( | |
- low=0, high=high - 1, size=(bsz, self.n_negatives * num) | |
+ neg_idxs = self._get_neg_idxs( | |
+ high, (bsz, self.n_negatives * num), | |
+ padding_counts=padding_counts, | |
) | |
neg_idxs[neg_idxs >= tszs] += 1 | |
@@ -529,7 +544,7 @@ class Wav2Vec2Model(BaseFairseqModel): | |
def forward( | |
self, source, padding_mask=None, mask=True, features_only=False, | |
- mask_indices=None, mask_channel_indices=None, | |
+ mask_indices=None, mask_channel_indices=None, padding_counts=None, | |
): | |
if self.feature_grad_mult > 0: | |
@@ -608,12 +623,18 @@ class Wav2Vec2Model(BaseFairseqModel): | |
y = self.project_q(y) | |
if self.negatives_from_everywhere: | |
- neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) | |
- negs, _ = self.sample_negatives(neg_cands, y.size(1)) | |
+ neg_cands, *_ = self.quantizer( | |
+ unmasked_features, produce_targets=False, | |
+ ) | |
+ negs, _ = self.sample_negatives( | |
+ neg_cands, y.size(1), padding_counts=padding_counts, | |
+ ) | |
negs = self.project_q(negs) | |
else: | |
- negs, _ = self.sample_negatives(y, y.size(1)) | |
+ negs, _ = self.sample_negatives( | |
+ y, y.size(1), padding_counts=padding_counts, | |
+ ) | |
if self.codebook_negatives > 0: | |
cb_negs = self.quantizer.sample_from_codebook( | |
@@ -628,10 +649,14 @@ class Wav2Vec2Model(BaseFairseqModel): | |
y = self.project_q(y) | |
if self.negatives_from_everywhere: | |
- negs, _ = self.sample_negatives(unmasked_features, y.size(1)) | |
+ negs, _ = self.sample_negatives( | |
+ unmasked_features, y.size(1), padding_counts=padding_counts, | |
+ ) | |
negs = self.project_q(negs) | |
else: | |
- negs, _ = self.sample_negatives(y, y.size(1)) | |
+ negs, _ = self.sample_negatives( | |
+ y, y.size(1), padding_counts=padding_counts, | |
+ ) | |
if x.device.type != 'xla': | |
# tpu-comment: reducing the size in a dynamic way causes |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment