
Commit 2e9fb13

Patrick von Platen authored
[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)
* fix_torch_device_generate_test
* remove @
* start adding tests
* correct wav2vec2 pretraining
* up
* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
1 parent 5f2791c commit 2e9fb13

File tree

7 files changed: +98 -5 lines changed

examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py (+13 -1)
examples/research_projects/wav2vec2/run_pretrain.py (+22 -1)
src/transformers/models/hubert/modeling_hubert.py (+7 -1)
src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py (+6)
src/transformers/models/wav2vec2/modeling_wav2vec2.py (+14 -2)
tests/test_modeling_flax_wav2vec2.py (+18)
tests/test_modeling_wav2vec2.py (+18)

examples/research_projects/jax-projects/wav2vec2/run_wav2vec2_pretrain_flax.py

+13-1
@@ -174,11 +174,23 @@ def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> D
         )
         mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
 
+        batch_size = batch["input_values"].shape[0]
+
+        if batch["attention_mask"] is not None:
+            output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
+            attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
+
+            # these two operations make sure that all values
+            # before the output lengths indices are attended to
+            attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
+            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
+
         # sample randomly masked indices
         batch["mask_time_indices"] = _compute_mask_indices(
-            (batch["input_values"].shape[0], mask_indices_seq_length),
+            (batch_size, mask_indices_seq_length),
             self.model.config.mask_time_prob,
             self.model.config.mask_time_length,
+            attention_mask=attention_mask,
             min_masks=2,
         )
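The flip/cumsum/flip idiom added above is compact but non-obvious. A minimal NumPy sketch (illustration only, not part of the commit) of what it produces for two hypothetical output lengths:

import numpy as np

output_lengths = np.array([3, 5])   # hypothetical feature-extractor output lengths
mask_indices_seq_length = 6         # hypothetical padded feature length

attention_mask = np.zeros((2, mask_indices_seq_length), dtype=np.int8)
# place a single 1 at the last valid frame of each sequence
attention_mask[(np.arange(2), output_lengths - 1)] = 1
# flip, cumulative-sum, flip back: the 1 spreads over every earlier position,
# so exactly the frames up to the real output length become True
attention_mask = np.flip(np.flip(attention_mask, -1).cumsum(-1), -1).astype(bool)
print(attention_mask)
# [[ True  True  True False False False]
#  [ True  True  True  True  True False]]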

examples/research_projects/wav2vec2/run_pretrain.py

+22-1
@@ -172,12 +172,33 @@ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) ->
         )
         mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
 
+        batch_size = batch["input_values"].shape[0]
+
+        # make sure that no loss is computed on padded inputs
+        if batch["attention_mask"] is not None:
+            # compute real output lengths according to convolution formula
+            output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to(
+                torch.long
+            )
+
+            attention_mask = torch.zeros(
+                (batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device
+            )
+
+            # these two operations make sure that all values
+            # before the output lengths indices are attended to
+            attention_mask[
+                (torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1)
+            ] = 1
+            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+
         # sample randomly masked indices
         batch["mask_time_indices"] = _compute_mask_indices(
-            (batch["input_values"].shape[0], mask_indices_seq_length),
+            (batch_size, mask_indices_seq_length),
             self.model.config.mask_time_prob,
             self.model.config.mask_time_length,
             device=batch["input_values"].device,
+            attention_mask=attention_mask,
             min_masks=2,
         )
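In the hunk above, "convolution formula" refers to how each 1-D convolution in the Wav2Vec2 feature extractor shortens the input. A hedged sketch of that computation, assuming the base model's default kernel/stride settings (in the library the real values come from the model config, and self.model._get_feat_extract_output_lengths does this for you):

def feat_extract_output_length(input_length: int) -> int:
    # assumed defaults of the base feature extractor; real values live in the config
    conv_kernel = (10, 3, 3, 3, 3, 2, 2)
    conv_stride = (5, 2, 2, 2, 2, 2, 2)
    for kernel, stride in zip(conv_kernel, conv_stride):
        # standard 1-D conv output length (no padding, dilation 1)
        input_length = (input_length - kernel) // stride + 1
    return input_length

print(feat_extract_output_length(16000))  # 49 frames for one second of 16 kHz audio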

src/transformers/models/hubert/modeling_hubert.py

+7-1
@@ -47,6 +47,7 @@ def _compute_mask_indices(
     mask_prob: float,
     mask_length: int,
     device: torch.device,
+    attention_mask: Optional[torch.tensor] = None,
     min_masks: int = 0,
 ) -> torch.tensor:
     """
@@ -813,7 +814,10 @@ def __init__(self, config: HubertConfig):
 
     # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
     def _mask_hidden_states(
-        self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         """
         Masks extracted features along time axis and/or along feature axis according to `SpecAugment
@@ -836,6 +840,7 @@ def _mask_hidden_states(
                 mask_prob=self.config.mask_time_prob,
                 mask_length=self.config.mask_time_length,
                 device=hidden_states.device,
+                attention_mask=attention_mask,
                 min_masks=2,
             )
             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
@@ -847,6 +852,7 @@ def _mask_hidden_states(
                 mask_prob=self.config.mask_feature_prob,
                 mask_length=self.config.mask_feature_length,
                 device=hidden_states.device,
+                attention_mask=attention_mask,
             )
             hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0

src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py

+6
@@ -107,6 +107,7 @@ def _compute_mask_indices(
     shape: Tuple[int, int],
     mask_prob: float,
     mask_length: int,
+    attention_mask: Optional[np.ndarray] = None,
     min_masks: int = 0,
 ) -> np.ndarray:
     """
@@ -166,6 +167,10 @@ def _compute_mask_indices(
     # scatter indices to mask
     np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
 
+    if attention_mask is not None:
+        # make sure padded input ids cannot be masked
+        spec_aug_mask = np.where(attention_mask, spec_aug_mask, False)
+
     return spec_aug_mask
 
 
@@ -873,6 +878,7 @@ def __call__(
         """
         extract_features = self.feature_extractor(input_values)
 
+        # make sure that no loss is computed on padded inputs
         if attention_mask is not None:
             # compute real output lengths according to convolution formula
             output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
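The np.where guard added above is what actually keeps padded frames out of the sampled mask: wherever attention_mask is 0, the sampled value is overwritten with False. A tiny illustration (not from the commit) with hypothetical values:

import numpy as np

spec_aug_mask = np.array([[True, True, False, True]])  # hypothetical sampled mask
attention_mask = np.array([[1, 1, 1, 0]])              # last frame is padding
print(np.where(attention_mask, spec_aug_mask, False))
# [[ True  True False False]] -> the padded frame is forced back to False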

src/transformers/models/wav2vec2/modeling_wav2vec2.py

+14-2
@@ -120,6 +120,7 @@ def _compute_mask_indices(
     mask_prob: float,
     mask_length: int,
     device: torch.device,
+    attention_mask: Optional[torch.tensor] = None,
     min_masks: int = 0,
 ) -> torch.tensor:
     """
@@ -179,6 +180,10 @@ def _compute_mask_indices(
     # scatter indices to mask
     spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
 
+    if attention_mask is not None:
+        # make sure padded input ids cannot be masked
+        spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
+
     return spec_aug_mask
 
 
@@ -950,7 +955,10 @@ def __init__(self, config: Wav2Vec2Config):
         self.init_weights()
 
     def _mask_hidden_states(
-        self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
+        self,
+        hidden_states: torch.FloatTensor,
+        mask_time_indices: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
     ):
         """
         Masks extracted features along time axis and/or along feature axis according to `SpecAugment
@@ -973,6 +981,7 @@ def _mask_hidden_states(
                 mask_prob=self.config.mask_time_prob,
                 mask_length=self.config.mask_time_length,
                 device=hidden_states.device,
+                attention_mask=attention_mask,
                 min_masks=2,
             )
             hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
@@ -984,6 +993,7 @@ def _mask_hidden_states(
                 mask_prob=self.config.mask_feature_prob,
                 mask_length=self.config.mask_feature_length,
                 device=hidden_states.device,
+                attention_mask=attention_mask,
             )
             hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
@@ -1049,7 +1059,9 @@ def forward(
             attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
 
         hidden_states, extract_features = self.feature_projection(extract_features)
-        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+        hidden_states = self._mask_hidden_states(
+            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+        )
 
         encoder_outputs = self.encoder(
             hidden_states,
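The PyTorch changes combine the same two ingredients: forward builds a frame-level attention mask from the real output lengths, and _compute_mask_indices then drops any sampled index that falls on padding. A hedged standalone sketch (hypothetical helper mirroring the diff above, not a library API):

import torch

def build_feature_attention_mask(output_lengths: torch.Tensor, seq_length: int) -> torch.Tensor:
    # mark the last valid frame of each sequence ...
    attention_mask = torch.zeros((output_lengths.shape[0], seq_length), dtype=torch.long)
    attention_mask[(torch.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
    # ... then spread it backwards so every earlier frame is attended to
    return attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()

attention_mask = build_feature_attention_mask(torch.tensor([3, 5]), 6)
# hypothetical sampled SpecAugment mask for two sequences of 6 frames
spec_aug_mask = torch.tensor([[False, True, False, True, False, True],
                              [True, False, False, True, False, True]])
print(torch.where(attention_mask, spec_aug_mask, torch.tensor(False)))
# frames beyond the real lengths (3 and 5) can no longer be masked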

tests/test_modeling_flax_wav2vec2.py

+18
@@ -245,6 +245,24 @@ def test_compute_mask_indices_overlap(self):
         for batch_sum in mask.sum(axis=-1):
             self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
 
+    def test_compute_mask_indices_attn_mask_overlap(self):
+        batch_size = 4
+        sequence_length = 80
+        mask_prob = 0.5
+        mask_length = 4
+
+        attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
+        attention_mask[:2, sequence_length // 2 :] = 0
+
+        mask = _compute_mask_indices(
+            (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
+        )
+
+        for batch_sum in mask.sum(axis=-1):
+            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+        self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
+
     def test_compute_perplexity(self):
         probs = np.arange(100).reshape(2, 5, 10) / 100

tests/test_modeling_wav2vec2.py

+18
@@ -580,6 +580,24 @@ def test_compute_mask_indices_overlap(self):
         for batch_sum in mask.sum(axis=-1):
             self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
 
+    def test_compute_mask_indices_attn_mask_overlap(self):
+        batch_size = 4
+        sequence_length = 80
+        mask_prob = 0.5
+        mask_length = 4
+
+        attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
+        attention_mask[:2, sequence_length // 2 :] = 0
+
+        mask = _compute_mask_indices(
+            (batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
+        )
+
+        for batch_sum in mask.sum(axis=-1):
+            self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
+
+        self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
+
     def test_compute_perplexity(self):
         probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
