Skip to content

Commit d438eee

Browse files
Adding TFWav2Vec2Model (#11617)
* [WIP] Add TFWav2Vec2Model
  Work in progress for adding a tensorflow version of Wav2Vec2
* feedback changes
* small fix
* Test Feedback Round 1
* Add SpecAugment and CTC Loss
* correct spec augment mask creation
* docstring and correct copyright
* correct bugs
* remove bogus file
* finish tests correction
* del unnecessary layers
* Update src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* make style
* correct final bug
* Feedback Changes

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent 1ed2ebf commit d438eee

12 files changed: +2250 and -13 lines changed.

docs/source/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ Flax), PyTorch, and/or TensorFlow.
399399
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
400400
| VisualBert ||||||
401401
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
402-
| Wav2Vec2 |||| ||
402+
| Wav2Vec2 |||| ||
403403
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
404404
| XLM ||||||
405405
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

docs/source/model_doc/wav2vec2.rst

+14-1
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,22 @@ Wav2Vec2ForCTC
8080
.. autoclass:: transformers.Wav2Vec2ForCTC
8181
:members: forward
8282

83-
8483
Wav2Vec2ForPreTraining
8584
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8685

8786
.. autoclass:: transformers.Wav2Vec2ForPreTraining
8887
:members: forward
88+
89+
90+
TFWav2Vec2Model
91+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
92+
93+
.. autoclass:: transformers.TFWav2Vec2Model
94+
:members: call
95+
96+
97+
TFWav2Vec2ForCTC
98+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
99+
100+
.. autoclass:: transformers.TFWav2Vec2ForCTC
101+
:members: call

src/transformers/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -1430,6 +1430,14 @@
14301430
"TFTransfoXLPreTrainedModel",
14311431
]
14321432
)
1433+
_import_structure["models.wav2vec2"].extend(
1434+
[
1435+
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
1436+
"TFWav2Vec2ForCTC",
1437+
"TFWav2Vec2Model",
1438+
"TFWav2Vec2PreTrainedModel",
1439+
]
1440+
)
14331441
_import_structure["models.xlm"].extend(
14341442
[
14351443
"TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2743,6 +2751,12 @@
27432751
TFTransfoXLModel,
27442752
TFTransfoXLPreTrainedModel,
27452753
)
2754+
from .models.wav2vec2 import (
2755+
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
2756+
TFWav2Vec2ForCTC,
2757+
TFWav2Vec2Model,
2758+
TFWav2Vec2PreTrainedModel,
2759+
)
27462760
from .models.xlm import (
27472761
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
27482762
TFXLMForMultipleChoice,

src/transformers/convert_pytorch_checkpoint_to_tf2.py

+10
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
3838
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
3939
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
40+
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
4041
WEIGHTS_NAME,
4142
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
4243
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -79,10 +80,13 @@
7980
TFRobertaForSequenceClassification,
8081
TFT5ForConditionalGeneration,
8182
TFTransfoXLLMHeadModel,
83+
TFWav2Vec2Model,
8284
TFXLMRobertaForMaskedLM,
8385
TFXLMWithLMHeadModel,
8486
TFXLNetLMHeadModel,
8587
TransfoXLConfig,
88+
Wav2Vec2Config,
89+
Wav2Vec2Model,
8690
XLMConfig,
8791
XLMRobertaConfig,
8892
XLNetConfig,
@@ -287,6 +291,12 @@
287291
ElectraForPreTraining,
288292
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
289293
),
294+
"wav2vec2": (
295+
Wav2Vec2Config,
296+
TFWav2Vec2Model,
297+
Wav2Vec2Model,
298+
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
299+
),
290300
}
291301

292302

src/transformers/models/auto/modeling_tf_auto.py

+3
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@
163163
TFTransfoXLLMHeadModel,
164164
TFTransfoXLModel,
165165
)
166+
from ..wav2vec2.modeling_tf_wav2vec2 import TFWav2Vec2Model
166167
from ..xlm.modeling_tf_xlm import (
167168
TFXLMForMultipleChoice,
168169
TFXLMForQuestionAnsweringSimple,
@@ -218,6 +219,7 @@
218219
RoFormerConfig,
219220
T5Config,
220221
TransfoXLConfig,
222+
Wav2Vec2Config,
221223
XLMConfig,
222224
XLMRobertaConfig,
223225
XLNetConfig,
@@ -263,6 +265,7 @@
263265
(PegasusConfig, TFPegasusModel),
264266
(BlenderbotConfig, TFBlenderbotModel),
265267
(BlenderbotSmallConfig, TFBlenderbotSmallModel),
268+
(Wav2Vec2Config, TFWav2Vec2Model),
266269
]
267270
)
268271

src/transformers/models/wav2vec2/__init__.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# limitations under the License.
1818
from typing import TYPE_CHECKING
1919

20-
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
20+
from ...file_utils import _BaseLazyModule, is_tf_available, is_torch_available
2121

2222

2323
_import_structure = {
@@ -38,6 +38,15 @@
3838
]
3939

4040

41+
if is_tf_available():
42+
_import_structure["modeling_tf_wav2vec2"] = [
43+
"TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST",
44+
"TFWav2Vec2ForCTC",
45+
"TFWav2Vec2Model",
46+
"TFWav2Vec2PreTrainedModel",
47+
]
48+
49+
4150
if TYPE_CHECKING:
4251
from .configuration_wav2vec2 import WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP, Wav2Vec2Config
4352
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
@@ -54,6 +63,14 @@
5463
Wav2Vec2PreTrainedModel,
5564
)
5665

66+
if is_tf_available():
67+
from .modeling_tf_wav2vec2 import (
68+
TF_WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST,
69+
TFWav2Vec2ForCTC,
70+
TFWav2Vec2Model,
71+
TFWav2Vec2PreTrainedModel,
72+
)
73+
5774

5875
else:
5976
import importlib

0 commit comments

Comments
 (0)