Fine-tuning Wav2Vec2

The run_asr.py script allows one to fine-tune pretrained Wav2Vec2 models available on the Hugging Face Hub.

This fine-tuning script can also be run as a Google Colab notebook (TODO: link).

The script is actively maintained by Patrick von Platen. Feel free to ask a question on the Forum or post an issue on GitHub, tagging @patrickvonplaten.

Fine-Tuning with TIMIT

Let's take a look at the script used to fine-tune wav2vec2-base with the TIMIT dataset:

#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-base-timit-asr" \
--num_train_epochs="30" \
--per_device_train_batch_size="20" \
--per_device_eval_batch_size="20" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="facebook/wav2vec2-base" \
--fp16 \
--dataset_name="timit_asr" \
--train_split_name="train" \
--validation_split_name="test" \
--orthography="timit" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--verbose_logging

The resulting model and inference examples can be found on the Hugging Face Hub. Some of the arguments above may look unfamiliar, so let's break down what's going on:

--orthography="timit" applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset. In this case, we use the following instance of Orthography:

Orthography(
    do_lower_case=True,
    # break compounds like "quarter-century-old" and replace pauses "--"
    translation_table=str.maketrans({"-": " "}),
)

The instance above is used as follows (a short sketch of these steps follows the list):

  • creates a tokenizer with do_lower_case=True (ignores casing for input and lowercases output when decoding)
  • replaces "-" with " " to break compounds like "quarter-century-old" and to clean up suspended hyphens
  • cleans up consecutive whitespaces (replaces them with a single space: " ")
  • removes characters not in vocabulary (lacking respective sound units)
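
To make these steps concrete, here is a minimal sketch of the same normalization logic. The vocabulary and helper below are illustrative stand-ins, not the script's actual code (run_asr.py derives the vocabulary from the tokenizer):

import re

vocab = set("abcdefghijklmnopqrstuvwxyz' ")      # illustrative character vocabulary
translation_table = str.maketrans({"-": " "})

def normalize(text: str) -> str:
    text = text.lower()                          # do_lower_case=True
    text = text.translate(translation_table)     # break compounds, drop "--" pauses
    text = re.sub(r"\s+", " ", text).strip()     # collapse consecutive whitespace
    return "".join(c for c in text if c in vocab)  # drop out-of-vocabulary characters

print(normalize("A quarter-century-old -- example!"))
# a quarter century old example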

--verbose_logging logs text preprocessing updates and, when evaluating on the validation split every eval_steps, logs references and predictions.
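
Once fine-tuning finishes, the checkpoint saved in --output_dir can be loaded for inference with the standard transformers API. A minimal sketch, assuming a 16 kHz mono WAV file (the checkpoint path and file name are illustrative):

import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("./wav2vec2-base-timit-asr")
model = Wav2Vec2ForCTC.from_pretrained("./wav2vec2-base-timit-asr")

speech, sampling_rate = sf.read("example.wav")  # illustrative file; must be 16 kHz mono
inputs = processor(speech, sampling_rate=sampling_rate, return_tensors="pt")

with torch.no_grad():
    logits = model(inputs.input_values).logits

predicted_ids = torch.argmax(logits, dim=-1)
print(processor.batch_decode(predicted_ids))

processor.batch_decode collapses repeated CTC tokens and strips padding to produce the final transcripts.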

Fine-Tuning with Arabic Speech Corpus

Other datasets, like the Arabic Speech Corpus dataset, require more work! Let's take a look at the script used to fine-tune wav2vec2-large-xlsr-53:

#!/usr/bin/env bash
python run_asr.py \
--output_dir="./wav2vec2-large-xlsr-53-arabic-speech-corpus" \
--num_train_epochs="50" \
--per_device_train_batch_size="1" \
--per_device_eval_batch_size="1" \
--gradient_accumulation_steps="8" \
--evaluation_strategy="steps" \
--save_steps="500" \
--eval_steps="100" \
--logging_steps="50" \
--learning_rate="5e-4" \
--warmup_steps="3000" \
--model_name_or_path="elgeish/wav2vec2-large-xlsr-53-arabic" \
--fp16 \
--dataset_name="arabic_speech_corpus" \
--train_split_name="train" \
--validation_split_name="test" \
--max_duration_in_seconds="15" \
--orthography="buckwalter" \
--preprocessing_num_workers="$(nproc)" \
--group_by_length \
--freeze_feature_extractor \
--target_feature_extractor_sampling_rate \
--verbose_logging
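
Note that with --per_device_train_batch_size="1" and --gradient_accumulation_steps="8", gradients are accumulated over 8 steps before each optimizer update, for an effective train batch size of 1 × 8 = 8 per device. This keeps memory usage low for the large XLSR-53 model while preserving a reasonable update size.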

First, let's understand how this dataset represents Arabic text; it uses a format called Buckwalter transliteration. We use the lang-trans package to convert back to Arabic when logging. The Buckwalter format only includes ASCII characters, some of which are non-alphabetic (e.g., ">" maps to "أ").
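
As a quick illustration of the round trip (assuming the lang-trans package is installed; the sample string is ours):

from lang_trans import arabic

print(arabic.buckwalter.untransliterate("h*A"))  # هذا ("this")
print(arabic.buckwalter.transliterate("هذا"))    # h*A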

--orthography="buckwalter" applies certain text preprocessing rules, for tokenization and normalization, to clean up the dataset. In this case, we use the following instance of Orthography:

Orthography(
    vocab_file=pathlib.Path(__file__).parent.joinpath("vocab/buckwalter.json"),
    word_delimiter_token="/",  # "|" is Arabic letter alef with madda above
    words_to_remove={"sil"},  # fixing "sil" in arabic_speech_corpus dataset
    untransliterator=arabic.buckwalter.untransliterate,
    translation_table=str.maketrans({
        "-": " ",  # sometimes used to represent pauses
        "^": "v",  # fixing "tha" in arabic_speech_corpus dataset
    }),
)

The instance above is used as follows (a short sketch of these steps follows the list):

  • creates a tokenizer with Buckwalter vocabulary and word_delimiter_token="/"
  • replaces "-" with " " to clean up hyphens and fixes the orthography for "ث"
  • removes words used as indicators (in this case, "sil" is used for silence)
  • cleans up consecutive whitespaces (replaces them with a single space: " ")
  • removes characters not in vocabulary (lacking respective sound units)
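
Putting the pieces together, a minimal sketch of these cleanup rules on a made-up Buckwalter transcript (the helper is illustrative, not the script's actual code):

from lang_trans import arabic

translation_table = str.maketrans({"-": " ", "^": "v"})
words_to_remove = {"sil"}

def clean(text: str) -> str:
    text = text.translate(translation_table)                       # fix "^" -> "v" and hyphens
    words = [w for w in text.split() if w not in words_to_remove]  # drop "sil" markers
    return " ".join(words)                                         # also collapses whitespace

cleaned = clean("sil h*A mi^Al sil")  # illustrative transcript
print(cleaned)                        # h*A mivAl
print(arabic.buckwalter.untransliterate(cleaned))  # converted back to Arabic script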

--verbose_logging logs text preprocessing updates and, when evaluating on the validation split every eval_steps, logs references and predictions. With the Buckwalter format, text is also logged in the Arabic abjad.

--target_feature_extractor_sampling_rate resamples audio to the target feature extractor's sampling rate (16 kHz).
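
The resampling itself can be reproduced with librosa (using it here is our choice; any resampler targeting 16 kHz works). A sketch, assuming librosa >= 0.10, where orig_sr/target_sr are keyword-only, and an illustrative file name:

import librosa

# Load at the file's native rate, then resample to the feature extractor's 16 kHz.
speech, orig_sr = librosa.load("example.wav", sr=None)
speech_16k = librosa.resample(speech, orig_sr=orig_sr, target_sr=16_000)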

--max_duration_in_seconds="15" filters out examples whose audio is longer than the specified limit, which helps cap GPU memory usage.
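
The same filter can be approximated directly with the datasets library; a sketch where the audio column layout is an assumption worth verifying against the actual dataset:

from datasets import load_dataset

max_duration = 15.0  # seconds, matching --max_duration_in_seconds
dataset = load_dataset("arabic_speech_corpus", split="train")

def is_short_enough(example):
    audio = example["audio"]  # assumes a decoded audio column with "array" and "sampling_rate"
    return len(audio["array"]) / audio["sampling_rate"] <= max_duration

dataset = dataset.filter(is_short_enough)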