@@ -120,6 +120,7 @@ def _compute_mask_indices(
120
120
mask_prob : float ,
121
121
mask_length : int ,
122
122
device : torch .device ,
123
+ attention_mask : Optional [torch .tensor ] = None ,
123
124
min_masks : int = 0 ,
124
125
) -> torch .tensor :
125
126
"""
@@ -179,6 +180,10 @@ def _compute_mask_indices(
179
180
# scatter indices to mask
180
181
spec_aug_mask = spec_aug_mask .scatter (1 , spec_aug_mask_idxs , True )
181
182
183
+ if attention_mask is not None :
184
+ # make sure padded input ids cannot be masked
185
+ spec_aug_mask = torch .where (attention_mask .bool (), spec_aug_mask , False )
186
+
182
187
return spec_aug_mask
183
188
184
189
@@ -950,7 +955,10 @@ def __init__(self, config: Wav2Vec2Config):
950
955
self .init_weights ()
951
956
952
957
def _mask_hidden_states (
953
- self , hidden_states : torch .FloatTensor , mask_time_indices : Optional [torch .FloatTensor ] = None
958
+ self ,
959
+ hidden_states : torch .FloatTensor ,
960
+ mask_time_indices : Optional [torch .FloatTensor ] = None ,
961
+ attention_mask : Optional [torch .LongTensor ] = None ,
954
962
):
955
963
"""
956
964
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
@@ -973,6 +981,7 @@ def _mask_hidden_states(
973
981
mask_prob = self .config .mask_time_prob ,
974
982
mask_length = self .config .mask_time_length ,
975
983
device = hidden_states .device ,
984
+ attention_mask = attention_mask ,
976
985
min_masks = 2 ,
977
986
)
978
987
hidden_states [mask_time_indices ] = self .masked_spec_embed .to (hidden_states .dtype )
@@ -984,6 +993,7 @@ def _mask_hidden_states(
984
993
mask_prob = self .config .mask_feature_prob ,
985
994
mask_length = self .config .mask_feature_length ,
986
995
device = hidden_states .device ,
996
+ attention_mask = attention_mask ,
987
997
)
988
998
hidden_states [mask_feature_indices [:, None ].expand (- 1 , sequence_length , - 1 )] = 0
989
999
@@ -1049,7 +1059,9 @@ def forward(
1049
1059
attention_mask = attention_mask .flip ([- 1 ]).cumsum (- 1 ).flip ([- 1 ]).bool ()
1050
1060
1051
1061
hidden_states , extract_features = self .feature_projection (extract_features )
1052
- hidden_states = self ._mask_hidden_states (hidden_states , mask_time_indices = mask_time_indices )
1062
+ hidden_states = self ._mask_hidden_states (
1063
+ hidden_states , mask_time_indices = mask_time_indices , attention_mask = attention_mask
1064
+ )
1053
1065
1054
1066
encoder_outputs = self .encoder (
1055
1067
hidden_states ,
0 commit comments