diff --git a/.dockerignore b/.dockerignore index 45d6203913..0759ddb659 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,2 +1,4 @@ LibriSpeech Models +.venv* +venv* diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 4449a3b74e..eb7a5434d5 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -17,7 +17,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: '3.8.x' + python-version: '3.10.x' - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/wiki-publish.yml b/.github/workflows/wiki-publish.yml new file mode 100644 index 0000000000..ca7f6f4ab5 --- /dev/null +++ b/.github/workflows/wiki-publish.yml @@ -0,0 +1,19 @@ +name: Publish Wiki Pages +on: + push: + branches: [main] +concurrency: + group: publish-wiki + cancel-in-progress: true +permissions: + contents: write +jobs: + publish-wiki: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4.1.4 + - uses: nglehuy/github-wiki-action@master + with: + token: ${{ secrets.TOKEN }} + path: docs + preprocess: true diff --git a/.gitignore b/.gitignore index 9e789adac1..9e2319a21c 100755 --- a/.gitignore +++ b/.gitignore @@ -9,7 +9,8 @@ Session.vim .idea __pycache__ .pytest* -venv +venv* +.venv* my_train .DS_Store models/* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000..8139bcdd53 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,28 @@ +repos: + - repo: local + hooks: + - id: black-formatter-fix + name: black-formatter-fix + entry: bash -c "for f in $@; do black --verbose $f; done" + language: system + types: [python] + stages: [pre-commit] + fail_fast: true + verbose: true + - id: isort-fix + name: isort-fix + entry: bash -c "for f in $@; do echo -e \"Organize import for file $f\" && isort $f; done" + language: system + types: [python] + stages: [pre-commit] + fail_fast: true + verbose: true + - id: pylint-check + name: pylint-check + entry: bash -c "for f in $@; do pylint --rcfile=.pylintrc -rn -sn $f; done" + language: system + types: [python] + stages: [pre-commit] + fail_fast: true + require_serial: true + verbose: true diff --git a/.pylintrc b/.pylintrc index 410fdeee45..08db16fe3d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -107,8 +107,26 @@ disable=too-few-public-methods, dangerous-default-value, too-many-branches, ungrouped-imports, - attribute-defined-outside-init - + attribute-defined-outside-init, + too-many-public-methods, + use-dict-literal, + protected-access, + consider-using-enumerate, + too-many-statements, + assignment-from-none, + eval-used, + duplicate-code, + redefined-outer-name, + consider-using-f-string, + fixme, + unused-variable, + pointless-string-statement, + too-many-lines, + abstract-method, + too-many-ancestors, + import-outside-toplevel, + too-many-positional-arguments, + # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option # multiple time (only on the command line, not in the configuration file where @@ -209,7 +227,7 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. 
-generated-members= +generated-members=tensorflow.python,tensorflow.keras # Tells whether missing members accessed in mixin class should be ignored. A # mixin class is detected if its name ends with "mixin" (case insensitive). @@ -573,6 +591,4 @@ min-public-methods=2 [EXCEPTIONS] # Exceptions that will emit a warning when being caught. Defaults to -# "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception +overgeneral-exceptions= diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000000..ad8cb9108e --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,9 @@ +{ + "recommendations": [ + "ms-python.isort", + "ms-python.black-formatter", + "ms-python.pylint", + "ms-python.vscode-pylance", + "ms-python.python" + ] +} \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json index 201a69e3ee..1cf951d8b0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -14,6 +14,21 @@ "-s", "./tests/test_rnnt_loss.py" ] + }, + { + "name": "Test Prediction", + "type": "python", + "request": "launch", + "justMyCode": true, + "program": "./examples/inferences/main.py", + "args": [ + "--file-path", + "/Users/nglehuy/Data/Persona/MachineLearning/Datasets/LibriSpeech/test-clean/61/70970/61-70970-0030.flac", + "--config-path", + "~/Data/Persona/Projects/TensorFlowASR/examples/models/transducer/contextnet/small.yml.j2", + "--h5", + "~/Data/Persona/MachineLearning/Models/transducer/sp1k-contextnet/small/28.h5" + ] } ] } \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index 303282f88f..dca2d16a6e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,17 +1,31 @@ { - "python.linting.pylintEnabled": true, - "python.linting.flake8Enabled": false, - "python.linting.enabled": true, - "editor.formatOnSave": true, - "python.linting.lintOnSave": true, - "editor.codeActionsOnSave": { - "source.organizeImports": true, - }, - "isort.args": [ - "--profile", - "black", - "--line-length", - "130" - ], - "python.formatting.provider": "black" -} \ No newline at end of file + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.tabSize": 4 + }, + "[markdown]": { + "editor.tabSize": 2, + "editor.indentSize": 2, + "editor.detectIndentation": false + }, + "[json]": { + "editor.tabSize": 2 + }, + "[yaml]": { + "editor.tabSize": 2 + }, + "autoDocstring.docstringFormat": "numpy", + "black-formatter.args": ["--config", "${workspaceFolder}/pyproject.toml"], + "black-formatter.path": ["${interpreter}", "-m", "black"], + "editor.codeActionsOnSave": { + "source.fixAll": "explicit", + "source.organizeImports": "explicit" + }, + "editor.formatOnSave": true, + "isort.args": ["--settings-file", "${workspaceFolder}/pyproject.toml"], + "pylint.args": ["--rcfile=${workspaceFolder}/.pylintrc"], + "pylint.path": ["${interpreter}", "-m", "pylint"], + "python.analysis.fixAll": ["source.unusedImports", "source.convertImportFormat"], + "python.analysis.importFormat": "absolute", + "markdown.extension.list.indentationSize": "inherit" +} diff --git a/Dockerfile b/Dockerfile index 5b9ac59272..b7d221ce21 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM tensorflow/tensorflow:2.3.2-gpu +FROM tensorflow/tensorflow:2.18.0-gpu RUN apt-get update \ && apt-get upgrade -y \ @@ -9,8 +9,8 @@ RUN apt-get update \ RUN apt clean && apt-get clean # Install dependencies -COPY requirements.txt / -RUN pip --no-cache-dir install -r /requirements.txt +COPY requirements*.txt / +RUN pip 
--no-cache-dir install -r /requirements.txt -r /requirements.cuda.txt # Install rnnt_loss COPY scripts /scripts @@ -21,4 +21,4 @@ RUN if [ "$install_rnnt_loss" = "true" ] ; \ && ./scripts/install_rnnt_loss.sh \ else echo 'Using pure TensorFlow'; fi -RUN echo "export LD_LIBRARY_PATH=/usr/local/cuda-10.2/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc \ No newline at end of file +RUN echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}" >> /root/.bashrc \ No newline at end of file diff --git a/README.md b/README.md index 456e9e80fb..fce9f723c7 100755 --- a/README.md +++ b/README.md @@ -1,18 +1,18 @@

TensorFlowASR :zap:

GitHub | python | tensorflow | PyPI badges

Almost State-of-the-art Automatic Speech Recognition in Tensorflow 2

@@ -21,8 +21,6 @@ TensorFlowASR implements some automatic speech recognition architectures such as ## What's New? -- (9/4/2022) Breaking changes release v1.1.x - ## Table of Contents @@ -33,10 +31,6 @@ TensorFlowASR implements some automatic speech recognition architectures such as - [Baselines](#baselines) - [Publications](#publications) - [Installation](#installation) - - [Installing from source (recommended)](#installing-from-source-recommended) - - [Installing via PyPi](#installing-via-pypi) - - [Installing for development](#installing-for-development) - - [Running in a container](#running-in-a-container) - [Training \& Testing Tutorial](#training--testing-tutorial) - [Features Extraction](#features-extraction) - [Augmentations](#augmentations) @@ -61,69 +55,43 @@ TensorFlowASR implements some automatic speech recognition architectures such as ### Publications - **Conformer Transducer** (Reference: [https://arxiv.org/abs/2005.08100](https://arxiv.org/abs/2005.08100)) - See [examples/conformer](./examples/conformer) + See [examples/models/transducer/conformer](./examples/models/transducer/conformer) +- **Streaming Conformer** (Reference: [http://arxiv.org/abs/2010.11395](http://arxiv.org/abs/2010.11395)) + See [examples/models/transducer/conformer](./examples/models/transducer/conformer) - **ContextNet** (Reference: [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191)) - See [examples/contextnet](./examples/contextnet) + See [examples/models/transducer/contextnet](./examples/models/transducer/contextnet) - **RNN Transducer** (Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621)) - See [examples/rnn_transducer](./examples/rnn_transducer) + See [examples/models/transducer/rnnt](./examples/models/transducer/rnnt) - **Deep Speech 2** (Reference: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595)) - See [examples/deepspeech2](./examples/deepspeech2) + See [examples/models/ctc/deepspeech2](./examples/models/ctc/deepspeech2) - **Jasper** (Reference: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288)) - See [examples/jasper](./examples/jasper) + See [examples/models/ctc/jasper](./examples/models/ctc/jasper) ## Installation For training and testing, you should use `git clone` for installing necessary packages from other authors (`ctc_decoders`, `rnnt_loss`, etc.) 
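For quick reference, a typical install session under the new instructions might look roughly like this (a sketch only; the hardware argument to `setup.sh`, the optional `dev` flag, and the container route mirror the usage shown below):

```bash
# Clone the repository and install with the bundled setup script
git clone https://github.com/TensorSpeech/TensorFlowASR.git
cd TensorFlowASR

# Pick the target that matches your hardware: apple | tpu | gpu
# Append "dev" to also pull in development dependencies
./setup.sh gpu dev

# Alternatively, run everything inside a container
docker-compose up -d
```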
-### Installing from source (recommended) - -```bash -git clone https://github.com/TensorSpeech/TensorFlowASR.git -cd TensorFlowASR -# Tensorflow 2.x (with 2.x.x >= 2.5.1) -pip3 install ".[tf2.x]" # or ".[tf2.x-gpu]" -``` +**NOTE ONLY FOR APPLE SILICON**: TensorFlowASR requires python >= 3.12 -For anaconda3: +See the `requirements.[extra].txt` files for extra dependencies ```bash -conda create -y -n tfasr tensorflow-gpu python=3.8 # tensorflow if using CPU, this makes sure conda install all dependencies for tensorflow -conda activate tfasr -pip install -U tensorflow-gpu # upgrade to latest version of tensorflow git clone https://github.com/TensorSpeech/TensorFlowASR.git cd TensorFlowASR -# Tensorflow 2.x (with 2.x.x >= 2.5.1) -pip3 install ".[tf2.x]" # or ".[tf2.x-gpu]" -``` - -### Installing via PyPi - -```bash -# Tensorflow 2.x (with 2.x >= 2.3) -pip3 install "TensorFlowASR[tf2.x]" # or pip3 install "TensorFlowASR[tf2.x-gpu]" +./setup.sh [apple|tpu|gpu] [dev] ``` -### Installing for development - -```bash -git clone https://github.com/TensorSpeech/TensorFlowASR.git -cd TensorFlowASR -pip3 install -e ".[dev]" -pip3 install -e ".[tf2.x]" # or ".[tf2.x-gpu]" -``` - -### Running in a container +**Running in a container** ```bash docker-compose up -d ``` - ## Training & Testing Tutorial -- For training, please read [tutorial_training](./docs/1_tutorial_training.md) -- For testing, please read [tutorial_testing](./docs/2_tutorial_testing.md) +- For training, please read [tutorial_training](./docs/tutorials/training.md) +- For testing, please read [tutorial_testing](./docs/tutorials/testing.md) **FYI**: Keras builtin training uses **infinite dataset**, which avoids the potential last partial batch. @@ -131,7 +99,7 @@ See [examples](./examples/) for some predefined ASR models and results ## Features Extraction -See [features_extraction](./tensorflow_asr/featurizers/README.md) +See [features_extraction](./tensorflow_asr/features/README.md) ## Augmentations @@ -139,38 +107,13 @@ See [augmentations](./tensorflow_asr/augmentations/README.md) ## TFLite Convertion -After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **unicode code points**, then we can convert unicode points to string. - -1. Install `tf-nightly` using `pip install tf-nightly` -2. Build a model with the same architecture as the trained model _(if model has tflite argument, you must set it to True)_, then load the weights from trained model to the built model -3. Load `TFSpeechFeaturizer` and `TextFeaturizer` to model using function `add_featurizers` -4. Convert model's function to tflite as follows: - -```python -func = model.make_tflite_function(**options) # options are the arguments of the function -concrete_func = func.get_concrete_function() -converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) -converter.experimental_new_converter = True -converter.optimizations = [tf.lite.Optimize.DEFAULT] -converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, - tf.lite.OpsSet.SELECT_TF_OPS] -tflite_model = converter.convert() -``` - -5. Save the converted tflite model as follows: - -```python -if not os.path.exists(os.path.dirname(tflite_path)): - os.makedirs(os.path.dirname(tflite_path)) -with open(tflite_path, "wb") as tflite_out: - tflite_out.write(tflite_model) -``` +After converting to tflite, the tflite model is like a function that transforms directly from an **audio signal** to **text and tokens** -5. 
Then the `.tflite` model is ready to be deployed +See [tflite_convertion](./docs/tutorials/tflite.md) ## Pretrained Models -Go to [drive](https://drive.google.com/drive/folders/1BD0AK30n8hc-yR28C5FW3LqzZxtLOQfl?usp=sharing) +See the results on each example folder, e.g. [./examples/models//transducer/conformer/results/sentencepiece/README.md](./examples/models//transducer/conformer/results/sentencepiece/README.md) ## Corpus Sources @@ -183,11 +126,12 @@ Go to [drive](https://drive.google.com/drive/folders/1BD0AK30n8hc-yR28C5FW3LqzZx ### Vietnamese -| **Name** | **Source** | **Hours** | -| :------------------------------------- | :------------------------------------------------------------------------------------- | :-------- | -| Vivos | [https://ailab.hcmus.edu.vn/vivos](https://ailab.hcmus.edu.vn/vivos) | 15h | -| InfoRe Technology 1 | [InfoRe1 (passwd: BroughtToYouByInfoRe)](https://files.huylenguyen.com/25hours.zip) | 25h | -| InfoRe Technology 2 (used in VLSP2019) | [InfoRe2 (passwd: BroughtToYouByInfoRe)](https://files.huylenguyen.com/audiobooks.zip) | 415h | +| **Name** | **Source** | **Hours** | +| :------------------------------------- | :------------------------------------------------------------------------------------------------------------------- | :-------- | +| Vivos | [https://ailab.hcmus.edu.vn/vivos](https://www.kaggle.com/datasets/kynthesis/vivos-vietnamese-speech-corpus-for-asr) | 15h | +| InfoRe Technology 1 | [InfoRe1 (passwd: BroughtToYouByInfoRe)](https://files.huylenguyen.com/datasets/infore/25hours.zip) | 25h | +| InfoRe Technology 2 (used in VLSP2019) | [InfoRe2 (passwd: BroughtToYouByInfoRe)](https://files.huylenguyen.com/datasets/infore/audiobooks.zip) | 415h | +| VietBud500 | [https://huggingface.co/datasets/linhtran92/viet_bud500](https://huggingface.co/datasets/linhtran92/viet_bud500) | 500h | ## How to contribute diff --git a/docs/1_tutorial_training.md b/docs/1_tutorial_training.md deleted file mode 100644 index eddc4a3302..0000000000 --- a/docs/1_tutorial_training.md +++ /dev/null @@ -1,113 +0,0 @@ -# Training Tutorial - -These commands are example for librispeech dataset, but we can apply similar to other datasets - -## 1. Install packages (tf>=2.8) - -If you use google colab, it's recommended to use the tensorflow version pre-installed on the colab itself - -```bash -pip uninstall -y TensorFlowASR # uninstall for clean install if needed -pip install ".[tf2.x]" -``` - -## 2. Prepare transcripts files - -This is the example for preparing transcript files for librispeech data corpus - -```bash -python scripts/create_librispeech_trans.py \ ---directory=/path/to/dataset/train-clean-100 \ ---output=/path/to/dataset/train-clean-100/transcripts.tsv -``` - -Do the same thing with `train-clean-360`, `train-other-500`, `dev-clean`, `dev-other`, `test-clean`, `test-other` - -For other datasets, you must prepare your own python script like the `scripts/create_librispeech_trans.py` - -## 3. Prepare config file - -The config file is under format `config.j2` which is jinja2 format - -Please take a look in some examples for config files in `examples/*/config*.j2` - -## 4. 
[Optional][Required if using TPUs] Create tfrecords - -```bash -python scripts/create_tfrecords.py \ ---mode=train \ ---config-path=/path/to/config.j2 \ ---tfrecords-dir=/path/to/dataset/tfrecords \ ---tfrecords-shards=16 \ # available options are from 1 -> inf ---shuffle \ -/path/to/dataset/train-clean-100/transcripts.tsv \ -/path/to/dataset/train-clean-360/transcripts.tsv \ -/path/to/dataset/train-other-500/transcripts.tsv -``` - -Reduce the `--tfrecords-shards` if the size of the dataset is small - -Do the same thing with `--mode=eval` and `--mode=test` if needed, corresponds to `dev` and `test` datasets - -## 5. Generate vocabulary - -This step requires defining path to vocabulary file and other options for generating vocabulary in config file. - -Characters: - -```bash -Prepare like the files in vocabularies/*.characters -``` - -Wordpiece: - -```bash -python scripts/generate_vocab_wordpiece.py --config-path=/path/to/config.j2 -``` - -Sentencepiece: - -```bash -python scripts/generate_vocab_sentencepiece.py --config-path=/path/to/config.j2 -``` - -The inputs, outputs and other options of vocabulary are defined in the config file - -## 5. [Optional][Required if using TPUs] Generate metadata.json - -The metadata json file contains all the metadata of dataset derived with the current config of `speech_config` and `decoder_config` in the config file - -These metadata is for **static-shape** training, which is required for TPUs - -Static shape means that it will pad each record to the longest record size of the whole data, therefore if you use with `train` mode and `eval` mode, you have to generate metadata for both stages (aka modes) so that when loading the dataset, it will get the longest record size from both train and eval modes - -```bash -python scripts/generate_metadata.py \ ---stage=train \ ---config-path=/path/to/config.j2 \ ---metadata=/path/to/metadata.json \ -/path/to/dataset/train-clean-100/transcripts.tsv \ -/path/to/dataset/train-clean-360/transcripts.tsv \ -/path/to/dataset/train-other-500/transcripts.tsv -# same thing with eval mode -python scripts/generate_metadata.py \ ---stage=eval \ ---config-path=/path/to/config.j2 \ ---metadata=/path/to/metadata.json \ -/path/to/dataset/dev-clean/transcripts.tsv \ -/path/to/dataset/dev-other/transcripts.tsv -``` - -## 6. Update config file - -Update config file with: -- The paths to transcript files (and tfrecords if used) -- The path to metadata json file (if use static shape training) - -## 7. Run training - -```bash -python examples/conformer/train.py --mxp=auto --jit-compile --config-path=/path/to/config.j2 --tfrecords -``` - -See other options for each example \ No newline at end of file diff --git a/docs/2_tutorial_testing.md b/docs/2_tutorial_testing.md deleted file mode 100644 index 35c96c84cc..0000000000 --- a/docs/2_tutorial_testing.md +++ /dev/null @@ -1,71 +0,0 @@ -# Testing Tutorial - -These commands are example for librispeech dataset, but we can apply similar to other datasets - -## 1. Install packages (tf>=2.8) - -If you use google colab, it's recommended to use the tensorflow version pre-installed on the colab itself - -```bash -pip uninstall -y TensorFlowASR # uninstall for clean install if needed -pip install ".[tf2.x]" -``` - -## 2. 
Prepare transcripts files - -This is the example for preparing transcript files for librispeech data corpus - -```bash -python scripts/create_librispeech_trans.py \ ---directory=/path/to/dataset/test-clean \ ---output=/path/to/dataset/test-clean/transcripts.tsv -``` - -Do the same thing with `test-clean`, `test-other` - -For other datasets, you must prepare your own python script like the `scripts/create_librispeech_trans.py` - -## 3. Prepare config file - -The config file is under format `config.j2` which is jinja2 format - -Please take a look in some examples for config files in `examples/*/config*.j2` - -The config file is the same as the config used for training - -## 4. [Optional][Required if not exists] Generate vocabulary - -Use the same vocabulary file used in training - -Characters: - -```bash -Prepare like the files in vocabularies/*.characters -``` - -Wordpiece: - -```bash -python scripts/generate_vocab_wordpiece.py --config-path=/path/to/config.j2 -``` - -Sentencepiece: - -```bash -python scripts/generate_vocab_sentencepiece.py --config-path=/path/to/config.j2 -``` - -The inputs, outputs and other options of vocabulary are defined in the config file - -## 5. Update config file - -Update config file with: -- The paths to transcript files for test stage - -## 6. Run testing - -```bash -python examples/conformer/test.py --config-path=/path/to/config.j2 --saved=/path/to/saved_weights.h5 --bs=1 --output=/path/to/test.tsv -``` - -See other options for each example \ No newline at end of file diff --git a/docs/features.md b/docs/features.md new file mode 100644 index 0000000000..40bec0ca97 --- /dev/null +++ b/docs/features.md @@ -0,0 +1,22 @@ +# Speech Features Extraction + +See [feature_extraction.py](../tensorflow_asr/models/layers/feature_extraction.py) for more detail + +**Speech features** are extracted from the **Signal** with `sample_rate`, `frame_ms`, `stride_ms` and `num_feature_bins`. + +Speech features has the shape `(B, T, num_feature_bins, num_channels)` and it contains from 1-4 channels: + +1. Spectrogram, Log Mel Spectrogram, Log Gammatone Spectrogram or MFCCs +2. TODO: Delta features: like `librosa.feature.delta` from the features extracted on channel 1. +3. TODO: Delta deltas features: like `librosa.feature.delta` with `order=2` from the features extracted on channel 1. +4. 
TODO: Pitch features: like `librosa.core.piptrack` from the signal
+
+Implemented as a tensorflow keras [layer](../tensorflow_asr/models/layers/feature_extraction.py)
+
+![Spectrogram](./figs/spectrogram.png)
+
+![Log Mel Spectrogram](./figs/log_mel_spectrogram.png)
+
+![MFCCs](./figs/mfcc.png)
+
+![Log Gammatone Spectrogram](./figs/log_gammatone_spectrogram.png)
diff --git a/docs/figs/log_gammatone_spectrogram.png b/docs/figs/log_gammatone_spectrogram.png
new file mode 100644
index 0000000000..612401bd49
Binary files /dev/null and b/docs/figs/log_gammatone_spectrogram.png differ
diff --git a/docs/figs/log_mel_spectrogram.png b/docs/figs/log_mel_spectrogram.png
new file mode 100644
index 0000000000..b367ecc243
Binary files /dev/null and b/docs/figs/log_mel_spectrogram.png differ
diff --git a/docs/figs/mfcc.png b/docs/figs/mfcc.png
new file mode 100644
index 0000000000..ea3fe0301f
Binary files /dev/null and b/docs/figs/mfcc.png differ
diff --git a/docs/figs/spectrogram.png b/docs/figs/spectrogram.png
new file mode 100644
index 0000000000..2882fb597f
Binary files /dev/null and b/docs/figs/spectrogram.png differ
diff --git a/docs/tokenizers.md b/docs/tokenizers.md
new file mode 100644
index 0000000000..352a462a84
--- /dev/null
+++ b/docs/tokenizers.md
@@ -0,0 +1,26 @@
+- [Tokenizers](#tokenizers)
+  - [1. Character Tokenizer](#1-character-tokenizer)
+  - [2. Wordpiece Tokenizer](#2-wordpiece-tokenizer)
+  - [3. Sentencepiece Tokenizer](#3-sentencepiece-tokenizer)
+
+# Tokenizers
+
+## 1. Character Tokenizer
+
+See [librispeech config](../examples/datasets/librispeech/characters/char.yml.j2)
+
+This splits the text into characters and then maps each character to an index. The index starts from 1, and 0 is reserved for the blank token. This tokenizer is only used for languages that have a small number of characters and where a character is not a combination of other characters, for example English, Vietnamese, etc.
+
+## 2. Wordpiece Tokenizer
+
+See [librispeech config](../examples/datasets/librispeech/wordpiece/wp.yml.j2) for wordpiece split by whitespace
+
+See [librispeech config](../examples/datasets/librispeech/wordpiece/wp_whitespace.yml.j2) for wordpiece where whitespace is a separate token
+
+This splits the text into words and then splits each word into subwords. The subwords are then mapped to indices. The blank token can be set to index 0. This tokenizer is used for languages that have a large number of words and where a word can be a combination of other words, therefore it can be applied to any language.
+
+## 3. Sentencepiece Tokenizer
+
+See [librispeech config](../examples/datasets/librispeech/sentencepiece/sp.yml.j2)
+
+This splits the whole sentence into subwords and then maps each subword to an index. The blank token can be set to index 0. This tokenizer is used for languages that have a large number of words and where a word can be a combination of other words, therefore it can be applied to any language.
\ No newline at end of file
diff --git a/docs/tutorials/testing.md b/docs/tutorials/testing.md
new file mode 100644
index 0000000000..68aaf11620
--- /dev/null
+++ b/docs/tutorials/testing.md
@@ -0,0 +1,63 @@
+- [Testing Tutorial](#testing-tutorial)
+  - [1. Installation](#1-installation)
+  - [2. Prepare transcripts files](#2-prepare-transcripts-files)
+  - [3. Prepare config file](#3-prepare-config-file)
+  - [4. Run testing](#4-run-testing)
+
+
+# Testing Tutorial
+
+These commands are examples for the librispeech dataset, but similar steps can be applied to other datasets
+
+## 1. 
Installation + +```bash +./setup.sh [tpu|gpu|cpu] install +``` + +## 2. Prepare transcripts files + +This is the example for preparing transcript files for librispeech data corpus + +```bash +python examples/datasets/librispeech/prepare_transcript.py \ + --directory=/path/to/dataset/test-clean \ + --output=/path/to/dataset/test-clean/transcripts.tsv +``` + +Do the same thing with `test-clean`, `test-other` + +For other datasets, please make your own script to prepare the transcript files, take a look at the [`prepare_transcript.py`](../../examples/datasets/librispeech/prepare_transcript.py) file for more reference + +## 3. Prepare config file + +The config file is under format `config.yml.j2` which is jinja2 format with yaml content + +Please take a look in some examples for config files in `examples/*/*.yml.j2` + +The config file is the same as the config used for training + +The inputs, outputs and other options of vocabulary are defined in the config file + +For example: + +```jinja2 +{% import "examples/datasets/librispeech/sentencepiece/sp.yml.j2" as decoder_config with context %} +{{decoder_config}} + +{% import "examples/models/transducer/conformer/small.yml.j2" as config with context %} +{{config}} +``` + +## 4. Run testing + +```bash +tensorflow_asr test \ +--config-path /path/to/config.yml.j2 \ +--dataset_type slice \ +--datadir /path/to/datadir \ +--outputdir /path/to/modeldir/tests \ +--h5 /path/to/modeldir/weights.h5 +## See others params +tensorflow_asr test --help +``` \ No newline at end of file diff --git a/docs/tutorials/tflite.md b/docs/tutorials/tflite.md new file mode 100644 index 0000000000..5caf76d2fc --- /dev/null +++ b/docs/tutorials/tflite.md @@ -0,0 +1,66 @@ +- [TFLite Tutorial](#tflite-tutorial) + - [Conversion](#conversion) + - [Inference](#inference) + - [1. Input](#1-input) + - [2. Output](#2-output) + - [3. Example script](#3-example-script) + + +# TFLite Tutorial + +## Conversion + +```bash +tensorflow_asr tflite \ + --config-path=/path/to/config.yml.j2 \ + --h5=/path/to/weight.h5 \ + --bs=1 \ # Batch size + --beam-width=0 \ # Beam width, set >0 to enable beam search + --output=/path/to/output.tflite +## See others params +tensorflow_asr tflite --help +``` + +## Inference + +### 1. Input + +Input of each tflite depends on the models' parameters and configs. + +The `inputs`, `inputs_length` and `previous_tokens` are still the same as bellow for all models. + +```python +schemas.PredictInput( + inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32), + inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32), + previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)), + previous_encoder_states=tf.TensorSpec.from_tensor(self.get_initial_encoder_states(batch_size)), + previous_decoder_states=tf.TensorSpec.from_tensor(self.get_initial_decoder_states(batch_size)), +) +``` + +For models that don't have encoder states or decoder states, the default values are `tf.zeros([], dtype=self.dtype)` tensors for `previous_encoder_states` and `previous_decoder_states`. This is just for tflite conversion because tflite does not allow `None` value in `input_signature`. However, the output `next_encoder_states` and `next_decoder_states` are still `None`, so we can simply ignore those outputs. + +### 2. 
Output + +```python +schemas.PredictOutputWithTranscript( + transcript=self.tokenizer.detokenize(outputs.tokens), + tokens=outputs.tokens, + next_tokens=outputs.next_tokens, + next_encoder_states=outputs.next_encoder_states, + next_decoder_states=outputs.next_decoder_states, +) +``` + +This is for supporting streaming inference. + +Each output corresponds to the input = each chunk of audio signal. + +Then we can overwrite `previous_tokens`, `previous_encoder_states` and `previous_decoder_states` with `next_tokens`, `next_encoder_states` and `next_decoder_states` for the next chunk of audio signal. + +And continue until the end of the audio signal. + +### 3. Example script + +See [examples/inferences/tflite.py](../../examples/inferences/tflite.py) for more details. \ No newline at end of file diff --git a/docs/tutorials/training.md b/docs/tutorials/training.md new file mode 100644 index 0000000000..4d5e1e8069 --- /dev/null +++ b/docs/tutorials/training.md @@ -0,0 +1,94 @@ +- [Training Tutorial](#training-tutorial) + - [1. Install packages](#1-install-packages) + - [2. Prepare transcripts files](#2-prepare-transcripts-files) + - [3. Prepare config file](#3-prepare-config-file) + - [4. \[Optional\]\[Required if using TPUs\] Create tfrecords](#4-optionalrequired-if-using-tpus-create-tfrecords) + - [5. Generate vocabulary and metadata](#5-generate-vocabulary-and-metadata) + - [6. Run training](#6-run-training) + + +# Training Tutorial + +These commands are example for librispeech dataset, but we can apply similar to other datasets + +## 1. Installation + +```bash +./setup.sh [tpu|gpu|cpu] install +``` + +## 2. Prepare transcripts files + +This is the example for preparing transcript files for librispeech data corpus + +```bash +python examples/datasets/librispeech/prepare_transcript.py \ + --directory=/path/to/dataset/train-clean-100 \ + --output=/path/to/dataset/train-clean-100/transcripts.tsv +``` + +Do the same thing with `train-clean-360`, `train-other-500`, `dev-clean`, `dev-other`, `test-clean`, `test-other` + +For other datasets, please make your own script to prepare the transcript files, take a look at the [`prepare_transcript.py`](../../examples/datasets/librispeech/prepare_transcript.py) file for more reference + +## 3. Prepare config file + +The config file is under format `config.yml.j2` which is jinja2 format with yaml content + +Please take a look in some examples for config files in `examples/*/*.yml.j2` + +For example: + +```jinja2 +{% import "examples/datasets/librispeech/sentencepiece/sp.yml.j2" as decoder_config with context %} +{{decoder_config}} + +{% import "examples/models/transducer/conformer/small.yml.j2" as config with context %} +{{config}} +``` + +## 4. [Optional] Create tfrecords + +If you want to train with tfrecords + +```bash +tensorflow_asr utils create_tfrecords \ + --config-path=/path/to/config.yml.j2 \ + --mode=\["train","eval","test"\] \ + --datadir=/path/to/datadir +``` + +You can reduce the flag `--modes` to `--modes=\["train","eval"\]` to only create train and eval datasets + +## 5. Generate vocabulary and metadata + +This step requires defining path to vocabulary file and other options for generating vocabulary in config file. + +```bash +tensorflow_asr utils create_datasets_metadata \ + --config-path=/path/to/config.yml.j2 \ + --datadir=/path/to/datadir \ + --dataset-type="slice" +``` + +The inputs, outputs and other options of vocabulary are defined in the config file + +## 6. 
Run training + +```bash +tensorflow_asr train \ + --config-path=/path/to/config.yml.j2 \ + --modeldir=/path/to/modeldir \ + --datadir=/path/to/datadir \ + --dataset-type=tfrecord \ # or "generator" or "slice" \ + --dataset-cache \ + --mxp=strict \ + --bs=4 \ + --ga-steps=8 \ + --verbose=1 \ + --jit-compile \ + --device-type=tpu \ + --tpu-address=local +## See others params +tensorflow_asr train --help +``` \ No newline at end of file diff --git a/examples/ctc/conformer/confs/config_wp.j2 b/examples/ctc/conformer/confs/config_wp.j2 deleted file mode 100644 index 2231dd36ae..0000000000 --- a/examples/ctc/conformer/confs/config_wp.j2 +++ /dev/null @@ -1,169 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/conformer-ctc" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -model_config: - class_name: tensorflow_asr.models.ctc>Conformer - config: - name: conformer - encoder_subsampling: - type: conv2d - nlayers: 2 - filters: 144 - kernel_size: 3 - strides: 2 - padding: causal - norm: none - activation: relu - encoder_dmodel: 144 - encoder_num_blocks: 16 - encoder_head_size: 36 - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_interleave_relpe: True - encoder_use_attention_causal_mask: False - encoder_use_attention_auto_mask: True - encoder_kernel_size: 32 - encoder_dropout: 0.1 - encoder_padding: causal - encoder_depthwise_as_groupwise: False - encoder_ffm_residual_factor: 0.5 - encoder_mhsam_residual_factor: 1.0 - encoder_convm_residual_factor: 1.0 - encoder_module_norm_position: pre - encoder_block_norm_position: post - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: wordpiece - blank_index: 0 - unknown_token: "" - unknown_index: 1 - beam_width: 0 - norm_score: True - lm_config: null - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 5 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: 
True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.metadata.json - - test_dataset_config: - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: False - stage: test - - optimizer_config: - class_name: adam - config: - learning_rate: - class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule - config: - dmodel: 144 - warmup_steps: 10000 - max_lr: 0.00035 - min_lr: 1e-6 - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 100 - early_stopping: - monitor: val_loss - min_delta: 0 - patience: 1 - verbose: 0 - mode: min - baseline: null - restore_best_weights: True - start_from_epoch: 10 diff --git a/examples/ctc/conformer/tests/create_model.py b/examples/ctc/conformer/tests/create_model.py deleted file mode 100644 index 562d4b1904..0000000000 --- a/examples/ctc/conformer/tests/create_model.py +++ /dev/null @@ -1,92 +0,0 @@ -# %% -from tensorflow_asr.configs.config import Config -from tensorflow_asr.helpers import featurizer_helpers -from tensorflow_asr.models.ctc.conformer import Conformer -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() - -env_util.setup_seed() - -config_dict = { - "speech_config": { - "sample_rate": 16000, - "frame_ms": 25, - "stride_ms": 10, - "num_feature_bins": 80, - "feature_type": "log_mel_spectrogram", - }, - "decoder_config": { - "type": "wordpiece", - "blank_index": 0, - "unknown_token": "[PAD]", - "unknown_index": 0, - "beam_width": 0, - "norm_score": True, - "lm_config": None, - "vocabulary": "../../../vocabularies/librispeech/wordpiece/train_1000_50.tokens", - "vocab_size": 1000, - "max_token_length": 50, - "max_unique_chars": 1000, - "reserved_tokens": ["[PAD]"], - "normalization_form": "NFKC", - "num_iterations": 4, - }, - "model_config": { - "name": "conformer", - "encoder_subsampling": { - "type": "conv2d_blurpool", - "filters": 144, - "kernel_size": 3, - "strides": 2, - "conv_padding": "same", - "pool_padding": "reflect", - "activation": "relu", - }, - "encoder_dmodel": 144, - "encoder_num_blocks": 16, - "encoder_head_size": 36, - "encoder_num_heads": 4, - "encoder_mha_type": "relmha", - "encoder_use_attention_mask": True, - "encoder_kernel_size": 32, - "encoder_fc_factor": 0.5, - "encoder_dropout": 0.1, - "encoder_padding": "same", - }, -} - -config = Config(config_dict) - -speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - -global_batch_size = 2 -speech_featurizer.update_length(1200) -text_featurizer.update_length(700) - -conformer = Conformer( - **config.model_config, - vocab_size=text_featurizer.num_classes, -) -conformer.make(speech_featurizer.shape, batch_size=global_batch_size) -conformer.add_featurizers(speech_featurizer, text_featurizer) -conformer.summary() -# %% -import tensorflow as tf - -from tensorflow_asr.models.layers.multihead_attention import compute_self_attention_mask - -compute_self_attention_mask(tf.zeros([4, 10, 3]), [8, 7, 9, 10]) -# %% 
-conformer.save_weights("./conformer.h5") -conformer.load_weights("./conformer.h5") -# %% -conformer.save("./saved_model") - -# %% - -import tf2onnx - -tf2onnx.convert.from_keras(conformer, output_path="./conformer.onnx") - -# %% diff --git a/examples/ctc/deepspeech2/README.md b/examples/ctc/deepspeech2/README.md deleted file mode 100755 index 8d285316e5..0000000000 --- a/examples/ctc/deepspeech2/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# Deep Speech 2 - -References: [https://arxiv.org/abs/1512.02595](https://arxiv.org/abs/1512.02595) - -## Example YAML Config - -Go to [config.yml](./config.yml) - -## Training and Testing - -See `python examples/deepspeech2/train_*.py --help` - -See `python examples/deepspeech2/test_*.py --help` diff --git a/examples/ctc/deepspeech2/confs/config_wp.j2 b/examples/ctc/deepspeech2/confs/config_wp.j2 deleted file mode 100644 index 249408f65c..0000000000 --- a/examples/ctc/deepspeech2/confs/config_wp.j2 +++ /dev/null @@ -1,147 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/mnt/Miscellanea/Models/local/deepspeech2" %} -{% set datadir = "/mnt/Data/MLDL/Datasets/ASR/LibriSpeech" %} - -model_config: - name: deepspeech2 - conv_type: conv2d - conv_kernels: [[11, 41], [11, 21], [11, 11]] - conv_strides: [[2, 2], [1, 2], [1, 2]] - conv_filters: [32, 32, 96] - conv_dropout: 0.1 - rnn_nlayers: 5 - rnn_type: lstm - rnn_units: 512 - rnn_bidirectional: True - rnn_unroll: False - rnn_rowconv: 0 - rnn_dropout: 0.1 - fc_nlayers: 0 - fc_units: 1024 - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 1 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 0.5 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 0.5 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/train-clean-100/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords/100h - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: 
{{repodir}}/vocabularies/librispeech/wordpiece/train_1000.metadata.json - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 - diff --git a/examples/ctc/deepspeech2/results/sentencepiece.md b/examples/ctc/deepspeech2/results/sentencepiece.md deleted file mode 100644 index cac98a64d3..0000000000 --- a/examples/ctc/deepspeech2/results/sentencepiece.md +++ /dev/null @@ -1,169 +0,0 @@ -# Sentencepiece DeepSpeech2 - - -- [Sentencepiece DeepSpeech2](#sentencepiece-deepspeech2) - - [2023-02-12](#2023-02-12) - - - -## 2023-02-12 - -Config: - -```python -config = """ -{% set repodir = "/path/to/TensorFlowASR" %} -{% set modeldir = "/path/to/models/sp1k-deepspeech2/20230212" %} -{% set datadir = "/path/to/librispeech/tfrecords" %} - -model_config: - name: deepspeech2 - conv_type: conv2d - conv_kernels: [[11, 41], [11, 21], [11, 11]] - conv_strides: [[3, 2], [2, 2], [1, 2]] - conv_filters: [32, 32, 96] - conv_padding: same - conv_dropout: 0.1 - rnn_nlayers: 7 - rnn_type: lstm - rnn_bn_type: bn - rnn_units: 512 - rnn_bidirectional: True - rnn_unroll: False - rnn_rowconv: 0 - rnn_dropout: 0.1 - fc_nlayers: 0 - fc_units: 1024 - fc_dropout: 0.1 - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 128 - feature_type: spectrogram - -decoder_config: - type: sentencepiece - - blank_index: 0 - pad_token: "" - pad_index: 0 - unknown_token: "" - unknown_index: 1 - bos_token: "" - bos_index: 2 - eos_token: "" - eos_index: 3 - - beam_width: 0 - norm_score: True - lm_config: null - - model_type: bpe - vocabulary: {{repodir}}/vocabularies/librispeech/sentencepiece/train_bpe_1000.model - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: null - normalization_form: NFKC - num_iterations: 4 - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - data_paths: null - tfrecords_dir: {{datadir}} - shuffle: True - cache: False - buffer_size: 1000 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/sentencepiece/train_bpe_1000.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: null - - test_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - class_name: adam - config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 8 - num_epochs: 300 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - 
backup_and_restore: - backup_dir: {{modeldir}}/states - tensorboard: - log_dir: {{modeldir}}/tensorboard - write_graph: False - write_images: False - update_freq: epoch - profile_batch: 100 -""" -with open("/path/to/config.j2", "w") as f: - f.write(config) -``` - -Training: - -```bash -python /path/to/TensorFlowASR/examples/ctc/deepspeech2/train.py \ - --config-path=/path/to/config.j2 \ - --mxp=strict \ - --jit-compile \ - --tfrecords -``` - -Testing: - -```bash -python /path/to/TensorFlowASR/examples/ctc/deepspeech2/test.py \ - --config-path=/path/to/config.j2 \ - --saved=/path/to/models/sp1k-deepspeech2/20230212/checkpoints/25.h5 \ - --output=/path/to/models/sp1k-deepspeech2/20230212/tests/25.tsv \ - --bs=1 -``` - -RNNT Loss Curves: - - - -Error Rates: - -| Dataset | Mode | Batch size | Epoch | WER (%) | CER (%) | -| :--------------------- | :----------------------- | :--------: | :---: | :-----: | :-----: | -| librispeech-test-clean | greedy | 1 | 25 | | | -| librispeech-test-clean | beamsearch with size 500 | 1 | 25 | | | -| librispeech-test-other | greedy | 1 | 25 | | | -| librispeech-test-other | beamsearch with size 500 | 1 | 25 | | | \ No newline at end of file diff --git a/examples/ctc/jasper/README.md b/examples/ctc/jasper/README.md deleted file mode 100755 index 4c9195579f..0000000000 --- a/examples/ctc/jasper/README.md +++ /dev/null @@ -1,13 +0,0 @@ -# Jasper - -References: [https://arxiv.org/abs/1904.03288](https://arxiv.org/abs/1904.03288) - -## Example YAML Config - -Go to [config.yml](./config.yml) - -## Training and Testing - -See `python examples/jasper/train_*.py --help` - -See `python examples/jasper/test_*.py --help` diff --git a/examples/ctc/jasper/confs/config_wp.j2 b/examples/ctc/jasper/confs/config_wp.j2 deleted file mode 100644 index 7f936ffdc5..0000000000 --- a/examples/ctc/jasper/confs/config_wp.j2 +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." 
%} -{% set modeldir = "/mnt/Miscellanea/Models/local/jasper" %} -{% set datadir = "/mnt/Data/MLDL/Datasets/ASR/LibriSpeech" %} - -model_config: - name: jasper - dense: True - first_additional_block_channels: 256 - first_additional_block_kernels: 11 - first_additional_block_strides: 2 - first_additional_block_dilation: 1 - first_additional_block_dropout: 0.2 - nsubblocks: 3 - block_channels: [256, 384, 512, 640, 768] - block_kernels: [11, 13, 17, 21, 25] - block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3] - second_additional_block_channels: 896 - second_additional_block_kernels: 1 - second_additional_block_strides: 1 - second_additional_block_dilation: 2 - second_additional_block_dropout: 0.4 - third_additional_block_channels: 1024 - third_additional_block_kernels: 1 - third_additional_block_strides: 1 - third_additional_block_dilation: 1 - third_additional_block_dropout: 0.4 - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 1 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 0.5 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 0.5 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/train-clean-100/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords/100h - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000.metadata.json - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 - diff --git a/examples/ctc/transformer/confs/config_char.j2 b/examples/ctc/transformer/confs/config_char.j2 deleted file mode 100644 index 4a6bfdb9db..0000000000 --- a/examples/ctc/transformer/confs/config_char.j2 +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the 
"License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Volumes/Data/Miscellanea/Models/local/transformer-ctc" %} -{% set datadir = "/Volumes/Data/MLDL/Datasets/ASR/LibriSpeech" %} - -model_config: - class_name: tensorflow_asr.models.ctc>transformer - config: - name: transformer - encoder_subsampling: - type: conv2d - nlayers: 2 - filters: 512 - kernel_size: 3 - strides: 2 - padding: causal - norm: none - activation: relu - encoder_dropout: 0.1 - encoder_residual_factor: 1.0 - encoder_norm_position: post - encoder_dmodel: 512 - encoder_dff: 1024 - encoder_num_blocks: 6 - encoder_head_size: 128 - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_interleave_relpe: True - encoder_use_attention_causal_mask: False - encoder_use_attention_auto_mask: True - encoder_pwffn_activation: relu - encoder_memory_length: 512 - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: characters - blank_index: 0 - beam_width: 0 - norm_score: True - lm_config: null - vocabulary: {{repodir}}/vocabularies/english.characters - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 5 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/characters/train.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: {{repodir}}/vocabularies/librispeech/characters/train.metadata.json - - test_dataset_config: - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: False - stage: test - - optimizer_config: - class_name: adam - config: - learning_rate: - class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule - config: - dmodel: 512 - warmup_steps: 10000 - max_lr: null - min_lr: 1e-6 - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/examples/datasets/librispeech/characters/char.yml.j2 b/examples/datasets/librispeech/characters/char.yml.j2 new file mode 100644 index 0000000000..6f5141d3a1 --- /dev/null +++ 
b/examples/datasets/librispeech/characters/char.yml.j2 @@ -0,0 +1,15 @@ +{% set vocabsize = 29 %} +{% set vocabprefix = repodir ~ "/examples/datasets/librispeech/characters/english" %} +{% set metadata = vocabprefix ~ ".metadata.json" %} + +decoder_config: + type: characters + blank_index: 0 + beam_width: 0 + norm_score: True + lm_config: null + vocabulary: {{vocabprefix}}.vocab + vocab_size: {{vocabsize}} + +{% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} +{{data_config}} \ No newline at end of file diff --git a/vocabularies/librispeech/characters/train.metadata.json b/examples/datasets/librispeech/characters/english.metadata.json similarity index 69% rename from vocabularies/librispeech/characters/train.metadata.json rename to examples/datasets/librispeech/characters/english.metadata.json index 8f79f0fe58..60afb73c56 100644 --- a/vocabularies/librispeech/characters/train.metadata.json +++ b/examples/datasets/librispeech/characters/english.metadata.json @@ -1,11 +1,11 @@ { "train": { - "max_input_length": 2972, + "max_input_length": 475760, "max_label_length": 524, "num_entries": 281241 }, "eval": { - "max_input_length": 3514, + "max_input_length": 562480, "max_label_length": 516, "num_entries": 5567 } diff --git a/vocabularies/english.characters b/examples/datasets/librispeech/characters/english.vocab similarity index 100% rename from vocabularies/english.characters rename to examples/datasets/librispeech/characters/english.vocab diff --git a/examples/datasets/librispeech/config.yml.j2 b/examples/datasets/librispeech/config.yml.j2 new file mode 100644 index 0000000000..bc59df7e00 --- /dev/null +++ b/examples/datasets/librispeech/config.yml.j2 @@ -0,0 +1,59 @@ +data_config: + train_dataset_config: + enabled: True + sample_rate: 16000 + data_paths: + - {{datadir}}/train-clean-100/transcripts.tsv + - {{datadir}}/train-clean-360/transcripts.tsv + - {{datadir}}/train-other-500/transcripts.tsv + tfrecords_dir: {{datadir}}/tfrecords + tfrecords_shards: 32 + shuffle: True + cache: False + buffer_size: 1024 + drop_remainder: True + stage: train + metadata: {{metadata}} + indefinite: True + + eval_dataset_config: + enabled: True + sample_rate: 16000 + data_paths: + - {{datadir}}/dev-clean/transcripts.tsv + - {{datadir}}/dev-other/transcripts.tsv + tfrecords_dir: {{datadir}}/tfrecords + buffer_size: 1024 + tfrecords_shards: 2 + shuffle: True + cache: False + drop_remainder: True + stage: eval + metadata: {{metadata}} + indefinite: True + + test_dataset_configs: + - name: test-clean + enabled: True + sample_rate: 16000 + data_paths: + - {{datadir}}/test-clean/transcripts.tsv + tfrecords_dir: {{datadir}}/tfrecords + shuffle: False + cache: False + buffer_size: null + drop_remainder: False + stage: test + indefinite: False + - name: test-other + enabled: True + sample_rate: 16000 + data_paths: + - {{datadir}}/test-other/transcripts.tsv + tfrecords_dir: {{datadir}}/tfrecords + shuffle: False + cache: False + buffer_size: null + drop_remainder: False + stage: test + indefinite: False \ No newline at end of file diff --git a/scripts/create_librispeech_trans.py b/examples/datasets/librispeech/prepare_transcript.py similarity index 94% rename from scripts/create_librispeech_trans.py rename to examples/datasets/librispeech/prepare_transcript.py index f7093ed39b..a23ad2e881 100644 --- a/scripts/create_librispeech_trans.py +++ b/examples/datasets/librispeech/prepare_transcript.py @@ -17,7 +17,6 @@ import unicodedata import librosa -from tqdm.auto import tqdm from 
tensorflow_asr.utils import cli_util, file_util @@ -33,7 +32,9 @@ def main( text_files = glob.glob(os.path.join(directory, "**", "*.txt"), recursive=True) - for text_file in tqdm(text_files, desc="[Loading]"): + from tqdm.auto import tqdm + + for text_file in tqdm(text_files, desc="[Loading]", disable=False): current_dir = os.path.dirname(text_file) with open(text_file, "r", encoding="utf-8") as txt: lines = txt.read().splitlines() diff --git a/examples/datasets/librispeech/sentencepiece/sp.256.yml.j2 b/examples/datasets/librispeech/sentencepiece/sp.256.yml.j2 new file mode 100644 index 0000000000..5ec3ddfad7 --- /dev/null +++ b/examples/datasets/librispeech/sentencepiece/sp.256.yml.j2 @@ -0,0 +1,30 @@ +{% set vocabsize = 256 %} +{% set vocabprefix = repodir ~ "/examples/datasets/librispeech/sentencepiece/train_bpe_" ~ vocabsize %} +{% set metadata = vocabprefix ~ ".metadata.json" %} + +decoder_config: + type: sentencepiece + blank_index: 0 + unknown_token: "" + unknown_index: 0 + pad_token: "" + pad_index: -1 + bos_token: "" + bos_index: -1 + eos_token: "" + eos_index: -1 + beam_width: 0 + norm_score: True + lm_config: null + model_type: bpe + vocabulary: {{vocabprefix}}.model + vocab_size: {{vocabsize}} + reserved_tokens: null + normalization_form: NFKC + max_sentencepiece_length: 16 + max_sentence_length: 1048576 + character_coverage: 1.0 + keep_whitespace: False + +{% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} +{{data_config}} \ No newline at end of file diff --git a/examples/datasets/librispeech/sentencepiece/sp.yml.j2 b/examples/datasets/librispeech/sentencepiece/sp.yml.j2 new file mode 100644 index 0000000000..16c0f5ae53 --- /dev/null +++ b/examples/datasets/librispeech/sentencepiece/sp.yml.j2 @@ -0,0 +1,30 @@ +{% set vocabsize = 1000 %} +{% set vocabprefix = repodir ~ "/examples/datasets/librispeech/sentencepiece/train_bpe_" ~ vocabsize %} +{% set metadata = vocabprefix ~ ".metadata.json" %} + +decoder_config: + type: sentencepiece + blank_index: 0 + unknown_token: "" + unknown_index: 0 + pad_token: "" + pad_index: -1 + bos_token: "" + bos_index: -1 + eos_token: "" + eos_index: -1 + beam_width: 0 + norm_score: True + lm_config: null + model_type: bpe + vocabulary: {{vocabprefix}}.model + vocab_size: {{vocabsize}} + reserved_tokens: null + normalization_form: NFKC + max_sentencepiece_length: 16 + max_sentence_length: 1048576 + character_coverage: 1.0 + keep_whitespace: False + +{% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} +{{data_config}} \ No newline at end of file diff --git a/vocabularies/librispeech/sentencepiece/train_8000&960.model b/examples/datasets/librispeech/sentencepiece/train_8000&960.model similarity index 100% rename from vocabularies/librispeech/sentencepiece/train_8000&960.model rename to examples/datasets/librispeech/sentencepiece/train_8000&960.model diff --git a/vocabularies/librispeech/sentencepiece/train_bpe_1000.metadata.json b/examples/datasets/librispeech/sentencepiece/train_bpe_1000.metadata.json similarity index 55% rename from vocabularies/librispeech/sentencepiece/train_bpe_1000.metadata.json rename to examples/datasets/librispeech/sentencepiece/train_bpe_1000.metadata.json index f6cd8daceb..6bd5be7c42 100644 --- a/vocabularies/librispeech/sentencepiece/train_bpe_1000.metadata.json +++ b/examples/datasets/librispeech/sentencepiece/train_bpe_1000.metadata.json @@ -1,11 +1,11 @@ { "train": { - "max_input_length": 2972, - "max_label_length": 231, + "max_input_length": 475760, 
+ "max_label_length": 230, "num_entries": 281241 }, "eval": { - "max_input_length": 3514, + "max_input_length": 562480, "max_label_length": 225, "num_entries": 5567 } diff --git a/vocabularies/librispeech/sentencepiece/train_bpe_1000.model b/examples/datasets/librispeech/sentencepiece/train_bpe_1000.model similarity index 97% rename from vocabularies/librispeech/sentencepiece/train_bpe_1000.model rename to examples/datasets/librispeech/sentencepiece/train_bpe_1000.model index f0b50e999f..b754edf7a5 100644 Binary files a/vocabularies/librispeech/sentencepiece/train_bpe_1000.model and b/examples/datasets/librispeech/sentencepiece/train_bpe_1000.model differ diff --git a/examples/datasets/librispeech/sentencepiece/train_bpe_1000.vocab b/examples/datasets/librispeech/sentencepiece/train_bpe_1000.vocab new file mode 100644 index 0000000000..b5c4a3c777 --- /dev/null +++ b/examples/datasets/librispeech/sentencepiece/train_bpe_1000.vocab @@ -0,0 +1,1000 @@ + 0 +▁t -0 +he -1 +▁a -2 +▁the -3 +in -4 +▁s -5 +▁w -6 +▁o -7 +re -8 +nd -9 +▁b -10 +▁h -11 +er -12 +▁m -13 +▁i -14 +ou -15 +▁c -16 +▁f -17 +at -18 +ed -19 +▁and -20 +en -21 +▁to -22 +▁of -23 +on -24 +is -25 +▁d -26 +ing -27 +▁th -28 +▁p -29 +▁he -30 +or -31 +▁l -32 +es -33 +▁in -34 +ll -35 +it -36 +ar -37 +as -38 +an -39 +▁n -40 +▁g -41 +om -42 +▁be -43 +▁ha -44 +▁e -45 +le -46 +ot -47 +▁y -48 +ut -49 +ow -50 +ic -51 +▁wh -52 +▁it -53 +ld -54 +ve -55 +▁that -56 +ly -57 +▁was -58 +id -59 +se -60 +st -61 +▁on -62 +gh -63 +ent -64 +▁re -65 +▁you -66 +im -67 +ce -68 +▁u -69 +ver -70 +ion -71 +▁as -72 +et -73 +▁for -74 +ay -75 +▁we -76 +▁his -77 +ith -78 +al -79 +ir -80 +▁r -81 +▁with -82 +▁st -83 +ad -84 +ur -85 +ght -86 +▁an -87 +▁her -88 +▁not -89 +▁had -90 +▁is -91 +ter -92 +her -93 +ac -94 +am -95 +▁at -96 +oo -97 +▁but -98 +ould -99 +▁she -100 +▁k -101 +▁se -102 +▁sa -103 +▁sh -104 +▁fr -105 +▁him -106 +▁so -107 +ill -108 +▁me -109 +ain -110 +▁su -111 +ight -112 +ch -113 +red -114 +ct -115 +all -116 +ro -117 +ke -118 +ess -119 +il -120 +ore -121 +▁de -122 +▁they -123 +▁my -124 +▁whe -125 +▁all -126 +ich -127 +▁ne -128 +ri -129 +▁by -130 +▁have -131 +ome -132 +pp -133 +▁this -134 +▁li -135 +▁do -136 +▁con -137 +us -138 +▁which -139 +▁ch -140 +ul -141 +qu -142 +▁j -143 +▁up -144 +▁said -145 +▁from -146 +ard -147 +ge -148 +▁or -149 +▁v -150 +▁one -151 +th -152 +▁no -153 +▁ex -154 +▁were -155 +▁there -156 +pe -157 +and -158 +est -159 +▁man -160 +▁who -161 +ble -162 +ant -163 +ie -164 +▁al -165 +res -166 +ous -167 +ust -168 +very -169 +ation -170 +▁fe -171 +▁them -172 +lf -173 +▁when -174 +ind -175 +nt -176 +ame -177 +ra -178 +▁go -179 +ers -180 +ast -181 +fe -182 +ood -183 +▁kn -184 +▁int -185 +ist -186 +art -187 +▁are -188 +out -189 +▁would -190 +▁le -191 +os -192 +▁their -193 +ong -194 +▁what -195 +our -196 +▁if -197 +ound -198 +▁com -199 +▁ab -200 +▁out -201 +▁wor -202 +em -203 +▁will -204 +ak -205 +▁mis -206 +ate -207 +ol -208 +um -209 +un -210 +itt -211 +ough -212 +ked -213 +ap -214 +ig -215 +one -216 +▁been -217 +own -218 +ive -219 +▁then -220 +▁br -221 +ven -222 +if -223 +▁ar -224 +▁tr -225 +self -226 +▁pl -227 +▁ro -228 +ther -229 +▁pr -230 +reat -231 +▁un -232 +▁af -233 +▁sp -234 +▁qu -235 +▁pro -236 +ity -237 +hed -238 +▁tw -239 +▁ag -240 +▁could -241 +ost -242 +ace -243 +ort -244 +ure -245 +ake -246 +ack -247 +▁am -248 +▁any -249 +▁some -250 +▁your -251 +▁more -252 +▁can -253 +au -254 +▁tim -255 +ep -256 +▁en -257 +ag -258 +ck -259 +▁cl -260 +▁into -261 +ry -262 +hing -263 +▁now -264 +nder -265 +are -266 +▁very -267 +▁gr -268 +el -269 
+ose -270 +▁loo -271 +▁bo -272 +ved -273 +op -274 +▁other -275 +▁did -276 +ance -277 +▁than -278 +ittle -279 +▁little -280 +ine -281 +ies -282 +way -283 +ite -284 +▁like -285 +ide -286 +ass -287 +▁bl -288 +able -289 +▁lo -290 +urn -291 +ought -292 +▁know -293 +other -294 +▁time -295 +▁im -296 +▁dis -297 +▁us -298 +▁co -299 +fore -300 +▁te -301 +▁how -302 +ence -303 +▁day -304 +▁ad -305 +ade -306 +▁about -307 +ice -308 +▁see -309 +▁over -310 +pt -311 +cc -312 +▁too -313 +ink -314 +▁fl -315 +wn -316 +▁great -317 +▁after -318 +pl -319 +de -320 +▁per -321 +▁again -322 +ment -323 +▁upon -324 +▁hand -325 +ab -326 +ree -327 +▁has -328 +ish -329 +ci -330 +▁only -331 +ally -332 +▁well -333 +▁should -334 +▁po -335 +▁mar -336 +ress -337 +▁say -338 +▁good -339 +ather -340 +▁two -341 +ings -342 +▁pe -343 +ount -344 +▁our -345 +ire -346 +ving -347 +▁down -348 +ars -349 +ert -350 +we -351 +▁before -352 +ile -353 +▁app -354 +ves -355 +▁every -356 +▁its -357 +▁old -358 +▁thr -359 +▁mu -360 +▁made -361 +ick -362 +ied -363 +▁long -364 +te -365 +age -366 +ft -367 +▁where -368 +▁never -369 +ang -370 +▁pre -371 +▁must -372 +▁sm -373 +▁such -374 +ull -375 +ful -376 +▁str -377 +ions -378 +▁sc -379 +▁off -380 +▁came -381 +ious -382 +ue -383 +▁miss -384 +ward -385 +▁fir -386 +ild -387 +▁even -388 +▁under -389 +▁these -390 +act -391 +▁come -392 +▁part -393 +▁fo -394 +ated -395 +ness -396 +▁rem -397 +▁bec -398 +ord -399 +▁may -400 +ty -401 +▁think -402 +▁much -403 +per -404 +▁mister -405 +▁way -406 +led -407 +orn -408 +▁ey -409 +▁let -410 +▁cont -411 +▁gl -412 +▁thought -413 +▁look -414 +ect -415 +▁spe -416 +▁back -417 +ise -418 +▁bet -419 +▁ye -420 +ady -421 +ach -422 +ans -423 +▁just -424 +▁first -425 +▁here -426 +ren -427 +▁ho -428 +▁des -429 +▁ob -430 +▁own -431 +ried -432 +ud -433 +ary -434 +▁went -435 +▁himself -436 +▁mo -437 +cl -438 +▁men -439 +air -440 +ave -441 +ath -442 +▁sl -443 +ff -444 +co -445 +▁cr -446 +llow -447 +▁res -448 +▁might -449 +ily -450 +▁seem -451 +int -452 +ip -453 +▁beg -454 +ouse -455 +anc -456 +▁wat -457 +▁through -458 +▁comp -459 +ber -460 +▁car -461 +▁away -462 +▁em -463 +▁get -464 +▁imp -465 +▁head -466 +oss -467 +▁don -468 +▁bel -469 +▁life -470 +▁without -471 +▁pass -472 +▁most -473 +▁make -474 +ened -475 +▁cons -476 +▁som -477 +▁turn -478 +av -479 +ng -480 +▁shall -481 +▁those -482 +▁eyes -483 +▁pres -484 +▁acc -485 +▁house -486 +iz -487 +▁somet -488 +▁jo -489 +▁still -490 +▁call -491 +hes -492 +▁op -493 +▁night -494 +ause -495 +▁wom -496 +less -497 +▁last -498 +ks -499 +ared -500 +▁comm -501 +▁nothing -502 +▁ent -503 +▁tell -504 +▁new -505 +▁take -506 +ign -507 +▁being -508 +▁many -509 +▁word -510 +▁found -511 +ons -512 +▁ret -513 +ase -514 +▁while -515 +▁ear -516 +▁att -517 +ory -518 +▁saw -519 +ix -520 +▁put -521 +oth -522 +ne -523 +▁ser -524 +▁peop -525 +iend -526 +▁wr -527 +ark -528 +▁young -529 +dy -530 +aking -531 +les -532 +▁la -533 +▁once -534 +ens -535 +▁count -536 +pect -537 +▁friend -538 +▁people -539 +ible -540 +ors -541 +▁mat -542 +fect -543 +ince -544 +▁room -545 +ered -546 +▁three -547 +▁yet -548 +ail -549 +▁same -550 +▁father -551 +▁right -552 +▁child -553 +igh -554 +▁cour -555 +▁another -556 +▁place -557 +ult -558 +iv -559 +▁though -560 +ition -561 +▁ind -562 +▁want -563 +▁nor -564 +▁far -565 +▁king -566 +▁end -567 +▁happ -568 +▁heart -569 +▁face -570 +▁ever -571 +▁nat -572 +get -573 +thing -574 +▁took -575 +▁hu -576 +▁love -577 +▁dist -578 +ew -579 +ever -580 +▁arm -581 +ian -582 +▁inst -583 +man -584 +▁work -585 +▁light -586 +▁set -587 +▁ple -588 +ict -589 
+▁looked -590 +▁char -591 +▁missus -592 +▁ac -593 +▁mind -594 +▁inte -595 +▁rep -596 +▁asked -597 +▁supp -598 +cess -599 +▁yes -600 +ently -601 +▁left -602 +ertain -603 +gg -604 +▁ke -605 +ished -606 +▁pers -607 +▁things -608 +ub -609 +ways -610 +▁mom -611 +irl -612 +alk -613 +▁sir -614 +▁moment -615 +▁wa -616 +ations -617 +▁sat -618 +sel -619 +▁find -620 +ia -621 +ower -622 +rew -623 +▁world -624 +ject -625 +vent -626 +▁give -627 +▁gen -628 +▁cap -629 +so -630 +▁gu -631 +▁sw -632 +▁why -633 +lt -634 +ling -635 +▁always -636 +▁mother -637 +dd -638 +pped -639 +▁soon -640 +▁ans -641 +▁act -642 +▁form -643 +▁el -644 +▁heard -645 +der -646 +ret -647 +▁thing -648 +▁seemed -649 +▁something -650 +ange -651 +▁door -652 +▁sub -653 +▁girl -654 +ced -655 +ither -656 +▁appe -657 +▁wind -658 +▁mon -659 +▁dif -660 +▁because -661 +ss -662 +▁told -663 +▁going -664 +orm -665 +▁home -666 +▁war -667 +ained -668 +▁got -669 +aught -670 +▁gi -671 +▁god -672 +▁eng -673 +▁sur -674 +land -675 +ning -676 +▁hands -677 +▁woman -678 +aut -679 +▁vo -680 +▁poss -681 +▁follow -682 +▁feel -683 +ched -684 +▁rel -685 +ph -686 +ple -687 +ical -688 +▁return -689 +ook -690 +▁boy -691 +▁knew -692 +▁reg -693 +▁each -694 +ner -695 +▁rest -696 +▁kind -697 +▁ma -698 +▁exp -699 +▁cle -700 +iver -701 +▁oh -702 +▁hel -703 +▁sil -704 +ual -705 +▁water -706 +ting -707 +▁del -708 +▁ass -709 +▁inf -710 +▁wo -711 +▁bre -712 +▁certain -713 +▁against -714 +▁conf -715 +cept -716 +▁belie -717 +▁hard -718 +row -719 +▁unt -720 +▁years -721 +▁quite -722 +iness -723 +▁near -724 +▁ph -725 +ined -726 +▁side -727 +▁hor -728 +▁four -729 +ired -730 +ters -731 +ool -732 +▁few -733 +ier -734 +rest -735 +▁done -736 +most -737 +▁half -738 +▁che -739 +▁better -740 +ited -741 +▁tre -742 +▁min -743 +ock -744 +ps -745 +▁also -746 +uck -747 +▁care -748 +oub -749 +▁began -750 +ully -751 +ised -752 +▁having -753 +ru -754 +▁enough -755 +▁gener -756 +▁dra -757 +▁seen -758 +▁lady -759 +▁pur -760 +aps -761 +ott -762 +▁hum -763 +ross -764 +aken -765 +ying -766 +▁ter -767 +ank -768 +▁inde -769 +▁called -770 +▁hour -771 +ial -772 +ason -773 +▁beh -774 +▁does -775 +▁whole -776 +▁morn -777 +▁ste -778 +▁pleas -779 +▁turned -780 +ib -781 +▁ref -782 +ense -783 +▁ins -784 +ream -785 +▁occ -786 +▁course -787 +gether -788 +▁both -789 +▁gave -790 +uth -791 +▁cur -792 +▁sou -793 +een -794 +▁read -795 +▁add -796 +ween -797 +▁col -798 +selves -799 +▁between -800 +▁among -801 +ular -802 +▁beaut -803 +▁keep -804 +▁inc -805 +▁poor -806 +▁sure -807 +▁morning -808 +▁white -809 +ged -810 +▁dear -811 +▁name -812 +▁toward -813 +▁whom -814 +▁small -815 +▁sk -816 +▁repl -817 +▁lar -818 +ute -819 +▁felt -820 +osed -821 +bo -822 +ating -823 +▁open -824 +▁six -825 +▁myself -826 +ond -827 +▁however -828 +xt -829 +▁bu -830 +▁herself -831 +▁inter -832 +▁high -833 +aint -834 +▁fore -835 +ction -836 +▁stood -837 +▁hund -838 +▁tra -839 +▁hundred -840 +▁ev -841 +▁sent -842 +aster -843 +▁sim -844 +ife -845 +▁show -846 +▁round -847 +▁point -848 +▁almost -849 +▁days -850 +▁words -851 +vel -852 +▁gra -853 +ale -854 +▁dr -855 +▁gre -856 +▁eight -857 +ents -858 +dden -859 +ates -860 +▁bus -861 +▁fam -862 +ces -863 +▁land -864 +▁stand -865 +ung -866 +▁ed -867 +▁sun -868 +haps -869 +ird -870 +▁mean -871 +▁perhaps -872 +ned -873 +ures -874 +iet -875 +▁since -876 +▁sudden -877 +▁sle -878 +▁best -879 +▁dark -880 +iss -881 +▁replied -882 +▁voice -883 +▁bar -884 +▁met -885 +▁till -886 +▁anything -887 +▁until -888 +▁underst -889 +its -890 +▁black -891 +oud -892 +aring -893 +▁bro -894 +▁looking -895 +ins -896 
+▁cried -897 +amp -898 +▁prin -899 +▁fact -900 +▁next -901 +▁less -902 +▁law -903 +▁lay -904 +up -905 +▁power -906 +▁prop -907 +▁brought -908 +not -909 +enty -910 +ately -911 +rent -912 +▁country -913 +▁help -914 +med -915 +▁vis -916 +▁sn -917 +als -918 +▁air -919 +▁quest -920 +▁together -921 +fully -922 +▁spo -923 +▁adv -924 +▁person -925 +▁need -926 +▁use -927 +▁indeed -928 +▁contin -929 +oney -930 +ows -931 +▁present -932 +▁gent -933 +▁par -934 +▁unc -935 +ured -936 +▁run -937 +▁full -938 +▁aw -939 +▁rather -940 +▁ide -941 +nded -942 +▁feet -943 +tain -944 +▁cond -945 +▁sy -946 +▁lat -947 +be -948 +▁fall -949 +du -950 +▁five -951 +eter -952 +▁har -953 +▁fin -954 +cei -955 +▁bed -956 +▁mil -957 +▁doct -958 +▁interest -959 +oc -960 +▁matter -961 +▁gone -962 +ressed -963 +▁lord -964 +▁wife -965 +▁pat -966 +▁es -967 +fort -968 +ering -969 +▁serv -970 +▁ -971 +e -972 +t -973 +a -974 +o -975 +n -976 +i -977 +h -978 +s -979 +r -980 +d -981 +l -982 +u -983 +m -984 +c -985 +w -986 +f -987 +g -988 +y -989 +p -990 +b -991 +v -992 +k -993 +' -994 +x -995 +j -996 +q -997 +z -998 diff --git a/examples/datasets/librispeech/sentencepiece/train_bpe_256.metadata.json b/examples/datasets/librispeech/sentencepiece/train_bpe_256.metadata.json new file mode 100644 index 0000000000..e102ed6b8b --- /dev/null +++ b/examples/datasets/librispeech/sentencepiece/train_bpe_256.metadata.json @@ -0,0 +1,12 @@ +{ + "train": { + "max_input_length": 475760, + "max_label_length": 270, + "num_entries": 281241 + }, + "eval": { + "max_input_length": 562480, + "max_label_length": 260, + "num_entries": 5567 + } +} \ No newline at end of file diff --git a/examples/datasets/librispeech/sentencepiece/train_bpe_256.model b/examples/datasets/librispeech/sentencepiece/train_bpe_256.model new file mode 100644 index 0000000000..5fb573af3c Binary files /dev/null and b/examples/datasets/librispeech/sentencepiece/train_bpe_256.model differ diff --git a/examples/datasets/librispeech/sentencepiece/train_bpe_256.vocab b/examples/datasets/librispeech/sentencepiece/train_bpe_256.vocab new file mode 100644 index 0000000000..b0dab472e3 --- /dev/null +++ b/examples/datasets/librispeech/sentencepiece/train_bpe_256.vocab @@ -0,0 +1,256 @@ + 0 +▁t -0 +he -1 +▁a -2 +▁the -3 +in -4 +▁s -5 +▁w -6 +▁o -7 +re -8 +nd -9 +▁b -10 +▁h -11 +er -12 +▁m -13 +▁i -14 +ou -15 +▁c -16 +▁f -17 +at -18 +ed -19 +▁and -20 +en -21 +▁to -22 +▁of -23 +on -24 +is -25 +▁d -26 +ing -27 +▁th -28 +▁p -29 +▁he -30 +or -31 +▁l -32 +es -33 +▁in -34 +ll -35 +it -36 +ar -37 +as -38 +an -39 +▁n -40 +▁g -41 +om -42 +▁be -43 +▁ha -44 +▁e -45 +le -46 +ot -47 +▁y -48 +ut -49 +ow -50 +ic -51 +▁wh -52 +▁it -53 +ld -54 +ve -55 +▁that -56 +ly -57 +▁was -58 +id -59 +se -60 +st -61 +▁on -62 +gh -63 +ent -64 +▁re -65 +▁you -66 +im -67 +ce -68 +▁u -69 +ver -70 +ion -71 +▁as -72 +et -73 +▁for -74 +ay -75 +▁we -76 +▁his -77 +ith -78 +al -79 +ir -80 +▁r -81 +▁with -82 +▁st -83 +ad -84 +ur -85 +ght -86 +▁an -87 +▁her -88 +▁not -89 +▁had -90 +▁is -91 +ter -92 +her -93 +ac -94 +am -95 +▁at -96 +oo -97 +▁but -98 +ould -99 +▁she -100 +▁k -101 +▁se -102 +▁sa -103 +▁sh -104 +▁fr -105 +▁him -106 +▁so -107 +ill -108 +▁me -109 +ain -110 +▁su -111 +ight -112 +ch -113 +red -114 +ct -115 +all -116 +ro -117 +ke -118 +ess -119 +il -120 +ore -121 +▁de -122 +▁they -123 +▁my -124 +▁whe -125 +▁all -126 +ich -127 +▁ne -128 +ri -129 +▁by -130 +▁have -131 +ome -132 +pp -133 +▁this -134 +▁li -135 +▁do -136 +▁con -137 +us -138 +▁which -139 +▁ch -140 +ul -141 +qu -142 +▁j -143 +▁up -144 +▁said -145 +▁from -146 +ard -147 
+ge -148 +▁or -149 +▁v -150 +▁one -151 +th -152 +▁no -153 +▁ex -154 +▁were -155 +▁there -156 +pe -157 +and -158 +est -159 +▁man -160 +▁who -161 +ble -162 +ant -163 +ie -164 +▁al -165 +res -166 +ous -167 +ust -168 +very -169 +ation -170 +▁fe -171 +▁them -172 +lf -173 +▁when -174 +ind -175 +nt -176 +ame -177 +ra -178 +▁go -179 +ers -180 +ast -181 +fe -182 +ood -183 +▁kn -184 +▁int -185 +ist -186 +art -187 +▁are -188 +out -189 +▁would -190 +▁le -191 +os -192 +▁their -193 +ong -194 +▁what -195 +our -196 +▁if -197 +ound -198 +▁com -199 +▁ab -200 +▁out -201 +▁wor -202 +em -203 +▁will -204 +ak -205 +▁mis -206 +ate -207 +ol -208 +um -209 +un -210 +itt -211 +ough -212 +ked -213 +ap -214 +ig -215 +one -216 +▁been -217 +own -218 +ive -219 +▁then -220 +▁br -221 +ven -222 +if -223 +▁ar -224 +▁tr -225 +self -226 +▁ -227 +e -228 +t -229 +a -230 +o -231 +n -232 +i -233 +h -234 +s -235 +r -236 +d -237 +l -238 +u -239 +m -240 +c -241 +w -242 +f -243 +g -244 +y -245 +p -246 +b -247 +v -248 +k -249 +' -250 +x -251 +j -252 +q -253 +z -254 diff --git a/vocabularies/librispeech/subwords/train_1030_4.metadata.json b/examples/datasets/librispeech/subwords/train_1030_4.metadata.json similarity index 100% rename from vocabularies/librispeech/subwords/train_1030_4.metadata.json rename to examples/datasets/librispeech/subwords/train_1030_4.metadata.json diff --git a/vocabularies/librispeech/subwords/train_1030_4.subwords b/examples/datasets/librispeech/subwords/train_1030_4.subwords similarity index 100% rename from vocabularies/librispeech/subwords/train_1030_4.subwords rename to examples/datasets/librispeech/subwords/train_1030_4.subwords diff --git a/vocabularies/librispeech/wordpiece/train_1000_50.metadata.json b/examples/datasets/librispeech/wordpiece/train_1000.metadata.json similarity index 69% rename from vocabularies/librispeech/wordpiece/train_1000_50.metadata.json rename to examples/datasets/librispeech/wordpiece/train_1000.metadata.json index e8b13c28c9..db00829752 100644 --- a/vocabularies/librispeech/wordpiece/train_1000_50.metadata.json +++ b/examples/datasets/librispeech/wordpiece/train_1000.metadata.json @@ -1,11 +1,11 @@ { "train": { - "max_input_length": 2972, + "max_input_length": 475760, "max_label_length": 202, "num_entries": 281241 }, "eval": { - "max_input_length": 3514, + "max_input_length": 562480, "max_label_length": 190, "num_entries": 5567 } diff --git a/vocabularies/librispeech/wordpiece/train_1000_50.tokens b/examples/datasets/librispeech/wordpiece/train_1000.vocab similarity index 99% rename from vocabularies/librispeech/wordpiece/train_1000_50.tokens rename to examples/datasets/librispeech/wordpiece/train_1000.vocab index 3285502ebe..bdd7925b66 100644 --- a/vocabularies/librispeech/wordpiece/train_1000_50.tokens +++ b/examples/datasets/librispeech/wordpiece/train_1000.vocab @@ -1,4 +1,3 @@ - ' a diff --git a/examples/datasets/librispeech/wordpiece/train_1000_whitespace.metadata.json b/examples/datasets/librispeech/wordpiece/train_1000_whitespace.metadata.json new file mode 100644 index 0000000000..0df2400bde --- /dev/null +++ b/examples/datasets/librispeech/wordpiece/train_1000_whitespace.metadata.json @@ -0,0 +1,12 @@ +{ + "train": { + "max_input_length": 475760, + "max_label_length": 281, + "num_entries": 281241 + }, + "eval": { + "max_input_length": 562480, + "max_label_length": 279, + "num_entries": 5567 + } +} \ No newline at end of file diff --git a/vocabularies/librispeech/wordpiece/train_100h_1000_50.tokens b/examples/datasets/librispeech/wordpiece/train_1000_whitespace.vocab 
similarity index 91% rename from vocabularies/librispeech/wordpiece/train_100h_1000_50.tokens rename to examples/datasets/librispeech/wordpiece/train_1000_whitespace.vocab index 6bec0af51a..1a6c23d274 100644 --- a/vocabularies/librispeech/wordpiece/train_100h_1000_50.tokens +++ b/examples/datasets/librispeech/wordpiece/train_1000_whitespace.vocab @@ -1,4 +1,4 @@ -[PAD] + ' a @@ -34,956 +34,963 @@ to ##s in he -was that +was it ##ing ##ed his -had -as -with you +with +as for -##ly +had +is her +##ly but -is not she -##y +##d at on be -##d +##y him -##e they -by -##er +##e have +by this -my -were -which all +##er +which +my +said from ##t so -said one +were me we there ##n -their no +or when +are +their an -or them would -##a -##al -who if +##a what -are +who +##al +will been ##r -up -then out -##ation +then +up +##ion +do could -will -##h -into -more man +more +into ##or -##ion -##es -some -very -do -##ate now +very +##ation +##h +your +some little +##ate ##on -about -your -did -than -##ity -##l time -##ies -##le +##ity +##es ##en -like ##in +about +like +than +##le +did upon +can +only +has +any +##k +##m ##ry well -##m -has -only -##o -other +##ies +##l +##able two -##ment -any -##nt ##ine -can -our -made -after -##ure -##ance +other +##ment +see before -##ant -its +##it ##an +##o +its +good down +##ch over -##able -such -old -see -these -##ness -##ful -##k know -day -##ic -came -##ge +##ance +made +##ant +our +after +should +##ar +##ness great -us -##th -mister -good +old +such ##ers -should +##ful +came +must +how +day +##nt +##g +never +these +come +##ure much -##ated +##at +mister +go +us +##ic ##age -how -##it -##ar +##ary ##est -way -never ##ive -##ow -must -##ary -##ter -come -##ch -##ous +##th where -##at -again -back +##ated +may first -##ain -go +##ous +way +again +here himself -##ast +went ##ce -men -long -own -may +##ain +##ise ##nce -##ight -even -went +##ow +long +back say -just -here -might +men +own +##st +am +think ##ious -##ise -##g -through -##ted -eyes -##ood -make +thought too +away +might +just +even +##et +##p +##ge without -think -house -those -thought +##ice +through life -many -every +make +##ight most -away -##et -##st -##ace -being +##ie +every +those +##ter +##ted +eyes don -am -##ard +shall +##id +##ast still nothing +take +being +##ard +many +hand while -##ie -##ice -##ty -##p -people last -though -young -##id -yet -found -three -##led +house +once off -##ish -hand -##ily -get +saw night -take -##ied +let +##ward +people ##less -asked -##rs -saw -##ling +##all +##ish +three +yet +found +##ne same -missus -##ars ##ay -##ible head -##ot -##ne +get +##ace +##ily another -right -##ent -left -once -tell -shall place -ever -took -face -##ned -seemed -always +##us +though +##ot ##ction -room -new -under +##ood +face +tell +took father -##ick +##ible +young ##ct -##ting -##are -why -told -let +room +##ling +##ty +ever looked -heard +##ied +##ve +missus put -##ual -##ring -##ian -##ia -because -##cy -##ward -things +##are +##ll +##ent +##ick +asked +under +##is +##ning +##rs +left +right ##ide -something ##b -##ake -##les -going -##ill -##ts -##ions -give -mother -##ning -##ence +things ll +##ered +sir +##f look -##i -##ations -love -mind -thing -##all -##ass -soon -##red -##ue -##ress +##led +##ia +##man +give king -##ite -knew -heart -##land -##ical -each +why +always +heard world -##ged -against -far -moment -having -woman -##aid -##is -few -##ore -began -miss +thing +seemed +something +##ions +new +##red +##cy +##row +##ence +because yes +mother +##ting +told 
+going +##ian +mind +##ad +love +##der +##se +##ite +got +##ual +##ake +##air +##ress ##rew -##row door -##ined -##ist -better -##ared -##oke +##aid +woman home +far +##ore +knew +soon +each +##ned +##ts +moment +##les +##ical work -##us -##der -##man +heart +##ened +##ory +against +oh +##am +##ue +##ave years -enough -sir -##f -##ded -##te -##air -##ered +quite +##ist +##ire +##ined +##as +find done -got -##ort -##iving -seen -side ##ath -##ire +better +##ill +few +##land +##te +side +##ase +also ##ors +began +water +##ile +half ##and -##ial -##ome +having +enough +##ised +##our +seen +##ars +##re +##ped +##i called -whole -##old -between -morning ##w -##ll -felt -girl -##ile -##as -##ving -##ad -##re -##ised -herself -find -##ried -##il -##se -however +whole +god +hands +part turned -##ired -white -also -half -perhaps -##ail -##el +##ged +lady +##ale +course +gave ##gs -##ms -replied -water +both +between +##ried +light +##ame ##ced -hundred -##ained -quite -##ase -##x -myself -oh -part -##ating -##our -##ve -course +morning +##ded poor -voice -##ave -##ched -##one -both -name -##ale -##tion -gave -hands +##ief +##ial whom -days -almost -among +##iving set -##ally -##ld -##ief +##ained +##ired +felt +##rt +##op +myself +however +##ared +miss +herself +girl +##hip +##ations +##ort +stood +white ##side -together +dear +##ms +##il +almost +##ect +days +##ks +name words -##ank -##ition -##ut -until -##alk -##im -##ose -##ame +nor ##ft -##per -##op -##un +##ched +##el +perhaps +##om +##ount +##ring +##ail +among +replied +voice ##ton -##ience -##ead +##per anything -feet -##ished -next +end +##pt want +round +until +##one +looking +cried +##ating +##ead +does four -stood -##ped -##ntly +hundred +next +##fully +till +##ble brought -light +##ut +##ings +small +##ished +##ose +together +##ition +boy +indeed +since +##ass +##tion +rather +##ind +##ute +##rown +##ack ##ited -##uck -best +feet +ve ##cted -others five -##oat -##uth -##ressed -##om -##ond -##hip -##rown -##fully -##ool -looking -##aw -##ff -##ap -##am -##pt -nor -small -along -##hy +##ience +word +friend +country ##ually -near -rather -##ble -since -believe -money -does -passed -##ened +##ier +best +##een +##ntly +gone +matter +wife +##ap +##ressed +##ments +##ome +full +##ur +sure +taken +##ding +others +##ach +##low +cannot +##ally +##im ##ient +themselves +sat +##ple +kind +death ##orn +##rty +##ply +##oat +air +thus +##ool lay -##ind +##ock +##x +along +child +near +##ever +##ther +##um +high +##ree open -end -indeed -round -##low +##ust +behind +children +true +answered +whose +##uth ##ular -##ect -kind -full -##by -##ments -##ures +money twenty -##ount -cried -taken -##ls -matter -sure -##oon -##ering -##ings -word +##ult +##ss +##ey +##uck +thou +believe +large ##ased -##ier -##ities -##ory -##ree -country -##ured -dear +keep +passed ##ak -gone -child -god -ve -whose -##ade -##ple -##ply -answered -##ever -##ur -##ute -high -less -uncle -##ably -##illed -themselves -##mer -sat -##ins +##den +black nature +##end +##tle +##ins +##ering +alone +##cle +##de +##ord +##ping ##ants -air -##ized -black -behind -john -true -wife -##ived -##ther -power -##ult -lady -##ding -##um -death -around -boy -##bed ##bject -certain -during -women +##ured +##ities +##ves +doctor +less +##ging +##oke +##arry +power +##oon +leave +##illed +lord +##ably +##ond +##als +##ank +##ost ##body -##osition +fire +given +##ess +hope +##ough +certain +therefore +##mber +body +help +##atch +##ls rest -children -keep -##uty -##ians -often +sea +hear ##opped 
-till -##ost -already -hope -thus -##ane -present +##oy +speak +##ds +##by +often +##bed ##ound -##overed -doctor -large -became -returned -##ack +present wish -cannot -didn -##atch -##arry +really +##sion ##ates -##ey -##ks -##tle -body -general -master -##ange -land -sent -##ach -alone -given -leave -re ##asure -case -short -##ances -everything -##cing -##med -##den -##ock +##ceived +##gth +fell +fact +##art +##osition +thousand +dead +##ans +son coming -says -##mber -##sion +everything +##overed +##ked ##ying -friend -##fied -held -really -##amp +##uty +##ving +hour +##ven above -fell -ground -##ess -##ided -##ook -fire -help -##alked -##ans -city -thousand -##gth -##iness -##ip -##ust -evening -##oy -speak -##read -fact -order -state -##aring -hear -sometimes -##oved -##tered -##ved -##owed -##ps -##men -##be +##ner +##its ##pect -##ping -kept -##ns -within -##als -##ape -hour -point -friends -##ign -##ves -least -##ister -##ny +master +already +##ized +##ff +itself +used ##ten -family -##lish -captain -care +least +suddenly +re +##ances +sent +##ook +around +order +##ird +##ither +##ided +##dle +evening +ground +##iness use -beautiful -making +##fied +case +known +##dy +##ld +prince +##ign +##men +##ell whether -##day -##ines -##ration -sight -##art +##aring +within +##ains +##read times -##alking +general +##oot +friends +women +##ps +##tered +hard +##cing +returned bed -suddenly -six +reason +held ##oud -dead -call +state +point +during +care +beautiful +says +##ade +earth either -itself -thou +##leep +lost +##owed +##ister +kept +captain +means +red +six +dark +horse +land +became +second +table +mean +##ches +making +##oss ##ray -##ung -son -##iles -able +read towards -used -##ains -##ord -dark +##dom +##ver manner -##gue -##iety -##ul -mean -several -##ution -appeared -lost -town -##ages -story -##ell -##orm -lord -##ner +ask +##day +##ung +short +##go ##aught -##eep -big -known -possible -##ke -sea +##airs +call +##ians +##ream +brother ##ark -##ck -##ds -##ird -##vice -fine -means +##here +##mon +##son +feel +city +##old +sometimes +##arly +##iety +##ved +##rain +question +##dded +ready +##ny +##ape +##ause +##ages +##rm +##lay +##ows +else +##unt +##ration +live +close +##ip +##dge +##ution +didn +sight +##ility +won person -continued -second -strange -##aint -##ild -##ven -##ley -red -##ces -##lf -##nts -human +##amp +##cept +##ines +sun +##lled ##uble -hard +answer +business +family +arms +##bly +ten +turn year -##ility -##oss -street -##oken -feel -reached -##end -##go -##rm -##ver -close -hair -##ither -##pose -question +##ange +fear +##lain +become +##light +soul +##un +possible +##aint +letter +several ##ative -##my -##teen -arms +##king +##iles rose -##anding +##ong +##oved +idea ##iously -##rit -##ugh -##urn -##airs -become -followed -longer -therefore -won -##cious -##ern -##iled -##ows -##ught -##ging -##ham -##its +##itting +##ote +able +big +##lass ##itted -##ream -##rty -##tely -business -understand -##ination -##rt -##vil -brother -happy sort -sound -table +town +across +strange +##ctly +story +##ury +##ces +##hy +##ild +##lish +##ater +strong +##over +form +ran +window +##ane ##ather -##dle -##ming -##ored +##ns +##lic +truth +##be +longer +fine +followed +understand +##hes +##ger +##pose +##ately +##iled +bring +##ink +##cious +thy +##mer +suppose +##gue +##lood ##ture -##umber -across different +doubt +taking +hair +happy +##right +##ears +##ature +daughter +thee +spoke +##my +human +cold +##rying +##race +law +return +##vil +reached +##utes +appeared +##eal 
+##ield +##ination +##teen +pretty +##ait +opened +##orm +##ize +##tely +##une +##lation +##vice +##rove +##eck +sound +##ear +saying +##self +##istance +met +##rit +##antly +##ets +##reen ##ising -##itting -##pected -##rain -else -fellow -live -##anged -##ense -##ger -##ining -##way -certainly -turn -##arted -##ately -##rable +##tention +river +##ope +john +##rance +##anding +road +##ived +##aw +talk war -window -##bs -##light -blue -earth -met +continued need -tree -##lay -##lled -##mon -##ush -wanted -##ches -##cial -##sing -peter -##ause -##tention -##une -prince -read -toward -##bly -##cept -##ents -##ep -##used -green -##ong -ready -reason -##arriage -##itude -ask +##anged +##lly +##u carried -daughter -ought -princess -##rove -feeling -idea -lived -ten -##lad -answer -##aged -##ang run -##ait -##dom -##rying -##ubt -church -cold -eighteen -fear -later -##ctly -##lic -##ries -##uit -strong -##any -##ek -##ote -##ror -added -although +##ox +##ush +eye +certainly +##ures +wanted husband -party -river -##ism +##lue +##hed show -##ber -##ius -##u -##ury -eye -##out -##ove -##ters -##val -foot -girls -suppose -taking -##ize -##od -##utes -trees -##ailed +tree +##avy ##pression -book -##erved -##ingly -##lain -##ushed -entered -pretty +hold +past +##ries ##lock -##rance -##son -fall -hours -road -sleep -##oad -##tory -##where -clear -received -##aimed -##istance -##ists -horse +##pected +##arriage +##lad +##aving +##cial +street +##sing +##umber low +##ips +##ural +##ingly +##ret +fellow +sense +##erved +##arted +##owing +##ike +deep +##ah +##bs +bad +feeling +##ledge +ought +##itude +book +##laimed +##ney +##ley +##ont +cut +##ining +##cent ## ##' ##c diff --git a/examples/datasets/librispeech/wordpiece/wp.yml.j2 b/examples/datasets/librispeech/wordpiece/wp.yml.j2 new file mode 100644 index 0000000000..5873d5d464 --- /dev/null +++ b/examples/datasets/librispeech/wordpiece/wp.yml.j2 @@ -0,0 +1,24 @@ +{% set vocabsize = 1000 %} +{% set vocabprefix = repodir ~ "/examples/datasets/librispeech/wordpiece/train_" ~ vocabsize %} +{% set metadata = vocabprefix ~ ".metadata.json" %} + +decoder_config: + type: wordpiece + blank_index: 0 + unknown_token: "" + unknown_index: 0 + beam_width: 0 + norm_score: True + lm_config: null + vocabulary: {{vocabprefix}}.vocab + keep_whitespace: False + vocab_size: {{vocabsize}} + max_token_length: 50 + max_unique_chars: 1000 + reserved_tokens: + - "" + normalization_form: NFKC + num_iterations: 4 + +{% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} +{{data_config}} \ No newline at end of file diff --git a/examples/datasets/librispeech/wordpiece/wp_whitespace.yml.j2 b/examples/datasets/librispeech/wordpiece/wp_whitespace.yml.j2 new file mode 100644 index 0000000000..7cc0f34ef1 --- /dev/null +++ b/examples/datasets/librispeech/wordpiece/wp_whitespace.yml.j2 @@ -0,0 +1,24 @@ +{% set vocabsize = 1000 %} +{% set vocabprefix = repodir ~ "/examples/datasets/librispeech/wordpiece/train_" ~ vocabsize ~ "_whitespace" %} +{% set metadata = vocabprefix ~ ".metadata.json" %} + +decoder_config: + type: wordpiece + blank_index: 0 + unknown_token: "" + unknown_index: 0 + beam_width: 0 + norm_score: True + lm_config: null + vocabulary: {{vocabprefix}}.vocab + keep_whitespace: True + vocab_size: {{vocabsize}} + max_token_length: 50 + max_unique_chars: 1000 + reserved_tokens: + - "" + normalization_form: NFKC + num_iterations: 4 + +{% import "examples/datasets/librispeech/config.yml.j2" as data_config with context %} +{{data_config}} \ No 
newline at end of file
diff --git a/examples/datasets/vietbud500/config.yml.j2 b/examples/datasets/vietbud500/config.yml.j2
new file mode 100644
index 0000000000..8724d76383
--- /dev/null
+++ b/examples/datasets/vietbud500/config.yml.j2
@@ -0,0 +1,44 @@
+data_config:
+  train_dataset_config:
+    enabled: True
+    sample_rate: 16000
+    data_paths:
+      - {{datadir}}/train/transcripts.tsv
+    tfrecords_dir: {{datadir}}/tfrecords
+    tfrecords_shards: 32
+    shuffle: True
+    cache: False
+    buffer_size: 1024
+    drop_remainder: True
+    stage: train
+    metadata: {{metadata}}
+    indefinite: True
+
+  eval_dataset_config:
+    enabled: True
+    sample_rate: 16000
+    data_paths:
+      - {{datadir}}/validation/transcripts.tsv
+    tfrecords_dir: {{datadir}}/tfrecords
+    buffer_size: 1024
+    tfrecords_shards: 2
+    shuffle: True
+    cache: False
+    drop_remainder: True
+    stage: eval
+    metadata: {{metadata}}
+    indefinite: True
+
+  test_dataset_configs:
+    - name: test
+      enabled: True
+      sample_rate: 16000
+      data_paths:
+        - {{datadir}}/test/transcripts.tsv
+      tfrecords_dir: {{datadir}}/tfrecords
+      shuffle: False
+      cache: False
+      buffer_size: null
+      drop_remainder: False
+      stage: test
+      indefinite: False
\ No newline at end of file
diff --git a/examples/datasets/vietbud500/download.py b/examples/datasets/vietbud500/download.py
new file mode 100644
index 0000000000..8e24d1f3d2
--- /dev/null
+++ b/examples/datasets/vietbud500/download.py
@@ -0,0 +1,52 @@
+import os
+
+import datasets
+import librosa
+import soundfile
+from tqdm import tqdm
+
+from tensorflow_asr.utils import cli_util, data_util
+
+MAPPING = {
+    "audio.array": "audio",
+    "audio.sampling_rate": "sample_rate",
+    "transcription": "transcript",
+}
+
+
+def load_item_from_mapping(item):
+    data = {}
+    for path, key in MAPPING.items():
+        data[key] = data_util.get(item, path)
+    if not all(x in data for x in ["audio", "transcript"]):
+        return None
+    return data["audio"], int(data["sample_rate"]), str(data["transcript"])
+
+
+def main(
+    directory: str,
+    token: str,
+):
+    dataset_list = datasets.load_dataset("linhtran92/viet_bud500", token=token, streaming=True, keep_in_memory=False)
+    for stage in dataset_list.keys():
+        print(f"[Loading {stage}]")
+        output = os.path.realpath(os.path.join(directory, stage, "audio"))
+        tsv_output = os.path.realpath(os.path.join(directory, stage, "transcripts.tsv"))
+        os.makedirs(output, exist_ok=True)
+        with open(tsv_output, "w", encoding="utf-8") as out:
+            out.write("PATH\tDURATION\tTRANSCRIPT\n")
+            index = 1
+            for item in tqdm(dataset_list[stage], desc=f"[Loading to {output}]", disable=False):
+                data = load_item_from_mapping(item)
+                if data is None:
+                    continue
+                audio, sample_rate, transcript = data
+                path = os.path.join(output, f"{index}.wav")
+                soundfile.write(path, audio, sample_rate)
+                duration = librosa.get_duration(y=audio, sr=sample_rate)
+                out.write(f"{path}\t{duration}\t{transcript}\n")
+                index += 1
+
+
+if __name__ == "__main__":
+    cli_util.run(main)
diff --git a/examples/datasets/vietbud500/sentencepiece/sp.256.yml.j2 b/examples/datasets/vietbud500/sentencepiece/sp.256.yml.j2
new file mode 100644
index 0000000000..f59abd2318
--- /dev/null
+++ b/examples/datasets/vietbud500/sentencepiece/sp.256.yml.j2
@@ -0,0 +1,30 @@
+{% set vocabsize = 256 %}
+{% set vocabprefix = repodir ~ "/examples/datasets/vietbud500/sentencepiece/train_bpe_" ~ vocabsize %}
+{% set metadata = vocabprefix ~ ".metadata.json" %}
+
+decoder_config:
+  type: sentencepiece
+  blank_index: 0
+  unknown_token: ""
+  unknown_index: 0
+  pad_token: ""
+  pad_index: -1
+  bos_token: ""
+
bos_index: -1 + eos_token: "" + eos_index: -1 + beam_width: 0 + norm_score: True + lm_config: null + model_type: bpe + vocabulary: {{vocabprefix}}.model + vocab_size: {{vocabsize}} + reserved_tokens: null + normalization_form: NFKC + max_sentencepiece_length: 16 + max_sentence_length: 1048576 + character_coverage: 1.0 + keep_whitespace: False + +{% import "examples/datasets/vietbud500/config.yml.j2" as data_config with context %} +{{data_config}} \ No newline at end of file diff --git a/examples/datasets/vietbud500/sentencepiece/sp.yml.j2 b/examples/datasets/vietbud500/sentencepiece/sp.yml.j2 new file mode 100644 index 0000000000..c57225f786 --- /dev/null +++ b/examples/datasets/vietbud500/sentencepiece/sp.yml.j2 @@ -0,0 +1,30 @@ +{% set vocabsize = 1000 %} +{% set vocabprefix = repodir ~ "/examples/datasets/vietbud500/sentencepiece/train_bpe_" ~ vocabsize %} +{% set metadata = vocabprefix ~ ".metadata.json" %} + +decoder_config: + type: sentencepiece + blank_index: 0 + unknown_token: "" + unknown_index: 0 + pad_token: "" + pad_index: -1 + bos_token: "" + bos_index: -1 + eos_token: "" + eos_index: -1 + beam_width: 0 + norm_score: True + lm_config: null + model_type: bpe + vocabulary: {{vocabprefix}}.model + vocab_size: {{vocabsize}} + reserved_tokens: null + normalization_form: NFKC + max_sentencepiece_length: 16 + max_sentence_length: 1048576 + character_coverage: 1.0 + keep_whitespace: False + +{% import "examples/datasets/vietbud500/config.yml.j2" as data_config with context %} +{{data_config}} \ No newline at end of file diff --git a/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.metadata.json b/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.metadata.json new file mode 100644 index 0000000000..d790faa2f5 --- /dev/null +++ b/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.metadata.json @@ -0,0 +1,12 @@ +{ + "train": { + "max_input_length": 256498, + "max_label_length": 75, + "num_entries": 634158 + }, + "eval": { + "max_input_length": 117571, + "max_label_length": 42, + "num_entries": 7500 + } +} \ No newline at end of file diff --git a/vocabularies/librispeech/sentencepiece/train_uni_1000.model b/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.model similarity index 94% rename from vocabularies/librispeech/sentencepiece/train_uni_1000.model rename to examples/datasets/vietbud500/sentencepiece/train_bpe_1000.model index 18f11beaa3..14ed518b83 100644 Binary files a/vocabularies/librispeech/sentencepiece/train_uni_1000.model and b/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.model differ diff --git a/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.vocab b/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.vocab new file mode 100644 index 0000000000..872e7378b7 --- /dev/null +++ b/examples/datasets/vietbud500/sentencepiece/train_bpe_1000.vocab @@ -0,0 +1,1000 @@ + 0 +▁c -0 +ng -1 +▁t -2 +nh -3 +▁đ -4 +▁m -5 +▁l -6 +▁th -7 +▁v -8 +▁ch -9 +▁b -10 +▁nh -11 +▁k -12 +▁n -13 +▁h -14 +▁kh -15 +▁ng -16 +▁s -17 +▁g -18 +▁là -19 +ông -20 +▁tr -21 +▁r -22 +▁không -23 +ời -24 +▁p -25 +▁ph -26 +▁cá -27 +▁có -28 +ên -29 +▁d -30 +ôi -31 +ình -32 +▁gi -33 +anh -34 +qu -35 +▁qu -36 +▁và -37 +ột -38 +ới -39 +▁củ -40 +▁của -41 +iế -42 +ười -43 +▁như -44 +▁một -45 +▁tôi -46 +▁nó -47 +▁mà -48 +▁người -49 +iệ -50 +▁x -51 +▁anh -52 +▁đư -53 +ại -54 +ất -55 +ấy -56 +▁nà -57 +▁mình -58 +▁đi -59 +▁thì -60 +▁cái -61 +ợc -62 +em -63 +▁được -64 +ay -65 +▁cũ -66 +uy -67 +▁co -68 +▁cũng -69 +ững -70 +ong -71 +▁những -72 +▁cho -73 +▁con -74 +ai -75 
+ải -76 +▁em -77 +▁ngh -78 +▁cả -79 +ều -80 +▁đó -81 +▁cô -82 +ồi -83 +▁lại -84 +▁với -85 +ch -86 +ao -87 +ân -88 +▁này -89 +▁đã -90 +▁trong -91 +ần -92 +uố -93 +▁để -94 +▁làm -95 +▁nói -96 +▁ta -97 +ạn -98 +▁phải -99 +▁ra -100 +ây -101 +▁chú -102 +▁nhưng -103 +ướ -104 +ang -105 +au -106 +▁rồi -107 +▁sẽ -108 +âu -109 +ến -110 +▁về -111 +▁nhi -112 +iết -113 +an -114 +ác -115 +▁khi -116 +òn -117 +▁ti -118 +▁gì -119 +▁thế -120 +▁bạn -121 +ước -122 +▁ở -123 +▁họ -124 +▁đến -125 +▁còn -126 +▁thể -127 +▁các -128 +ết -129 +▁mẹ -130 +▁việ -131 +ươ -132 +ật -133 +▁ông -134 +▁biết -135 +úc -136 +▁nhà -137 +▁chúng -138 +ương -139 +ận -140 +oà -141 +▁chị -142 +ành -143 +▁bà -144 +ơn -145 +▁thấy -146 +▁từ -147 +ầu -148 +ậy -149 +▁chuy -150 +▁nào -151 +ăn -152 +▁chỉ -153 +ờng -154 +ữa -155 +▁rất -156 +▁sự -157 +ồng -158 +▁nhiều -159 +ùng -160 +inh -161 +ền -162 +▁số -163 +▁vậy -164 +▁vào -165 +ện -166 +ức -167 +ằng -168 +▁sao -169 +êu -170 +▁việc -171 +ính -172 +ẫn -173 +ọi -174 +iể -175 +ái -176 +▁nên -177 +iện -178 +▁y -179 +ực -180 +▁vì -181 +uốn -182 +ối -183 +ản -184 +àng -185 +▁hai -186 +▁tiế -187 +▁chuyện -188 +▁giờ -189 +▁đây -190 +▁cảm -191 +ắt -192 +ộc -193 +ôn -194 +▁muốn -195 +òng -196 +▁đầu -197 +▁ngà -198 +▁nghĩ -199 +ơi -200 +iên -201 +▁sau -202 +▁công -203 +ưa -204 +▁ấy -205 +ặt -206 +▁đang -207 +▁ngày -208 +▁nữa -209 +ường -210 +▁mới -211 +iến -212 +▁lúc -213 +âm -214 +▁vẫn -215 +▁lên -216 +▁bị -217 +ìn -218 +ội -219 +ếu -220 +ỏi -221 +▁nhìn -222 +uộc -223 +▁hơn -224 +ắc -225 +ung -226 +ổi -227 +▁cứ -228 +át -229 +áng -230 +▁thành -231 +▁cùng -232 +▁hay -233 +▁sống -234 +iểu -235 +▁quá -236 +▁đâu -237 +▁tình -238 +eo -239 +▁ho -240 +▁trước -241 +▁nếu -242 +ậu -243 +▁điều -244 +▁ăn -245 +▁thôi -246 +▁nhất -247 +▁độ -248 +▁ai -249 +ăm -250 +ạnh -251 +▁mọi -252 +▁lu -253 +▁yêu -254 +ích -255 +▁trên -256 +▁khác -257 +ôm -258 +ắm -259 +▁qua -260 +▁nhau -261 +ánh -262 +▁nhận -263 +▁nghe -264 +▁quan -265 +ẳng -266 +ục -267 +▁đị -268 +▁cuộc -269 +ảo -270 +▁theo -271 +▁tới -272 +▁chưa -273 +▁chính -274 +ăng -275 +▁học -276 +▁thật -277 +▁cách -278 +ừng -279 +ập -280 +▁thời -281 +ứng -282 +êm -283 +ỗi -284 +oàn -285 +▁mặt -286 +áo -287 +▁tư -288 +▁tâm -289 +án -290 +ấn -291 +▁thân -292 +▁cậu -293 +▁đấy -294 +ặc -295 +▁cần -296 +▁tự -297 +ắn -298 +ách -299 +ừa -300 +uống -301 +▁bên -302 +▁tại -303 +úng -304 +ọng -305 +▁khó -306 +▁tay -307 +▁quy -308 +▁ý -309 +ài -310 +▁lời -311 +▁nước -312 +▁chồng -313 +▁câu -314 +▁thứ -315 +▁tố -316 +▁hôm -317 +▁nay -318 +▁lần -319 +▁hiện -320 +▁lo -321 +am -322 +ửa -323 +▁rằng -324 +ìm -325 +▁hết -326 +▁luôn -327 +▁định -328 +▁lớ -329 +▁cơ -330 +ắng -331 +▁tiền -332 +oài -333 +▁bao -334 +▁gia -335 +▁hỏi -336 +▁bản -337 +▁chi -338 +▁động -339 +▁hiểu -340 +▁bố -341 +▁đồng -342 +▁nhân -343 +ởng -344 +▁bác -345 +▁chẳng -346 +▁vợ -347 +▁tin -348 +▁thương -349 +▁tìm -350 +▁thực -351 +ượ -352 +▁chứ -353 +ám -354 +ọn -355 +▁đúng -356 +▁bây -357 +▁tu -358 +ặp -359 +▁mắt -360 +▁gái -361 +▁vừa -362 +ồn -363 +áu -364 +▁thu -365 +ạo -366 +ầm -367 +▁tiếng -368 +ốc -369 +▁ba -370 +▁đời -371 +▁ngoài -372 +▁lấy -373 +▁do -374 +▁hàng -375 +▁bảo -376 +▁lớn -377 +▁tiếp -378 +ấu -379 +ống -380 +▁lòng -381 +ấp -382 +iệt -383 +▁trở -384 +ởi -385 +▁lý -386 +▁gặp -387 +▁gọi -388 +▁gian -389 +▁tốt -390 +▁đối -391 +▁nhiên -392 +▁bu -393 +àn -394 +▁lắm -395 +▁năm -396 +ứa -397 +▁thường -398 +in -399 +▁phòng -400 +▁sinh -401 +áp -402 +▁tất -403 +iệu -404 +▁đều -405 +▁mấy -406 +▁bất -407 +▁ki -408 +▁bằng -409 +▁kho -410 +▁bắt -411 +▁bình -412 +▁hình -413 +▁thích -414 +en -415 +▁đưa -416 
+▁đường -417 +uyên -418 +▁vu -419 +▁xin -420 +ạt -421 +▁xuống -422 +ọc -423 +▁ngồi -424 +▁bé -425 +▁bởi -426 +▁bộ -427 +ãy -428 +▁cháu -429 +ào -430 +▁tho -431 +▁vị -432 +▁xem -433 +▁tính -434 +ẹp -435 +▁chủ -436 +▁xe -437 +▁bỏ -438 +▁chắc -439 +▁khiến -440 +iệm -441 +úp -442 +▁ngay -443 +▁dụ -444 +▁vô -445 +ầy -446 +▁hành -447 +▁mu -448 +▁vui -449 +▁trọng -450 +▁nam -451 +ệnh -452 +▁xu -453 +▁dù -454 +▁đình -455 +ảnh -456 +▁đứa -457 +▁từng -458 +▁tuy -459 +ãi -460 +▁thêm -461 +uối -462 +▁giúp -463 +▁mỗi -464 +▁sợ -465 +▁cố -466 +▁mất -467 +▁nhớ -468 +ảng -469 +▁trường -470 +▁đứng -471 +ạy -472 +▁toàn -473 +▁sáng -474 +▁điể -475 +ịch -476 +áy -477 +iền -478 +▁trai -479 +▁đẹp -480 +▁cười -481 +▁nhỏ -482 +▁cửa -483 +▁mang -484 +ẩn -485 +▁phát -486 +▁càng -487 +▁trả -488 +▁thay -489 +▁quyết -490 +▁quay -491 +ợp -492 +ợi -493 +óng -494 +▁đừng -495 +▁chịu -496 +▁hãy -497 +ạc -498 +ặng -499 +▁ngo -500 +▁hội -501 +▁quân -502 +▁điểm -503 +▁hợp -504 +▁đàn -505 +▁xong -506 +▁kia -507 +▁đề -508 +▁giới -509 +▁liên -510 +▁minh -511 +ấm -512 +▁trung -513 +▁đại -514 +▁cao -515 +▁hoà -516 +iệp -517 +▁kỳ -518 +út -519 +▁bước -520 +▁thức -521 +▁cha -522 +▁chiến -523 +▁viên -524 +▁kết -525 +▁đủ -526 +▁bệnh -527 +▁chút -528 +ạng -529 +▁sĩ -530 +▁đổi -531 +ượng -532 +óc -533 +▁hoàng -534 +▁chiế -535 +▁trẻ -536 +▁thông -537 +▁phần -538 +ưới -539 +ướng -540 +▁dự -541 +▁phí -542 +▁năng -543 +▁chu -544 +▁trời -545 +▁ngủ -546 +▁hoàn -547 +▁nơi -548 +▁gần -549 +▁an -550 +▁mở -551 +▁sức -552 +ễn -553 +ưởng -554 +▁lâu -555 +▁dân -556 +▁quốc -557 +▁đau -558 +ệt -559 +▁giải -560 +ét -561 +▁cuối -562 +▁lực -563 +▁mua -564 +ảy -565 +▁rõ -566 +▁tập -567 +▁tiên -568 +ằm -569 +▁giả -570 +▁tưởng -571 +▁chân -572 +▁chí -573 +▁hoa -574 +▁mạnh -575 +▁kể -576 +▁lẽ -577 +▁à -578 +▁vấn -579 +▁giá -580 +▁đất -581 +▁gh -582 +▁sang -583 +▁việt -584 +▁vẻ -585 +▁hoặc -586 +▁đánh -587 +▁tác -588 +ắp -589 +▁tổ -590 +▁tên -591 +▁hạnh -592 +▁xa -593 +▁phụ -594 +▁kinh -595 +▁phương -596 +ít -597 +▁ảnh -598 +▁điện -599 +ộng -600 +▁trình -601 +▁nghĩa -602 +▁đôi -603 +▁phúc -604 +ẩm -605 +▁biệt -606 +ạm -607 +▁tạo -608 +▁khách -609 +▁hệ -610 +▁tài -611 +▁sản -612 +▁thủ -613 +▁chọn -614 +▁nhanh -615 +▁cảnh -616 +iển -617 +▁chết -618 +▁quả -619 +▁suy -620 +▁chạy -621 +▁di -622 +▁tức -623 +▁thằng -624 +▁khỏi -625 +▁dụng -626 +ển -627 +▁mai -628 +▁đơn -629 +▁đồ -630 +▁giống -631 +▁đêm -632 +▁máy -633 +▁lã -634 +▁chơi -635 +▁ơi -636 +▁nằm -637 +▁chiếc -638 +▁phía -639 +▁quen -640 +ưng -641 +▁khu -642 +▁xuất -643 +ạch -644 +▁nữ -645 +▁tối -646 +êng -647 +▁ơn -648 +ép -649 +▁nội -650 +▁quý -651 +▁thanh -652 +▁chỗ -653 +▁nghiệp -654 +▁thị -655 +ỉnh -656 +▁vụ -657 +iếm -658 +iêng -659 +▁văn -660 +óa -661 +▁chia -662 +iếu -663 +▁dễ -664 +▁kiểu -665 +▁mày -666 +▁báo -667 +▁cầu -668 +ậm -669 +▁mặc -670 +▁bán -671 +▁tế -672 +▁nổi -673 +▁đo -674 +▁bàn -675 +▁dài -676 +▁buồn -677 +▁ít -678 +▁nguy -679 +▁yên -680 +▁hồi -681 +▁đáng -682 +▁lỗi -683 +▁giữa -684 +▁mong -685 +▁giữ -686 +▁tiêu -687 +▁khổ -688 +▁ké -689 +▁hà -690 +▁ạ -691 +▁trị -692 +âng -693 +▁uống -694 +▁dùng -695 +▁to -696 +▁đặt -697 +▁tục -698 +▁vũ -699 +▁hóa -700 +▁đặc -701 +▁thiên -702 +▁thư -703 +▁vật -704 +▁ty -705 +▁hôn -706 +▁thái -707 +▁hơi -708 +▁tháng -709 +▁sử -710 +▁huy -711 +▁giác -712 +ềm -713 +▁ngờ -714 +▁trí -715 +▁cứu -716 +▁dám -717 +iểm -718 +▁cổ -719 +▁nguyên -720 +▁pháp -721 +▁giáo -722 +im -723 +▁già -724 +òa -725 +▁xúc -726 +▁thầy -727 +▁riêng -728 +▁cạnh -729 +▁mãi -730 +▁áo -731 +▁thần -732 +▁lượng -733 +▁kiếm -734 +ĩnh -735 +▁tranh -736 +▁lập -737 +▁quyền -738 +ốn -739 
+▁chuyển -740 +ón -741 +▁chào -742 +▁hướng -743 +▁chất -744 +▁vài -745 +▁nhật -746 +▁cơm -747 +▁coi -748 +▁sách -749 +▁tuổi -750 +▁chung -751 +▁bài -752 +▁hương -753 +▁ánh -754 +▁phong -755 +úi -756 +ưu -757 +▁mỹ -758 +▁trái -759 +▁cây -760 +▁linh -761 +▁chăm -762 +▁tuấn -763 +▁duy -764 +▁so -765 +▁dành -766 +▁sai -767 +▁đọc -768 +ớm -769 +▁biến -770 +ùa -771 +▁thiết -772 +▁trang -773 +▁kiến -774 +▁phố -775 +▁truy -776 +▁loại -777 +▁khỏ -778 +ếp -779 +ịp -780 +ợng -781 +▁chuẩn -782 +▁dưới -783 +▁gắng -784 +▁chứng -785 +▁khỏe -786 +▁buổi -787 +▁đầy -788 +▁chấp -789 +▁khăn -790 +▁khá -791 +▁lắng -792 +▁diễn -793 +▁hắn -794 +▁tượng -795 +▁hiệu -796 +▁may -797 +▁đông -798 +▁tiến -799 +▁tương -800 +▁cực -801 +▁thoại -802 +▁thẳng -803 +▁liệu -804 +▁phẩm -805 +▁xã -806 +▁ôm -807 +▁đội -808 +▁tỉnh -809 +▁xảy -810 +▁chờ -811 +▁sớm -812 +▁thuộc -813 +▁hồ -814 +▁liền -815 +▁nỗi -816 +▁dẫn -817 +▁món -818 +▁khóc -819 +▁quên -820 +▁sát -821 +▁nghi -822 +uốt -823 +▁đợi -824 +▁chiều -825 +▁dậy -826 +▁viện -827 +uôi -828 +▁cấp -829 +▁hải -830 +óm -831 +▁phục -832 +▁quanh -833 +▁kéo -834 +▁nghỉ -835 +▁tao -836 +▁dạ -837 +▁yếu -838 +▁nghiệm -839 +▁chắn -840 +▁trải -841 +▁mối -842 +▁tinh -843 +▁lặng -844 +▁căn -845 +ộn -846 +▁sâu -847 +▁đạo -848 +ửi -849 +▁ngọc -850 +▁bay -851 +▁ví -852 +▁nhẹ -853 +ượt -854 +▁khí -855 +▁tích -856 +▁vọng -857 +▁vâng -858 +▁xưa -859 +▁lịch -860 +▁sẻ -861 +▁giọng -862 +▁cầm -863 +▁lợi -864 +▁vệ -865 +▁hưởng -866 +èo -867 +▁địa -868 +▁ban -869 +▁thiếu -870 +▁diện -871 +▁mức -872 +ưỡ -873 +▁hòa -874 +uốc -875 +ựa -876 +ẹn -877 +▁bọn -878 +▁xử -879 +▁ngoại -880 +▁tử -881 +▁cụ -882 +▁thậm -883 +▁hộ -884 +▁triển -885 +▁nuôi -886 +▁hả -887 +▁dương -888 +▁sắp -889 +▁trưởng -890 +▁mời -891 +▁thử -892 +▁giao -893 +▁nhiêu -894 +▁thuốc -895 +▁kẻ -896 +▁nhờ -897 +▁thú -898 +▁trò -899 +▁khả -900 +▁kỹ -901 +▁cưới -902 +▁dần -903 +▁quê -904 +▁ -905 +n -906 +h -907 +c -908 +i -909 +t -910 +g -911 +m -912 +a -913 +đ -914 +à -915 +u -916 +l -917 +o -918 +ư -919 +y -920 +ô -921 +v -922 +r -923 +b -924 +k -925 +á -926 +ó -927 +ì -928 +s -929 +ế -930 +p -931 +ờ -932 +ấ -933 +ạ -934 +ả -935 +ê -936 +ộ -937 +ớ -938 +â -939 +ố -940 +ệ -941 +ề -942 +ủ -943 +d -944 +ậ -945 +ể -946 +e -947 +ợ -948 +ú -949 +q -950 +ữ -951 +ơ -952 +ồ -953 +ọ -954 +ầ -955 +ị -956 +ứ -957 +x -958 +ắ -959 +ã -960 +ở -961 +ũ -962 +ự -963 +í -964 +ò -965 +ă -966 +ừ -967 +ặ -968 +ẽ -969 +ẹ -970 +ù -971 +ỏ -972 +ụ -973 +ổ -974 +ỉ -975 +ĩ -976 +ằ -977 +ẫ -978 +ý -979 +é -980 +ử -981 +ỗ -982 +ẻ -983 +ẳ -984 +ẩ -985 +ễ -986 +è -987 +ỡ -988 +õ -989 +ỳ -990 +ỹ -991 +ỷ -992 +ẵ -993 +ỵ -994 +w -995 +f -996 +j -997 +z -998 diff --git a/examples/datasets/vietbud500/sentencepiece/train_bpe_256.metadata.json b/examples/datasets/vietbud500/sentencepiece/train_bpe_256.metadata.json new file mode 100644 index 0000000000..4b867852ee --- /dev/null +++ b/examples/datasets/vietbud500/sentencepiece/train_bpe_256.metadata.json @@ -0,0 +1,12 @@ +{ + "train": { + "max_input_length": 256498, + "max_label_length": 100, + "num_entries": 634158 + }, + "eval": { + "max_input_length": 117571, + "max_label_length": 57, + "num_entries": 7500 + } +} \ No newline at end of file diff --git a/examples/datasets/vietbud500/sentencepiece/train_bpe_256.model b/examples/datasets/vietbud500/sentencepiece/train_bpe_256.model new file mode 100644 index 0000000000..8c43f4bef1 Binary files /dev/null and b/examples/datasets/vietbud500/sentencepiece/train_bpe_256.model differ diff --git a/examples/datasets/vietbud500/sentencepiece/train_bpe_256.vocab 
b/examples/datasets/vietbud500/sentencepiece/train_bpe_256.vocab new file mode 100644 index 0000000000..a67a33d269 --- /dev/null +++ b/examples/datasets/vietbud500/sentencepiece/train_bpe_256.vocab @@ -0,0 +1,256 @@ + 0 +▁c -0 +ng -1 +▁t -2 +nh -3 +▁đ -4 +▁m -5 +▁l -6 +▁th -7 +▁v -8 +▁ch -9 +▁b -10 +▁nh -11 +▁k -12 +▁n -13 +▁h -14 +▁kh -15 +▁ng -16 +▁s -17 +▁g -18 +▁là -19 +ông -20 +▁tr -21 +▁r -22 +▁không -23 +ời -24 +▁p -25 +▁ph -26 +▁cá -27 +▁có -28 +ên -29 +▁d -30 +ôi -31 +ình -32 +▁gi -33 +anh -34 +qu -35 +▁qu -36 +▁và -37 +ột -38 +ới -39 +▁củ -40 +▁của -41 +iế -42 +ười -43 +▁như -44 +▁một -45 +▁tôi -46 +▁nó -47 +▁mà -48 +▁người -49 +iệ -50 +▁x -51 +▁anh -52 +▁đư -53 +ại -54 +ất -55 +ấy -56 +▁nà -57 +▁mình -58 +▁đi -59 +▁thì -60 +▁cái -61 +ợc -62 +em -63 +▁được -64 +ay -65 +▁cũ -66 +uy -67 +▁co -68 +▁cũng -69 +ững -70 +ong -71 +▁những -72 +▁cho -73 +▁con -74 +ai -75 +ải -76 +▁em -77 +▁ngh -78 +▁cả -79 +ều -80 +▁đó -81 +▁cô -82 +ồi -83 +▁lại -84 +▁với -85 +ch -86 +ao -87 +ân -88 +▁này -89 +▁đã -90 +▁trong -91 +ần -92 +uố -93 +▁để -94 +▁làm -95 +▁nói -96 +▁ta -97 +ạn -98 +▁phải -99 +▁ra -100 +ây -101 +▁chú -102 +▁nhưng -103 +ướ -104 +ang -105 +au -106 +▁rồi -107 +▁sẽ -108 +âu -109 +ến -110 +▁về -111 +▁nhi -112 +iết -113 +an -114 +ác -115 +▁khi -116 +òn -117 +▁ti -118 +▁gì -119 +▁thế -120 +▁bạn -121 +ước -122 +▁ở -123 +▁họ -124 +▁đến -125 +▁còn -126 +▁thể -127 +▁các -128 +ết -129 +▁mẹ -130 +▁việ -131 +ươ -132 +ật -133 +▁ông -134 +▁biết -135 +úc -136 +▁nhà -137 +▁chúng -138 +ương -139 +ận -140 +oà -141 +▁chị -142 +ành -143 +▁bà -144 +ơn -145 +▁thấy -146 +▁từ -147 +ầu -148 +ậy -149 +▁chuy -150 +▁nào -151 +ăn -152 +▁chỉ -153 +ờng -154 +ữa -155 +▁rất -156 +▁sự -157 +ồng -158 +▁nhiều -159 +ùng -160 +▁ -161 +n -162 +h -163 +c -164 +i -165 +t -166 +g -167 +m -168 +a -169 +đ -170 +à -171 +u -172 +l -173 +o -174 +ư -175 +y -176 +ô -177 +v -178 +r -179 +b -180 +k -181 +á -182 +ó -183 +ì -184 +s -185 +ế -186 +p -187 +ờ -188 +ấ -189 +ạ -190 +ả -191 +ê -192 +ộ -193 +ớ -194 +â -195 +ố -196 +ệ -197 +ề -198 +ủ -199 +d -200 +ậ -201 +ể -202 +e -203 +ợ -204 +ú -205 +q -206 +ữ -207 +ơ -208 +ồ -209 +ọ -210 +ầ -211 +ị -212 +ứ -213 +x -214 +ắ -215 +ã -216 +ở -217 +ũ -218 +ự -219 +í -220 +ò -221 +ă -222 +ừ -223 +ặ -224 +ẽ -225 +ẹ -226 +ù -227 +ỏ -228 +ụ -229 +ổ -230 +ỉ -231 +ĩ -232 +ằ -233 +ẫ -234 +ý -235 +é -236 +ử -237 +ỗ -238 +ẻ -239 +ẳ -240 +ẩ -241 +ễ -242 +è -243 +ỡ -244 +õ -245 +ỳ -246 +ỹ -247 +ỷ -248 +ẵ -249 +ỵ -250 +w -251 +f -252 +j -253 +z -254 diff --git a/vocabularies/vietnamese.characters b/examples/datasets/vivos/vietnamese.characters similarity index 100% rename from vocabularies/vietnamese.characters rename to examples/datasets/vivos/vietnamese.characters diff --git a/examples/demonstration/conformer.py b/examples/demonstration/conformer.py deleted file mode 100644 index 255695ed62..0000000000 --- a/examples/demonstration/conformer.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse -import os - -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() -import tensorflow as tf - -parser = argparse.ArgumentParser(prog="Conformer non streaming") - -parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") - -parser.add_argument("--config", type=str, default=None, help="Path to conformer config yaml") - -parser.add_argument("--saved", type=str, default=None, help="Path to conformer saved h5 weights") - -parser.add_argument("--beam_width", type=int, default=0, help="Beam width") - -parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") - -parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") - -parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") - -parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords") - -parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") - -args = parser.parse_args() - -env_util.setup_devices([args.device], cpu=args.cpu) - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer, read_raw_audio -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SentencePieceFeaturizer, SubwordFeaturizer -from tensorflow_asr.models.transducer.conformer import Conformer -from tensorflow_asr.utils.data_util import create_inputs - -config = Config(args.config) -speech_featurizer = SpeechFeaturizer(config.speech_config) -if args.sentence_piece: - logger.info("Loading SentencePiece model ...") - text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords) -elif args.subwords and os.path.exists(args.subwords): - logger.info("Loading subwords ...") - text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords) -else: - text_featurizer = CharFeaturizer(config.decoder_config) -text_featurizer.decoder_config.beam_width = args.beam_width - -# build model -conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) -conformer.make(speech_featurizer.shape) -conformer.load_weights(args.saved, by_name=True, skip_mismatch=True) -conformer.summary() -conformer.add_featurizers(speech_featurizer, text_featurizer) - -signal = read_raw_audio(args.filename) -features = speech_featurizer.tf_extract(signal) -input_length = tf.shape(features)[0] - -if args.beam_width: - inputs = create_inputs(features[None, ...], input_length[None, ...]) - transcript = conformer.recognize_beam(inputs) - logger.info(f"Transcript: {transcript[0].numpy().decode('UTF-8')}") -elif args.timestamp: - transcript, stime, etime, _, _ = conformer.recognize_tflite_with_timestamp( - signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state() - ) - logger.info(f"Transcript: {transcript}") - logger.info(f"Start time: {stime}") - logger.info(f"End time: {etime}") -else: - code_points, _, _ = conformer.recognize_tflite( - signal, tf.constant(text_featurizer.blank, dtype=tf.int32), conformer.predict_net.get_initial_state() - ) - transcript = tf.strings.unicode_encode(code_points, "UTF-8").numpy().decode("UTF-8") - logger.info(f"Transcript: {transcript}") diff --git a/examples/demonstration/rnn_transducer.py b/examples/demonstration/rnn_transducer.py deleted file mode 100644 index 
63a46ffe89..0000000000 --- a/examples/demonstration/rnn_transducer.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -from tensorflow_asr.utils import data_util, env_util, math_util - -logger = env_util.setup_environment() -import tensorflow as tf - -parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming") - -parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") - -parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml") - -parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights") - -parser.add_argument("--beam_width", type=int, default=0, help="Beam width") - -parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") - -parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") - -parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") - -parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords") - -parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") - -args = parser.parse_args() - -env_util.setup_devices([args.device], cpu=args.cpu) - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer, read_raw_audio -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SentencePieceFeaturizer, SubwordFeaturizer -from tensorflow_asr.models.transducer.rnn_transducer import RnnTransducer - -config = Config(args.config) -speech_featurizer = SpeechFeaturizer(config.speech_config) -if args.sentence_piece: - logger.info("Loading SentencePiece model ...") - text_featurizer = SentencePieceFeaturizer(config.decoder_config) -elif args.subwords: - logger.info("Loading subwords ...") - text_featurizer = SubwordFeaturizer(config.decoder_config) -else: - text_featurizer = CharFeaturizer(config.decoder_config) -text_featurizer.decoder_config.beam_width = args.beam_width - -# build model -rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes) -rnnt.make(speech_featurizer.shape) -rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True) -rnnt.summary() -rnnt.add_featurizers(speech_featurizer, text_featurizer) - -signal = read_raw_audio(args.filename) -features = speech_featurizer.tf_extract(signal) -input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor) - -if args.beam_width: - transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) - logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) -elif args.timestamp: - transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp( - signal=signal, - 
predicted=tf.constant(text_featurizer.blank, dtype=tf.int32), - encoder_states=rnnt.encoder.get_initial_state(), - prediction_states=rnnt.predict_net.get_initial_state(), - ) - logger.info("Transcript:", transcript) - logger.info("Start time:", stime) - logger.info("End time:", etime) -else: - transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) - logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) diff --git a/examples/demonstration/streaming_tflite_conformer.py b/examples/demonstration/streaming_tflite_conformer.py deleted file mode 100644 index 321f2a9c5f..0000000000 --- a/examples/demonstration/streaming_tflite_conformer.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import queue -import sys -from multiprocessing import Event, Manager, Process - -import numpy as np -import sounddevice as sd -import soundfile as sf -import tensorflow as tf - - -def int_or_str(text): - """Helper function for argument parsing.""" - try: - return int(text) - except ValueError: - return text - - -parser = argparse.ArgumentParser(prog="Conformer audio file streaming") - -parser.add_argument("-l", "--list-devices", action="store_true", help="show list of audio devices and exit") - -args, remaining = parser.parse_known_args() - -if args.list_devices: - print(sd.query_devices()) - parser.exit(0) - -parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") - -parser.add_argument("-d", "--device", type=int_or_str, help="output device (numeric ID or substring)") - -parser.add_argument("-b", "--blocksize", type=int, default=4096, help="block size (default: %(default)s)") - -parser.add_argument("-q", "--buffersize", type=int, default=20, help="number of blocks used for buffering (default: %(default)s)") - -parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite") - -parser.add_argument("--blank", type=int, default=0, help="Path to conformer tflite") - -parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network") - -parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)") - -parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network") - -args = parser.parse_args(remaining) - -if args.blocksize == 0: - parser.error("blocksize must not be zero") -if args.buffersize < 1: - parser.error("buffersize must be at least 1") - -q = queue.Queue(maxsize=args.buffersize) -m = Manager() -Q = m.Queue() -E = Event() - - -def recognizer(Q): - tflitemodel = tf.lite.Interpreter(model_path=args.tflite) - - input_details = tflitemodel.get_input_details() - output_details = tflitemodel.get_output_details() - - tflitemodel.resize_tensor_input(input_details[0]["index"], [args.blocksize]) - tflitemodel.allocate_tensors() - - def recognize(signal, 
lastid, states): - if signal.shape[0] < args.blocksize: - signal = tf.pad(signal, [[0, args.blocksize - signal.shape[0]]]) - tflitemodel.set_tensor(input_details[0]["index"], signal) - tflitemodel.set_tensor(input_details[1]["index"], lastid) - tflitemodel.set_tensor(input_details[2]["index"], states) - tflitemodel.invoke() - upoints = tflitemodel.get_tensor(output_details[0]["index"]) - lastid = tflitemodel.get_tensor(output_details[1]["index"]) - states = tflitemodel.get_tensor(output_details[2]["index"]) - text = "".join([chr(u) for u in upoints]) - return text, lastid, states - - lastid = args.blank * tf.ones(shape=[], dtype=tf.int32) - states = tf.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32) - transcript = "" - - while True: - try: - data = Q.get() - text, lastid, states = recognize(data, lastid, states) - transcript += text - print(transcript, flush=True) - except queue.Empty: - pass - - -tflite_process = Process(target=recognizer, args=[Q]) -tflite_process.start() - - -def send(q, Q, E): - def callback(outdata, frames, time, status): - assert frames == args.blocksize - if status.output_underflow: - print("Output underflow: increase blocksize?", file=sys.stderr) - raise sd.CallbackAbort - assert not status - try: - data = q.get_nowait() - Q.put(np.frombuffer(data, dtype=np.float32)) - except queue.Empty as e: - print("Buffer is empty: increase buffersize?", file=sys.stderr) - raise sd.CallbackAbort from e - if len(data) < len(outdata): - outdata[: len(data)] = data - outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) - raise sd.CallbackStop - else: - outdata[:] = data - - try: - with sf.SoundFile(args.filename) as f: - for _ in range(args.buffersize): - data = f.buffer_read(args.blocksize, dtype="float32") - if not data: - break - q.put_nowait(data) # Pre-fill queue - stream = sd.RawOutputStream( - samplerate=f.samplerate, - blocksize=args.blocksize, - device=args.device, - channels=f.channels, - dtype="float32", - callback=callback, - finished_callback=E.set, - ) - with stream: - timeout = args.blocksize * args.buffersize / f.samplerate - while data: - data = f.buffer_read(args.blocksize, dtype="float32") - q.put(data, timeout=timeout) - E.wait() - - except KeyboardInterrupt: - parser.exit("\nInterrupted by user") - except queue.Full: - # A timeout occurred, i.e. there was an error in the callback - parser.exit(1) - except Exception as e: - parser.exit(type(e).__name__ + ": " + str(e)) - - -send_process = Process(target=send, args=[q, Q, E]) -send_process.start() -send_process.join() -send_process.close() - -tflite_process.terminate() diff --git a/examples/demonstration/tflite_conformer.py b/examples/demonstration/tflite_conformer.py deleted file mode 100644 index 534b40715c..0000000000 --- a/examples/demonstration/tflite_conformer.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import argparse - -import tensorflow as tf - -from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio - -parser = argparse.ArgumentParser(prog="Conformer non streaming") - -parser.add_argument("filename", metavar="FILENAME", help="Audio file to be played back") - -parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite") - -parser.add_argument("--blank", type=int, default=0, help="Blank index") - -parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network") - -parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)") - -parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network") - -args = parser.parse_args() - -tflitemodel = tf.lite.Interpreter(model_path=args.tflite) - -signal = read_raw_audio(args.filename) - -input_details = tflitemodel.get_input_details() -output_details = tflitemodel.get_output_details() -tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape) -tflitemodel.allocate_tensors() -tflitemodel.set_tensor(input_details[0]["index"], signal) -tflitemodel.set_tensor(input_details[1]["index"], tf.constant(args.blank, dtype=tf.int32)) -tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32)) -tflitemodel.invoke() -hyp = tflitemodel.get_tensor(output_details[0]["index"]) - -print("".join([chr(u) for u in hyp])) diff --git a/examples/demonstration/README.md b/examples/inferences/README.md similarity index 100% rename from examples/demonstration/README.md rename to examples/inferences/README.md diff --git a/examples/inferences/main.py b/examples/inferences/main.py new file mode 100644 index 0000000000..efe52761aa --- /dev/null +++ b/examples/inferences/main.py @@ -0,0 +1,60 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from tensorflow_asr import keras, schemas, tf, tokenizers +from tensorflow_asr.configs import Config +from tensorflow_asr.models import base_model +from tensorflow_asr.utils import cli_util, data_util, env_util, file_util + +logger = tf.get_logger() + + +def main( + file_path: str, + config_path: str, + h5: str, + repodir: str = os.getcwd(), +): + env_util.setup_seed() + file_path = file_util.preprocess_paths(file_path) + + config = Config(config_path, training=False, repodir=repodir) + tokenizer = tokenizers.get(config) + + model: base_model.BaseModel = keras.Model.from_config(config.model_config) + model.make(batch_size=1) + model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5), skip_mismatch=False) + model.summary() + + signal = data_util.read_raw_audio(data_util.load_and_convert_to_wav(file_path)) + signal = tf.reshape(signal, [1, -1]) + signal_length = tf.reshape(tf.shape(signal)[1], [1]) + + outputs = model.recognize( + schemas.PredictInput( + inputs=signal, + inputs_length=signal_length, + previous_tokens=model.get_initial_tokens(), + previous_encoder_states=model.get_initial_encoder_states(), + previous_decoder_states=model.get_initial_decoder_states(), + ) + ) + transcript = tokenizer.detokenize(outputs.tokens)[0].numpy().decode("utf-8") + logger.info(f"Transcript: {transcript}") + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/examples/inferences/rnn_transducer.py b/examples/inferences/rnn_transducer.py new file mode 100644 index 0000000000..b6d471a6d4 --- /dev/null +++ b/examples/inferences/rnn_transducer.py @@ -0,0 +1,89 @@ +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. 
+ +# import argparse + +# from tensorflow_asr.utils import data_util, env_util, math_util + +# logger = env_util.setup_environment() +# import tensorflow as tf + +# parser = argparse.ArgumentParser(prog="Rnn Transducer non streaming") + +# parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") + +# parser.add_argument("--config", type=str, default=None, help="Path to rnnt config yaml") + +# parser.add_argument("--saved", type=str, default=None, help="Path to rnnt saved h5 weights") + +# parser.add_argument("--beam_width", type=int, default=0, help="Beam width") + +# parser.add_argument("--timestamp", default=False, action="store_true", help="Return with timestamp") + +# parser.add_argument("--device", type=int, default=0, help="Device's id to run test on") + +# parser.add_argument("--cpu", default=False, action="store_true", help="Whether to only use cpu") + +# parser.add_argument("--subwords", default=False, action="store_true", help="Path to file that stores generated subwords") + +# parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model") + +# args = parser.parse_args() + +# env_util.setup_devices([args.device], cpu=args.cpu) + +# from tensorflow_asr.configs import Config +# from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer, read_raw_audio +# from tensorflow_asr.models.transducer.rnnt import RnnTransducer +# from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, SubwordFeaturizer + +# config = Config(args.config) +# speech_featurizer = SpeechFeaturizer(config.speech_config) +# if args.sentence_piece: +# logger.info("Loading SentencePiece model ...") +# text_featurizer = SentencePieceTokenizer(config.decoder_config) +# elif args.subwords: +# logger.info("Loading subwords ...") +# text_featurizer = SubwordFeaturizer(config.decoder_config) +# else: +# text_featurizer = CharTokenizer(config.decoder_config) +# text_featurizer.decoder_config.beam_width = args.beam_width + +# # build model +# rnnt = RnnTransducer(**config.model_config, vocab_size=text_featurizer.num_classes) +# rnnt.make(speech_featurizer.shape) +# rnnt.load_weights(args.saved, by_name=True, skip_mismatch=True) +# rnnt.summary() +# rnnt.add_featurizers(speech_featurizer, text_featurizer) + +# signal = read_raw_audio(args.filename) +# features = speech_featurizer.tf_extract(signal) +# input_length = math_util.get_reduced_length(tf.shape(features)[0], rnnt.time_reduction_factor) + +# if args.beam_width: +# transcript = rnnt.recognize_beam(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) +# logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) +# elif args.timestamp: +# transcript, stime, etime, _, _, _ = rnnt.recognize_tflite_with_timestamp( +# signal=signal, +# predicted=tf.constant(text_featurizer.blank, dtype=tf.int32), +# encoder_states=rnnt.encoder.get_initial_state(), +# prediction_states=rnnt.predict_net.get_initial_state(), +# ) +# logger.info("Transcript:", transcript) +# logger.info("Start time:", stime) +# logger.info("End time:", etime) +# else: +# transcript = rnnt.recognize(data_util.create_inputs(inputs=features[None, ...], inputs_length=input_length[None, ...])) +# logger.info("Transcript:", transcript[0].numpy().decode("UTF-8")) diff --git a/examples/inferences/streaming_tflite_conformer.py b/examples/inferences/streaming_tflite_conformer.py new file mode 100644 index 0000000000..46c0523a58 --- /dev/null +++ 
b/examples/inferences/streaming_tflite_conformer.py @@ -0,0 +1,172 @@ +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +# import argparse +# import queue +# import sys +# from multiprocessing import Event, Manager, Process + +# import numpy as np +# import sounddevice as sd +# import soundfile as sf +# import tensorflow as tf + + +# def int_or_str(text): +# """Helper function for argument parsing.""" +# try: +# return int(text) +# except ValueError: +# return text + + +# parser = argparse.ArgumentParser(prog="Conformer audio file streaming") + +# parser.add_argument("-l", "--list-devices", action="store_true", help="show list of audio devices and exit") + +# args, remaining = parser.parse_known_args() + +# if args.list_devices: +# print(sd.query_devices()) +# parser.exit(0) + +# parser.add_argument("filename", metavar="FILENAME", help="audio file to be played back") + +# parser.add_argument("-d", "--device", type=int_or_str, help="output device (numeric ID or substring)") + +# parser.add_argument("-b", "--blocksize", type=int, default=4096, help="block size (default: %(default)s)") + +# parser.add_argument("-q", "--buffersize", type=int, default=20, help="number of blocks used for buffering (default: %(default)s)") + +# parser.add_argument("--tflite", type=str, default=None, help="Path to conformer tflite") + +# parser.add_argument("--blank", type=int, default=0, help="Path to conformer tflite") + +# parser.add_argument("--num_rnns", type=int, default=1, help="Number of RNN layers in prediction network") + +# parser.add_argument("--nstates", type=int, default=2, help="Number of RNN states in prediction network (1 for GRU and 2 for LSTM)") + +# parser.add_argument("--statesize", type=int, default=320, help="Size of RNN state in prediction network") + +# args = parser.parse_args(remaining) + +# if args.blocksize == 0: +# parser.error("blocksize must not be zero") +# if args.buffersize < 1: +# parser.error("buffersize must be at least 1") + +# q = queue.Queue(maxsize=args.buffersize) +# m = Manager() +# Q = m.Queue() +# E = Event() + + +# def recognizer(Q): +# tflitemodel = tf.lite.Interpreter(model_path=args.tflite) + +# input_details = tflitemodel.get_input_details() +# output_details = tflitemodel.get_output_details() + +# tflitemodel.resize_tensor_input(input_details[0]["index"], [args.blocksize]) +# tflitemodel.allocate_tensors() + +# def recognize(signal, lastid, states): +# if signal.shape[0] < args.blocksize: +# signal = tf.pad(signal, [[0, args.blocksize - signal.shape[0]]]) +# tflitemodel.set_tensor(input_details[0]["index"], signal) +# tflitemodel.set_tensor(input_details[1]["index"], lastid) +# tflitemodel.set_tensor(input_details[2]["index"], states) +# tflitemodel.invoke() +# upoints = tflitemodel.get_tensor(output_details[0]["index"]) +# lastid = tflitemodel.get_tensor(output_details[1]["index"]) +# states = tflitemodel.get_tensor(output_details[2]["index"]) +# text = "".join([chr(u) for u in upoints]) +# return 
text, lastid, states + +# lastid = args.blank * tf.ones(shape=[], dtype=tf.int32) +# states = tf.zeros(shape=[args.num_rnns, args.nstates, 1, args.statesize], dtype=tf.float32) +# transcript = "" + +# while True: +# try: +# data = Q.get() +# text, lastid, states = recognize(data, lastid, states) +# transcript += text +# print(transcript, flush=True) +# except queue.Empty: +# pass + + +# tflite_process = Process(target=recognizer, args=[Q]) +# tflite_process.start() + + +# def send(q, Q, E): +# def callback(outdata, frames, time, status): +# assert frames == args.blocksize +# if status.output_underflow: +# print("Output underflow: increase blocksize?", file=sys.stderr) +# raise sd.CallbackAbort +# assert not status +# try: +# data = q.get_nowait() +# Q.put(np.frombuffer(data, dtype=np.float32)) +# except queue.Empty as e: +# print("Buffer is empty: increase buffersize?", file=sys.stderr) +# raise sd.CallbackAbort from e +# if len(data) < len(outdata): +# outdata[: len(data)] = data +# outdata[len(data) :] = b"\x00" * (len(outdata) - len(data)) +# raise sd.CallbackStop +# else: +# outdata[:] = data + +# try: +# with sf.SoundFile(args.filename) as f: +# for _ in range(args.buffersize): +# data = f.buffer_read(args.blocksize, dtype="float32") +# if not data: +# break +# q.put_nowait(data) # Pre-fill queue +# stream = sd.RawOutputStream( +# samplerate=f.samplerate, +# blocksize=args.blocksize, +# device=args.device, +# channels=f.channels, +# dtype="float32", +# callback=callback, +# finished_callback=E.set, +# ) +# with stream: +# timeout = args.blocksize * args.buffersize / f.samplerate +# while data: +# data = f.buffer_read(args.blocksize, dtype="float32") +# q.put(data, timeout=timeout) +# E.wait() + +# except KeyboardInterrupt: +# parser.exit("\nInterrupted by user") +# except queue.Full: +# # A timeout occurred, i.e. there was an error in the callback +# parser.exit(1) +# except Exception as e: +# parser.exit(type(e).__name__ + ": " + str(e)) + + +# send_process = Process(target=send, args=[q, Q, E]) +# send_process.start() +# send_process.join() +# send_process.close() + +# tflite_process.terminate() diff --git a/examples/inferences/tflite.py b/examples/inferences/tflite.py new file mode 100644 index 0000000000..8ab5e8d70e --- /dev/null +++ b/examples/inferences/tflite.py @@ -0,0 +1,72 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging + +import tensorflow_text as tft +from tensorflow.lite.python import interpreter + +from tensorflow_asr import tf +from tensorflow_asr.utils import cli_util, data_util + +logger = logging.getLogger(__name__) + + +def main( + audio_file_path: str, + tflite: str, + sample_rate: int = 16000, + blank: int = 0, +): + wav = data_util.load_and_convert_to_wav(audio_file_path, sample_rate=sample_rate) + signal = data_util.read_raw_audio(wav) + signal = tf.reshape(signal, [1, -1]) + signal_length = tf.reshape(tf.shape(signal)[1], [1]) + + tflitemodel = interpreter.InterpreterWithCustomOps(model_path=tflite, custom_op_registerers=tft.tflite_registrar.SELECT_TFTEXT_OPS) + input_details = tflitemodel.get_input_details() + output_details = tflitemodel.get_output_details() + + tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape, strict=True) + tflitemodel.allocate_tensors() + tflitemodel.set_tensor(input_details[0]["index"], signal) + tflitemodel.set_tensor(input_details[1]["index"], signal_length) + tflitemodel.set_tensor(input_details[2]["index"], tf.ones(input_details[2]["shape"], dtype=input_details[2]["dtype"]) * blank) + tflitemodel.set_tensor(input_details[3]["index"], tf.zeros(input_details[3]["shape"], dtype=input_details[3]["dtype"])) + tflitemodel.set_tensor(input_details[4]["index"], tf.zeros(input_details[4]["shape"], dtype=input_details[4]["dtype"])) + + tflitemodel.invoke() + + transcript = tflitemodel.get_tensor(output_details[0]["index"]) + tokens = tflitemodel.get_tensor(output_details[1]["index"]) + next_tokens = tflitemodel.get_tensor(output_details[2]["index"]) + if len(output_details) > 4: + next_encoder_states = tflitemodel.get_tensor(output_details[3]["index"]) + next_decoder_states = tflitemodel.get_tensor(output_details[4]["index"]) + elif len(output_details) > 3: + next_encoder_states = None + next_decoder_states = tflitemodel.get_tensor(output_details[3]["index"]) + else: + next_encoder_states = None + next_decoder_states = None + + logger.info(f"Transcript: {transcript}") + logger.info(f"Tokens: {tokens}") + logger.info(f"Next tokens: {next_tokens}") + logger.info(f"Next encoder states: {None if next_encoder_states is None else next_encoder_states.shape}") + logger.info(f"Next decoder states: {None if next_decoder_states is None else next_decoder_states.shape}") + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/examples/demonstration/wavs/1089-134691-0000.flac b/examples/inferences/wavs/1089-134691-0000.flac similarity index 100% rename from examples/demonstration/wavs/1089-134691-0000.flac rename to examples/inferences/wavs/1089-134691-0000.flac diff --git a/examples/demonstration/wavs/2033-164915-0001.flac b/examples/inferences/wavs/2033-164915-0001.flac similarity index 100% rename from examples/demonstration/wavs/2033-164915-0001.flac rename to examples/inferences/wavs/2033-164915-0001.flac diff --git a/examples/models/ctc/conformer/results/sentencepiece/README.md b/examples/models/ctc/conformer/results/sentencepiece/README.md new file mode 100644 index 0000000000..c2c0d6e2f3 --- /dev/null +++ b/examples/models/ctc/conformer/results/sentencepiece/README.md @@ -0,0 +1,81 @@ +- [\[English\] LibriSpeech](#english-librispeech) + - [I. Small + SentencePiece 256](#i-small--sentencepiece-256) + - [II. Small + Streaming + SentencePiece 256](#ii-small--streaming--sentencepiece-256) + +# [English] LibriSpeech + +## I. 
Small + SentencePiece 256 + +| Category | Description | +| :---------------- | :--------------------------------------------------------------------------------------- | +| Config | [small.yml.j2](../../small.yml.j2) | +| Tensorflow | **2.18.0** | +| Device | Google Cloud TPUs v4-8 | +| Mixed Precision | strict | +| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) | +| Max Epochs | 450 | +| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-ctc/tensorFlow2/v3-small) | + +**Config:** + +```jinja2 +{% import "examples/datasets/librispeech/sentencepiece/sp.256.yml.j2" as decoder_config with context %} +{{decoder_config}} +{% import "examples/models/ctc/conformer/small.yml.j2" as config with context %} +{{config}} +``` + +**Results:** + +| Epoch | Dataset | decoding | wer | cer | mer | wil | wip | +| :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------- | :------- | +| 170 | test-clean | greedy | 0.0967171 | 0.031954 | 0.0958403 | 0.168307 | 0.831693 | +| 170 | test-other | greedy | 0.201612 | 0.0812955 | 0.197415 | 0.330207 | 0.669793 | + + +## II. Small + Streaming + SentencePiece 256 + +| Category | Description | +| :---------------- | :------------------------------------------------------------------------------------------------- | +| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) | +| Tensorflow | **2.18.0** | +| Device | Google Cloud TPUs v4-8 | +| Mixed Precision | strict | +| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) | +| Max Epochs | 450 | +| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-ctc/tensorFlow2/v3-small-streaming) | + +**Config:** + +```jinja2 +{% import "examples/datasets/librispeech/sentencepiece/sp.256.yml.j2" as decoder_config with context %} +{{decoder_config}} +{% import "examples/models/ctc/conformer/small-streaming.yml.j2" as config with context %} +{{config}} +``` + +**Tensorboard:** + + + + + + + +
+ Tensorboard figures: Epoch Loss (figs/librispeech-small-streaming-epoch-loss.jpg), Batch Loss (figs/librispeech-small-streaming-batch-loss.jpg), Learning Rate (figs/librispeech-small-streaming-lr.jpg)
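For readers who want to try the pretrained checkpoint listed above, the snippet below is an editorial sketch (not part of this diff) of loading a composed config and its weights with the same calls used in the `examples/inferences/main.py` script added earlier in this change; `my_composed_config.yml.j2` and `pretrained.weights.h5` are placeholders for the config block above and the checkpoint from the pretrained link.

```python
# Editorial sketch, mirroring examples/inferences/main.py from this diff.
# The config and weight paths are placeholders, not files shipped by this change.
import os

from tensorflow_asr import keras, tokenizers
from tensorflow_asr.configs import Config
from tensorflow_asr.models import base_model

config = Config("my_composed_config.yml.j2", training=False, repodir=os.getcwd())
tokenizer = tokenizers.get(config)  # tokenizer built from the imported decoder_config

model: base_model.BaseModel = keras.Model.from_config(config.model_config)
model.make(batch_size=1)  # build the model graph for single-utterance inference
model.load_weights("pretrained.weights.h5")
model.summary()
```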
+ +**Results:** + +| Epoch | Dataset | decoding | wer | cer | mer | wil | wip | +| :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------ | :------ | +| 60 | test-clean | greedy | 0.0848106 | 0.0286257 | 0.0841686 | 0.14896 | 0.85104 | +| 60 | test-other | greedy | 0.217221 | 0.0913044 | 0.213409 | 0.3555 | 0.6445 | \ No newline at end of file diff --git a/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-batch-loss.jpg b/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-batch-loss.jpg new file mode 100644 index 0000000000..df3e7b7991 Binary files /dev/null and b/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-batch-loss.jpg differ diff --git a/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-epoch-loss.jpg b/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-epoch-loss.jpg new file mode 100644 index 0000000000..7eefbd1680 Binary files /dev/null and b/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-epoch-loss.jpg differ diff --git a/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-lr.jpg b/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-lr.jpg new file mode 100644 index 0000000000..0f91d35e28 Binary files /dev/null and b/examples/models/ctc/conformer/results/sentencepiece/figs/librispeech-small-streaming-lr.jpg differ diff --git a/examples/models/ctc/conformer/small-streaming.yml.j2 b/examples/models/ctc/conformer/small-streaming.yml.j2 new file mode 100644 index 0000000000..56f285ff70 --- /dev/null +++ b/examples/models/ctc/conformer/small-streaming.yml.j2 @@ -0,0 +1,95 @@ +model_config: + class_name: tensorflow_asr.models.ctc.conformer>Conformer + config: + name: conformer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + augmentation_config: null + encoder_subsampling: + class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling + config: + filters: [176, 176] + kernels: [3, 3] + strides: [2, 2] + paddings: ["causal", "causal"] + norms: ["layer", "layer"] + activations: ["swish", "swish"] + encoder_ffm_residual_factor: 0.5 + encoder_mhsam_residual_factor: 1.0 + encoder_convm_residual_factor: 1.0 + encoder_dmodel: 176 + encoder_num_blocks: 16 + encoder_head_size: 44 # == dmodel // num_heads + encoder_num_heads: 4 + encoder_mha_type: relmha + encoder_interleave_relpe: True + encoder_use_attention_causal_mask: False + encoder_use_attention_auto_mask: True + encoder_mhsam_use_attention_bias: True + encoder_convm_dw_norm_type: layer + encoder_kernel_size: 31 + encoder_dropout: 0.1 + encoder_padding: causal + encoder_memory_length: null + encoder_history_size: 64 # frames = 4 * chunk_size + encoder_chunk_size: 16 # frames + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 176 + warmup_steps: 10000 + max_lr: 0.05/(176**0.5) + min_lr: null + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: null + + gradn_config: null + + batch_size: 8 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: 
tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch \ No newline at end of file diff --git a/examples/models/ctc/conformer/small.yml.j2 b/examples/models/ctc/conformer/small.yml.j2 new file mode 100644 index 0000000000..b51bbeb7e5 --- /dev/null +++ b/examples/models/ctc/conformer/small.yml.j2 @@ -0,0 +1,104 @@ +model_config: + class_name: tensorflow_asr.models.ctc.conformer>Conformer + config: + name: conformer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 0.5 + num_masks: 5 + mask_factor: -1 # whole utterance + p_upperbound: 0.05 + mask_value: 0 + freq_masking: + prob: 0.5 + num_masks: 2 + mask_factor: 27 + mask_value: 0 + encoder_subsampling: + class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling + config: + filters: [176, 176] + kernels: [3, 3] + strides: [2, 2] + paddings: ["causal", "causal"] + norms: ["batch", "batch"] + activations: ["swish", "swish"] + encoder_ffm_residual_factor: 0.5 + encoder_mhsam_residual_factor: 1.0 + encoder_convm_residual_factor: 1.0 + encoder_dmodel: 176 + encoder_num_blocks: 16 + encoder_head_size: 44 # == dmodel // num_heads + encoder_num_heads: 4 + encoder_mha_type: relmha + encoder_interleave_relpe: True + encoder_use_attention_causal_mask: False + encoder_use_attention_auto_mask: True + encoder_mhsam_use_attention_bias: True + encoder_kernel_size: 31 + encoder_dropout: 0.1 + encoder_padding: causal + encoder_memory_length: null + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 176 + warmup_steps: 10000 + max_lr: 0.05/(176**0.5) + min_lr: null + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: null + + gradn_config: null + + batch_size: 8 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch \ No newline at end of file diff --git a/examples/models/ctc/deepspeech2/base.yml.j2 b/examples/models/ctc/deepspeech2/base.yml.j2 new file mode 100644 index 0000000000..f9762f69ef --- 
/dev/null +++ b/examples/models/ctc/deepspeech2/base.yml.j2 @@ -0,0 +1,100 @@ +model_config: + class_name: tensorflow_asr.models.ctc.deepspeech2>DeepSpeech2 + config: + name: deepspeech2 + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 160 + feature_type: spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 5 + mask_factor: -1 # whole utterance + p_upperbound: 0.05 + mask_value: 0 + freq_masking: + prob: 1.0 + num_masks: 1 + mask_factor: 27 + mask_value: 0 + conv_type: conv2d + conv_kernels: [ [ 11, 41 ], [ 11, 21 ] ] + conv_strides: [ [ 2, 2 ], [ 1, 2 ] ] + conv_filters: [ 32, 32 ] + conv_activation: relu + conv_padding: same + conv_initializer: he_uniform + rnn_nlayers: 5 + rnn_type: lstm + rnn_units: 512 + rnn_bidirectional: True + rnn_unroll: False + rnn_rowconv: 0 + rnn_rowconv_activation: relu + rnn_dropout: 0.5 + fc_nlayers: 1 + fc_units: 1024 + fc_activation: relu + fc_dropout: 0.5 + fc_initializer: he_uniform + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 0.0005 + bias_regularizer: + class_name: l2 + config: + l2: 0.0005 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: ExponentialDecay + module: keras.src.optimizers.schedules.learning_rate_schedule + config: + initial_learning_rate: 0.0001 + decay_steps: 5000 + decay_rate: 0.9 + staircase: True + + gwn_config: null + + gradn_config: null + + batch_size: 16 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch + diff --git a/examples/models/ctc/deepspeech2/uni.yml.j2 b/examples/models/ctc/deepspeech2/uni.yml.j2 new file mode 100644 index 0000000000..5d79f9abc9 --- /dev/null +++ b/examples/models/ctc/deepspeech2/uni.yml.j2 @@ -0,0 +1,103 @@ +model_config: + class_name: tensorflow_asr.models.ctc.deepspeech2>DeepSpeech2 + config: + name: deepspeech2 + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 160 + feature_type: spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 5 + mask_factor: -1 # whole utterance + p_upperbound: 0.05 + mask_value: 0 + freq_masking: + prob: 1.0 + num_masks: 1 + mask_factor: 27 + mask_value: 0 + conv_type: conv2d + conv_kernels: [ [ 11, 41 ], [ 11, 21 ] ] + conv_strides: [ [ 2, 2 ], [ 1, 2 ] ] + conv_filters: [ 32, 32 ] + conv_activation: relu + conv_padding: causal + conv_initializer: he_uniform + rnn_nlayers: 5 + rnn_type: lstm + rnn_units: 512 + rnn_bidirectional: False + rnn_unroll: False + rnn_rowconv: 3 + rnn_rowconv_activation: relu + rnn_dropout: 0.1 + fc_nlayers: 1 + fc_units: 1024 + fc_activation: relu + fc_dropout: 0.1 + fc_initializer: he_uniform + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 0.0005 
+ bias_regularizer: + class_name: l2 + config: + l2: 0.0005 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 512 + warmup_steps: 10000 + min_lr: 1e-6 + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: null + + gradn_config: null + + batch_size: 16 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch + diff --git a/examples/models/ctc/jasper/base.yml.j2 b/examples/models/ctc/jasper/base.yml.j2 new file mode 100644 index 0000000000..2fd04b762e --- /dev/null +++ b/examples/models/ctc/jasper/base.yml.j2 @@ -0,0 +1,80 @@ +model_config: + class_name: tensorflow_asr.models.ctc.jasper>Jasper + config: + name: jasper + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + log_base: "10" + dense: True + first_additional_block_channels: 256 + first_additional_block_kernels: 11 + first_additional_block_strides: 2 + first_additional_block_dilation: 1 + first_additional_block_dropout: 0.2 + nsubblocks: 3 + block_channels: [ 256, 384, 512, 640, 768 ] + block_kernels: [ 11, 13, 17, 21, 25 ] + block_dropout: [ 0.2, 0.2, 0.2, 0.3, 0.3 ] + second_additional_block_channels: 896 + second_additional_block_kernels: 1 + second_additional_block_strides: 1 + second_additional_block_dilation: 2 + second_additional_block_dropout: 0.4 + third_additional_block_channels: 1024 + third_additional_block_kernels: 1 + third_additional_block_strides: 1 + third_additional_block_dilation: 1 + third_additional_block_dropout: 0.4 + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: 0.001 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + + gwn_config: null + + gradn_config: null + + batch_size: 16 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch + diff --git a/tensorflow_asr/configs/__init__.py b/examples/models/ctc/transformer/README.md similarity index 100% rename from tensorflow_asr/configs/__init__.py rename to 
examples/models/ctc/transformer/README.md diff --git a/examples/models/ctc/transformer/base-streaming.yml.j2 b/examples/models/ctc/transformer/base-streaming.yml.j2 new file mode 100644 index 0000000000..169251dd1b --- /dev/null +++ b/examples/models/ctc/transformer/base-streaming.yml.j2 @@ -0,0 +1,101 @@ +model_config: + class_name: tensorflow_asr.models.ctc.transformer>Transformer + config: + name: transformer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 5 + mask_factor: -1 + p_upperbound: 0.05 + freq_masking: + prob: 1.0 + num_masks: 2 + mask_factor: 27 + encoder_subsampling: + type: conv2d + filters: [512, 512] + kernels: [3, 3] + strides: [2, 2] + paddings: ["causal", "causal"] + norms: ["batch", "batch"] + activations: ["relu", "relu"] + encoder_dropout: 0.1 + encoder_residual_factor: 1.0 + encoder_norm_position: post + encoder_dmodel: 512 + encoder_dff: 1024 + encoder_num_blocks: 6 + encoder_head_size: 128 + encoder_num_heads: 4 + encoder_mha_type: mha + encoder_interleave_relpe: True + encoder_use_attention_causal_mask: False + encoder_use_attention_auto_mask: True + encoder_pwffn_activation: relu + encoder_memory_length: null + encoder_history_size: 64 # frames = 4 * chunk_size + encoder_chunk_size: 16 # frames + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 512 + warmup_steps: 10000 + max_lr: null + min_lr: null + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + + gwn_config: + predict_net_step: 0 + predict_net_stddev: 0.075 + + gradn_config: null + + batch_size: 8 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch diff --git a/examples/models/ctc/transformer/base.yml.j2 b/examples/models/ctc/transformer/base.yml.j2 new file mode 100644 index 0000000000..c4e347a171 --- /dev/null +++ b/examples/models/ctc/transformer/base.yml.j2 @@ -0,0 +1,97 @@ +model_config: + class_name: tensorflow_asr.models.ctc.transformer>Transformer + config: + name: transformer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 5 + mask_factor: -1 + p_upperbound: 0.05 + freq_masking: + prob: 1.0 + num_masks: 2 + mask_factor: 27 + encoder_subsampling: + type: conv2d + filters: [512, 512] + kernels: [3, 3] + strides: [2, 2] + paddings: ["causal", "causal"] + norms: ["batch", "batch"] + activations: ["relu", "relu"] + encoder_dropout: 0.1 + encoder_residual_factor: 1.0 + 
encoder_norm_position: post + encoder_dmodel: 512 + encoder_dff: 1024 + encoder_num_blocks: 6 + encoder_head_size: 128 + encoder_num_heads: 4 + encoder_mha_type: mha + encoder_interleave_relpe: True + encoder_use_attention_causal_mask: False + encoder_use_attention_auto_mask: True + encoder_pwffn_activation: relu + encoder_memory_length: null + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 512 + warmup_steps: 10000 + max_lr: null + min_lr: null + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + + gwn_config: null + + gradn_config: null + + batch_size: 8 + ga_steps: 4 + num_epochs: 450 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch diff --git a/examples/models/transducer/conformer/README.md b/examples/models/transducer/conformer/README.md new file mode 100644 index 0000000000..1fbad1ae48 --- /dev/null +++ b/examples/models/transducer/conformer/README.md @@ -0,0 +1,5 @@ +# Conformer Transducer + +## Results + +See [results](./results) for more details. \ No newline at end of file diff --git a/examples/models/transducer/conformer/inference/gen_saved_model.py b/examples/models/transducer/conformer/inference/gen_saved_model.py new file mode 100644 index 0000000000..c9cc875950 --- /dev/null +++ b/examples/models/transducer/conformer/inference/gen_saved_model.py @@ -0,0 +1,56 @@ +# # pylint: disable=no-member +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. 
+ +# import os + +# import fire +# from tensorflow_asr import tf, keras + +# from tensorflow_asr.configs import Config +# from tensorflow_asr.helpers import featurizer_helpers +# from tensorflow_asr.models.transducer.conformer import Conformer +# from tensorflow_asr.utils import env_util + +# logger = env_util.setup_environment() + +# DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") + + +# def main( +# config_path: str = DEFAULT_YAML, +# saved: str = None, +# output_dir: str = None, +# ): +# assert saved and output_dir +# tf.random.set_seed(0) +# keras.backend.clear_session() + +# logger.info("Load config and featurizers ...") +# config = Config(config_path) +# speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) + +# logger.info("Build and load model ...") +# conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) +# conformer.make(speech_featurizer.shape) +# conformer.add_featurizers(speech_featurizer, text_featurizer) +# conformer.load_weights(saved, by_name=True) +# conformer.summary() + +# logger.info("Save model ...") +# tf.saved_model.save(conformer, export_dir=output_dir, signatures=conformer.recognize_from_signal.get_concrete_function()) + + +# if __name__ == "__main__": +# fire.Fire(main) diff --git a/examples/models/transducer/conformer/inference/run_saved_model.py b/examples/models/transducer/conformer/inference/run_saved_model.py new file mode 100644 index 0000000000..56da5da980 --- /dev/null +++ b/examples/models/transducer/conformer/inference/run_saved_model.py @@ -0,0 +1,43 @@ +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. +# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +# import os + +# import fire +# from tensorflow_asr import tf, keras + +# from tensorflow_asr.features.speech_featurizers import read_raw_audio +# from tensorflow_asr.utils import env_util + +# logger = env_util.setup_environment() + +# DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") + + +# def main( +# saved_model: str = None, +# filename: str = None, +# ): +# keras.backend.clear_session() + +# module = tf.saved_model.load(export_dir=saved_model) + +# signal = read_raw_audio(filename) +# transcript = module.pred(signal) + +# print("Transcript: ", "".join([chr(u) for u in transcript])) + + +# if __name__ == "__main__": +# fire.Fire(main) diff --git a/examples/models/transducer/conformer/results/sentencepiece/README.md b/examples/models/transducer/conformer/results/sentencepiece/README.md new file mode 100644 index 0000000000..142429ca72 --- /dev/null +++ b/examples/models/transducer/conformer/results/sentencepiece/README.md @@ -0,0 +1,115 @@ +- [\[English\] LibriSpeech](#english-librispeech) + - [I. Small + SentencePiece 1k](#i-small--sentencepiece-1k) + - [II. Small + Streaming + SentencePiece 1k](#ii-small--streaming--sentencepiece-1k) +- [\[Vietnamese\] VietBud500](#vietnamese-vietbud500) + - [I. 
Small + Streaming + SentencePiece 1k](#i-small--streaming--sentencepiece-1k) + + + +# [English] LibriSpeech + +## I. Small + SentencePiece 1k + +| Category | Description | +| :---------------- | :---------------------------------------------------------------------------------------------- | +| Config | [small.yml.j2](../../small.yml.j2) | +| Tensorflow | **2.18.0** | +| Device | Google Cloud TPUs v4-8 | +| Mixed Precision | strict | +| Global Batch Size | 4 * 4 * 8 = 128 (as 4 TPUs, 8 Gradient Accumulation Steps) | +| Max Epochs | 300 | +| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-transducer/tensorFlow2/v3-small) | + +**Config:** + +```jinja2 +{% import "examples/datasets/librispeech/sentencepiece/sp.yml.j2" as decoder_config with context %} +{{decoder_config}} +{% import "examples/models/transducer/conformer/small.yml.j2" as config with context %} +{{config}} +``` + +**Results:** + +| Epoch | Dataset | decoding | wer | cer | mer | wil | wip | +| :---- | :--------- | :------- | :------- | :------- | :------- | :------- | :------- | +| 157 | test-clean | greedy | 0.062918 | 0.025361 | 0.062527 | 0.109992 | 0.890007 | +| 157 | test-other | greedy | 0.142616 | 0.066839 | 0.140610 | 0.239201 | 0.760798 | + +## II. Small + Streaming + SentencePiece 1k + +| Category | Description | +| :---------------- | :-------------------------------------------------------------------------------------------------------- | +| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) | +| Tensorflow | **2.18.0** | +| Device | Google Cloud TPUs v4-8 | +| Mixed Precision | strict | +| Global Batch Size | 4 * 4 * 8 = 128 (as 4 TPUs, 8 Gradient Accumulation Steps) | +| Max Epochs | 300 | +| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-conformer-transducer/tensorFlow2/v3-small-streaming) | + +**Config:** + +```jinja2 +{% import "examples/datasets/librispeech/sentencepiece/sp.yml.j2" as decoder_config with context %} +{{decoder_config}} +{% import "examples/models/transducer/conformer/small-streaming.yml.j2" as config with context %} +{{config}} +``` + +**Results:** + +| Epoch | Dataset | decoding | wer | cer | mer | wil | wip | +| :---- | :--------- | :------- | :-------- | :-------- | :-------- | :------- | :------- | +| 45 | test-clean | greedy | 0.0797322 | 0.0312862 | 0.0790049 | 0.137228 | 0.862772 | +| 45 | test-other | greedy | 0.211872 | 0.104173 | 0.207305 | 0.341269 | 0.658731 | + + + +# [Vietnamese] VietBud500 + +## I. Small + Streaming + SentencePiece 1k + +| Category | Description | +| :---------------- | :---------------------------------------------------------------------------------------------------------------- | +| Config | [small-streaming.yml.j2](../../small-streaming.yml.j2) | +| Tensorflow | **2.18.0** | +| Device | Google Cloud TPUs v4-8 | +| Mixed Precision | strict | +| Global Batch Size | 8 * 4 * 8 = 256 (as 4 TPUs, 8 Gradient Accumulation Steps) | +| Max Epochs | 300 | +| Pretrained | [Link](https://www.kaggle.com/models/lordh9072/tfasr-vietbud500-conformer-transducer/tensorFlow2/small-streaming) | + +**Config:** + +```jinja2 +{% import "examples/datasets/vietbud500/sentencepiece/sp.yml.j2" as decoder_config with context %} +{{decoder_config}} +{% import "examples/models/transducer/conformer/small-streaming.yml.j2" as config with context %} +{{config}} +``` + +**Tensorboard:** + + + + + + + +
+![Epoch Loss](./figs/vietbud500-small-streaming-epoch-loss.jpg)
+
+![Batch Loss](./figs/vietbud500-small-streaming-batch-loss.jpg)
+
+![Learning Rate](./figs/vietbud500-small-streaming-lr.jpg)
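+
+The results table below reports `wer`, `cer`, `mer`, `wil` and `wip`. As a rough sketch only (assuming the `jiwer` package, which this diff does not confirm as the evaluation tool), such metrics can be computed from reference/hypothesis transcript pairs:
+
+```python
+# Hedged sketch: computing the metric columns shown in the results table below.
+# Assumes the jiwer package; the repo's actual evaluation script may differ.
+import jiwer
+
+reference = "the quick brown fox"   # ground-truth transcript
+hypothesis = "the quick brown fax"  # decoded transcript
+
+word_measures = jiwer.process_words(reference, hypothesis)
+print(word_measures.wer, word_measures.mer, word_measures.wil, word_measures.wip)
+print(jiwer.cer(reference, hypothesis))  # character error rate
+```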
+ +**Results:** + +| Epoch | decoding | wer | cer | mer | wil | wip | +| :---- | :------- | :------- | :------- | :------ | :------- | :------- | +| 52 | greedy | 0.053723 | 0.034548 | 0.05362 | 0.086421 | 0.913579 | \ No newline at end of file diff --git a/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-batch-loss.jpg b/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-batch-loss.jpg new file mode 100644 index 0000000000..6d2802c432 Binary files /dev/null and b/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-batch-loss.jpg differ diff --git a/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-epoch-loss.jpg b/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-epoch-loss.jpg new file mode 100644 index 0000000000..52dedc719b Binary files /dev/null and b/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-epoch-loss.jpg differ diff --git a/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-lr.jpg b/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-lr.jpg new file mode 100644 index 0000000000..c21f9d733d Binary files /dev/null and b/examples/models/transducer/conformer/results/sentencepiece/figs/vietbud500-small-streaming-lr.jpg differ diff --git a/examples/transducer/conformer/results/figs/conformer.svg b/examples/models/transducer/conformer/results/subword - deprecated/figs/conformer.svg similarity index 100% rename from examples/transducer/conformer/results/figs/conformer.svg rename to examples/models/transducer/conformer/results/subword - deprecated/figs/conformer.svg diff --git a/examples/transducer/conformer/results/figs/subword_conformer_loss.svg b/examples/models/transducer/conformer/results/subword - deprecated/figs/subword_conformer_loss.svg similarity index 100% rename from examples/transducer/conformer/results/figs/subword_conformer_loss.svg rename to examples/models/transducer/conformer/results/subword - deprecated/figs/subword_conformer_loss.svg diff --git a/examples/models/transducer/conformer/small-streaming.yml.j2 b/examples/models/transducer/conformer/small-streaming.yml.j2 new file mode 100644 index 0000000000..3b05883464 --- /dev/null +++ b/examples/models/transducer/conformer/small-streaming.yml.j2 @@ -0,0 +1,109 @@ +model_config: + class_name: tensorflow_asr.models.transducer.conformer>Conformer + config: + name: conformer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + encoder_subsampling: + class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling + config: + filters: [144, 144] + kernels: [3, 3] + strides: [2, 2] + paddings: ["causal", "causal"] + norms: ["layer", "layer"] + activations: ["swish", "swish"] + encoder_ffm_residual_factor: 0.5 + encoder_mhsam_residual_factor: 1.0 + encoder_convm_residual_factor: 1.0 + encoder_dmodel: 144 + encoder_num_blocks: 16 + encoder_head_size: 36 # == dmodel // num_heads + encoder_num_heads: 4 + encoder_mha_type: relmha + encoder_interleave_relpe: True + encoder_use_attention_causal_mask: False + encoder_use_attention_auto_mask: True + encoder_mhsam_use_attention_bias: False + encoder_convm_dw_norm_type: layer + encoder_kernel_size: 31 + encoder_dropout: 0.1 + encoder_padding: causal + 
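+    # Streaming attention setup (interpretation based on the field names and the inline
+    # comments below, not stated elsewhere in this diff): the encoder attends chunk-wise
+    # over 16-frame chunks with up to 64 frames (4 chunks) of left history and no
+    # persistent memory carried across segments.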
encoder_memory_length: null + encoder_history_size: 64 # frames = 4 * chunk_size + encoder_chunk_size: 16 # frames + prediction_label_encode_mode: embedding + prediction_embed_dim: 320 + prediction_num_rnns: 1 + prediction_rnn_units: 320 + prediction_rnn_type: lstm + prediction_rnn_implementation: 2 + prediction_rnn_unroll: False + prediction_layer_norm: True + prediction_projection_units: 0 + joint_dim: 320 + prejoint_encoder_linear: True + prejoint_prediction_linear: True + postjoint_linear: False + joint_activation: tanh + joint_mode: add + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 144 + warmup_steps: 10000 + max_lr: 0.05/(144**0.5) + min_lr: null + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: null + + gradn_config: null + + batch_size: 8 + ga_steps: 4 + num_epochs: 300 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch \ No newline at end of file diff --git a/examples/models/transducer/conformer/small.yml.j2 b/examples/models/transducer/conformer/small.yml.j2 new file mode 100644 index 0000000000..6ca7269833 --- /dev/null +++ b/examples/models/transducer/conformer/small.yml.j2 @@ -0,0 +1,119 @@ +model_config: + class_name: tensorflow_asr.models.transducer.conformer>Conformer + config: + name: conformer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + nfft: 512 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 10 + mask_factor: -1 + p_upperbound: 0.05 + mask_value: 0 + freq_masking: + prob: 1.0 + num_masks: 1 + mask_factor: 27 + mask_value: 0 + encoder_subsampling: + class_name: tensorflow_asr.models.layers.subsampling>Conv2dSubsampling + config: + filters: [144, 144] + kernels: [3, 3] + strides: [2, 2] + paddings: ["causal", "causal"] + norms: ["batch", "batch"] + activations: ["swish", "swish"] + encoder_ffm_residual_factor: 0.5 + encoder_mhsam_residual_factor: 1.0 + encoder_convm_residual_factor: 1.0 + encoder_dmodel: 144 + encoder_num_blocks: 16 + encoder_head_size: 36 # == dmodel // num_heads + encoder_num_heads: 4 + encoder_mha_type: relmha + encoder_interleave_relpe: True + encoder_use_attention_causal_mask: False + encoder_use_attention_auto_mask: True + encoder_mhsam_use_attention_bias: False + encoder_kernel_size: 31 + encoder_dropout: 0.1 + encoder_padding: causal + encoder_memory_length: null + prediction_label_encode_mode: embedding + prediction_embed_dim: 320 + prediction_num_rnns: 1 + prediction_rnn_units: 320 + prediction_rnn_type: lstm + prediction_rnn_implementation: 2 + prediction_rnn_unroll: False + prediction_layer_norm: True + 
prediction_projection_units: 0 + joint_dim: 320 + prejoint_encoder_linear: True + prejoint_prediction_linear: True + postjoint_linear: False + joint_activation: tanh + joint_mode: add + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 144 + warmup_steps: 10000 + max_lr: 0.05/(144**0.5) + min_lr: null + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: null + + gradn_config: null + + batch_size: 2 + ga_steps: 16 + num_epochs: 300 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch \ No newline at end of file diff --git a/examples/models/transducer/contextnet/README.md b/examples/models/transducer/contextnet/README.md new file mode 100644 index 0000000000..b8b398b4bc --- /dev/null +++ b/examples/models/transducer/contextnet/README.md @@ -0,0 +1,5 @@ +# ContextNet Transducer + +## Results + +See [results](./results) for more details. \ No newline at end of file diff --git a/examples/models/transducer/contextnet/results/wordpiece/README.md b/examples/models/transducer/contextnet/results/wordpiece/README.md new file mode 100644 index 0000000000..ee486b2b22 --- /dev/null +++ b/examples/models/transducer/contextnet/results/wordpiece/README.md @@ -0,0 +1,59 @@ +**Table of Contents** +- [WordPiece 1k With Whitespace + Small + LibriSpeech](#wordpiece-1k-with-whitespace--small--librispeech) + - [Epoch Loss](#epoch-loss) + - [Batch Loss](#batch-loss) + - [Training Learning Rate](#training-learning-rate) + - [Results](#results) + +# WordPiece 1k With Whitespace + Small + LibriSpeech + + +| Category | Description | +| :---------------- | :--------------------------------- | +| Config | [small.yml.j2](../../small.yml.j2) | +| Tensorflow | **2.13.x** | +| Device | Google Colab TPUs | +| Global Batch Size | 2 * 16 * 8 = 256 (as 8 TPUs) | + + +### Epoch Loss + +![Epoch Loss](./figs/contextnet-small-wp1k-whitespace-epoch-loss.svg) + +### Batch Loss + +![Batch Loss](./figs/contextnet-small-wp1k-whitespace-batch-loss.svg) + +### Training Learning Rate + +![Learning Rate](./figs/contextnet-small-wp1k-whitespace-lr.svg) + +### Results + +Pretrain Model here: [link](https://drive.google.com/drive/folders/1xT3j_L5q4oSBeUiLArnBPliZ0g9k-N7O?usp=drive_link) + +```json +[ + { + "epoch": 273, + "test-clean": { + "greedy": { + "wer": 0.07923767498478393, + "cer": 0.0336269669307001, + "mer": 0.07840111410128536, + "wil": 0.13531145375649656, + "wip": 0.8646885462435034 + } + }, + "test-other": { + "greedy": { + "wer": 0.19121945627877654, + "cer": 0.09776798480704507, + "mer": 0.1870526453493805, + "wil": 0.3107931720744128, + "wip": 0.6892068279255872 + } + } + } +] +``` \ No newline at end of file diff --git 
a/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-batch-loss.svg b/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-batch-loss.svg new file mode 100644 index 0000000000..2842c0b879 --- /dev/null +++ b/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-batch-loss.svg @@ -0,0 +1 @@ +30354045505560657075-50k050k100k150k200k250k300k350k \ No newline at end of file diff --git a/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-epoch-loss.svg b/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-epoch-loss.svg new file mode 100644 index 0000000000..8c818c6a09 --- /dev/null +++ b/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-epoch-loss.svg @@ -0,0 +1 @@ +10152025303540455055606570050100150200250300 \ No newline at end of file diff --git a/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-lr.svg b/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-lr.svg new file mode 100644 index 0000000000..ecb8fa67b8 --- /dev/null +++ b/examples/models/transducer/contextnet/results/wordpiece/figs/contextnet-small-wp1k-whitespace-lr.svg @@ -0,0 +1 @@ +01e-42e-43e-44e-45e-46e-47e-48e-4-100k-50k050k100k150k200k250k300k350k400k450k \ No newline at end of file diff --git a/examples/models/transducer/contextnet/small.yml.j2 b/examples/models/transducer/contextnet/small.yml.j2 new file mode 100644 index 0000000000..914f9bee64 --- /dev/null +++ b/examples/models/transducer/contextnet/small.yml.j2 @@ -0,0 +1,268 @@ +model_config: + class_name: tensorflow_asr.models.transducer.contextnet>ContextNet + config: + name: contextnet + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + num_feature_bins: 80 + feature_type: log_mel_spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 10 + mask_factor: -1 # whole utterance + p_upperbound: 0.05 + mask_value: 0 + freq_masking: + prob: 1.0 + num_masks: 1 + mask_factor: 27 + mask_value: 0 + encoder_alpha: 0.5 + encoder_blocks: + # C0 + - nlayers: 1 + kernel_size: 5 + filters: 256 + strides: 1 + residual: False + activation: silu + padding: causal + # C1-C2 + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + # C3 + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 2 + residual: True + activation: silu + padding: causal + # C4-C6 + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + # C7 + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 2 + residual: True + activation: silu + padding: causal + # C8 - C10 + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 256 + strides: 1 + residual: True + 
activation: silu + padding: causal + # C11 - C13 + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + # C14 + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 2 + residual: True + activation: silu + padding: causal + # C15 - C21 + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + - nlayers: 5 + kernel_size: 5 + filters: 512 + strides: 1 + residual: True + activation: silu + padding: causal + # C22 + - nlayers: 1 + kernel_size: 5 + filters: 640 + strides: 1 + residual: False + activation: silu + padding: causal + prediction_label_encode_mode: embedding + prediction_embed_dim: 640 + prediction_num_rnns: 1 + prediction_rnn_units: 512 + prediction_rnn_type: lstm + prediction_rnn_implementation: 2 + prediction_rnn_unroll: False + prediction_layer_norm: False + prediction_projection_units: 0 + joint_dim: 512 + prejoint_encoder_linear: True + prejoint_prediction_linear: True + postjoint_linear: False + joint_activation: tanh + joint_mode: add + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 320 + warmup_steps: 15000 + max_lr: 0.0025 + min_lr: 1e-6 + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: + predict_net_step: 20000 + predict_net_stddev: 0.075 + + gradn_config: null + + batch_size: 4 + ga_steps: 8 + num_epochs: 400 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch diff --git a/tensorflow_asr/datasets/__init__.py b/examples/models/transducer/rnnt/README.md similarity index 100% rename from tensorflow_asr/datasets/__init__.py rename to examples/models/transducer/rnnt/README.md diff --git a/examples/models/transducer/rnnt/results/sentencepiece/README.md b/examples/models/transducer/rnnt/results/sentencepiece/README.md new file mode 100644 index 0000000000..03b137f092 --- /dev/null 
+++ b/examples/models/transducer/rnnt/results/sentencepiece/README.md @@ -0,0 +1,57 @@ +- [SentencePiece 256 + Tiny + LibriSpeech](#sentencepiece-256--tiny--librispeech) + - [Training Loss](#training-loss) + - [1. Epoch Loss](#1-epoch-loss) + - [2. Batch Loss](#2-batch-loss) + - [Results](#results) + + +# SentencePiece 256 + Tiny + LibriSpeech + +| Category | Description | +| :---------------- | :------------------------------- | +| Config | [tiny.yml.j2](../../tiny.yml.j2) | +| Tensorflow | **2.15.x** | +| Device | NVIDIA GeForce GTX 1650 | +| Global Batch Size | 3 | +| Max Epochs | 300 | + + +### Training Loss + +#### 1. Epoch Loss + +![Epoch Loss](./figs/rnnt-tiny-sp256-epoch-loss.svg) + +#### 2. Batch Loss + +![Batch Loss](./figs/rnnt-tiny-sp256-batch-loss.svg) + + +### Results + +Pretrain Model here: [link](https://drive.google.com/drive/folders/1h0BrCzZo8JTz_MUU5bJPJ3UBqroBnsuv?usp=sharing) + +```json +[ + { + "epoch": 136, + "test-clean": { + "greedy": { + "wer": 0.15853241022519782, + "cer": 0.07179696657549817, + "mer": 0.15537908021549876, + "wil": 0.2587056704145151, + "wip": 0.7412943295854849 + } + }, + "test-other": { + "greedy": { + "wer": 0.3457577899623636, + "cer": 0.18733822655980759, + "mer": 0.33391759995571874, + "wil": 0.5185365485613327, + "wip": 0.48146345143866726 + } + } + }, +] \ No newline at end of file diff --git a/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-batch-loss.svg b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-batch-loss.svg new file mode 100644 index 0000000000..c90c689ff3 --- /dev/null +++ b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-batch-loss.svg @@ -0,0 +1 @@ +303540455055606570-100k0100k200k300k400k500k600k700k800k \ No newline at end of file diff --git a/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-epoch-loss.svg b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-epoch-loss.svg new file mode 100644 index 0000000000..21b438c108 --- /dev/null +++ b/examples/models/transducer/rnnt/results/sentencepiece/figs/rnnt-tiny-sp256-epoch-loss.svg @@ -0,0 +1 @@ +3436384042444648505254565860-20020406080100120140 \ No newline at end of file diff --git a/examples/transducer/rnnt/results/subword.md b/examples/models/transducer/rnnt/results/subword - deprecated/README.md similarity index 100% rename from examples/transducer/rnnt/results/subword.md rename to examples/models/transducer/rnnt/results/subword - deprecated/README.md diff --git a/examples/transducer/rnnt/results/figs/epoch_learning_rate.svg b/examples/models/transducer/rnnt/results/subword - deprecated/figs/epoch_learning_rate.svg similarity index 100% rename from examples/transducer/rnnt/results/figs/epoch_learning_rate.svg rename to examples/models/transducer/rnnt/results/subword - deprecated/figs/epoch_learning_rate.svg diff --git a/examples/transducer/rnnt/results/figs/subword_rnnt_loss.svg b/examples/models/transducer/rnnt/results/subword - deprecated/figs/subword_rnnt_loss.svg similarity index 100% rename from examples/transducer/rnnt/results/figs/subword_rnnt_loss.svg rename to examples/models/transducer/rnnt/results/subword - deprecated/figs/subword_rnnt_loss.svg diff --git a/examples/models/transducer/rnnt/small.yml.j2 b/examples/models/transducer/rnnt/small.yml.j2 new file mode 100644 index 0000000000..72fdaf969b --- /dev/null +++ b/examples/models/transducer/rnnt/small.yml.j2 @@ -0,0 +1,100 @@ +model_config: + class_name: 
tensorflow_asr.models.transducer.rnnt>RnnTransducer + config: + name: rnn_transducer + speech_config: + sample_rate: 16000 + frame_ms: 25 + stride_ms: 10 + num_feature_bins: 80 + nfft: 512 + feature_type: log_mel_spectrogram + augmentation_config: + feature_augment: + time_masking: + prob: 1.0 + num_masks: 5 + mask_factor: -1 + p_upperbound: 0.05 + freq_masking: + prob: 1.0 + num_masks: 1 + mask_factor: 27 + encoder_reduction_positions: [ post, post, post, post ] + encoder_reduction_factors: [ 3, 0, 2, 0 ] # downsampled to 30ms and add 2 reduction after second layer + encoder_dmodel: 320 + encoder_rnn_type: lstm + encoder_rnn_units: 1024 + encoder_nlayers: 4 + encoder_layer_norm: True + prediction_label_encode_mode: embedding + prediction_embed_dim: 512 + prediction_num_rnns: 1 + prediction_rnn_units: 1024 + prediction_rnn_type: lstm + prediction_rnn_unroll: False + prediction_layer_norm: True + prediction_projection_units: 0 + joint_dim: 320 + prejoint_encoder_linear: True + prejoint_prediction_linear: True + postjoint_linear: False + joint_activation: tanh + joint_mode: add + blank: 0 + vocab_size: {{decoder_config.vocabsize}} + kernel_regularizer: + class_name: l2 + config: + l2: 1e-6 + +learning_config: + optimizer_config: + class_name: Adam + config: + learning_rate: + class_name: tensorflow_asr.optimizers.schedules>TransformerSchedule + config: + dmodel: 320 + warmup_steps: 10000 + max_lr: null + min_lr: 1e-6 + scale: 2.0 + beta_1: 0.9 + beta_2: 0.98 + epsilon: 1e-9 + weight_decay: 1e-6 + + gwn_config: + predict_net_step: 20000 + predict_net_stddev: 0.075 + + gradn_config: null + + batch_size: 4 + ga_steps: 8 + num_epochs: 300 + + callbacks: + - class_name: tensorflow_asr.callbacks>TerminateOnNaN + config: {} + - class_name: tensorflow_asr.callbacks>ModelCheckpoint + config: + filepath: {{modeldir}}/checkpoints/{epoch:02d}.weights.h5 + save_best_only: False + save_weights_only: True + save_freq: epoch + - class_name: tensorflow_asr.callbacks>TensorBoard + config: + log_dir: {{modeldir}}/tensorboard + histogram_freq: 0 + write_graph: False + write_images: False + write_steps_per_second: False + update_freq: batch + profile_batch: 0 + - class_name: tensorflow_asr.callbacks>KaggleModelBackupAndRestore + config: + model_handle: {{kaggle_model_handle}} + model_dir: {{modeldir}} + save_freq: epoch diff --git a/examples/test.py b/examples/test.py deleted file mode 100644 index 4735c04c8e..0000000000 --- a/examples/test.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2023 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from tensorflow_asr import tf # import to aid logging messages -from tensorflow_asr.configs.config import Config -from tensorflow_asr.helpers import dataset_helpers, exec_helpers, featurizer_helpers -from tensorflow_asr.utils import cli_util, env_util, file_util - - -def main( - config_path: str, - h5: str = None, - mxp: str = "none", - bs: int = None, - device: int = 0, - cpu: bool = False, - output: str = "test.tsv", -): - assert h5 and output - tf.keras.backend.clear_session() - env_util.setup_seed() - env_util.setup_devices([device], cpu=cpu) - env_util.setup_mxp(mxp=mxp) - - config = Config(config_path) - - batch_size = bs or config.learning_config.running_config.batch_size - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - model = tf.keras.models.model_from_config(config.model_config) - model.make(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=batch_size) - model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5)) - model.summary() - model.add_featurizers(speech_featurizer, text_featurizer) - - test_dataset = dataset_helpers.prepare_testing_datasets(config=config, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) - test_data_loader = test_dataset.create(batch_size) - - exec_helpers.run_testing(model=model, test_dataset=test_dataset, test_data_loader=test_data_loader, output=output) - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/examples/tflite.py b/examples/tflite.py deleted file mode 100644 index 73d3cc7068..0000000000 --- a/examples/tflite.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2023 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from tensorflow_asr import tf # import to aid logging messages -from tensorflow_asr.configs.config import Config -from tensorflow_asr.helpers import exec_helpers, featurizer_helpers -from tensorflow_asr.utils import cli_util, env_util, file_util - - -def main( - config_path: str, - h5: str = None, - output: str = None, -): - assert h5 and output - tf.keras.backend.clear_session() - env_util.setup_seed() - tf.compat.v1.enable_control_flow_v2() - - config = Config(config_path) - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - model = tf.keras.models.model_from_config(config.model_config) - model.make(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape) - model.load_weights(h5, by_name=file_util.is_hdf5_filepath(h5)) - model.summary() - model.add_featurizers(speech_featurizer, text_featurizer) - - exec_helpers.convert_tflite(model=model, output=output) - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/examples/train.py b/examples/train.py deleted file mode 100644 index 462026ce8e..0000000000 --- a/examples/train.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright 2023 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tensorflow_asr import tf # import to aid logging messages -from tensorflow_asr.configs.config import Config -from tensorflow_asr.helpers import dataset_helpers, featurizer_helpers -from tensorflow_asr.utils import cli_util, env_util, file_util - - -def main( - config_path: str, - tfrecords: bool = False, - bs: int = None, - spx: int = 1, - devices: list = None, - mxp: str = "none", - jit_compile: bool = False, - ga_steps: int = None, -): - tf.keras.backend.clear_session() - env_util.setup_seed() - strategy = env_util.setup_strategy(devices) - env_util.setup_mxp(mxp=mxp) - - config = Config(config_path) - - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - train_dataset, eval_dataset = dataset_helpers.prepare_training_datasets( - config=config, - speech_featurizer=speech_featurizer, - text_featurizer=text_featurizer, - tfrecords=tfrecords, - ) - - train_data_loader, eval_data_loader, global_batch_size = dataset_helpers.prepare_training_data_loaders( - config=config, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - strategy=strategy, - batch_size=bs, - ) - - with strategy.scope(): - model = tf.keras.models.model_from_config(config.model_config) - model.make(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size) - if config.learning_config.pretrained: - model.load_weights( - config.learning_config.pretrained, - by_name=file_util.is_hdf5_filepath(config.learning_config.pretrained), - skip_mismatch=True, - ) - model.compile( - optimizer=tf.keras.optimizers.get(config.learning_config.optimizer_config), - steps_per_execution=spx, - blank=text_featurizer.blank, - jit_compile=jit_compile, - mxp=mxp, - ga_steps=ga_steps or config.learning_config.running_config.ga_steps, - 
apply_gwn_config=config.learning_config.apply_gwn_config, - ) - model.summary() - - callbacks = [ - tf.keras.callbacks.TerminateOnNaN(), - tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint), - tf.keras.callbacks.BackupAndRestore(**config.learning_config.running_config.backup_and_restore), - tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard), - ] - if config.learning_config.running_config.early_stopping: - callbacks.append(tf.keras.callbacks.EarlyStopping(**config.learning_config.running_config.early_stopping)) - - model.fit( - train_data_loader, - epochs=config.learning_config.running_config.num_epochs, - validation_data=eval_data_loader, - callbacks=callbacks, - steps_per_epoch=train_dataset.total_steps, - validation_steps=eval_dataset.total_steps if eval_data_loader else None, - ) - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/examples/transducer/conformer/README.md b/examples/transducer/conformer/README.md deleted file mode 100755 index 082c09d86f..0000000000 --- a/examples/transducer/conformer/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Conformer: Convolution-augmented Transformer for Speech Recognition - -Reference: [https://arxiv.org/abs/2005.08100](https://arxiv.org/abs/2005.08100) - -## Example Model YAML Config - -Go to [config.yml](./config.yml) - -## Usage - -Training, see `python examples/transducer/conformer/train.py --help` - -Testing, see `python examples/transducer/conformer/test.py --help` - -TFLite Conversion, see `python examples/transducer/conformer/tflite.py --help` - diff --git a/examples/transducer/conformer/confs/config_char.j2 b/examples/transducer/conformer/confs/config_char.j2 deleted file mode 100644 index 769b4b6fdd..0000000000 --- a/examples/transducer/conformer/confs/config_char.j2 +++ /dev/null @@ -1,171 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." 
%} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/conformer" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -decoder_config: - type: characters - - blank_index: 0 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/english.characters - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: True - -model_config: - class_name: tensorflow_asr.models.transducer>Conformer - config: - encoder_subsampling: - type: conv2d - filters: 144 - nlayers: 2 - kernel_size: 3 - strides: 2 - padding: same - norm: none - activation: relu - encoder_ffm_residual_factor: 0.5 - encoder_mhsam_residual_factor: 1.0 - encoder_convm_residual_factor: 1.0 - encoder_dmodel: 144 - encoder_num_blocks: 2 - encoder_head_size: 36 # == dmodel // num_heads - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_interleave_relpe: True - encoder_use_attention_causal_mask: False - encoder_use_attention_auto_mask: True - encoder_kernel_size: 32 - encoder_dropout: 0.1 - encoder_padding: causal - prediction_label_encode_mode: embedding - prediction_embed_dim: 320 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False # False to use with CUDA or dynamic length, True to use with TPU and static length - prediction_layer_norm: False - prediction_projection_units: 144 - joint_dim: 320 - prejoint_encoder_linear: False - prejoint_prediction_linear: False - postjoint_linear: True - joint_activation: tanh - joint_mode: add - blank: 0 - vocab_size: 29 - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: zero - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - mask_value: zero - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - class_name: adam - config: - learning_rate: - class_name: tensorflow_asr.optimizers.schedule>TransformerSchedule - config: - dmodel: 144 - initial_lr: 1.0 - warmup_steps: 10000 - max_lr: 0.00035 - min_lr: 1e-6 - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/examples/transducer/conformer/confs/config_sp.j2 
b/examples/transducer/conformer/confs/config_sp.j2 deleted file mode 100644 index 808b55c18b..0000000000 --- a/examples/transducer/conformer/confs/config_sp.j2 +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/conformer" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -decoder_config: - type: sentencepiece - - blank_index: 0 - pad_token: "" - pad_index: 0 - unknown_token: "" - unknown_index: 1 - bos_token: "" - bos_index: 2 - eos_token: "" - eos_index: 3 - - beam_width: 0 - norm_score: True - lm_config: null - - model_type: bpe - vocabulary: {{repodir}}/vocabularies/librispeech/sentencepiece/train_bpe_1000.model - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: null - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: True - -model_config: - name: conformer - encoder_subsampling: - type: conv2d - filters: 144 - nlayers: 2 - kernel_size: 3 - strides: 2 - padding: valid - norm: batch - activation: swish - encoder_dmodel: 144 - encoder_num_blocks: 2 - encoder_head_size: 36 # == dmodel // num_heads - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_use_attention_causal_mask: False - encoder_kernel_size: 32 - encoder_fc_factor: 0.5 - encoder_dropout: 0.1 - encoder_padding: causal - prediction_label_encode_mode: embedding - prediction_embed_dim: 320 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False # False to use with CUDA or dynamic length, True to use with TPU and static length - prediction_layer_norm: False - prediction_projection_units: 144 - joint_dim: 320 - prejoint_encoder_linear: False - prejoint_prediction_linear: False - postjoint_linear: True - joint_activation: tanh - joint_mode: add - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: zero - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - mask_value: zero - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: False - stage: train - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - 
{{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/examples/transducer/conformer/confs/config_wp_4.j2 b/examples/transducer/conformer/confs/config_wp_4.j2 deleted file mode 100644 index 606ab911b0..0000000000 --- a/examples/transducer/conformer/confs/config_wp_4.j2 +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/conformer" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -decoder_config: - type: wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 0 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_4.tokens - vocab_size: 1000 - max_token_length: 4 - max_unique_chars: 1000 - reserved_tokens: - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: True - -model_config: - name: conformer - encoder_subsampling: - type: conv2d - filters: 144 - nlayers: 2 - kernel_size: 3 - strides: 2 - padding: same - norm: batch - activation: swish - encoder_dmodel: 144 - encoder_num_blocks: 2 - encoder_head_size: 36 # == dmodel // num_heads - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_use_attention_causal_mask: False - encoder_kernel_size: 32 - encoder_fc_factor: 0.5 - encoder_dropout: 0.1 - encoder_padding: causal - prediction_label_encode_mode: embedding - prediction_embed_dim: 320 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False # False to use with CUDA or dynamic length, True to use with TPU and static length - prediction_layer_norm: False - prediction_projection_units: 144 - joint_dim: 320 - prejoint_encoder_linear: False - prejoint_prediction_linear: False - postjoint_linear: True - joint_activation: tanh - 
joint_mode: add - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: zero - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - mask_value: zero - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/examples/transducer/conformer/confs/config_wp_6.j2 b/examples/transducer/conformer/confs/config_wp_6.j2 deleted file mode 100644 index 9c72673ea0..0000000000 --- a/examples/transducer/conformer/confs/config_wp_6.j2 +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." 
%} -{% set modeldir = "/Volumes/Data/Miscellanea/Models/local/conformer" %} -{% set datadir = "/Volumes/Data/MLDL/Datasets/ASR/LibriSpeech" %} - -decoder_config: - type: wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 0 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_6.tokens - vocab_size: 1000 - max_token_length: 6 - max_unique_chars: 1000 - reserved_tokens: - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: True - -model_config: - name: conformer - encoder_subsampling: - type: conv2d - nlayers: 2 - filters: 144 - kernel_size: 3 - strides: 2 - padding: causal - norm: batch - activation: swish - encoder_dmodel: 144 - encoder_num_blocks: 16 - encoder_head_size: 36 - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_interleave_relpe: True - encoder_use_attention_causal_mask: False - encoder_use_attention_auto_mask: True - encoder_kernel_size: 32 - encoder_dropout: 0.1 - encoder_padding: causal - encoder_dense_as_pointwise: False - encoder_depthwise_as_groupwise: False - encoder_ffm_residual_factor: 0.5 - encoder_mhsam_residual_factor: 1.0 - encoder_convm_residual_factor: 1.0 - encoder_module_norm_position: pre - encoder_block_norm_position: post - encoder_memory_length: 512 - prediction_label_encode_mode: embedding - prediction_embed_dim: 320 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False - prediction_layer_norm: True - prediction_projection_units: 0 - joint_dim: 320 - prejoint_encoder_linear: True - prejoint_prediction_linear: True - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: zero - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - mask_value: zero - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - 
log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/examples/transducer/conformer/confs/rezero_config_wp.j2 b/examples/transducer/conformer/confs/rezero_config_wp.j2 deleted file mode 100644 index 500403f38f..0000000000 --- a/examples/transducer/conformer/confs/rezero_config_wp.j2 +++ /dev/null @@ -1,177 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/conformer" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -decoder_config: - type: wordpiece - blank_index: 0 - unknown_token: "" - unknown_index: 1 - beam_width: 0 - norm_score: True - lm_config: null - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: True - -model_config: - name: conformer - encoder_subsampling: - type: conv2d - nlayers: 2 - filters: 144 - kernel_size: 3 - strides: 2 - padding: causal - norm: none - activation: relu - encoder_ffm_residual_factor: rezero - encoder_mhsam_residual_factor: rezero - encoder_convm_residual_factor: rezero - encoder_module_norm_position: none - encoder_block_norm_position: none - encoder_dmodel: 144 - encoder_num_blocks: 16 - encoder_head_size: 36 - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_interleave_relpe: True - encoder_use_attention_causal_mask: False - encoder_use_attention_auto_mask: True - encoder_kernel_size: 32 - encoder_dropout: 0.1 - encoder_padding: same - encoder_dense_as_pointwise: False - encoder_depthwise_as_groupwise: True - prediction_label_encode_mode: embedding - prediction_embed_dim: 320 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False - prediction_layer_norm: True - prediction_projection_units: 0 - joint_dim: 320 - prejoint_encoder_linear: True - prejoint_prediction_linear: True - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: zero - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - mask_value: zero - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - - 
eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/examples/transducer/conformer/inference/gen_saved_model.py b/examples/transducer/conformer/inference/gen_saved_model.py deleted file mode 100644 index 8c8846dfb8..0000000000 --- a/examples/transducer/conformer/inference/gen_saved_model.py +++ /dev/null @@ -1,56 +0,0 @@ -# pylint: disable=no-member -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -import fire -import tensorflow as tf - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.helpers import featurizer_helpers -from tensorflow_asr.models.transducer.conformer import Conformer -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - - -def main( - config_path: str = DEFAULT_YAML, - saved: str = None, - output_dir: str = None, -): - assert saved and output_dir - tf.random.set_seed(0) - tf.keras.backend.clear_session() - - logger.info("Load config and featurizers ...") - config = Config(config_path) - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - logger.info("Build and load model ...") - conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) - conformer.make(speech_featurizer.shape) - conformer.add_featurizers(speech_featurizer, text_featurizer) - conformer.load_weights(saved, by_name=True) - conformer.summary() - - logger.info("Save model ...") - tf.saved_model.save(conformer, export_dir=output_dir, signatures=conformer.recognize_from_signal.get_concrete_function()) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/examples/transducer/conformer/inference/run_saved_model.py b/examples/transducer/conformer/inference/run_saved_model.py deleted file mode 100644 index 35908af279..0000000000 --- a/examples/transducer/conformer/inference/run_saved_model.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -import fire -import tensorflow as tf - -from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - - -def main( - saved_model: str = None, - filename: str = None, -): - tf.keras.backend.clear_session() - - module = tf.saved_model.load(export_dir=saved_model) - - signal = read_raw_audio(filename) - transcript = module.pred(signal) - - print("Transcript: ", "".join([chr(u) for u in transcript])) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/examples/transducer/conformer/inference/run_tflite_model.py b/examples/transducer/conformer/inference/run_tflite_model.py deleted file mode 100644 index 3bea332e4c..0000000000 --- a/examples/transducer/conformer/inference/run_tflite_model.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import fire -import tensorflow as tf - -from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio - - -def main( - filename: str, - tflite: str = None, - blank: int = 0, - num_rnns: int = 1, - nstates: int = 2, - statesize: int = 320, -): - tflitemodel = tf.lite.Interpreter(model_path=tflite) - - signal = read_raw_audio(filename) - - input_details = tflitemodel.get_input_details() - output_details = tflitemodel.get_output_details() - tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape) - tflitemodel.allocate_tensors() - tflitemodel.set_tensor(input_details[0]["index"], signal) - tflitemodel.set_tensor(input_details[1]["index"], tf.constant(blank, dtype=tf.int32)) - tflitemodel.set_tensor(input_details[2]["index"], tf.zeros([num_rnns, nstates, 1, statesize], dtype=tf.float32)) - tflitemodel.invoke() - hyp = tflitemodel.get_tensor(output_details[0]["index"]) - - print("".join([chr(u) for u in hyp])) - - -if __name__ == "__main__": - fire.Fire(main) diff --git a/examples/transducer/conformer/results/characters.md b/examples/transducer/conformer/results/characters.md deleted file mode 100644 index 4c8379abe7..0000000000 --- a/examples/transducer/conformer/results/characters.md +++ /dev/null @@ -1,185 +0,0 @@ -# Characters Conformer Transducer - -- [Characters Conformer Transducer](#characters-conformer-transducer) - - [2023-02-12](#2023-02-12) - - -## 2023-02-12 - -Config: - -```python -config = """ -{% set repodir = "/path/to/TensorFlowASR" %} -{% set modeldir = "/path/to/models/char-conformer/20230212" %} -{% set datadir = "/path/to/librispeech/tfrecords" %} - -model_config: - name: conformer - encoder_subsampling: - type: conv2d - nlayers: 2 - filters: 144 - kernel_size: 3 - strides: 2 - padding: same - norm: batch - activation: swish - encoder_dmodel: 144 - encoder_num_blocks: 16 - encoder_head_size: 36 # == dmodel // num_heads - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_use_attention_causal_mask: False - encoder_kernel_size: 32 - encoder_fc_factor: 0.5 - encoder_dropout: 0.1 - encoder_padding: causal - prediction_label_encode_mode: embedding - prediction_embed_dim: 320 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False - prediction_layer_norm: False - prediction_projection_units: 0 - joint_dim: 320 - prejoint_encoder_linear: True - prejoint_prediction_linear: True - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: False - -decoder_config: - type: characters - blank_index: 0 - beam_width: 0 - norm_score: True - lm_config: null - vocabulary: {{repodir}}/vocabularies/librispeech/characters/english.characters - corpus_files: null - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: mean - freq_masking: - prob: 1.0 
- num_masks: 1 - mask_factor: 27 - mask_value: mean - data_paths: null - tfrecords_dir: {{datadir}} - shuffle: True - cache: False - buffer_size: 1000 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/characters/metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: null - - test_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 4 - num_epochs: 300 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - tensorboard: - log_dir: {{modeldir}}/tensorboard - write_graph: False - write_images: False - update_freq: epoch - profile_batch: 100 -""" -with open("/path/to/config.j2", "w") as f: - f.write(config) -``` - -Training: - -```bash -python /path/to/TensorFlowASR/examples/transducer/conformer/train.py \ - --config-path=/path/to/config.j2 \ - --mxp=strict \ - --jit-compile \ - --tfrecords -``` - -Testing: - -```bash -python /path/to/TensorFlowASR/examples/transducer/conformer/test.py \ - --config-path=/path/to/config.j2 \ - --saved=/path/to/models/char-conformer/20230212/checkpoints/25.h5 \ - --output=/path/to/models/char-conformer/20230212/tests/25.tsv \ - --bs=1 -``` - -RNNT Loss Curves: - - - -Error Rates: - -| Dataset | Mode | Batch size | Epoch | WER (%) | CER (%) | -| :--------------------- | :----: | :--------: | :---: | :-----: | :-----: | -| librispeech-test-clean | greedy | 1 | 25 | | | -| librispeech-test-other | greedy | 1 | 25 | | | \ No newline at end of file diff --git a/examples/transducer/conformer/results/figs/arch.png b/examples/transducer/conformer/results/figs/arch.png deleted file mode 100644 index 03874afbcf..0000000000 Binary files a/examples/transducer/conformer/results/figs/arch.png and /dev/null differ diff --git a/examples/transducer/conformer/results/figs/conformer.png b/examples/transducer/conformer/results/figs/conformer.png deleted file mode 100644 index 1e190e3ca3..0000000000 Binary files a/examples/transducer/conformer/results/figs/conformer.png and /dev/null differ diff --git a/examples/transducer/conformer/tests/gen_model_only_bp.py b/examples/transducer/conformer/tests/gen_model_only_bp.py deleted file mode 100644 index df03f64a1c..0000000000 --- a/examples/transducer/conformer/tests/gen_model_only_bp.py +++ /dev/null @@ -1,42 +0,0 @@ -# %% Imports -import os - -import tensorflow as tf - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.helpers import featurizer_helpers -from tensorflow_asr.models.transducer.conformer import Conformer -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() - - -# %% Load model - -config_path = f"{os.path.dirname(__file__)}/../../../models/wordpiece-conformer-v2/config.yml" - -config = Config(config_path) -tf.random.set_seed(0) -tf.keras.backend.clear_session() - -speech_featurizer, text_featurizer = 
featurizer_helpers.prepare_featurizers(config=config) - -h5 = f"{os.path.dirname(__file__)}/../../../models/wordpiece-conformer-v2/21.h5" - -# build model -conformer = Conformer(**config.model_config, vocab_size=text_featurizer.num_classes) -conformer.make(speech_featurizer.shape) -conformer.add_featurizers(speech_featurizer, text_featurizer) -conformer.summary() -# conformer.load_weights(h5, by_name=True) - -# %% Gen bp - -output_bp = f"{os.path.dirname(__file__)}/../../../models/wordpiece-conformer-v2/model_only_bp" - -tf.saved_model.save(conformer, output_bp) - -# %% Load bp -loaded_conformer = tf.saved_model.load(output_bp) - -# %% diff --git a/examples/transducer/contextnet/README.md b/examples/transducer/contextnet/README.md deleted file mode 100644 index 0b7e51ae15..0000000000 --- a/examples/transducer/contextnet/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# ContextNet: Improving Convolutional Neural Networks for Automatic Speech Recognition with Global Context - -Reference: [http://arxiv.org/abs/2005.03191](http://arxiv.org/abs/2005.03191) - -## Example Model YAML Config - -Go to [config.yml](./config.yml) - -## Usage - -Training, see `python examples/transducer/contextnet/train.py --help` - -Testing, see `python examples/transducer/contextnet/test.py --help` - -TFLite Conversion, see `python examples/transducer/contextnet/inference/gen_tflite_model.py --help` \ No newline at end of file diff --git a/examples/transducer/contextnet/confs/config_wp.j2 b/examples/transducer/contextnet/confs/config_wp.j2 deleted file mode 100644 index 683fdd5710..0000000000 --- a/examples/transducer/contextnet/confs/config_wp.j2 +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." 
%} -{% set modeldir = "/mnt/Miscellanea/Models/local/contextnet" %} -{% set datadir = "/mnt/Data/MLDL/Datasets/ASR/LibriSpeech" %} - -model_config: - name: contextnet - encoder_alpha: 0.5 - encoder_blocks: - # C0 - - nlayers: 1 - kernel_size: 5 - filters: 256 - strides: 1 - residual: False - activation: silu - padding: causal - # C1-C2 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - # C3 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 2 - residual: True - activation: silu - padding: causal - # C4-C6 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - # C7 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 2 - residual: True - activation: silu - padding: causal - # C8 - C10 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - # C11 - C13 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - # C14 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 2 - residual: True - activation: silu - padding: causal - # C15 - C21 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - # C22 - - nlayers: 1 - kernel_size: 5 - filters: 640 - strides: 1 - residual: False - activation: silu - padding: causal - prediction_label_encode_mode: embedding - prediction_embed_dim: 640 - prediction_num_rnns: 1 - prediction_rnn_units: 512 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False # False to use with GPU Cuda - prediction_layer_norm: False - prediction_projection_units: 0 - joint_dim: 512 - prejoint_encoder_linear: True - prejoint_prediction_linear: True - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: 
wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 1 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - gauss_noise: - prob: 0.1 - stddev: 0.075 - time_masking: - prob: 0.5 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 0.5 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/train-clean-100/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords/100h - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000.metadata.json - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - joint_net_step: 20000 - joint_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 - diff --git a/examples/transducer/contextnet/results/figs/1008_epoch_learning_rate.svg b/examples/transducer/contextnet/results/figs/1008_epoch_learning_rate.svg deleted file mode 100644 index f1c16b7273..0000000000 --- a/examples/transducer/contextnet/results/figs/1008_epoch_learning_rate.svg +++ /dev/null @@ -1 +0,0 @@ -3e-54e-55e-56e-57e-58e-59e-51e-42e-4-20020406080100120 \ No newline at end of file diff --git a/examples/transducer/contextnet/results/figs/1008_subword_contextnet_loss.svg b/examples/transducer/contextnet/results/figs/1008_subword_contextnet_loss.svg deleted file mode 100644 index f1c5ca2799..0000000000 --- a/examples/transducer/contextnet/results/figs/1008_subword_contextnet_loss.svg +++ /dev/null @@ -1 +0,0 @@ --202468101214161820-100102030405060708090 \ No newline at end of file diff --git a/examples/transducer/contextnet/results/wordpiece.md b/examples/transducer/contextnet/results/wordpiece.md deleted file mode 100644 index 9d947b107c..0000000000 --- a/examples/transducer/contextnet/results/wordpiece.md +++ /dev/null @@ -1,523 +0,0 @@ -# Wordpiece Contextnet Transducer - -- [Wordpiece Contextnet Transducer](#wordpiece-contextnet-transducer) - - [LibriSpeech Only 
Data](#librispeech-only-data) - - [Config](#config) - - [Training](#training) - - [Testing](#testing) - - -## LibriSpeech Only Data - -#### Config - -```python -config = """ -{% set repodir = "/path/to/TensorFlowASR" %} -{% set modeldir = "/path/to/models/wp1k-contextnet/only-data" %} -{% set datadir = "/path/to/librispeech/tfrecords" %} - -model_config: - name: contextnet - encoder_alpha: 0.5 - encoder_blocks: - # C0 - - nlayers: 1 - kernel_size: 5 - filters: 256 - strides: 1 - residual: False - activation: silu - padding: causal - # C1-C2 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - # C3 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 2 - residual: True - activation: silu - padding: causal - # C4-C6 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - # C7 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 2 - residual: True - activation: silu - padding: causal - # C8 - C10 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - padding: causal - # C11 - C13 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - # C14 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 2 - residual: True - activation: silu - padding: causal - # C15 - C21 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - padding: causal - # C22 - - nlayers: 1 - kernel_size: 5 - filters: 640 - strides: 1 - residual: False - activation: silu - padding: causal - prediction_label_encode_mode: embedding - prediction_embed_dim: 640 - prediction_num_rnns: 1 - prediction_rnn_units: 512 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False - prediction_layer_norm: True - prediction_projection_units: 0 - joint_dim: 512 - prejoint_encoder_linear: True - prejoint_prediction_linear: True - postjoint_linear: False - joint_activation: tanh - 
joint_mode: add - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 1 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: null - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - data_paths: null - tfrecords_dir: {{datadir}} - shuffle: True - cache: False - buffer_size: 1000 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: null - - test_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 15000 - max_lr: 0.0025 - - running_config: - batch_size: 8 - num_epochs: 300 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - tensorboard: - log_dir: {{modeldir}}/tensorboard - write_graph: False - write_images: False - update_freq: epoch - profile_batch: 100 -""" -with open("/path/to/config.j2", "w") as f: - f.write(config) -``` - -#### Training - -```bash -python /path/to/TensorFlowASR/examples/transducer/contextnet/train.py \ - --config-path=/path/to/config.j2 \ - --mxp=strict \ - --jit-compile \ - --tfrecords -``` - -Outputs: - -``` -INFO:tensorflow:Use RNNT loss in TensorFlow -INFO:tensorflow:Deallocate tpu buffers before initializing tpu system. 
-INFO:tensorflow:All TPUs: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')] -INFO:tensorflow:Found TPU system: -INFO:tensorflow:*** Num TPU Cores: 8 -INFO:tensorflow:*** Num TPU Workers: 1 -INFO:tensorflow:*** Num TPU Cores Per Worker: 8 -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) -INFO:tensorflow:USING mixed precision policy mixed_bfloat16 -INFO:tensorflow:Loading wordpiece ... -INFO:tensorflow:Loading metadata from /content/TensorFlowASR/vocabularies/librispeech/wordpiece/train_1000_50.metadata.json ... 
-INFO:tensorflow:TFRecords're already existed: train -INFO:tensorflow:Use GPU/TPU implementation for RNNT loss -Model: "contextnet" -__________________________________________________________________________________________________________________________________________ - Layer (type) Output Shape Param # Trainable -========================================================================================================================================== - encoder (ContextNetEncoder) ((8, 372, 640), 6888392 Y - (8,)) - - prediction (TransducerPrediction) (8, 203, 512) 3002368 Y - - joint (TransducerJoint) (8, 372, 203, 1000) 939496 Y - -========================================================================================================================================== -Total params: 10,830,258 -Trainable params: 10,771,120 -Non-trainable params: 59,138 -__________________________________________________________________________________________________________________________________________ -Epoch 1/300 -WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. -Instructions for updating: -Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 -4394/4394 [==============================] - 3321s 659ms/step - loss: 354.1810 - per_batch_avg_loss: 354.1461 -Epoch 2/300 -4394/4394 [==============================] - 2871s 653ms/step - loss: 208.4284 - per_batch_avg_loss: 208.4139 -Epoch 3/300 -4394/4394 [==============================] - 2889s 657ms/step - loss: 107.2801 - per_batch_avg_loss: 107.2722 -Epoch 4/300 -4394/4394 [==============================] - 2892s 658ms/step - loss: 53.1701 - per_batch_avg_loss: 53.1681 -Epoch 5/300 -4394/4394 [==============================] - 2897s 659ms/step - loss: 36.3960 - per_batch_avg_loss: 36.3937 -Epoch 6/300 -4394/4394 [==============================] - 2895s 659ms/step - loss: 29.5058 - per_batch_avg_loss: 29.5055 -Epoch 7/300 -4394/4394 [==============================] - 2899s 660ms/step - loss: 25.5613 - per_batch_avg_loss: 25.5621 -Epoch 8/300 -4394/4394 [==============================] - 2898s 660ms/step - loss: 22.9065 - per_batch_avg_loss: 22.9076 -Epoch 9/300 -4394/4394 [==============================] - 2900s 660ms/step - loss: 20.9543 - per_batch_avg_loss: 20.9534 -Epoch 10/300 -4394/4394 [==============================] - 2894s 659ms/step - loss: 19.4345 - per_batch_avg_loss: 19.4343 -Epoch 11/300 -4394/4394 [==============================] - 2898s 660ms/step - loss: 18.1739 - per_batch_avg_loss: 18.1751 -Epoch 12/300 -4394/4394 [==============================] - 2903s 661ms/step - loss: 17.1167 - per_batch_avg_loss: 17.1171 -Epoch 13/300 -4394/4394 [==============================] - 2899s 660ms/step - loss: 16.2142 - per_batch_avg_loss: 16.2144 -Epoch 14/300 -4394/4394 [==============================] - 2899s 660ms/step - loss: 15.4079 - per_batch_avg_loss: 15.4081 -Epoch 15/300 -4394/4394 [==============================] - 2894s 659ms/step - loss: 14.6829 - per_batch_avg_loss: 14.6824 -Epoch 16/300 -4394/4394 [==============================] - 2897s 659ms/step - loss: 14.0212 - per_batch_avg_loss: 14.0217 -Epoch 17/300 -4394/4394 [==============================] - 2900s 660ms/step - loss: 13.4270 - 
per_batch_avg_loss: 13.4276 -Epoch 18/300 -4394/4394 [==============================] - 2893s 658ms/step - loss: 12.8534 - per_batch_avg_loss: 12.8523 -Epoch 19/300 -4394/4394 [==============================] - 2904s 661ms/step - loss: 12.3035 - per_batch_avg_loss: 12.3036 -Epoch 20/300 -4394/4394 [==============================] - 2894s 659ms/step - loss: 11.8388 - per_batch_avg_loss: 11.8391 -Epoch 21/300 -4394/4394 [==============================] - 2892s 658ms/step - loss: 11.3558 - per_batch_avg_loss: 11.3556 -Epoch 22/300 -4394/4394 [==============================] - 2890s 658ms/step - loss: 10.8828 - per_batch_avg_loss: 10.8833 -Epoch 23/300 -4394/4394 [==============================] - 2890s 658ms/step - loss: 10.4597 - per_batch_avg_loss: 10.4601 -Epoch 24/300 -4394/4394 [==============================] - 2888s 657ms/step - loss: 10.0510 - per_batch_avg_loss: 10.0514 -Epoch 25/300 -4394/4394 [==============================] - 2893s 658ms/step - loss: 9.6318 - per_batch_avg_loss: 9.6321 -Epoch 26/300 -4394/4394 [==============================] - 2900s 660ms/step - loss: 9.2559 - per_batch_avg_loss: 9.2558 -Epoch 27/300 -4394/4394 [==============================] - 2897s 659ms/step - loss: 8.8785 - per_batch_avg_loss: 8.8787 -Epoch 28/300 -4394/4394 [==============================] - 2898s 659ms/step - loss: 8.5476 - per_batch_avg_loss: 8.5474 -Epoch 29/300 -4394/4394 [==============================] - 2895s 659ms/step - loss: 8.2002 - per_batch_avg_loss: 8.2000 -Epoch 30/300 -4394/4394 [==============================] - 3303s 661ms/step - loss: 8.0303 - per_batch_avg_loss: 8.0302 -Epoch 31/300 -4394/4394 [==============================] - 2874s 654ms/step - loss: 7.5840 - per_batch_avg_loss: 7.5840 -Epoch 32/300 -4394/4394 [==============================] - 2889s 658ms/step - loss: 7.2728 - per_batch_avg_loss: 7.2727 -Epoch 33/300 -4394/4394 [==============================] - 2874s 654ms/step - loss: 6.9903 - per_batch_avg_loss: 6.9904 -Epoch 34/300 -4394/4394 [==============================] - 2883s 656ms/step - loss: 6.7392 - per_batch_avg_loss: 6.7389 -Epoch 35/300 -4394/4394 [==============================] - 2888s 657ms/step - loss: 6.4757 - per_batch_avg_loss: 6.4754 -Epoch 36/300 -4394/4394 [==============================] - 2895s 659ms/step - loss: 6.2007 - per_batch_avg_loss: 6.2010 -Epoch 37/300 -4394/4394 [==============================] - 2892s 658ms/step - loss: 5.9911 - per_batch_avg_loss: 5.9912 -Epoch 38/300 -4394/4394 [==============================] - 2896s 659ms/step - loss: 5.7801 - per_batch_avg_loss: 5.7799 -Epoch 39/300 -4394/4394 [==============================] - 2886s 657ms/step - loss: 5.5511 - per_batch_avg_loss: 5.5512 -Epoch 40/300 -4394/4394 [==============================] - 2893s 658ms/step - loss: 5.3473 - per_batch_avg_loss: 5.3477 -Epoch 41/300 -4394/4394 [==============================] - 2897s 659ms/step - loss: 5.1490 - per_batch_avg_loss: 5.1488 -Epoch 42/300 -4394/4394 [==============================] - 3323s 662ms/step - loss: 5.1151 - per_batch_avg_loss: 5.1150 -Epoch 43/300 -4394/4394 [==============================] - 2868s 653ms/step - loss: 4.7958 - per_batch_avg_loss: 4.7961 -Epoch 44/300 -4394/4394 [==============================] - 2871s 653ms/step - loss: 4.6019 - per_batch_avg_loss: 4.6019 -Epoch 45/300 -4394/4394 [==============================] - 2872s 654ms/step - loss: 4.4455 - per_batch_avg_loss: 4.4454 -Epoch 46/300 -4394/4394 [==============================] - 2880s 655ms/step - loss: 4.2980 - per_batch_avg_loss: 4.2980 
-Epoch 47/300 -4394/4394 [==============================] - 2864s 652ms/step - loss: 4.1674 - per_batch_avg_loss: 4.1674 -Epoch 48/300 -4394/4394 [==============================] - 2858s 650ms/step - loss: 4.0385 - per_batch_avg_loss: 4.0385 -Epoch 49/300 -4394/4394 [==============================] - 2860s 651ms/step - loss: 3.9321 - per_batch_avg_loss: 3.9321 -Epoch 50/300 -4394/4394 [==============================] - 2882s 656ms/step - loss: 3.8276 - per_batch_avg_loss: 3.8274 -Epoch 51/300 -4394/4394 [==============================] - 2885s 657ms/step - loss: 3.7106 - per_batch_avg_loss: 3.7105 -Epoch 52/300 -4394/4394 [==============================] - 2876s 655ms/step - loss: 3.6935 - per_batch_avg_loss: 3.6934 -Epoch 53/300 -4394/4394 [==============================] - 2862s 651ms/step - loss: 3.6176 - per_batch_avg_loss: 3.6176 -Epoch 54/300 -4394/4394 [==============================] - 2907s 661ms/step - loss: 3.5398 - per_batch_avg_loss: 3.5400 -Epoch 55/300 -4394/4394 [==============================] - 2880s 655ms/step - loss: 3.4705 - per_batch_avg_loss: 3.4707 -Epoch 56/300 -4394/4394 [==============================] - 2880s 655ms/step - loss: 3.4127 - per_batch_avg_loss: 3.4126 -Epoch 57/300 -4394/4394 [==============================] - 2910s 662ms/step - loss: 3.4820 - per_batch_avg_loss: 3.4821 -Epoch 58/300 -4394/4394 [==============================] - 2890s 658ms/step - loss: 3.5663 - per_batch_avg_loss: 3.5663 -Epoch 59/300 -4394/4394 [==============================] - 2894s 659ms/step - loss: 3.6664 - per_batch_avg_loss: 3.6661 -Epoch 60/300 -4394/4394 [==============================] - 2884s 656ms/step - loss: 3.8075 - per_batch_avg_loss: 3.8073 -Epoch 61/300 -4394/4394 [==============================] - 2869s 653ms/step - loss: 3.9362 - per_batch_avg_loss: 3.9362 -Epoch 62/300 -4394/4394 [==============================] - 2880s 655ms/step - loss: 4.0172 - per_batch_avg_loss: 4.0174 -Epoch 63/300 -4394/4394 [==============================] - 2894s 659ms/step - loss: 4.0682 - per_batch_avg_loss: 4.0683 -Epoch 64/300 -4394/4394 [==============================] - 2919s 664ms/step - loss: 4.0974 - per_batch_avg_loss: 4.0974 -Epoch 65/300 -4394/4394 [==============================] - 2923s 665ms/step - loss: 4.2700 - per_batch_avg_loss: 4.2701 -Epoch 66/300 -4394/4394 [==============================] - 2924s 665ms/step - loss: 4.4607 - per_batch_avg_loss: 4.4609 -Epoch 67/300 -4394/4394 [==============================] - 2918s 664ms/step - loss: 4.5037 - per_batch_avg_loss: 4.5036 -Epoch 68/300 -4394/4394 [==============================] - 2915s 663ms/step - loss: 4.5902 - per_batch_avg_loss: 4.5899 -``` - -#### Testing - -```bash -python /path/to/TensorFlowASR/examples/transducer/contextnet/test.py \ - --config-path=/path/to/config.j2 \ - --saved=/path/to/models/wp1k-contextnet/only-data/checkpoints/40.h5 \ - --output=/path/to/models/wp1k-contextnet/only-data/tests/40.tsv \ - --bs=1 -``` - -RNNT Loss Curves: - - - -Error Rates: - -| Dataset | Mode | Batch size | Epoch | WER (%) | CER (%) | -| :--------- | :----: | :--------: | :---: | :----------------: | :----------------: | -| test-clean | greedy | 1 | 40 | 18.036746978759766 | 8.55042114853859 | -| test-clean | greedy | 1 | 56 | 18.39812844991684 | 8.690726011991501 | -| test-other | greedy | 1 | 56 | 38.31839859485626 | 21.644461154937744 | \ No newline at end of file diff --git a/examples/transducer/rnnt/README.md b/examples/transducer/rnnt/README.md deleted file mode 100644 index 2ba34ece3e..0000000000 --- 
a/examples/transducer/rnnt/README.md +++ /dev/null @@ -1,15 +0,0 @@ -# Streaming End-to-end Speech Recognition For Mobile Devices - -Reference: [https://arxiv.org/abs/1811.06621](https://arxiv.org/abs/1811.06621) - -## Example Model YAML Config - -Go to [config.yml](./config.yml) - -## Usage - -Training, see `python examples/transducer/rnnt/train.py --help` - -Testing, see `python examples/transducer/rnnt/test.py --help` - -TFLite Conversion, see `python examples/transducer/rnnt/tflite.py --help` \ No newline at end of file diff --git a/examples/transducer/rnnt/confs/config_char.j2 b/examples/transducer/rnnt/confs/config_char.j2 deleted file mode 100644 index 73f0fc39d5..0000000000 --- a/examples/transducer/rnnt/confs/config_char.j2 +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/rnn-transducer" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -model_config: - name: rnn_transducer - encoder_reductions: - 0: 2 - 1: 2 - encoder_dmodel: 640 - encoder_rnn_type: lstm - encoder_rnn_units: 1024 - encoder_nlayers: 2 - encoder_layer_norm: False - prediction_label_encode_mode: embedding - prediction_embed_dim: 640 - prediction_num_rnns: 2 - prediction_rnn_units: 1024 - prediction_rnn_type: lstm - prediction_rnn_unroll: False - prediction_layer_norm: False - prediction_projection_units: 640 - joint_dim: 640 - prejoint_encoder_linear: False - prejoint_prediction_linear: False - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: characters - - blank_index: 0 - unknown_token: "" - unknown_index: 0 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/english.characters - vocab_size: 1000 - max_token_length: 4 - max_unique_chars: 1000 - reserved_tokens: - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: 
{{repodir}}/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - class_name: adam - config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 - diff --git a/examples/transducer/rnnt/confs/config_wp.j2 b/examples/transducer/rnnt/confs/config_wp.j2 deleted file mode 100644 index 8a8660ab21..0000000000 --- a/examples/transducer/rnnt/confs/config_wp.j2 +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Users/nlhuy/Paraphernalia/models/local/rnn-transducer" %} -{% set datadir = "/Users/nlhuy/Paraphernalia/data/LibriSpeech" %} - -model_config: - name: rnn_transducer - encoder_reductions: - 0: 2 - 1: 2 - encoder_dmodel: 640 - encoder_rnn_type: lstm - encoder_rnn_units: 1024 - encoder_nlayers: 2 - encoder_layer_norm: False - prediction_label_encode_mode: embedding - prediction_embed_dim: 640 - prediction_num_rnns: 2 - prediction_rnn_units: 1024 - prediction_rnn_type: lstm - prediction_rnn_unroll: False - prediction_layer_norm: False - prediction_projection_units: 640 - joint_dim: 640 - prejoint_encoder_linear: False - prejoint_prediction_linear: False - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: wordpiece - - blank_index: 0 - unknown_token: "" - unknown_index: 0 - - beam_width: 0 - norm_score: True - lm_config: null - - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_4.tokens - vocab_size: 1000 - max_token_length: 4 - max_unique_chars: 1000 - reserved_tokens: - - "" - normalization_form: NFKC - num_iterations: 4 - - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: 
train - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json - - test_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - class_name: adam - config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 - diff --git a/examples/transducer/rnnt/results/sentencepiece.md b/examples/transducer/rnnt/results/sentencepiece.md deleted file mode 100644 index d95723ef59..0000000000 --- a/examples/transducer/rnnt/results/sentencepiece.md +++ /dev/null @@ -1,286 +0,0 @@ -# Sentencepiece RNN Transducer - -- [Sentencepiece RNN Transducer](#sentencepiece-rnn-transducer) - - [LibriSpeech Only Data](#librispeech-only-data) - - [Config](#config) - - [Training](#training) - - [Testing](#testing) - - -## LibriSpeech Only Data - -#### Config - -```python -config = """ -{% set repodir = "/path/to/TensorFlowASR" %} -{% set modeldir = "/path/to/models/sp1k-rnnt/only-data" %} -{% set datadir = "/path/to/librispeech/tfrecords" %} - -model_config: - name: rnnt - encoder_reductions: - 0: 4 - 1: 2 - encoder_dmodel: 256 - encoder_rnn_type: lstm - encoder_rnn_units: 512 - encoder_rnn_unroll: False - encoder_nlayers: 8 - encoder_layer_norm: True - prediction_label_encode_mode: embedding - prediction_embed_dim: 512 - prediction_num_rnns: 2 - prediction_rnn_units: 512 - prediction_rnn_type: lstm - prediction_rnn_unroll: False - prediction_layer_norm: True - prediction_projection_units: 256 - joint_dim: 256 - prejoint_encoder_linear: False - prejoint_prediction_linear: False - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - -decoder_config: - type: sentencepiece - - blank_index: 0 - pad_token: "<pad>" - pad_index: 0 - unknown_token: "<unk>" - unknown_index: 1 - bos_token: "<s>" - bos_index: 2 - eos_token: "</s>" - eos_index: 3 - - beam_width: 0 - norm_score: True - lm_config: null - - model_type: bpe - vocabulary: {{repodir}}/vocabularies/librispeech/sentencepiece/train_bpe_1000.model - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: null - normalization_form: NFKC - num_iterations: 4 - - corpus_files: null - -learning_config: - train_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: {{datadir}} - shuffle: True - cache: False - buffer_size: 1000 - drop_remainder: True - stage: train - metadata: {{repodir}}/vocabularies/librispeech/sentencepiece/train_bpe_1000.metadata.json - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: null - tfrecords_dir: null - shuffle: False
- cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - metadata: null - - test_dataset_config: - enabled: True - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: False - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - running_config: - batch_size: 6 - num_epochs: 300 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - backup_and_restore: - backup_dir: {{modeldir}}/states - tensorboard: - log_dir: {{modeldir}}/tensorboard - write_graph: False - write_images: False - update_freq: epoch - profile_batch: 100 -""" -with open("/path/to/config.j2", "w") as file: - file.write(config) -``` - -#### Training - -```bash -python /path/to/TensorFlowASR/examples/transducer/rnnt/train.py \ - --config-path=/path/to/config.j2 \ - --mxp=strict \ - --jit-compile \ - --tfrecords -``` - -Outputs: - -``` -2023-02-17 15:05:56.437429: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected -INFO:tensorflow:Use RNNT loss in TensorFlow -INFO:tensorflow:Deallocate tpu buffers before initializing tpu system. -INFO:tensorflow:All TPUs: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')] -INFO:tensorflow:Found TPU system: -INFO:tensorflow:*** Num TPU Cores: 8 -INFO:tensorflow:*** Num TPU Workers: 1 -INFO:tensorflow:*** Num TPU Cores Per Worker: 8 -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) -INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) -INFO:tensorflow:*** Available Device: 
_DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) -INFO:tensorflow:USING mixed precision policy mixed_bfloat16 -INFO:tensorflow:Loading SentencePiece model ... -INFO:tensorflow:Loading metadata from /content/TensorFlowASR/vocabularies/librispeech/sentencepiece/train_bpe_1000.metadata.json ... -INFO:tensorflow:TFRecords're already existed: train -INFO:tensorflow:Use GPU/TPU implementation for RNNT loss -Model: "rnnt" -__________________________________________________________________________________________________________________________________________ - Layer (type) Output Shape Param # Trainable -========================================================================================================================================== - encoder (RnnTransducerEncoder) ((6, 372, 256), 13821952 Y - (6,)) - - prediction (TransducerPrediction) (6, 232, 256) 4450816 Y - - joint (TransducerJoint) (6, 372, 232, 1000) 257000 Y - -========================================================================================================================================== -Total params: 18,529,770 -Trainable params: 18,529,768 -Non-trainable params: 2 -__________________________________________________________________________________________________________________________________________ -Epoch 1/300 -WARNING:tensorflow:From /usr/local/lib/python3.8/dist-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. -Instructions for updating: -Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 -5859/5859 [==============================] - 5698s 959ms/step - loss: 202.8960 - avg_loss: 201.8269 - avg_loss_scaled: 25.2734 -Epoch 2/300 -5859/5859 [==============================] - 5587s 953ms/step - loss: 45.2534 - avg_loss: 45.2210 - avg_loss_scaled: 5.6632 -Epoch 3/300 -5859/5859 [==============================] - 5590s 954ms/step - loss: 32.3394 - avg_loss: 32.4127 - avg_loss_scaled: 4.0592 -Epoch 4/300 -5859/5859 [==============================] - 5587s 954ms/step - loss: 25.4292 - avg_loss: 25.5215 - avg_loss_scaled: 3.1960 -Epoch 5/300 -5859/5859 [==============================] - 5585s 953ms/step - loss: 21.2849 - avg_loss: 21.3877 - avg_loss_scaled: 2.6787 -Epoch 6/300 -5859/5859 [==============================] - 5587s 954ms/step - loss: 18.3444 - avg_loss: 18.3039 - avg_loss_scaled: 2.2923 -Epoch 7/300 -5859/5859 [==============================] - 5587s 954ms/step - loss: 16.1143 - avg_loss: 16.1171 - avg_loss_scaled: 2.0184 -Epoch 8/300 -5859/5859 [==============================] - 5588s 954ms/step - loss: 14.2889 - avg_loss: 14.3153 - avg_loss_scaled: 1.7928 -Epoch 9/300 -5859/5859 [==============================] - 5591s 954ms/step - loss: 12.7466 - avg_loss: 12.7925 - avg_loss_scaled: 1.6021 -Epoch 10/300 -5859/5859 [==============================] - 5589s 954ms/step - loss: 11.4285 - avg_loss: 11.4078 - avg_loss_scaled: 1.4286 -Epoch 11/300 -5859/5859 [==============================] - 5588s 954ms/step - loss: 10.2693 - avg_loss: 10.2889 - avg_loss_scaled: 1.2885 -Epoch 12/300 -5859/5859 [==============================] - 5588s 954ms/step - loss: 9.2380 - avg_loss: 9.2959 - avg_loss_scaled: 1.1642 -Epoch 13/300 -5859/5859 [==============================] - 5589s 954ms/step - loss: 8.3355 - 
avg_loss: 8.3695 - avg_loss_scaled: 1.0481 -Epoch 14/300 -5859/5859 [==============================] - 5591s 954ms/step - loss: 7.5065 - avg_loss: 7.4994 - avg_loss_scaled: 0.9392 -Epoch 15/300 -5859/5859 [==============================] - 5592s 954ms/step - loss: 6.7857 - avg_loss: 6.7983 - avg_loss_scaled: 0.8514 -Epoch 16/300 -5859/5859 [==============================] - 5704s 961ms/step - loss: 6.1724 - avg_loss: 6.2067 - avg_loss_scaled: 0.7773 -Epoch 17/300 -5859/5859 [==============================] - 5610s 957ms/step - loss: 5.5486 - avg_loss: 5.5334 - avg_loss_scaled: 0.6930 -Epoch 18/300 -5859/5859 [==============================] - 5601s 956ms/step - loss: 5.0254 - avg_loss: 4.9964 - avg_loss_scaled: 0.6257 -Epoch 19/300 -5859/5859 [==============================] - 5596s 955ms/step - loss: 4.5515 - avg_loss: 4.5450 - avg_loss_scaled: 0.5692 -Epoch 20/300 -5859/5859 [==============================] - 5600s 956ms/step - loss: 4.1555 - avg_loss: 4.1187 - avg_loss_scaled: 0.5158 -Epoch 21/300 -5859/5859 [==============================] - 5593s 955ms/step - loss: 3.7699 - avg_loss: 3.7569 - avg_loss_scaled: 0.4705 -Epoch 22/300 -5859/5859 [==============================] - 5586s 953ms/step - loss: 3.4460 - avg_loss: 3.4470 - avg_loss_scaled: 0.4317 -Epoch 23/300 -5859/5859 [==============================] - 5585s 953ms/step - loss: 3.1594 - avg_loss: 3.1790 - avg_loss_scaled: 0.3981 -Epoch 24/300 -5859/5859 [==============================] - 5592s 954ms/step - loss: 2.8957 - avg_loss: 2.9069 - avg_loss_scaled: 0.3640 -Epoch 25/300 -5859/5859 [==============================] - 5596s 955ms/step - loss: 2.6875 - avg_loss: 2.6777 - avg_loss_scaled: 0.3353 -Epoch 26/300 -5859/5859 [==============================] - 5587s 954ms/step - loss: 2.4714 - avg_loss: 2.4769 - avg_loss_scaled: 0.3102 -Epoch 27/300 -5859/5859 [==============================] - 5583s 953ms/step - loss: 2.2982 - avg_loss: 2.3308 - avg_loss_scaled: 0.2919 -Epoch 28/300 -5859/5859 [==============================] - 5582s 953ms/step - loss: 2.1336 - avg_loss: 2.1169 - avg_loss_scaled: 0.2651 -Epoch 29/300 -5859/5859 [==============================] - 5583s 953ms/step - loss: 2.0019 - avg_loss: 2.0183 - avg_loss_scaled: 0.2528 -Epoch 30/300 -5859/5859 [==============================] - 5584s 953ms/step - loss: 1.8734 - avg_loss: 1.8510 - avg_loss_scaled: 0.2318 -``` - -#### Testing - -```bash -python /path/to/TensorFlowASR/examples/transducer/rnnt/test.py \ - --config-path=/path/to/config.j2 \ - --saved=/path/to/models/sp1k-rnnt/only-data/checkpoints/30.h5 \ - --output=/path/to/models/sp1k-rnnt/only-data/tests/30.tsv \ - --bs=1 -``` - -RNNT Loss Curves: - - - -Error Rates: - -| Dataset | Mode | Batch size | Epoch | WER (%) | CER (%) | -| :--------- | :----: | :--------: | :---: | :---------------: | :----------------: | -| test-clean | greedy | 1 | 30 | 14.17757123708725 | 6.1642616987228394 | -| test-other | greedy | 1 | 30 | 33.20023715496063 | 17.79550015926361 | \ No newline at end of file diff --git a/examples/transducer/transformer/README.md b/examples/transducer/transformer/README.md deleted file mode 100755 index 532eb95342..0000000000 --- a/examples/transducer/transformer/README.md +++ /dev/null @@ -1,14 +0,0 @@ -# Transformer Transducer - -## Example Model YAML Config - -Go to [confs](./confs/) - -## Usage - -Training, see `python examples/transducer/transformer/train.py --help` - -Testing, see `python examples/transducer/transformer/test.py --help` - -TFLite Conversion, see `python 
examples/transducer/transformer/tflite.py --help` - diff --git a/examples/transducer/transformer/confs/rezero_config_wp.j2 b/examples/transducer/transformer/confs/rezero_config_wp.j2 deleted file mode 100644 index ccf854b85d..0000000000 --- a/examples/transducer/transformer/confs/rezero_config_wp.j2 +++ /dev/null @@ -1,173 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -{% set repodir = "." %} -{% set modeldir = "/Volumes/Data/Miscellanea/Models/local/tranformer-transducer" %} -{% set datadir = "/Volumes/Data/MLDL/Datasets/ASR/LibriSpeech" %} - -decoder_config: - type: wordpiece - blank_index: 0 - unknown_token: "" - unknown_index: 1 - beam_width: 0 - norm_score: True - lm_config: null - vocabulary: {{repodir}}/vocabularies/librispeech/wordpiece/train_1000_50.tokens - vocab_size: 1000 - max_token_length: 50 - max_unique_chars: 1000 - reserved_tokens: - - "" - - "" - normalization_form: NFKC - num_iterations: 4 - corpus_files: - - {{datadir}}/train-clean-100/transcripts.tsv - - {{datadir}}/train-clean-360/transcripts.tsv - - {{datadir}}/train-other-500/transcripts.tsv - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - normalize_feature: True - -model_config: - name: transformer - encoder_subsampling: - type: conv2d - nlayers: 2 - filters: 512 - kernel_size: 3 - strides: 2 - padding: causal - norm: none - activation: relu - encoder_dropout: 0.1 - encoder_residual_factor: rezero - encoder_norm_position: none - encoder_dmodel: 512 - encoder_dff: 1024 - encoder_num_blocks: 6 - encoder_head_size: 128 - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_interleave_relpe: True - encoder_use_attention_causal_mask: False - encoder_use_attention_auto_mask: True - encoder_pwffn_activation: relu - encoder_memory_length: 512 - prediction_label_encode_mode: embedding - prediction_embed_dim: 512 - prediction_num_rnns: 1 - prediction_rnn_units: 512 - prediction_rnn_type: lstm - prediction_rnn_implementation: 2 - prediction_rnn_unroll: False - prediction_layer_norm: True - prediction_projection_units: 0 - joint_dim: 512 - prejoint_encoder_linear: True - prejoint_prediction_linear: True - postjoint_linear: False - joint_activation: tanh - joint_mode: add - -learning_config: - train_dataset_config: - enabled: True - use_tf: True - augmentation_config: - feature_augment: - time_masking: - prob: 1.0 - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - mask_value: zero - freq_masking: - prob: 1.0 - num_masks: 1 - mask_factor: 27 - mask_value: zero - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: {{datadir}}/tfrecords - shuffle: True - cache: True - buffer_size: 100 - drop_remainder: True - stage: train - - eval_dataset_config: - enabled: False - use_tf: True - data_paths: - - {{datadir}}/dev-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: eval - - test_dataset_config: - 
enabled: False - use_tf: True - data_paths: - - {{datadir}}/test-clean/transcripts.tsv - tfrecords_dir: null - shuffle: False - cache: True - buffer_size: 100 - drop_remainder: True - stage: test - - optimizer_config: - beta_1: 0.9 - beta_2: 0.98 - epsilon: 1e-9 - - learning_rate_config: - warmup_steps: 10000 - max_lr_numerator: 0.05 - - apply_gwn_config: - predict_net_step: 20000 - predict_net_stddev: 0.075 - - running_config: - batch_size: 2 - num_epochs: 100 - checkpoint: - filepath: {{modeldir}}/checkpoints/{epoch:02d}.h5 - save_best_only: False - save_weights_only: True - save_freq: epoch - options: - experimental_enable_async_checkpoint: True - backup_and_restore: - backup_dir: {{modeldir}}/states - save_freq: epoch - delete_checkpoint: False - tensorboard: - log_dir: {{modeldir}}/tensorboard - histogram_freq: 1 - write_graph: True - write_images: True - update_freq: epoch - profile_batch: 2 diff --git a/pyproject.toml b/pyproject.toml index b32f7badbc..01a5325f3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,3 +4,16 @@ line-length = 150 [tool.isort] profile = "black" line_length = 150 + +[tool.pytest.ini_options] +minversion = "6.0" +log_cli = true +log_cli_level = "WARNING" +log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +log_cli_date_format = "%Y-%m-%d %H:%M:%S" +testpaths = "tests" +python_files = "test_*.py" +addopts = "-s --durations=0" +filterwarnings = ["error", "ignore::UserWarning", "ignore::DeprecationWarning"] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "session" diff --git a/requirements.apple.txt b/requirements.apple.txt new file mode 100644 index 0000000000..d4dee8c591 --- /dev/null +++ b/requirements.apple.txt @@ -0,0 +1,2 @@ +tensorflow~=2.18.0 +tensorflow-text @ https://github.com/sun1638650145/Libraries-and-Extensions-for-TensorFlow-for-Apple-Silicon/releases/download/v2.18/tensorflow_text-2.18.1-cp312-cp312-macosx_11_0_arm64.whl \ No newline at end of file diff --git a/requirements.cpu.txt b/requirements.cpu.txt new file mode 100644 index 0000000000..f1ce6960c3 --- /dev/null +++ b/requirements.cpu.txt @@ -0,0 +1,2 @@ +tensorflow~=2.18.0 +tensorflow-text~=2.18.0 \ No newline at end of file diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 0000000000..0fbb9b1433 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,9 @@ +pytest>=7.4.1 +black>=24.3.0 +pylint>=3.2.4 +matplotlib>=3.7.2 +pydot-ng>=2.0.0 +graphviz>=0.20.1 +pre-commit>=3.7.0 +tf2onnx>=1.16.1 +netron>=8.0.3 \ No newline at end of file diff --git a/requirements.gpu.txt b/requirements.gpu.txt new file mode 100644 index 0000000000..9ea84ee21e --- /dev/null +++ b/requirements.gpu.txt @@ -0,0 +1,2 @@ +tensorflow[and-cuda]~=2.18.0 +tensorflow-text~=2.18.0 \ No newline at end of file diff --git a/requirements.text.txt b/requirements.text.txt new file mode 100644 index 0000000000..cda1621c37 --- /dev/null +++ b/requirements.text.txt @@ -0,0 +1 @@ +tensorflow-text~=2.18.0 \ No newline at end of file diff --git a/requirements.tpu.txt b/requirements.tpu.txt new file mode 100644 index 0000000000..6284ec1e89 --- /dev/null +++ b/requirements.tpu.txt @@ -0,0 +1 @@ +tensorflow-tpu~=2.18.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 961e07d52b..dea5d19220 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,70 +1,17 @@ -SoundFile==0.10.3.post1 -tensorflow-datasets>=4.9.2 -nltk==3.7 -sentencepiece==0.1.97 -tqdm==4.64.1 -librosa==0.9.2 -PyYAML==6.0 -sounddevice==0.4.5 -jinja2==3.1.2 -fire==0.4.0 - -# 
extra=dev -pytest==7.3.1 -black==23.3.0 -pylint==2.17.4 -matplotlib==3.7.1 -pydot==1.4.2 -netron==6.9.2 -graphviz==0.20.1 -tf2onnx==1.14.0 - -# extra=tf2.8 -tensorflow>=2.8.0,<2.9.0 -tensorflow-text>=2.8.0,<2.9.0 -tensorflow-io>=0.25.0,<0.26.0 - -# extra=tf2.8-gpu -tensorflow-gpu>=2.8.0,<2.9.0 -tensorflow-text>=2.8.0,<2.9.0 -tensorflow-io>=0.25.0,<0.26.0 - -# extra=tf2.9 -tensorflow>=2.9.0,<2.10.0 -tensorflow-text>=2.9.0,<2.10.0 -tensorflow-io>=0.26.0,<0.27.0 - -# extra=tf2.9-gpu -tensorflow-gpu>=2.9.0,<2.10.0 -tensorflow-text>=2.9.0,<2.10.0 -tensorflow-io>=0.26.0,<0.27.0 - -# extra=tf2.10 -tensorflow>=2.10.0,<2.11.0 -tensorflow-text>=2.10.0,<2.11.0 -tensorflow-io>=0.27.0,<0.28.0 - -# extra=tf2.10-gpu -tensorflow-gpu>=2.10.0,<2.11.0 -tensorflow-text>=2.10.0,<2.11.0 -tensorflow-io>=0.27.0,<0.28.0 - -# extra=tf2.11 -tensorflow>=2.11.0,<2.12.0 -tensorflow-text>=2.11.0,<2.12.0 -tensorflow-io>=0.28.0,<0.31.0 - -# extra=tf2.11-gpu -tensorflow-gpu>=2.11.0,<2.12.0 -tensorflow-text>=2.11.0,<2.12.0 -tensorflow-io>=0.28.0,<0.31.0 - -# extra=tf2.12 -tensorflow>=2.12.0,<2.13.0 -tensorflow-text>=2.12.0,<2.13.0 -tensorflow-io>=0.32.0,<0.33.0 - -# extra=tf2.12-gpu -tensorflow-gpu>=2.12.0,<2.13.0 -tensorflow-text>=2.12.0,<2.13.0 -tensorflow-io>=0.32.0,<0.33.0 \ No newline at end of file +SoundFile~=0.12.1 +nltk>=3.9.0 +sentencepiece~=0.2.0 +tqdm>=4.67.1 +librosa~=0.10.1 +PyYAML~=6.0.1 +sounddevice~=0.4.6 +jinja2~=3.1.3 +fire>=0.7.0 +jiwer~=3.0.3 +keras-nightly~=3.9.0.dev # https://github.com/keras-team/keras/issues/20568#issuecomment-2510432421 +cached_property~=2.0.1 +ipywidgets~=8.1.5 +ipython<9.0.0 +kagglehub~=0.3.6 +datasets~=3.5.1 +tabulate~=0.9.0 \ No newline at end of file diff --git a/scripts/create_tfrecords.py b/scripts/create_tfrecords.py deleted file mode 100644 index fd640d9b43..0000000000 --- a/scripts/create_tfrecords.py +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from tensorflow_asr.utils import env_util - -logger = env_util.setup_environment() - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets.asr_dataset import ASRTFRecordDataset -from tensorflow_asr.helpers import featurizer_helpers -from tensorflow_asr.utils import cli_util, file_util - - -def main( - *transcripts, - mode: str = None, - config_path: str = None, - tfrecords_dir: str = None, - tfrecords_shards: int = 16, - shuffle: bool = True, -): - data_paths = file_util.preprocess_paths(transcripts) - tfrecords_dir = file_util.preprocess_paths(tfrecords_dir, isdir=True) - logger.info(f"Create tfrecords to directory: {tfrecords_dir}") - - config = Config(config_path) - - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - tfrecord_dataset = ASRTFRecordDataset( - data_paths=data_paths, - tfrecords_dir=tfrecords_dir, - speech_featurizer=speech_featurizer, - text_featurizer=text_featurizer, - stage=mode, - shuffle=shuffle, - tfrecords_shards=tfrecords_shards, - ) - tfrecord_dataset.create_tfrecords() - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/scripts/create_vocab_from_trans.py b/scripts/create_vocab_from_trans.py deleted file mode 100644 index 08ffaf6fcb..0000000000 --- a/scripts/create_vocab_from_trans.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -from tqdm.auto import tqdm - -parser = argparse.ArgumentParser(prog="Create vocabulary file from transcripts") - -parser.add_argument("--output", type=str, default=None, help="The output .txt vocabulary file path") - -parser.add_argument("transcripts", nargs="+", type=str, default=None, help="Transcript .tsv files") - -args = parser.parse_args() - -assert args.output and args.transcripts - -lines = [] -for trans in args.transcripts: - with open(trans, "r", encoding="utf-8") as t: - lines.extend(t.read().splitlines()[1:]) - -vocab = {} -for line in tqdm(lines, desc="[Processing]"): - line = line.split("\t")[-1] - for c in line: - vocab[c] = 1 - -with open(args.output, "w", encoding="utf-8") as out: - for key in vocab.keys(): - out.write(f"{key}\n") diff --git a/scripts/generate_metadata.py b/scripts/generate_metadata.py deleted file mode 100644 index b282cbdc3a..0000000000 --- a/scripts/generate_metadata.py +++ /dev/null @@ -1,50 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from tensorflow_asr.utils import env_util - -env_util.setup_environment() - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets.asr_dataset import ASRDataset -from tensorflow_asr.helpers import featurizer_helpers -from tensorflow_asr.utils import cli_util, file_util - - -def main( - *transcripts, - stage: str = "train", - config_path: str = None, - metadata: str = None, -): - transcripts = file_util.preprocess_paths(transcripts) - - config = Config(config_path) - - speech_featurizer, text_featurizer = featurizer_helpers.prepare_featurizers(config=config) - - dataset = ASRDataset( - data_paths=transcripts, - speech_featurizer=speech_featurizer, - text_featurizer=text_featurizer, - stage=stage, - shuffle=False, - ) - - dataset.update_metadata(metadata) - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/scripts/generate_vocab_sentencepiece.py b/scripts/generate_vocab_sentencepiece.py deleted file mode 100644 index 987e4f9082..0000000000 --- a/scripts/generate_vocab_sentencepiece.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tensorflow as tf - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer -from tensorflow_asr.utils import cli_util, env_util - -logger = env_util.setup_environment() - - -def main( - config_path: str, -): - tf.keras.backend.clear_session() - config = Config(config_path) - SentencePieceFeaturizer.build_from_corpus(config.decoder_config) - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/scripts/generate_vocab_subwords.py b/scripts/generate_vocab_subwords.py deleted file mode 100644 index 310e9d15d8..0000000000 --- a/scripts/generate_vocab_subwords.py +++ /dev/null @@ -1,23 +0,0 @@ -import argparse -import os - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.text_featurizers import SubwordFeaturizer - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - -parser = argparse.ArgumentParser(prog="Vocab Training with Subwords") - -parser.add_argument("corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords") - -parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file") - -parser.add_argument("--output_file", type=str, default=None, help="Path to file that stores generated subwords") - -args = parser.parse_args() - -config = Config(args.config) - -print("Generating subwords ...") - -SubwordFeaturizer.build_from_corpus(config.decoder_config) diff --git a/scripts/generate_vocab_wordpiece.py b/scripts/generate_vocab_wordpiece.py deleted file mode 100644 index 8e11a3cb67..0000000000 --- a/scripts/generate_vocab_wordpiece.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in 
compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.text_featurizers import WordPieceFeaturizer -from tensorflow_asr.utils import cli_util, env_util - -logger = env_util.setup_environment() - - -def main( - config_path: str, -): - config = Config(config_path) - WordPieceFeaturizer.build_from_corpus(decoder_config=config.decoder_config) - - -if __name__ == "__main__": - cli_util.run(main) diff --git a/scripts/install_ctc_decoders.sh b/scripts/install_ctc_decoders.sh index 8937e5e12c..2af40cc17b 100755 --- a/scripts/install_ctc_decoders.sh +++ b/scripts/install_ctc_decoders.sh @@ -1,18 +1,16 @@ #!/usr/bin/env bash -mkdir externals -cd ./externals || exit +PROJECT_DIR=$(realpath "$(dirname $0)/..") + +mkdir -p $PROJECT_DIR/externals +cd $PROJECT_DIR/externals || exit # Install baidu's beamsearch_with_lm if [ ! -d ctc_decoders ]; then - git clone https://github.com/nglehuy/ctc_decoders.git - + git clone --depth 1 https://github.com/nglehuy/ctc_decoders.git cd ./ctc_decoders || exit chmod a+x setup.sh - chown "$USER":"$USER" setup.sh ./setup.sh - - cd .. fi -cd .. +cd $PROJECT_DIR || exit diff --git a/scripts/install_ctc_loss.sh b/scripts/install_ctc_loss.sh new file mode 100755 index 0000000000..24627429bc --- /dev/null +++ b/scripts/install_ctc_loss.sh @@ -0,0 +1,47 @@ +#!/usr/bin/env bash + +PROJECT_DIR=$(realpath "$(dirname $0)/..") +cd "$PROJECT_DIR" || exit + +mkdir -p $PROJECT_DIR/externals +cd $PROJECT_DIR/externals || exit + +TF_VERSION=$(python3 -c "import tensorflow as tf; print(tf.__version__)") + +# Install rnnt_loss +if [ ! -d warp-ctc ]; then + git clone --depth 1 https://github.com/nglehuy/warp-ctc.git + cd $PROJECT_DIR/externals/warp-ctc/tensorflow_binding + if [ ! -d tensorflow ]; then + git clone --depth 1 --branch v$TF_VERSION https://github.com/tensorflow/tensorflow.git + fi + cd ../../ +fi + +export TENSORFLOW_SRC_PATH="$PROJECT_DIR/externals/warp-ctc/tensorflow_binding/tensorflow" + +rm -rf $PROJECT_DIR/externals/warp-ctc/build +mkdir -p $PROJECT_DIR/externals/warp-ctc/build +cd $PROJECT_DIR/externals/warp-ctc/build || exit + +if [ "$CUDA_HOME" ]; then + cmake \ + -DWITH_GPU=ON \ + -DCUDA_TOOLKIT_ROOT_DIR="$CUDA_HOME" .. +else + cmake \ + -DWITH_GPU=OFF \ + .. +fi + +make -j $(nproc) + +cd $PROJECT_DIR/externals/warp-ctc/tensorflow_binding || exit + +if [ "$CUDA_HOME" ]; then + CUDA="$CUDA_HOME" python3 setup.py install +else + python3 setup.py install +fi + +cd $PROJECT_DIR || exit \ No newline at end of file diff --git a/scripts/install_rnnt_loss.sh b/scripts/install_rnnt_loss.sh index a956774b32..e71afbff03 100755 --- a/scripts/install_rnnt_loss.sh +++ b/scripts/install_rnnt_loss.sh @@ -1,16 +1,28 @@ -#!/usr/bin/env sh +#!/usr/bin/env bash -mkdir -p externals -cd ./externals || exit +PROJECT_DIR=$(realpath "$(dirname $0)/..") +cd "$PROJECT_DIR" || exit + +mkdir -p $PROJECT_DIR/externals +cd $PROJECT_DIR/externals || exit + +TF_VERSION=$(python3 -c "import tensorflow as tf; print(tf.__version__)") # Install rnnt_loss if [ ! 
-d warp-transducer ]; then - git clone https://github.com/nglehuy/warp-transducer.git + git clone --depth 1 https://github.com/nglehuy/warp-transducer.git + cd $PROJECT_DIR/externals/warp-transducer/tensorflow_binding + if [ ! -d tensorflow ]; then + git clone --depth 1 --branch v$TF_VERSION https://github.com/tensorflow/tensorflow.git + fi + cd ../../ fi -cd ./warp-transducer || exit -rm -rf build -mkdir -p build && cd build || exit +export TENSORFLOW_SRC_PATH="$PROJECT_DIR/externals/warp-transducer/tensorflow_binding/tensorflow" + +rm -rf $PROJECT_DIR/externals/warp-transducer/build +mkdir -p $PROJECT_DIR/externals/warp-transducer/build +cd $PROJECT_DIR/externals/warp-transducer/build || exit if [ "$CUDA_HOME" ]; then cmake \ @@ -27,16 +39,14 @@ else -DCMAKE_CXX_COMPILER_LAUNCHER="$(which g++)" .. fi -make +make -j $(nproc) -cd ../tensorflow_binding || exit +cd $PROJECT_DIR/externals/warp-transducer/tensorflow_binding || exit if [ "$CUDA_HOME" ]; then - CUDA="$CUDA_HOME" python setup.py install + CUDA="$CUDA_HOME" python3 setup.py install else - python setup.py install + python3 setup.py install fi -cd ../.. - -cd .. +cd $PROJECT_DIR || exit \ No newline at end of file diff --git a/scripts/saved_model_to_weights.py b/scripts/saved_model_to_weights.py deleted file mode 100644 index 451c2deace..0000000000 --- a/scripts/saved_model_to_weights.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -import tensorflow as tf - -parser = argparse.ArgumentParser(prog="Convert saved model to weights") - -parser.add_argument("--saved", type=str, default=None, help="Saved model path") - -parser.add_argument("output", type=str, default=None, help="output file to store weights") - -args = parser.parse_args() - -assert args.saved and args.output - -model = tf.keras.models.load_model(args.saved) - -model.save_weights(args.output) diff --git a/setup.cfg b/setup.cfg index d5af6a6e7a..f9a30092bd 100755 --- a/setup.cfg +++ b/setup.cfg @@ -6,3 +6,7 @@ max-line-length = 150 ignore = E402,E701,E702,E704,E251,E203,W503,W504,C901,E501 max-line-length = 150 indent-size = 4 + +[options.entry_points] +console_scripts = + tensorflow_asr = tensorflow_asr.scripts:main \ No newline at end of file diff --git a/setup.py b/setup.py index 3eef87d682..48c6b59e9a 100644 --- a/setup.py +++ b/setup.py @@ -12,37 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from collections import defaultdict -from typing import List +import glob +import os from setuptools import find_packages, setup +install_requires = [] +extras_requires = {} -def parse_requirements(lines: List[str]): - _extras_requires = defaultdict(list) - extra = "requires" - for line in lines: - line = line.strip() - if line.startswith("# extra="): - extra = line.split("=")[1].strip() - continue - if line and line[0] != "#": - lib_package = line.split("#")[0].strip() # split comments - _extras_requires[extra].append(lib_package) - _install_requires = _extras_requires.pop("requires") - return _install_requires, _extras_requires - - -with open("requirements.txt", "r", encoding="utf-8") as fr: - install_requires, extras_requires = parse_requirements(fr.readlines()) +for req_file in glob.glob("requirements*.txt", recursive=False): + name = os.path.basename(req_file).split(".") + extra = name[1] if len(name) > 2 else None + with open(req_file, "r", encoding="utf-8") as fr: + if not extra: + install_requires = fr.readlines() + else: + extras_requires[extra] = fr.readlines() with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() - setup( name="TensorFlowASR", - version="2.0.0", + version="3.0.0", author="Huy Le Nguyen", author_email="nlhuy.cs.16@gmail.com", description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2", @@ -62,5 +54,5 @@ def parse_requirements(lines: List[str]): "License :: OSI Approved :: Apache Software License", "Topic :: Software Development :: Libraries :: Python Modules", ], - python_requires=">=3.6, <4", + python_requires=">=3.8, <4", ) diff --git a/setup.sh b/setup.sh new file mode 100755 index 0000000000..2b255b2031 --- /dev/null +++ b/setup.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash + +python3 -m pip install -r requirements.text.txt + +case "$1" in +tpu) + python3 -m pip uninstall -y tensorflow + python3 -m pip install -r requirements.tpu.txt -f https://storage.googleapis.com/libtpu-tf-releases/index.html --force +;; +gpu) + python3 -m pip install -r requirements.gpu.txt +;; +cpu) + python3 -m pip install -r requirements.cpu.txt +;; +apple) + python3 -m pip install -r requirements.apple.txt +;; +*) echo -e "Usage: $0 " +esac + +python3 -m pip uninstall -y keras # use keras-nightly +python3 -m pip install -r requirements.txt --force + +case "$2" in +dev) + python3 -m pip install -r requirements.dev.txt + python3 -m pip install -e . +;; +install) + python3 -m pip install -e . 
+;; +esac \ No newline at end of file diff --git a/tensorflow_asr/__init__.py b/tensorflow_asr/__init__.py index d3e37cab76..67a85f074f 100644 --- a/tensorflow_asr/__init__.py +++ b/tensorflow_asr/__init__.py @@ -1,15 +1,28 @@ +# pylint: disable=protected-access import os -import warnings -os.environ["TF_CPP_MIN_LOG_LEVEL"] = os.environ.get("TF_CPP_MIN_LOG_LEVEL", "2") +os.environ["TF_CPP_MIN_LOG_LEVEL"] = os.environ.get("TF_CPP_MIN_LOG_LEVEL") or "3" os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = os.environ.get("TF_FORCE_GPU_ALLOW_GROWTH", "true") -import tensorflow as tf +# import submodules to register keras objects +import glob +from os.path import basename, dirname, isdir, isfile, join -logger = tf.get_logger() -logger.setLevel(os.environ.get("LOG_LEVEL", "info").upper()) -logger.propagate = False -warnings.simplefilter("ignore") +import keras +import tensorflow as tf # for reference -from tensorflow_asr.models import * -from tensorflow_asr.optimizers import * +from tensorflow_asr.utils import env_util # import here fist to apply logging + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") + + +__all__ = ["keras", "tf", "env_util"] diff --git a/tensorflow_asr/abstracts.py b/tensorflow_asr/abstracts.py new file mode 100644 index 0000000000..eb51282b58 --- /dev/null +++ b/tensorflow_asr/abstracts.py @@ -0,0 +1,41 @@ +import typing +from abc import ABC, abstractmethod + +from tensorflow_asr import tf + + +class AbstractTokenizer(ABC): + initialized: bool + + @abstractmethod + def make(self): + pass + + @abstractmethod + def tokenize(self, text: str) -> tf.Tensor: + pass + + @abstractmethod + def detokenize(self, indices: tf.Tensor) -> tf.Tensor: + pass + + @abstractmethod + def prepand_blank(self, text: tf.Tensor) -> tf.Tensor: + pass + + +class AbstractDataset(ABC): + name: str + num_entries: int + + @abstractmethod + def read_entries(self): + pass + + @abstractmethod + def generator(self) -> typing.Generator: + pass + + @abstractmethod + def vocab_generator(self) -> typing.Generator: + pass diff --git a/tensorflow_asr/augmentations/augmentation.py b/tensorflow_asr/augmentations/augmentation.py index 09cab27efd..289e4c2ec4 100644 --- a/tensorflow_asr/augmentations/augmentation.py +++ b/tensorflow_asr/augmentations/augmentation.py @@ -14,8 +14,7 @@ from typing import List -import tensorflow as tf - +from tensorflow_asr import tf from tensorflow_asr.augmentations.methods import gaussnoise, specaugment from tensorflow_asr.augmentations.methods.base_method import AugmentationMethod @@ -28,30 +27,72 @@ class Augmentation: def __init__(self, config: dict = None): - if not config: - config = {} - self.signal_augmentations = self.parse(config.pop("signal_augment", {})) - self.feature_augmentations = self.parse(config.pop("feature_augment", {})) + _config = config or {} + self.signal_augmentations = self.parse(_config.pop("signal_augment", {})) + self.feature_augmentations = self.parse(_config.pop("feature_augment", {})) def _augment(self, inputs, augmentations: List[AugmentationMethod]): outputs = inputs for au in augmentations: - p = tf.random.uniform([]) - outputs = tf.where(tf.less(p, au.prob), au.augment(outputs), outputs) + outputs = au.augment(outputs) + # p = tf.random.uniform(shape=[], dtype=tf.float32) + # outputs = tf.cond(tf.less(p, au.prob), lambda: 
au.augment(outputs), lambda: outputs) return outputs - @tf.function - def signal_augment(self, inputs): - return self._augment(inputs, self.signal_augmentations) + def signal_augment(self, inputs, inputs_length): + """ + Augment audio signals + + Parameters + ---------- + inputs : tf.Tensor, shape [B, None] + Original audio signals + inputs_length : tf.Tensor, shape [B] + Original audio signals length + + Returns + ------- + tf.Tensor, shape [B, None] + Augmented audio signals + """ + return tf.map_fn( + fn=lambda x: self._augment(x, self.signal_augmentations), + elems=(inputs, inputs_length), + fn_output_signature=( + tf.TensorSpec.from_tensor(inputs[0]), + tf.TensorSpec.from_tensor(inputs_length[0]), + ), + ) + + def feature_augment(self, inputs, inputs_length): + """ + Augment audio features + + Parameters + ---------- + inputs : tf.Tensor, shape [B, T, F] + Original audio features + inputs_length : tf.Tensor, shape [B] + Original audio features length - @tf.function - def feature_augment(self, inputs): - return self._augment(inputs, self.feature_augmentations) + Returns + ------- + tf.Tensor, shape [B, T, F] + Augmented audio features + """ + return tf.map_fn( + fn=lambda x: self._augment(x, self.feature_augmentations), + elems=(inputs, inputs_length), + fn_output_signature=( + tf.TensorSpec.from_tensor(inputs[0]), + tf.TensorSpec.from_tensor(inputs_length[0]), + ), + ) @staticmethod def parse(config: dict) -> list: augmentations = [] - for key, value in config.items(): + for key, value in sorted(config.items(), key=lambda x: x[0]): au = AUGMENTATIONS.get(key, None) if au is None: raise KeyError(f"No tf augmentation named: {key}\n" f"Available tf augmentations: {AUGMENTATIONS.keys()}") diff --git a/tensorflow_asr/augmentations/methods/base_method.py b/tensorflow_asr/augmentations/methods/base_method.py index 64d141af21..6b1208b612 100644 --- a/tensorflow_asr/augmentations/methods/base_method.py +++ b/tensorflow_asr/augmentations/methods/base_method.py @@ -12,13 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - class AugmentationMethod: def __init__(self, prob: float = 0.5): self.prob = prob - @tf.function def augment(self, *args, **kwargs): raise NotImplementedError() diff --git a/tensorflow_asr/augmentations/methods/gaussnoise.py b/tensorflow_asr/augmentations/methods/gaussnoise.py index 41083e49ed..6395d7261e 100644 --- a/tensorflow_asr/augmentations/methods/gaussnoise.py +++ b/tensorflow_asr/augmentations/methods/gaussnoise.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorflow as tf - +from tensorflow_asr import tf from tensorflow_asr.augmentations.methods.base_method import AugmentationMethod @@ -28,7 +27,11 @@ def __init__( self.mean = mean self.stddev = stddev - @tf.function - def augment(self, inputs: tf.Tensor): + def augment(self, args): + inputs, inputs_length = args + prob = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=tf.float32) + do_apply = tf.where(tf.less_equal(prob, self.prob), tf.constant(1, inputs.dtype), tf.constant(0, inputs.dtype)) noise = tf.random.normal(shape=tf.shape(inputs), mean=self.mean, stddev=self.stddev, dtype=inputs.dtype) - return tf.add(inputs, noise) + noise *= tf.sequence_mask(inputs_length, inputs.shape[1], dtype=inputs.dtype) + noise *= do_apply + return tf.add(inputs, noise), inputs_length diff --git a/tensorflow_asr/augmentations/methods/specaugment.py b/tensorflow_asr/augmentations/methods/specaugment.py index f1ec30c2f5..22f008f27a 100644 --- a/tensorflow_asr/augmentations/methods/specaugment.py +++ b/tensorflow_asr/augmentations/methods/specaugment.py @@ -12,12 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +from dataclasses import asdict, dataclass +from tensorflow_asr import tf from tensorflow_asr.augmentations.methods.base_method import AugmentationMethod from tensorflow_asr.utils import shape_util -MASK_VALUES = ["mean", "min", "max", "zero"] + +@dataclass +class MASK_VALUES: + MEAN: str = "mean" + MIN: str = "min" + MAX: str = "max" + ZERO: str = "zero" + + +def get_mask_value(inputs: tf.Tensor, mask_value=MASK_VALUES.ZERO): + if isinstance(mask_value, (int, float)): + return tf.constant(mask_value, dtype=inputs.dtype) + if mask_value == MASK_VALUES.MEAN: + return tf.reduce_mean(inputs) + if mask_value == MASK_VALUES.MIN: + return tf.reduce_min(inputs) + if mask_value == MASK_VALUES.MAX: + return tf.reduce_max(inputs) + return tf.constant(0, dtype=inputs.dtype) # default zero class FreqMasking(AugmentationMethod): @@ -26,41 +45,46 @@ def __init__( num_masks: int = 1, mask_factor: float = 27, prob: float = 1.0, - mask_value: str = "zero", + mask_value="zero", ): super().__init__(prob=prob) self.num_masks = num_masks self.mask_factor = mask_factor self.mask_value = mask_value - if self.mask_value not in MASK_VALUES: - raise ValueError(f"mask_value must in {MASK_VALUES}") + if self.mask_value not in asdict(MASK_VALUES()).values(): + if not isinstance(self.mask_value, (int, float)): + raise ValueError(f"mask_value must in {asdict(MASK_VALUES()).values()} or a number") - @tf.function - def augment(self, spectrogram: tf.Tensor): + def augment(self, args): """ Masking the frequency channels (shape[1]) - Args: - spectrogram: shape (T, num_feature_bins, V) - Returns: - frequency masked spectrogram + + Parameters + ---------- + spectrogram : tf.Tensor, shape [T, num_feature_bins] or [T, num_feature_bins, 1] + Audio features + + Returns + ------- + tf.Tensor, shape [T, num_feature_bins] or [T, num_feature_bins, 1] + Masked frequency dim of audio features """ - _, F, _ = shape_util.shape_list(spectrogram, out_type=tf.int32) - if self.mask_value == "mean": - mval = tf.reduce_mean(spectrogram) - elif self.mask_value == "min": - mval = tf.reduce_min(spectrogram) - elif self.mask_value == "max": - mval = tf.reduce_max(spectrogram) - elif self.mask_value == "zero": - mval = tf.constant(0, dtype=spectrogram.dtype) - for _ in range(self.num_masks): - f = tf.random.uniform([], minval=0, maxval=self.mask_factor, 
dtype=tf.dtypes.int32) - f = tf.minimum(f, F) - f0 = tf.random.uniform([], minval=0, maxval=F - f, dtype=tf.dtypes.int32) - indices = tf.reshape(tf.range(F), (1, -1, 1)) - condition = tf.math.logical_and(tf.math.greater_equal(indices, f0), tf.math.less(indices, f0 + f)) - spectrogram = tf.where(condition, mval, spectrogram) - return spectrogram + with tf.name_scope("freq_masking_specaugment"): + spectrogram, spectrogram_length = args + _, frequency_length, *rest = shape_util.shape_list(spectrogram, out_type=tf.int32) + indices_shape = (1, -1) + (1,) * len(rest) + mval = get_mask_value(spectrogram, mask_value=self.mask_value) + F = tf.convert_to_tensor(self.mask_factor, dtype=tf.int32) + for _ in range(self.num_masks): + prob = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=spectrogram.dtype) + do_apply = tf.where(tf.less_equal(prob, self.prob), tf.constant(1, tf.int32), tf.constant(0, tf.int32)) + f = tf.random.uniform(shape=[], minval=0, maxval=F, dtype=tf.int32) + f = do_apply * tf.minimum(f, frequency_length) + f0 = do_apply * tf.random.uniform(shape=[], minval=0, maxval=(frequency_length - f), dtype=tf.int32) + indices = tf.reshape(tf.range(frequency_length), indices_shape) + condition = tf.math.logical_and(tf.math.greater_equal(indices, f0), tf.math.less(indices, f0 + f)) + spectrogram = tf.where(condition, mval, spectrogram) + return spectrogram, spectrogram_length class TimeMasking(AugmentationMethod): @@ -77,32 +101,37 @@ def __init__( self.mask_factor = mask_factor self.p_upperbound = p_upperbound self.mask_value = mask_value - if self.mask_value not in MASK_VALUES: - raise ValueError(f"mask_value must in {MASK_VALUES}") + if self.mask_value not in asdict(MASK_VALUES()).values(): + if not isinstance(self.mask_value, (int, float)): + raise ValueError(f"mask_value must in {asdict(MASK_VALUES()).values()} or a number") - @tf.function - def augment(self, spectrogram: tf.Tensor): + def augment(self, args): """ Masking the time channel (shape[0]) - Args: - spectrogram: shape (T, num_feature_bins, V) - Returns: - frequency masked spectrogram + + Parameters + ---------- + spectrogram : tf.Tensor, shape [T, num_feature_bins] or [T, num_feature_bins, 1] + Audio features + + Returns + ------- + tf.Tensor, shape [T, num_feature_bins] or [T, num_feature_bins, 1] + Masked time dim of audio features """ - T, _, _ = shape_util.shape_list(spectrogram, out_type=tf.int32) - if self.mask_value == "mean": - mval = tf.reduce_mean(spectrogram) - elif self.mask_value == "min": - mval = tf.reduce_min(spectrogram) - elif self.mask_value == "max": - mval = tf.reduce_max(spectrogram) - elif self.mask_value == "zero": - mval = tf.constant(0, dtype=spectrogram.dtype) - for _ in range(self.num_masks): - t = tf.random.uniform([], minval=0, maxval=self.mask_factor, dtype=tf.int32) - t = tf.minimum(t, tf.cast(tf.cast(T, dtype=tf.float32) * self.p_upperbound, dtype=tf.int32)) - t0 = tf.random.uniform([], minval=0, maxval=(T - t), dtype=tf.int32) - indices = tf.reshape(tf.range(T), (-1, 1, 1)) - condition = tf.math.logical_and(tf.math.greater_equal(indices, t0), tf.math.less(indices, t0 + t)) - spectrogram = tf.where(condition, mval, spectrogram) - return spectrogram + with tf.name_scope("time_masking_specaugment"): + spectrogram, spectrogram_length = args + max_length, *rest = shape_util.shape_list(spectrogram, out_type=tf.int32) + indices_shape = (-1,) + (1,) * len(rest) + mval = get_mask_value(spectrogram, mask_value=self.mask_value) + T = tf.cast(tf.floor(tf.cast(spectrogram_length, 
dtype=spectrogram.dtype) * self.p_upperbound), dtype=tf.int32) + for _ in range(self.num_masks): + prob = tf.random.uniform(shape=[], minval=0, maxval=1, dtype=spectrogram.dtype) + do_apply = tf.where(tf.less_equal(prob, self.prob), tf.constant(1, tf.int32), tf.constant(0, tf.int32)) + t = tf.random.uniform(shape=[], minval=0, maxval=T, dtype=tf.int32) + t = do_apply * tf.minimum(t, spectrogram_length) + t0 = do_apply * tf.random.uniform(shape=[], minval=0, maxval=(spectrogram_length - t), dtype=tf.int32) + indices = tf.reshape(tf.range(max_length), indices_shape) + condition = tf.math.logical_and(tf.math.greater_equal(indices, t0), tf.math.less(indices, t0 + t)) + spectrogram = tf.where(condition, mval, spectrogram) + return spectrogram, spectrogram_length diff --git a/tensorflow_asr/callbacks.py b/tensorflow_asr/callbacks.py new file mode 100644 index 0000000000..9e16aab766 --- /dev/null +++ b/tensorflow_asr/callbacks.py @@ -0,0 +1,414 @@ +# Copyright 2023 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import logging +import os +import shutil +from http import HTTPStatus + +import numpy as np +from keras.src.saving import serialization_lib + +from tensorflow_asr import keras, tf +from tensorflow_asr.datasets import ASRDataset +from tensorflow_asr.utils import file_util + +logger = logging.getLogger(__name__) + + +@keras.utils.register_keras_serializable(package=__name__) +class TestLogger(keras.callbacks.Callback): + def __init__(self): + super().__init__() + self.wer = {"numer": 0, "denom": 0} + self.cer = {"numer": 0, "denom": 0} + + @staticmethod + def compute_wer(decode, target, dtype=tf.float32): + decode = tf.strings.split(decode) + target = tf.strings.split(target) + distances = tf.cast(tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False), dtype) # [B] + lengths = tf.cast(target.row_lengths(axis=1), dtype) # [B] + return distances, lengths + + @staticmethod + def compute_cer(decode, target, dtype=tf.float32): + decode = tf.strings.bytes_split(decode) # [B, N] + target = tf.strings.bytes_split(target) # [B, M] + distances = tf.cast(tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False), dtype) # [B] + lengths = tf.cast(target.row_lengths(axis=1), dtype) # [B] + return distances, lengths + + def on_test_batch_end(self, batch, logs=None): + if logs is None: + return + + predictions = logs.pop("predictions") + if predictions is None: + return + + transcripts = self.model.tokenizer.detokenize(predictions.pop("_tokens")) + targets = self.model.tokenizer.detokenize(predictions.pop("_labels")) + + wer_numer, wer_denom = tf.nest.map_structure(tf.reduce_sum, TestLogger.compute_wer(transcripts, targets)) + cer_numer, cer_denom = tf.nest.map_structure(tf.reduce_sum, TestLogger.compute_cer(transcripts, targets)) + + self.wer["numer"] += wer_numer.numpy() + self.wer["denom"] += wer_denom.numpy() + self.cer["numer"] += cer_numer.numpy() + self.cer["denom"] += cer_denom.numpy() + + def on_test_end(self, logs=None): 
+ logs = logs or {} + logs["wer"] = np.divide(self.wer["numer"], self.wer["denom"]) # handled nan + logs["cer"] = np.divide(self.cer["numer"], self.cer["denom"]) + return logs + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class PredictLogger(keras.callbacks.Callback): + def __init__(self, test_dataset: ASRDataset, output_file_path: str): + super().__init__() + self.test_dataset = test_dataset + self.output_file_path = output_file_path + + def on_predict_begin(self, logs=None): + self.index = 0 + self.output_file = tf.io.gfile.GFile(self.output_file_path, mode="w") + self.output_file.write("\t".join(("PATH", "GROUND_TRUTH", "GREEDY", "BEAM_SEARCH")) + "\n") # header + + def on_predict_batch_end(self, batch, logs=None): + if logs is None: + return + + transcripts = self.model.tokenizer.detokenize(logs.pop("tokens")) + beam_transcripts = self.model.tokenizer.detokenize(logs.pop("beam_tokens")) + targets = self.model.tokenizer.detokenize(logs.pop("labels")) + + for i, item in enumerate(zip(targets.numpy(), transcripts.numpy(), beam_transcripts.numpy()), start=self.index): + groundtruth, greedy, beam = [x.decode("utf-8") for x in item] + path = self.test_dataset.entries[i][0] + line = "\t".join((path, groundtruth, greedy, beam)) + "\n" + self.output_file.write(line) + self.index += 1 + + def on_predict_end(self, logs=None): + self.index = 0 + self.output_file.close() + + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class TensorBoard(keras.callbacks.TensorBoard): + def __init__( + self, + log_dir="logs", + histogram_freq=0, + write_graph=True, + write_images=False, + write_steps_per_second=False, + update_freq="epoch", + profile_batch=0, + embeddings_freq=0, + embeddings_metadata=None, + **kwargs, + ): + log_dir = file_util.preprocess_paths(log_dir, isdir=True) + super().__init__( + log_dir, + histogram_freq, + write_graph, + write_images, + write_steps_per_second, + update_freq, + profile_batch, + embeddings_freq, + embeddings_metadata, + **kwargs, + ) + self._profile_batch = profile_batch + + def on_train_batch_end(self, batch, logs=None): + train_logs = dict((logs or {}).items()) + train_logs = self._collect_learning_rate(train_logs) + return super().on_train_batch_end(batch, train_logs) + + def get_config(self): + return { + "log_dir": self.log_dir, + "histogram_freq": self.histogram_freq, + "write_graph": self.write_graph, + "write_images": self.write_images, + "write_steps_per_second": self.write_steps_per_second, + "update_freq": self.update_freq, + "profile_batch": self._profile_batch, + "embeddings_freq": self.embeddings_freq, + "embeddings_metadata": self.embeddings_metadata, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class TerminateOnNaN(keras.callbacks.TerminateOnNaN): + def get_config(self): + return {} + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class ModelCheckpoint(keras.callbacks.ModelCheckpoint): + def __init__( + self, + filepath, + monitor="val_loss", + verbose=0, + save_best_only=False, + save_weights_only=False, + mode="auto", + save_freq="epoch", + initial_value_threshold=None, + keep_checkpoints=5, + ): + filepath = 
file_util.preprocess_paths(filepath) + self._mode = mode + self._keep_checkpoints = keep_checkpoints + super().__init__(filepath, monitor, verbose, save_best_only, save_weights_only, mode, save_freq, initial_value_threshold) + + def _delete_obsolete_checkpoint(self, epoch, batch=None, logs=None): + for ep in range(int(epoch) - self._keep_checkpoints): + filepath = self._get_file_path(epoch=ep, batch=batch, logs=logs) + if tf.io.gfile.exists(filepath): + tf.io.gfile.remove(filepath) + + def on_train_batch_end(self, batch, logs=None): + super().on_train_batch_end(batch, logs) + if self._should_save_on_batch(batch): + self._delete_obsolete_checkpoint(self._current_epoch, batch, logs) + + def on_epoch_end(self, epoch, logs=None): + super().on_epoch_end(epoch, logs) + if self.save_freq == "epoch": + self._delete_obsolete_checkpoint(epoch, None, logs) + + def get_config(self): + return { + "filepath": self.filepath, + "monitor": self.monitor, + "verbose": self.verbose, + "save_best_only": self.save_best_only, + "save_weights_only": self.save_weights_only, + "mode": self._mode, + "save_freq": self.save_freq, + "initial_value_threshold": self.best, + "keep_checkpoints": self._keep_checkpoints, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class BackupAndRestore(keras.callbacks.BackupAndRestore): + def __init__( + self, + backup_dir, + save_freq="epoch", + double_checkpoint=True, + delete_checkpoint=False, + ): + backup_dir = file_util.preprocess_paths(backup_dir, isdir=True) + super().__init__(backup_dir=backup_dir, save_freq=save_freq, double_checkpoint=double_checkpoint, delete_checkpoint=delete_checkpoint) + + def get_config(self): + return { + "backup_dir": self.backup_dir, + "save_freq": self.save_freq, + "delete_checkpoint": self.delete_checkpoint, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class EarlyStopping(keras.callbacks.EarlyStopping): + def __init__( + self, + monitor="val_loss", + min_delta=0, + patience=0, + verbose=0, + mode="auto", + baseline=None, + restore_best_weights=False, + start_from_epoch=0, + ): + super().__init__(monitor, min_delta, patience, verbose, mode, baseline, restore_best_weights, start_from_epoch) + self._mode = mode + + def get_config(self): + return { + "monitor": self.monitor, + "min_delta": self.min_delta, + "patience": self.patience, + "verbose": self.verbose, + "mode": self._mode, + "baseline": self.baseline, + "restore_best_weights": self.restore_best_weights, + "start_from_epoch": self.start_from_epoch, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras.utils.register_keras_serializable(package=__name__) +class KaggleModelBackupAndRestore(BackupAndRestore): + def __init__( + self, + model_dir: str, + model_handle: str = None, + save_freq="epoch", + ): + backup_dir = os.path.join(model_dir, "states") + super().__init__(backup_dir, save_freq=save_freq, double_checkpoint=True, delete_checkpoint=False) + + try: + # use option 2,3 to authenticate kaggle: https://github.com/Kaggle/kagglehub?tab=readme-ov-file#option-2-read-credentials-from-environment-variables pylint: disable=line-too-long + self._api = importlib.import_module("kagglehub") + + logging.getLogger("kagglehub").disabled = True + logging.getLogger("kagglehub").handlers.clear() + + except ImportError as e: + raise ImportError("Kaggle library is not installed. 
Please install it via `pip install '.[kaggle]'`.") from e + + self._model_handle = model_handle + self._model_dir = file_util.preprocess_paths(model_dir, isdir=True) + if file_util.is_cloud_path(model_dir): + raise ValueError(f"Model dir must be local path for Kaggle backup and restore. Received: {model_dir}") + self.save_freq = save_freq + if save_freq != "epoch" and not isinstance(save_freq, int): + raise ValueError( + "Invalid value for argument `save_freq`. " f"Received: save_freq={save_freq}. " "Expected either 'epoch' or an integer value." + ) + + self._batches_seen_since_last_saving = 0 + self._last_batch_seen = 0 + self._current_epoch = 0 + + def _restore_kaggle(self): + if not self._model_handle: + return + + if os.path.exists(self._weights_path) and os.path.exists(self._training_metadata_path): + logger.info(f"Backup and restore from '{self.backup_dir}'...") + return + + from kagglehub.exceptions import KaggleApiHTTPError # pylint: disable=import-outside-toplevel + + try: + cached_path = self._api.model_download(handle=self._model_handle, force_download=True) + logger.info(f"Restoring model from '{cached_path}'...") + has_version = False + try: + has_version = int(os.path.basename(cached_path)) + except: # pylint: disable=bare-except + pass + if not has_version: + latest_version = None + for x in os.listdir(cached_path): + try: + latest_version = max(filter(None, (latest_version, int(x)))) + except: # pylint: disable=bare-except + pass + if not latest_version: + logger.info(f"Model '{self._model_handle}' does not have any version. Skipping restore...") + return + cached_path = os.path.join(cached_path, str(latest_version)) + shutil.copytree(cached_path, self._model_dir, ignore_dangling_symlinks=True, dirs_exist_ok=True) + shutil.rmtree(cached_path) + logger.info(f"Model restored to '{self._model_dir}'") + + except KaggleApiHTTPError as e: + if e.response is not None and (e.response.status_code in (HTTPStatus.NOT_FOUND, HTTPStatus.FORBIDDEN)): + logger.info( + f"Model '{self._model_handle}' does not exist or access is forbidden. It will be auto-create on saving. Skipping restore..." 
+ ) + + def _backup_kaggle(self, logs, notes: str): + if not self._model_handle: + return + logs = logs or {} + loss = logs.get("loss") + if loss is not None: + if np.isnan(loss) or np.isinf(loss): + return # Don't save this epoch if loss is NaN or Inf + self._api.model_upload(handle=self._model_handle, local_model_dir=self._model_dir, version_notes=notes, ignore_patterns=[".DS_Store"]) + + def on_train_begin(self, logs=None): + self._restore_kaggle() + super().on_train_begin(logs) + + def on_epoch_end(self, epoch, logs=None): + self._current_epoch = epoch + 1 + self._last_batch_seen = 0 + if self.save_freq == "epoch": + self._save_model() + self._backup_kaggle(logs, notes=f"Backed up model at epoch {self._current_epoch}") + + def on_train_batch_end(self, batch, logs=None): + if self._should_save_on_batch(batch): + self._save_model() + self._backup_kaggle(logs, notes=f"Backed up model at batch {batch}") + + def get_config(self): + return { + "model_handle": self._model_handle, + "model_dir": self._model_dir, + "save_freq": self.save_freq, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def deserialize(callback_config): + if isinstance(callback_config, list): + return [serialization_lib.deserialize_keras_object(c) for c in callback_config] + return serialization_lib.deserialize_keras_object(callback_config) diff --git a/tensorflow_asr/configs.py b/tensorflow_asr/configs.py new file mode 100644 index 0000000000..759261d617 --- /dev/null +++ b/tensorflow_asr/configs.py @@ -0,0 +1,129 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import logging +from typing import Union + +from tensorflow_asr.utils import file_util + +logger = logging.getLogger(__name__) + + +class DecoderConfig: + def __init__(self, config: dict = None): + if not config: + config = {} + self.type: str = config.pop("type", "wordpiece") + + self.blank_index: int = config.pop("blank_index", 0) + self.pad_token: str = config.pop("pad_token", "") + self.pad_index: int = config.pop("pad_index", -1) + self.unknown_token: str = config.pop("unknown_token", "") + self.unknown_index: int = config.pop("unknown_index", 0) + self.bos_token: str = config.pop("bos_token", "") + self.bos_index: int = config.pop("bos_index", -1) + self.eos_token: str = config.pop("eos_token", "") + self.eos_index: int = config.pop("eos_index", -1) + + self.beam_width: int = config.pop("beam_width", 0) + self.norm_score: bool = config.pop("norm_score", True) + self.lm_config: dict = config.pop("lm_config", {}) + + self.model_type: str = config.pop("model_type", "unigram") + self.vocabulary: str = config.pop("vocabulary", None) + self.vocab_size: int = config.pop("vocab_size", 1000) + self.max_token_length: int = config.pop("max_token_length", 50) + self.max_unique_chars: int = config.pop("max_unique_chars", None) + self.num_iterations: int = config.pop("num_iterations", 4) + self.reserved_tokens: list = config.pop("reserved_tokens", None) + self.normalization_form: str = config.pop("normalization_form", "NFKC") + self.keep_whitespace: bool = config.pop("keep_whitespace", False) + self.max_sentence_length: int = config.pop("max_sentence_length", 1048576) # bytes + self.max_sentencepiece_length: int = config.pop("max_sentencepiece_length", 16) # bytes + self.character_coverage: float = config.pop("character_coverage", 1.0) # 0.9995 for languages with rich character, else 1.0 + + for k, v in config.items(): + setattr(self, k, v) + + +class DatasetConfig: + def __init__(self, config: dict = None): + if not config: + config = {} + self.name: str = config.pop("name", "") + self.enabled: bool = config.pop("enabled", True) + self.stage: str = config.pop("stage", None) + self.data_paths = config.pop("data_paths", None) + self.tfrecords_dir: str = config.pop("tfrecords_dir", None) + self.tfrecords_shards: int = config.pop("tfrecords_shards", 16) + self.tfrecords_buffer_size: int = config.pop("tfrecords_buffer_size", 32 * 1024 * 1024) + self.shuffle: bool = config.pop("shuffle", False) + self.cache: bool = config.pop("cache", False) + self.drop_remainder: bool = config.pop("drop_remainder", True) + self.buffer_size: int = config.pop("buffer_size", 1000) + self.metadata: str = config.pop("metadata", None) + self.sample_rate: int = config.pop("sample_rate", 16000) + for k, v in config.items(): + setattr(self, k, v) + + +class DataConfig: + def __init__(self, config: dict = None): + if not config: + config = {} + self.train_dataset_config = DatasetConfig(config.pop("train_dataset_config", {})) + self.eval_dataset_config = DatasetConfig(config.pop("eval_dataset_config", {})) + self.test_dataset_configs = [DatasetConfig(conf) for conf in config.pop("test_dataset_configs", [])] + _test_dataset_config = config.pop("test_dataset_config", None) + if _test_dataset_config: + self.test_dataset_configs.append(_test_dataset_config) + + +class LearningConfig: + def __init__(self, config: dict = None): + if not config: + config = {} + self.pretrained = config.pop("pretrained", None) + self.optimizer_config: dict = config.pop("optimizer_config", {}) + self.gwn_config = config.pop("gwn_config", 
None) + self.gradn_config = config.pop("gradn_config", None) + self.batch_size: int = config.pop("batch_size", 2) + self.ga_steps: int = config.pop("ga_steps", None) + self.num_epochs: int = config.pop("num_epochs", 300) + self.callbacks: list = config.pop("callbacks", []) + for k, v in config.items(): + setattr(self, k, v) + + +class Config: + """User config class for training, testing or infering""" + + def __init__(self, data: Union[str, dict], training=True, **kwargs): + config = data if isinstance(data, dict) else file_util.load_yaml(file_util.preprocess_paths(data), **kwargs) + self.decoder_config = DecoderConfig(config.pop("decoder_config", {})) + self.model_config: dict = config.pop("model_config", {}) + self.data_config = DataConfig(config.pop("data_config", {})) + self.learning_config = LearningConfig(config.pop("learning_config", {})) if training else None + for k, v in config.items(): + setattr(self, k, v) + + def __str__(self) -> str: + def default(x): + try: + return {k: v for k, v in vars(x).items() if not str(k).startswith("_")} + except: # pylint: disable=bare-except + return str(x) + + return json.dumps(vars(self), indent=2, default=default) diff --git a/tensorflow_asr/configs/config.py b/tensorflow_asr/configs/config.py deleted file mode 100644 index 8f247e8bf3..0000000000 --- a/tensorflow_asr/configs/config.py +++ /dev/null @@ -1,174 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Union - -import tensorflow as tf - -from tensorflow_asr.augmentations.augmentation import Augmentation -from tensorflow_asr.utils import file_util - - -class SpeechConfig: - def __init__(self, config: dict = None): - if not config: - config = {} - - # Sample rate in Hz - self.sample_rate: int = config.pop("sample_rate", 16000) - # Amount of data grabbed for each frame during analysis - self.frame_ms: int = config.pop("frame_ms", 25) - self.frame_length = int(round(self.sample_rate * self.frame_ms / 1000.0)) - # Number of ms to jump between frames - self.stride_ms: int = config.pop("stride_ms", 10) - self.frame_step = int(round(self.sample_rate * self.stride_ms / 1000.0)) - # Number of bins in the feature output - self.num_feature_bins: int = config.pop("num_feature_bins", 80) - # Type of feature extraction - self.feature_type: str = config.pop("feature_type", "log_mel_spectrogram") - - # The first-order filter coefficient used for preemphasis. When it is 0.0, preemphasis is turned off. - self.preemphasis: float = config.pop("preemphasis", 0.97) - # Whether to pad the end of `signals` with zeros when framing produces a frame that lies partially past its end. - self.pad_end: bool = config.pop("pad_end", False) - # Use librosa like stft - self.use_librosa_like_stft: bool = config.pop("use_librosa_like_stft", False) - # Whether to use twice the minimum fft resolution. - self.fft_overdrive: bool = config.pop("fft_overdrive", True) - # Whether to compute filterbank output on the energy of spectrum rather than just the magnitude. 
- self.compute_energy: bool = config.pop("compute_energy", False) - # Minimum output of filterbank output prior to taking logarithm. - self.output_floor: float = config.pop("output_floor", 1e-10) - # Use natural log - self.use_natural_log: bool = config.pop("use_natural_log", True) - # The lowest frequency of the feature analysis - self.lower_edge_hertz: float = config.pop("lower_edge_hertz", 125.0) - # The highest frequency of the feature analysis - self.upper_edge_hertz: float = config.pop("upper_edge_hertz", self.sample_rate / 2) - - self.normalize_signal: bool = config.pop("normalize_signal", False) - self.normalize_feature: bool = config.pop("normalize_feature", False) - self.normalize_per_frame: bool = config.pop("normalize_per_frame", False) - - for k, v in config.items(): - setattr(self, k, v) - - -class DecoderConfig: - def __init__(self, config: dict = None): - if not config: - config = {} - self.type: str = config.pop("type", "wordpiece") - - self.blank_index: int = config.pop("blank_index", 0) - self.pad_token: str = config.pop("pad_token", "") - self.pad_index: int = config.pop("pad_index", 0) - self.unknown_token: str = config.pop("unknown_token", "") - self.unknown_index: int = config.pop("unknown_index", 1) - self.bos_token: str = config.pop("bos_token", "") - self.bos_index: int = config.pop("bos_index", 2) - self.eos_token: str = config.pop("eos_token", "") - self.eos_index: int = config.pop("eos_index", 3) - - self.beam_width: int = config.pop("beam_width", 0) - self.norm_score: bool = config.pop("norm_score", True) - self.lm_config: dict = config.pop("lm_config", {}) - - self.model_type: str = config.pop("model_type", "unigram") - self.vocabulary: str = file_util.preprocess_paths(config.pop("vocabulary", None)) - self.vocab_size: int = config.pop("vocab_size", 1000) - self.max_token_length: int = config.pop("max_token_length", 50) - self.max_unique_chars: int = config.pop("max_unique_chars", None) - self.num_iterations: int = config.pop("num_iterations", 4) - self.reserved_tokens: list = config.pop("reserved_tokens", None) - self.normalization_form: str = config.pop("normalization_form", "NFKC") - - self.corpus_files = file_util.preprocess_paths(config.pop("corpus_files", [])) - - for k, v in config.items(): - setattr(self, k, v) - - -class DatasetConfig: - def __init__(self, config: dict = None): - if not config: - config = {} - self.enabled: bool = config.pop("enabled", True) - self.stage: str = config.pop("stage", None) - self.data_paths = file_util.preprocess_paths(config.pop("data_paths", None), enabled=self.enabled) - self.tfrecords_dir: str = file_util.preprocess_paths(config.pop("tfrecords_dir", None), isdir=True, enabled=self.enabled) - self.tfrecords_shards: int = config.pop("tfrecords_shards", 16) - self.shuffle: bool = config.pop("shuffle", False) - self.cache: bool = config.pop("cache", False) - self.drop_remainder: bool = config.pop("drop_remainder", True) - self.buffer_size: int = config.pop("buffer_size", 1000) - self.use_tf: bool = config.pop("use_tf", False) - self.augmentations = Augmentation(config.pop("augmentation_config", {})) - self.metadata: str = config.pop("metadata", None) - for k, v in config.items(): - setattr(self, k, v) - - -class RunningConfig: - def __init__(self, config: dict = None): - if not config: - config = {} - self.batch_size: int = config.pop("batch_size", 2) - self.ga_steps: int = config.pop("ga_steps", None) - self.num_epochs: int = config.pop("num_epochs", 100) - self.checkpoint: dict = {} - self.backup_and_restore: dict = 
{} - self.tensorboard: dict = {} - self.early_stopping: dict = {} - for k, v in config.items(): - setattr(self, k, v) - if k == "checkpoint": - if v and v.get("filepath"): - file_util.preprocess_paths(v.get("filepath")) - if v and v.get("options"): - self.checkpoint["options"] = tf.train.CheckpointOptions(**v.get("options")) - elif k == "backup_and_restore" and v: - if v and v.get("backup_dir"): - file_util.preprocess_paths(v.get("backup_dir"), isdir=True) - elif k == "tensorboard": - if v and v.get("log_dir"): - file_util.preprocess_paths(v.get("log_dir"), isdir=True) - - -class LearningConfig: - def __init__(self, config: dict = None): - if not config: - config = {} - self.pretrained = file_util.preprocess_paths(config.pop("pretrained", None)) - self.train_dataset_config = DatasetConfig(config.pop("train_dataset_config", {})) - self.eval_dataset_config = DatasetConfig(config.pop("eval_dataset_config", {})) - self.test_dataset_config = DatasetConfig(config.pop("test_dataset_config", {})) - self.optimizer_config: dict = config.pop("optimizer_config", {}) - self.running_config = RunningConfig(config.pop("running_config", {})) - self.apply_gwn_config = config.pop("apply_gwn_config", None) - for k, v in config.items(): - setattr(self, k, v) - - -class Config: - """User config class for training, testing or infering""" - - def __init__(self, data: Union[str, dict]): - config = data if isinstance(data, dict) else file_util.load_yaml(file_util.preprocess_paths(data)) - self.speech_config = SpeechConfig(config.pop("speech_config", {})) - self.decoder_config = DecoderConfig(config.pop("decoder_config", {})) - self.model_config: dict = config.pop("model_config", {}) - self.learning_config = LearningConfig(config.pop("learning_config", {})) - for k, v in config.items(): - setattr(self, k, v) diff --git a/tensorflow_asr/datasets.py b/tensorflow_asr/datasets.py new file mode 100755 index 0000000000..294af243d2 --- /dev/null +++ b/tensorflow_asr/datasets.py @@ -0,0 +1,502 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Dataset Structures :kissing: + +# To make a custom dataset, inherit the `BaseDataset` class and override following methods: + +# 1. `create` to create `tf.data.Dataset` instance. +# 2. `parse` for transforming `tf.data.Dataset` during creation by applyting `tf.data.Dataset.map` function. + +# _Note_: To create transcripts for **librispeech**, see [create_librispeech_trans.py](../../scripts/create_librispeech_trans.py) + +# ## ASR Datasets + +# An ASR dataset is some `.tsv` files in format: `PATH\tDURATION\tTRANSCRIPT`. You must create those files by your own with your own data and methods. + +# **Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT` +# because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob: + +# **For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file. 
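+# A character vocabulary `.txt` file lists one token per line (in the character tokenizer, empty
+# lines and lines starting with `#` are skipped as comments, and the space character counts as a
+# regular token). The snippet below is only an illustrative sketch of such a file with extra
+# punctuation, not a file shipped with the project:
+
+# ```text
+# # custom character vocabulary (one token per line)
+# a
+# b
+# c
+# '
+# .
+# ,
+# ```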
+# Default is [English](../featurizers/english.txt) + +# **Inputs** + +# ```python +# class ASRTFRecordDataset(ASRDataset): +# """ Dataset for ASR using TFRecords """ + +# class ASRSliceDataset(ASRDataset): +# """ Dataset for ASR using Slice """ +# ``` + +# **Outputs when iterating dataset** + +# ```python +# ( +# { +# "inputs": ..., +# "inputs_length": ..., +# "predictions": ..., +# "predictions_length": ..., +# }, +# { +# "labels": ..., +# "labels_length": ... +# } +# ) +# ``` + +# Where `predictions` and `predictions_length` are the label prepanded by blank and its length for training *Transducer* + +import json +import logging +import os +from dataclasses import asdict, dataclass + +import numpy as np + +from tensorflow_asr import schemas, tf +from tensorflow_asr.abstracts import AbstractDataset, AbstractTokenizer +from tensorflow_asr.configs import Config, DatasetConfig +from tensorflow_asr.utils import data_util, feature_util, file_util, math_util + +logger = logging.getLogger(__name__) + + +@dataclass +class ASR_DATASER_TYPES: + TFRECORD: str = "tfrecord" + SLICE: str = "slice" + GENERATOR: str = "generator" + HUGGINGFACE: str = "huggingface" + + +def get( + tokenizer: AbstractTokenizer, + dataset_config: DatasetConfig, + dataset_type: str, + dataset_cache: bool = False, +): + dataset_config.cache = dataset_cache + if dataset_type == ASR_DATASER_TYPES.TFRECORD: + return ASRTFRecordDataset(tokenizer=tokenizer, **vars(dataset_config)) + if dataset_type == ASR_DATASER_TYPES.SLICE: + return ASRSliceDataset(tokenizer=tokenizer, **vars(dataset_config)) + if dataset_type == ASR_DATASER_TYPES.GENERATOR: + return ASRDataset(tokenizer=tokenizer, **vars(dataset_config)) + raise ValueError(f"dataset_type must in {asdict(ASR_DATASER_TYPES()).values()}") + + +def get_global_shape( + config: Config, + strategy: tf.distribute.Strategy, + *datasets: "ASRDataset", + batch_size: int = None, +): + batch_size = (batch_size or config.learning_config.running_config.batch_size) * strategy.num_replicas_in_sync + + max_input_length, max_label_length = 0, 0 + for dset in datasets: + max_input_length = max(max_input_length, dset.max_input_length or 0) + max_label_length = max(max_label_length, dset.max_label_length or 0) + max_input_length = None if max_input_length == 0 else max_input_length + max_label_length = None if max_label_length == 0 else max_label_length + + input_shape = [max_input_length] + prediction_shape = [max_label_length + 1] if max_label_length else [None] + label_shape = [max_label_length] + padded_shapes = schemas.TrainData( + inputs=schemas.TrainInput( + inputs=tf.TensorShape(input_shape), + inputs_length=tf.TensorShape([]), + predictions=tf.TensorShape(prediction_shape), + predictions_length=tf.TensorShape([]), + ), + labels=schemas.TrainLabel( + labels=tf.TensorShape(label_shape), + labels_length=tf.TensorShape([]), + ), + ) + + model_shapes = dict( + batch_size=batch_size, + input_shape=input_shape, + prediction_shape=prediction_shape, + ) + return model_shapes, batch_size, padded_shapes + + +BUFFER_SIZE = 100 +TFRECORD_BUFFER_SIZE = 32 * 1024 * 1024 +TFRECORD_SHARDS = 16 +AUTOTUNE = int(os.environ.get("AUTOTUNE") or tf.data.AUTOTUNE) + + +class ASRDataset(AbstractDataset): + def __init__( + self, + stage: str, + tokenizer: AbstractTokenizer, + data_paths: list, + tfrecords_dir: str = None, + tfrecords_shards: int = TFRECORD_SHARDS, + tfrecords_buffer_size: int = TFRECORD_BUFFER_SIZE, + tfrecords_compression_type: str = "GZIP", + item_mapping: dict = None, + cache: bool = False, + 
shuffle: bool = False, + indefinite: bool = True, + drop_remainder: bool = True, + enabled: bool = True, + metadata: str = None, + buffer_size: int = BUFFER_SIZE, + sample_rate: int = 16000, + name: str = "", + **kwargs, + ): + self.tokenizer = tokenizer + self.data_paths = data_paths or [] + if not isinstance(self.data_paths, list): + raise ValueError("data_paths must be a list of string paths") + self.cache = cache # whether to cache transformed dataset to memory + self.shuffle = shuffle # whether to shuffle tf.data.Dataset + self.buffer_size = buffer_size # shuffle buffer size + self.stage = stage # for defining tfrecords files + self.enabled = enabled + self.drop_remainder = drop_remainder # whether to drop remainder for multi gpu training + self.indefinite = indefinite # Whether to make dataset repeat indefinitely -> avoid the potential last partial batch + self.total_steps = None # for better training visualization + self.metadata = metadata + self.sample_rate = sample_rate + self.use_ga = False + self.name = name or stage + self.tfrecords_dir = tfrecords_dir + if tfrecords_shards <= 0: + raise ValueError("tfrecords_shards must be positive") + self.tfrecords_shards = tfrecords_shards + self.tfrecords_buffer_size = tfrecords_buffer_size + self.tfrecords_compression_type = tfrecords_compression_type + self.item_mapping = item_mapping or {} + + for key, value in kwargs.items(): + setattr(self, key, value) + + self.entries = [] + self.max_input_length = None + self.max_label_length = None + self.load_metadata() + + # -------------------------------- metadata ------------------------------------- + + def compute_metadata(self): + if not self.tokenizer.initialized: + raise ValueError("Tokenizer must be initialized before computing metadata") + + from tqdm import tqdm # pylint: disable=import-outside-toplevel + + self.max_input_length = 0 if self.max_input_length is None else self.max_input_length + self.max_label_length = 0 if self.max_label_length is None else self.max_label_length + self.read_entries() + for _, duration, transcript in tqdm(self.entries, desc=f"Computing metadata for entries in {self.stage} dataset", disable=False): + input_length = math_util.get_nsamples(duration, self.sample_rate) + label = self.tokenizer.tokenize(transcript).numpy() + label_length = len(label) + self.max_input_length = max(self.max_input_length, input_length) + self.max_label_length = max(self.max_label_length, label_length) + + def save_metadata(self): + if self.metadata is None: + return + self.metadata = file_util.preprocess_paths(self.metadata) + if tf.io.gfile.exists(self.metadata): + with tf.io.gfile.GFile(self.metadata, "r") as f: + try: + content = json.loads(f.read()) + except json.JSONDecodeError as e: + raise ValueError(f"File {self.metadata} is currently not in json format. 
Please update the file") from e + else: + content = {} + content[self.stage] = dict( + max_input_length=self.max_input_length, + max_label_length=self.max_label_length, + num_entries=self.total_steps, + ) + with tf.io.gfile.GFile(self.metadata, "w") as f: + f.write(json.dumps(content, indent=2)) + logger.info(f"Metadata written to {self.metadata}") + + def load_metadata(self): + if self.metadata is None: + return + if not self.enabled: + return + content = None + self.metadata = file_util.preprocess_paths(self.metadata) + if tf.io.gfile.exists(self.metadata): + logger.info(f"Loading metadata from {self.metadata} ...") + with tf.io.gfile.GFile(self.metadata, "r") as f: + try: + content = json.loads(f.read()).get(self.stage, {}) + except json.JSONDecodeError as e: + raise ValueError(f"File {self.metadata} must be in json format") from e + if not content: + return + self.max_input_length = content.get("max_input_length") + self.max_label_length = content.get("max_label_length") + self.total_steps = int(content.get("num_entries", 0)) + self.num_entries = self.total_steps + + def update_metadata(self): + self.load_metadata() + self.compute_metadata() + self.save_metadata() + + # -------------------------------- ENTRIES ------------------------------------- + + def read_entries(self): + if hasattr(self, "entries") and len(self.entries) > 0: + return + self.data_paths = file_util.preprocess_paths(self.data_paths, enabled=self.enabled, check_exists=True) + for file_path in self.data_paths: + logger.info(f"Reading {file_path} ...") + with tf.io.gfile.GFile(file_path, "r") as f: + for line in f.read().splitlines()[1:]: # Skip the header of tsv file + self.entries.append(line.split("\t", 2)) # The files is "\t" seperated + self.entries = np.array(self.entries) + if self.shuffle: + np.random.shuffle(self.entries) # Mix transcripts.tsv + self.total_steps = len(self.entries) + self.num_entries = self.total_steps + + def vocab_generator(self): + for *_, transcript in self.entries: + yield transcript + + # -------------------------------- LOAD AND PREPROCESS ------------------------------------- + + def generator(self): + for path, _, transcript in self.entries: + audio = data_util.load_and_convert_to_wav(path, sample_rate=self.sample_rate).numpy() + yield bytes(path, "utf-8"), audio, bytes(transcript, "utf-8") + + def _process_item(self, path: tf.Tensor, audio: tf.Tensor, transcript: tf.Tensor): + with tf.device("/CPU:0"): + inputs = data_util.read_raw_audio(audio) + inputs_length = tf.shape(inputs, out_type=tf.int32)[0] + + labels = self.tokenizer.tokenize(transcript) + labels_length = tf.shape(labels, out_type=tf.int32)[0] + + predictions = self.tokenizer.prepand_blank(labels) + predictions_length = tf.shape(predictions, out_type=tf.int32)[0] + + return path, inputs, inputs_length, labels, labels_length, predictions, predictions_length + + def parse(self, path: tf.Tensor, audio: tf.Tensor, transcript: tf.Tensor) -> schemas.TrainData: + ( + _, + inputs, + inputs_length, + labels, + labels_length, + predictions, + predictions_length, + ) = self._process_item(path=path, audio=audio, transcript=transcript) + return schemas.TrainData( + inputs=schemas.TrainInput(inputs=inputs, inputs_length=inputs_length, predictions=predictions, predictions_length=predictions_length), + labels=schemas.TrainLabel(labels=labels, labels_length=labels_length), + ) + + # -------------------------------- CREATION ------------------------------------- + + def process( + self, + dataset: tf.data.Dataset, + batch_size: int, + 
ga_steps: int = 1, + padded_shapes=None, + ): + dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE, deterministic=False) + + if self.cache: + dataset = dataset.cache() # cache original (unchanged data) + + if self.shuffle: + dataset = dataset.shuffle(max(self.buffer_size or self.num_entries, batch_size * 2), reshuffle_each_iteration=True) + + if self.indefinite and hasattr(self, "total_steps") and self.total_steps: + dataset = dataset.repeat() + + if padded_shapes is None: + padded_shapes = schemas.TrainData( + inputs=schemas.TrainInput( + inputs=tf.TensorShape([self.max_input_length]), + inputs_length=tf.TensorShape([]), + predictions=tf.TensorShape([self.max_label_length + 1 if self.max_label_length else None]), + predictions_length=tf.TensorShape([]), + ), + labels=schemas.TrainLabel( + labels=tf.TensorShape([self.max_label_length]), + labels_length=tf.TensorShape([]), + ), + ) + + # PADDED BATCH the dataset + dataset = dataset.padded_batch( + batch_size=batch_size, + padded_shapes=padded_shapes, + padding_values=schemas.TrainData( + inputs=schemas.TrainInput(inputs=0.0, inputs_length=0, predictions=self.tokenizer.blank, predictions_length=0), + labels=schemas.TrainLabel(labels=self.tokenizer.blank, labels_length=0), + ), + drop_remainder=self.drop_remainder, + ) + + # only apply for training dataset, eval and test dataset should not use GA + if ga_steps > 1 and self.stage == "train": + self.use_ga = True + + # PREFETCH to improve speed of input length + dataset = dataset.prefetch(AUTOTUNE) + + # Update metadata + if hasattr(self, "num_entries") and self.num_entries > 0: + self.total_steps = math_util.get_num_batches(self.num_entries, batch_size, drop_remainders=self.drop_remainder) + if self.use_ga: + self.total_steps = math_util.get_num_batches(self.total_steps, ga_steps, drop_remainders=False) + + return dataset + + def create(self, batch_size: int, ga_steps: int = 1, padded_shapes=None): + if not self.enabled: + return None + if not self.tokenizer.initialized: + return None + self.read_entries() + if not self.total_steps or self.total_steps == 0: + return None + dataset = tf.data.Dataset.from_generator( + self.generator, + output_types=(tf.string, tf.string, tf.string), + output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([])), + ) + return self.process(dataset, batch_size, ga_steps=ga_steps, padded_shapes=padded_shapes) + + +class ASRTFRecordDataset(ASRDataset): + """Dataset for ASR using TFRecords""" + + def write_tfrecord_file(self, splitted_entries: tuple): + shard_path, entries = splitted_entries + logger.info(f"Processing {shard_path} ...") + with tf.io.TFRecordWriter(shard_path, options=tf.io.TFRecordOptions(compression_type=self.tfrecords_compression_type)) as writer: + for path, _, transcript in entries: + audio = data_util.load_and_convert_to_wav(path, sample_rate=self.sample_rate).numpy() + feature = dict( + path=feature_util.bytestring_feature([path.encode("utf-8")]), + audio=feature_util.bytestring_feature([audio]), + transcript=feature_util.bytestring_feature([transcript.encode("utf-8")]), + ) + example = tf.train.Example(features=tf.train.Features(feature=feature)) + writer.write(example.SerializeToString()) + logger.info(f"Created {shard_path}") + + def create_tfrecords(self): + if not self.tfrecords_dir: + return False + self.tfrecords_dir = file_util.preprocess_paths(self.tfrecords_dir, isdir=True, enabled=self.enabled) + + if tf.io.gfile.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")): + logger.info(f"TFRecords're 
already existed: {self.stage}") + return True + + logger.info(f"Creating {self.stage}.tfrecord ...") + + self.read_entries() + if not self.total_steps or self.total_steps == 0: + return False + + def get_shard_path(shard_id: int): + return os.path.join(self.tfrecords_dir, f"{self.stage}_{shard_id}.tfrecord") + + shards = [get_shard_path(idx) for idx in range(1, self.tfrecords_shards + 1)] + + splitted_entries = np.array_split(self.entries, self.tfrecords_shards) + for entries in zip(shards, splitted_entries): + self.write_tfrecord_file(entries) + + return True + + def parse(self, record: tf.Tensor, **kwargs): + feature_description = dict( + path=tf.io.FixedLenFeature([], tf.string), + audio=tf.io.FixedLenFeature([], tf.string), + transcript=tf.io.FixedLenFeature([], tf.string), + ) + example = tf.io.parse_single_example(record, feature_description) + return super().parse(**example) + + def create(self, batch_size: int, ga_steps: int = 1, padded_shapes=None): + if not self.enabled: + return None + if not self.tokenizer.initialized: + return None + have_data = self.create_tfrecords() + if not have_data: + return None + + pattern = os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord") + files_ds = tf.data.Dataset.list_files(pattern, shuffle=self.shuffle) + ignore_order = tf.data.Options() + ignore_order.deterministic = False + files_ds = files_ds.with_options(ignore_order) + dataset = tf.data.TFRecordDataset( + files_ds, + compression_type=self.tfrecords_compression_type, + buffer_size=self.tfrecords_buffer_size, + num_parallel_reads=AUTOTUNE, + ) + + return self.process(dataset, batch_size, ga_steps=ga_steps, padded_shapes=padded_shapes) + + +class ASRSliceDataset(ASRDataset): + """Dataset for ASR using Slice""" + + def load(self, record): + audio = tf.numpy_function( + lambda path: data_util.load_and_convert_to_wav(path.decode("utf-8"), sample_rate=self.sample_rate).numpy(), + inp=[record[0]], + Tout=tf.string, + ) + return record[0], audio, record[2] + + def create(self, batch_size: int, ga_steps: int = 1, padded_shapes=None): + if not self.enabled: + return None + if not self.tokenizer.initialized: + return None + self.read_entries() + if not self.total_steps or self.total_steps == 0: + return None + + dataset = tf.data.Dataset.from_tensor_slices(self.entries) + options = tf.data.Options() + options.deterministic = False + options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA + dataset = dataset.with_options(options) + dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE, deterministic=False) + + return self.process(dataset, batch_size, ga_steps=ga_steps, padded_shapes=padded_shapes) diff --git a/tensorflow_asr/datasets/README.md b/tensorflow_asr/datasets/README.md deleted file mode 100644 index 2115b2e36c..0000000000 --- a/tensorflow_asr/datasets/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# Dataset Structures :kissing: - -To make a custom dataset, inherit the `BaseDataset` class and override following methods: - -1. `create` to create `tf.data.Dataset` instance. -2. `parse` for transforming `tf.data.Dataset` during creation by applyting `tf.data.Dataset.map` function. - -_Note_: To create transcripts for **librispeech**, see [create_librispeech_trans.py](../../scripts/create_librispeech_trans.py) - -## ASR Datasets - -An ASR dataset is some `.tsv` files in format: `PATH\tDURATION\tTRANSCRIPT`. You must create those files by your own with your own data and methods. 
- -**Note**: Each `.tsv` file must include a header `PATH\tDURATION\tTRANSCRIPT` because it will remove these headers when loading dataset, otherwise you will lose 1 data file :sob: - -**For transcript**, if you want to include characters such as dots, commas, double quote, etc.. you must create your own `.txt` vocabulary file. Default is [English](../featurizers/english.txt) - -**Inputs** - -```python -class ASRTFRecordDataset(ASRDataset): - """ Dataset for ASR using TFRecords """ - -class ASRSliceDataset(ASRDataset): - """ Dataset for ASR using Slice """ -``` - -**Outputs when iterating dataset** - -```python -( - { - "inputs": ..., - "inputs_length": ..., - "predictions": ..., - "predictions_length": ..., - }, - { - "labels": ..., - "labels_length": ... - } -) -``` - -Where `predictions` and `predictions_length` are the label prepanded by blank and its length for training *Transducer* \ No newline at end of file diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py deleted file mode 100755 index 5287773cda..0000000000 --- a/tensorflow_asr/datasets/asr_dataset.py +++ /dev/null @@ -1,441 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -from typing import Union - -import numpy as np -import tensorflow as tf -import tqdm - -from tensorflow_asr.augmentations.augmentation import Augmentation -from tensorflow_asr.datasets.base_dataset import AUTOTUNE, BUFFER_SIZE, TFRECORD_SHARDS, BaseDataset -from tensorflow_asr.featurizers.speech_featurizers import ( - SpeechFeaturizer, - load_and_convert_to_wav, - read_raw_audio, - tf_read_raw_audio, -) -from tensorflow_asr.featurizers.text_featurizers import TextFeaturizer -from tensorflow_asr.utils import data_util, feature_util, file_util, math_util - -logger = tf.get_logger() - - -class ASRDataset(BaseDataset): - """Dataset for ASR using Generator""" - - def __init__( - self, - stage: str, - speech_featurizer: SpeechFeaturizer, - text_featurizer: TextFeaturizer, - data_paths: list, - augmentations: Augmentation = Augmentation(None), - cache: bool = False, - shuffle: bool = False, - indefinite: bool = False, - drop_remainder: bool = True, - use_tf: bool = False, - enabled: bool = True, - metadata: str = None, - buffer_size: int = BUFFER_SIZE, - **kwargs, - ): - super().__init__( - data_paths=data_paths, - augmentations=augmentations, - cache=cache, - shuffle=shuffle, - stage=stage, - buffer_size=buffer_size, - drop_remainder=drop_remainder, - use_tf=use_tf, - enabled=enabled, - metadata=metadata, - indefinite=indefinite, - ) - self.entries = [] - self.speech_featurizer = speech_featurizer - self.text_featurizer = text_featurizer - if self.metadata: - self.load_metadata(metadata=metadata) - - # -------------------------------- metadata ------------------------------------- - - def compute_metadata(self): - self.read_entries() - for _, duration, transcript in tqdm.tqdm(self.entries, desc=f"Computing metadata for entries in {self.stage} dataset"): 
- input_length = self.speech_featurizer.get_length_from_duration(duration) - label = self.text_featurizer.extract(transcript).numpy() - label_length = len(label) - self.speech_featurizer.update_length(input_length) - self.text_featurizer.update_length(label_length) - - def save_metadata( - self, - metadata: str = None, - ): - if metadata is None: - return - metadata = file_util.preprocess_paths(metadata) - if tf.io.gfile.exists(metadata): - with tf.io.gfile.GFile(metadata, "r") as f: - try: - content = json.loads(f.read()) - except json.JSONDecodeError as e: - raise ValueError(f"File {metadata} is currently not in json format. Please update the file") from e - else: - content = {} - content[self.stage] = { - "max_input_length": self.speech_featurizer.max_length, - "max_label_length": self.text_featurizer.max_length, - "num_entries": self.total_steps, - } - with tf.io.gfile.GFile(metadata, "w") as f: - f.write(json.dumps(content, indent=2)) - logger.info(f"Metadata written to {metadata}") - - def load_metadata( - self, - metadata: Union[str, dict] = None, - ): - if metadata is None: - return - if not self.enabled: - return - content = None - if isinstance(metadata, dict): - content = metadata - else: - metadata = file_util.preprocess_paths(metadata) - if tf.io.gfile.exists(metadata): - logger.info(f"Loading metadata from {metadata} ...") - with tf.io.gfile.GFile(metadata, "r") as f: - try: - content = json.loads(f.read()).get(self.stage, {}) - except json.JSONDecodeError as e: - raise ValueError(f"File {metadata} must be in json format") from e - if not content: - return - self.speech_featurizer.update_length(int(content.get("max_input_length", 0))) - self.text_featurizer.update_length(int(content.get("max_label_length", 0))) - self.total_steps = int(content.get("num_entries", 0)) - - def update_metadata( - self, - metadata: str = None, - ): - self.load_metadata(metadata) - self.compute_metadata() - self.save_metadata(metadata) - - # -------------------------------- ENTRIES ------------------------------------- - - def read_entries(self): - if hasattr(self, "entries") and len(self.entries) > 0: - return - for file_path in self.data_paths: - logger.info(f"Reading {file_path} ...") - with tf.io.gfile.GFile(file_path, "r") as f: - for line in f.read().splitlines()[1:]: # Skip the header of tsv file - self.entries.append(line.split("\t", 2)) # The files is "\t" seperated - self.entries = np.array(self.entries) - if self.shuffle: - np.random.shuffle(self.entries) # Mix transcripts.tsv - self.total_steps = len(self.entries) - - # -------------------------------- LOAD AND PREPROCESS ------------------------------------- - - def generator(self): - for path, _, transcript in self.entries: - audio = load_and_convert_to_wav(path).numpy() - yield bytes(path, "utf-8"), audio, bytes(transcript, "utf-8") - - def preprocess( - self, - path: tf.Tensor, - audio: tf.Tensor, - transcript: tf.Tensor, - ): - with tf.device("/CPU:0"): - - def fn(_path: bytes, _audio: bytes, _transcript: bytes): - signal = read_raw_audio(_audio, sample_rate=self.speech_featurizer.speech_config.sample_rate) - signal = self.augmentations.signal_augment(signal) - features = self.speech_featurizer.extract(signal.numpy()) - features = self.augmentations.feature_augment(features) - features = tf.convert_to_tensor(features, tf.float32) - input_length = tf.shape(features, out_type=tf.int32)[0] - - label = self.text_featurizer.extract(_transcript) - label_length = tf.shape(label, out_type=tf.int32)[0] - - prediction = 
self.text_featurizer.prepand_blank(label) - prediction_length = tf.shape(prediction, out_type=tf.int32)[0] - - return _path, features, input_length, label, label_length, prediction, prediction_length - - return tf.numpy_function( - fn, - inp=[path, audio, transcript], - Tout=[tf.string, tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.int32], - ) - - def tf_preprocess( - self, - path: tf.Tensor, - audio: tf.Tensor, - transcript: tf.Tensor, - ): - with tf.device("/CPU:0"): - signal = tf_read_raw_audio(audio, self.speech_featurizer.speech_config.sample_rate) - signal = self.augmentations.signal_augment(signal) - features = self.speech_featurizer.tf_extract(signal) - features = self.augmentations.feature_augment(features) - input_length = tf.shape(features, out_type=tf.int32)[0] - - label = self.text_featurizer.tf_extract(transcript) - label_length = tf.shape(label, out_type=tf.int32)[0] - - prediction = self.text_featurizer.prepand_blank(label) - prediction_length = tf.shape(prediction, out_type=tf.int32)[0] - - return path, features, input_length, label, label_length, prediction, prediction_length - - def parse( - self, - path: tf.Tensor, - audio: tf.Tensor, - transcript: tf.Tensor, - ): - """ - Returns: - path, features, input_lengths, labels, label_lengths, pred_inp - """ - data = self.tf_preprocess(path, audio, transcript) if self.use_tf else self.preprocess(path, audio, transcript) - _, features, input_length, label, label_length, prediction, prediction_length = data - return ( - data_util.create_inputs(inputs=features, inputs_length=input_length, predictions=prediction, predictions_length=prediction_length), - data_util.create_labels(labels=label, labels_length=label_length), - ) - - # -------------------------------- CREATION ------------------------------------- - - def process( - self, - dataset: tf.data.Dataset, - batch_size: int, - ): - if self.cache: - dataset = dataset.cache() # cache original (unchanged data) - - dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE, deterministic=False) - self.total_steps = math_util.get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder) - - if self.shuffle: - dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True) - - if self.indefinite and self.total_steps: - dataset = dataset.repeat() - - # PADDED BATCH the dataset - dataset = dataset.padded_batch( - batch_size=batch_size, - padded_shapes=( - data_util.create_inputs( - inputs=tf.TensorShape(self.speech_featurizer.shape), - inputs_length=tf.TensorShape([]), - predictions=tf.TensorShape(self.text_featurizer.prepand_shape), - predictions_length=tf.TensorShape([]), - ), - data_util.create_labels(labels=tf.TensorShape(self.text_featurizer.shape), labels_length=tf.TensorShape([])), - ), - padding_values=( - data_util.create_inputs(inputs=0.0, inputs_length=0, predictions=self.text_featurizer.blank, predictions_length=0), - data_util.create_labels(labels=self.text_featurizer.blank, labels_length=0), - ), - drop_remainder=self.drop_remainder, - ) - - # PREFETCH to improve speed of input length - dataset = dataset.prefetch(AUTOTUNE) - return dataset - - def create( - self, - batch_size: int, - ): - if not self.enabled: - return None - self.read_entries() - if not self.total_steps or self.total_steps == 0: - return None - dataset = tf.data.Dataset.from_generator( - self.generator, - output_types=(tf.string, tf.string, tf.string), - output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([])), - ) - return self.process(dataset, 
batch_size) - - -class ASRTFRecordDataset(ASRDataset): - """Dataset for ASR using TFRecords""" - - def __init__( - self, - data_paths: list, - tfrecords_dir: str, - speech_featurizer: SpeechFeaturizer, - text_featurizer: TextFeaturizer, - stage: str, - augmentations: Augmentation = Augmentation(None), - tfrecords_shards: int = TFRECORD_SHARDS, - cache: bool = False, - shuffle: bool = False, - use_tf: bool = False, - enabled: bool = True, - metadata: str = None, - indefinite: bool = False, - drop_remainder: bool = True, - buffer_size: int = BUFFER_SIZE, - compression_type: str = "GZIP", - **kwargs, - ): - super().__init__( - stage=stage, - speech_featurizer=speech_featurizer, - text_featurizer=text_featurizer, - data_paths=data_paths, - augmentations=augmentations, - cache=cache, - shuffle=shuffle, - buffer_size=buffer_size, - drop_remainder=drop_remainder, - use_tf=use_tf, - enabled=enabled, - metadata=metadata, - indefinite=indefinite, - ) - if not self.stage: - raise ValueError("stage must be defined, either 'train', 'eval' or 'test'") - self.tfrecords_dir = tfrecords_dir - if tfrecords_shards <= 0: - raise ValueError("tfrecords_shards must be positive") - self.tfrecords_shards = tfrecords_shards - self.compression_type = compression_type - - def write_tfrecord_file( - self, - splitted_entries: tuple, - ): - shard_path, entries = splitted_entries - logger.info(f"Processing {shard_path} ...") - with tf.io.TFRecordWriter(shard_path, options=tf.io.TFRecordOptions(compression_type=self.compression_type)) as writer: - for path, _, transcript in entries: - audio = load_and_convert_to_wav(path).numpy() - feature = { - "path": feature_util.bytestring_feature([path.encode("utf-8")]), - "audio": feature_util.bytestring_feature([audio]), - "transcript": feature_util.bytestring_feature([transcript.encode("utf-8")]), - } - example = tf.train.Example(features=tf.train.Features(feature=feature)) - writer.write(example.SerializeToString()) - logger.info(f"Created {shard_path}") - - def create_tfrecords(self): - if not self.tfrecords_dir: - return False - - if tf.io.gfile.glob(os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord")): - logger.info(f"TFRecords're already existed: {self.stage}") - return True - - logger.info(f"Creating {self.stage}.tfrecord ...") - - self.read_entries() - if not self.total_steps or self.total_steps == 0: - return False - - def get_shard_path(shard_id: int): - return os.path.join(self.tfrecords_dir, f"{self.stage}_{shard_id}.tfrecord") - - shards = [get_shard_path(idx) for idx in range(1, self.tfrecords_shards + 1)] - - splitted_entries = np.array_split(self.entries, self.tfrecords_shards) - for entries in zip(shards, splitted_entries): - self.write_tfrecord_file(entries) - - return True - - def parse( - self, - record: tf.Tensor, - **kwargs, - ): - feature_description = { - "path": tf.io.FixedLenFeature([], tf.string), - "audio": tf.io.FixedLenFeature([], tf.string), - "transcript": tf.io.FixedLenFeature([], tf.string), - } - example = tf.io.parse_single_example(record, feature_description) - return super().parse(**example) - - def create( - self, - batch_size: int, - ): - if not self.enabled: - return None - have_data = self.create_tfrecords() - if not have_data: - return None - - pattern = os.path.join(self.tfrecords_dir, f"{self.stage}*.tfrecord") - files_ds = tf.data.Dataset.list_files(pattern, shuffle=self.shuffle) - ignore_order = tf.data.Options() - ignore_order.deterministic = False - files_ds = files_ds.with_options(ignore_order) - dataset = 
tf.data.TFRecordDataset(files_ds, compression_type=self.compression_type, num_parallel_reads=AUTOTUNE) - - return self.process(dataset, batch_size) - - -class ASRSliceDataset(ASRDataset): - """Dataset for ASR using Slice""" - - @staticmethod - def load(record): - audio = tf.numpy_function(lambda path: load_and_convert_to_wav(path.decode("utf-8")).numpy(), inp=[record[0]], Tout=tf.string) - return record[0], audio, record[2] - - def create( - self, - batch_size: int, - ): - if not self.enabled: - return None - self.read_entries() - if not self.total_steps or self.total_steps == 0: - return None - - dataset = tf.data.Dataset.from_tensor_slices(self.entries) - options = tf.data.Options() - options.deterministic = False - options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA - dataset = dataset.with_options(options) - dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE, deterministic=False) - - return self.process(dataset, batch_size) diff --git a/tensorflow_asr/datasets/base_dataset.py b/tensorflow_asr/datasets/base_dataset.py deleted file mode 100644 index f3bd489cd1..0000000000 --- a/tensorflow_asr/datasets/base_dataset.py +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import tensorflow as tf - -from tensorflow_asr.augmentations.augmentation import Augmentation - -BUFFER_SIZE = 100 -TFRECORD_SHARDS = 16 -AUTOTUNE = tf.data.experimental.AUTOTUNE - - -class BaseDataset: - """Based dataset for all models""" - - def __init__( - self, - data_paths: list, - augmentations: Augmentation = Augmentation(None), - cache: bool = False, - shuffle: bool = False, - buffer_size: int = BUFFER_SIZE, - indefinite: bool = False, - drop_remainder: bool = True, - use_tf: bool = False, - enabled: bool = True, - metadata: str = None, - stage: str = "train", - **kwargs - ): - self.data_paths = data_paths or [] - if not isinstance(self.data_paths, list): - raise ValueError("data_paths must be a list of string paths") - self.augmentations = augmentations # apply augmentation - self.cache = cache # whether to cache transformed dataset to memory - self.shuffle = shuffle # whether to shuffle tf.data.Dataset - if buffer_size <= 0 and shuffle: - raise ValueError("buffer_size must be positive when shuffle is on") - self.buffer_size = buffer_size # shuffle buffer size - self.stage = stage # for defining tfrecords files - self.use_tf = use_tf - self.enabled = enabled - self.drop_remainder = drop_remainder # whether to drop remainder for multi gpu training - self.indefinite = indefinite # Whether to make dataset repeat indefinitely -> avoid the potential last partial batch - self.total_steps = None # for better training visualization - self.metadata = metadata - - def parse(self, *args, **kwargs): - raise NotImplementedError() - - def create(self, batch_size): - raise NotImplementedError() diff --git a/tensorflow_asr/featurizers/README.md b/tensorflow_asr/features/README.md similarity index 100% rename from tensorflow_asr/featurizers/README.md rename to tensorflow_asr/features/README.md diff --git a/tensorflow_asr/featurizers/__init__.py b/tensorflow_asr/features/__init__.py similarity index 100% rename from tensorflow_asr/featurizers/__init__.py rename to tensorflow_asr/features/__init__.py diff --git a/tensorflow_asr/featurizers/methods/gammatone.py b/tensorflow_asr/features/gammatone.py similarity index 97% rename from tensorflow_asr/featurizers/methods/gammatone.py rename to tensorflow_asr/features/gammatone.py index f4bae289d3..57ec79729d 100644 --- a/tensorflow_asr/featurizers/methods/gammatone.py +++ b/tensorflow_asr/features/gammatone.py @@ -14,8 +14,8 @@ """ This code is inspired from https://github.com/detly/gammatone """ import numpy as np -import tensorflow as tf +from tensorflow_asr import tf from tensorflow_asr.utils.shape_util import shape_list pi = tf.constant(np.pi, dtype=tf.complex64) @@ -120,11 +120,11 @@ def erb_point( # All of the following expressions are derived in Apple TR #35, "An # Efficient Implementation of the Patterson-Holdsworth Cochlear Filter # Bank." See pages 33-34. 
- erb_point = -ear_q * min_bw + tf.exp(fraction * (-tf.math.log(high_freq + ear_q * min_bw) + tf.math.log(low_freq + ear_q * min_bw))) * ( - high_freq + ear_q * min_bw + erbp = (-ear_q * min_bw) + ( + tf.exp(fraction * ((-1 * tf.math.log(high_freq + ear_q * min_bw)) + tf.math.log(low_freq + ear_q * min_bw))) * (high_freq + ear_q * min_bw) ) - return tf.cast(erb_point, tf.complex64) + return tf.cast(erbp, tf.complex64) def erb_space( diff --git a/tensorflow_asr/featurizers/figs/log_gammatone_spectrogram.png b/tensorflow_asr/featurizers/figs/log_gammatone_spectrogram.png deleted file mode 100644 index 4639b52a05..0000000000 Binary files a/tensorflow_asr/featurizers/figs/log_gammatone_spectrogram.png and /dev/null differ diff --git a/tensorflow_asr/featurizers/figs/log_mel_spectrogram.png b/tensorflow_asr/featurizers/figs/log_mel_spectrogram.png deleted file mode 100644 index c94cc6d946..0000000000 Binary files a/tensorflow_asr/featurizers/figs/log_mel_spectrogram.png and /dev/null differ diff --git a/tensorflow_asr/featurizers/figs/mfcc.png b/tensorflow_asr/featurizers/figs/mfcc.png deleted file mode 100644 index 65872169eb..0000000000 Binary files a/tensorflow_asr/featurizers/figs/mfcc.png and /dev/null differ diff --git a/tensorflow_asr/featurizers/figs/spectrogram.png b/tensorflow_asr/featurizers/figs/spectrogram.png deleted file mode 100644 index f7ff8231f8..0000000000 Binary files a/tensorflow_asr/featurizers/figs/spectrogram.png and /dev/null differ diff --git a/tensorflow_asr/featurizers/speech_featurizers.py b/tensorflow_asr/featurizers/speech_featurizers.py deleted file mode 100755 index 71c64dafeb..0000000000 --- a/tensorflow_asr/featurizers/speech_featurizers.py +++ /dev/null @@ -1,309 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) and Huy Phan (@pquochuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import io -import math -import os -from typing import Union - -import librosa -import numpy as np -import soundfile as sf -import tensorflow as tf -import tensorflow_io as tfio - -from tensorflow_asr.configs.config import SpeechConfig -from tensorflow_asr.featurizers.methods import gammatone -from tensorflow_asr.utils import env_util, math_util - - -def load_and_convert_to_wav( - path: str, -) -> tf.Tensor: - wave, rate = librosa.load(os.path.expanduser(path), sr=None, mono=True) - return tf.audio.encode_wav(tf.expand_dims(wave, axis=-1), sample_rate=rate) - - -def read_raw_audio( - audio: Union[str, bytes, np.ndarray], - sample_rate=16000, -) -> np.ndarray: - if isinstance(audio, str): - wave, _ = librosa.load(os.path.expanduser(audio), sr=sample_rate, mono=True) - elif isinstance(audio, bytes): - wave, sr = sf.read(io.BytesIO(audio)) - if wave.ndim > 1: - wave = np.mean(wave, axis=-1) - wave = np.asfortranarray(wave) - if sr != sample_rate: - wave = librosa.resample(wave, orig_sr=sr, target_sr=sample_rate) - elif isinstance(audio, np.ndarray): - if audio.ndim > 1: - ValueError("input audio must be single channel") - return audio - else: - raise ValueError("input audio must be either a path or bytes") - return wave - - -def tf_read_raw_audio( - audio: tf.Tensor, - sample_rate=16000, -) -> tf.Tensor: - wave, rate = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1) - if not env_util.has_devices("TPU"): - resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate) - return tf.reshape(resampled, shape=[-1]) # reshape for using tf.signal - return tf.reshape(wave, shape=[-1]) # reshape for using tf.signal - - -def slice_signal( - signal, - window_size, - stride=0.5, -) -> np.ndarray: - """Return windows of the given signal by sweeping in stride fractions of window""" - assert signal.ndim == 1, signal.ndim - n_samples = signal.shape[0] - offset = int(window_size * stride) - slices = [] - for beg_i, end_i in zip(range(0, n_samples, offset), range(window_size, n_samples + offset, offset)): - slice_ = signal[beg_i:end_i] - if slice_.shape[0] < window_size: - slice_ = np.pad(slice_, (0, window_size - slice_.shape[0]), "constant", constant_values=0.0) - if slice_.shape[0] == window_size: - slices.append(slice_) - return np.array(slices, dtype=np.float32) - - -def tf_merge_slices( - slices: tf.Tensor, -) -> tf.Tensor: - # slices shape = [batch, window_size] - return tf.keras.backend.flatten(slices) # return shape = [-1, ] - - -def tf_normalize_audio_features( - audio_feature: tf.Tensor, - per_frame=False, -) -> tf.Tensor: - """ - TF z-score features normalization - Args: - audio_feature: tf.Tensor with shape [T, F] - per_frame: - - Returns: - normalized audio features with shape [T, F] - """ - axis = -1 if per_frame else 0 - mean = tf.reduce_mean(audio_feature, axis=axis, keepdims=True) - stddev = tf.sqrt(tf.math.reduce_variance(audio_feature, axis=axis, keepdims=True) + 1e-9) - return tf.divide(tf.subtract(audio_feature, mean), stddev) - - -def tf_normalize_signal( - signal: tf.Tensor, -) -> tf.Tensor: - """ - TF Normailize signal to [-1, 1] range - Args: - signal: tf.Tensor with shape [None] - - Returns: - normalized signal with shape [None] - """ - gain = 1.0 / (tf.reduce_max(tf.abs(signal), axis=-1) + 1e-9) - return signal * gain - - -def tf_preemphasis( - signal: tf.Tensor, - coeff=0.97, -): - """ - TF Pre-emphasis - Args: - signal: tf.Tensor with shape [None] - coeff: Float that indicates the preemphasis coefficient - - Returns: - 
pre-emphasized signal with shape [None] - """ - if not coeff or coeff <= 0.0: - return signal - s0 = tf.expand_dims(signal[0], axis=-1) - s1 = signal[1:] - coeff * signal[:-1] - return tf.concat([s0, s1], -1) - - -def tf_depreemphasis( - signal: tf.Tensor, - coeff=0.97, -) -> tf.Tensor: - """ - TF Depreemphasis - Args: - signal: tf.Tensor with shape [B, None] - coeff: Float that indicates the preemphasis coefficient - - Returns: - depre-emphasized signal with shape [B, None] - """ - if not coeff or coeff <= 0.0: - return signal - - def map_fn(elem): - x = tf.expand_dims(elem[0], axis=-1) - for n in range(1, elem.shape[0], 1): - current = coeff * x[n - 1] + elem[n] - x = tf.concat([x, [current]], 0) - return x - - return tf.map_fn(map_fn, signal) - - -class SpeechFeaturizer: - def __init__(self, speech_config: SpeechConfig): - self.speech_config = speech_config - self.max_length = 0 - - @property - def nfft(self) -> int: - """Number of FFT""" - fft_length = int(max(512, math.pow(2, math.ceil(math.log(self.speech_config.frame_length, 2))))) - if self.speech_config.fft_overdrive: - fft_length *= 2 - return fft_length - - @property - def shape(self) -> list: - length = self.max_length if self.max_length > 0 else None - return [length, self.speech_config.num_feature_bins, 1] - - def get_length_from_duration(self, duration): - nsamples = math.ceil(float(duration) * self.speech_config.sample_rate) - # https://www.tensorflow.org/api_docs/python/tf/signal/frame - if self.speech_config.use_librosa_like_stft: - return 1 + (nsamples - self.nfft) // self.speech_config.frame_step - if self.speech_config.pad_end: - return -(-nsamples // self.speech_config.frame_step) - return 1 + (nsamples - self.speech_config.frame_length) // self.speech_config.frame_step - - def update_length(self, length: int): - self.max_length = max(self.max_length, length) - - def reset_length(self): - self.max_length = 0 - - def stft(self, signal): - if self.speech_config.use_librosa_like_stft: - # signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT") - window = tf.signal.hann_window(self.speech_config.frame_length, periodic=True) - left_pad = (self.nfft - self.speech_config.frame_length) // 2 - right_pad = self.nfft - self.speech_config.frame_length - left_pad - window = tf.pad(window, [[left_pad, right_pad]]) - framed_signals = tf.signal.frame(signal, frame_length=self.nfft, frame_step=self.speech_config.frame_step) - framed_signals *= window - fft_features = tf.abs(tf.signal.rfft(framed_signals, [self.nfft])) - else: - fft_features = tf.abs( - tf.signal.stft( - signal, - frame_length=self.speech_config.frame_length, - frame_step=self.speech_config.frame_step, - fft_length=self.nfft, - pad_end=self.speech_config.pad_end, - ) - ) - if self.speech_config.compute_energy: - fft_features = tf.square(fft_features) - return fft_features - - def logarithm(self, S): - if self.speech_config.use_natural_log: - return tf.math.log(tf.maximum(float(self.speech_config.output_floor), S)) - log_spec = 10.0 * math_util.log10(tf.maximum(self.speech_config.output_floor, S)) - log_spec -= 10.0 * math_util.log10(tf.maximum(self.speech_config.output_floor, 1.0)) - return log_spec - - def extract(self, signal: np.ndarray) -> np.ndarray: - signal = np.asfortranarray(signal) - features = self.tf_extract(tf.convert_to_tensor(signal, dtype=tf.float32)) - return features.numpy() - - def tf_extract(self, signal: tf.Tensor) -> tf.Tensor: - """ - Extract speech features from signals (for using in tflite) - Args: - signal: tf.Tensor with 
shape [None] - - Returns: - features: tf.Tensor with shape [T, F, 1] - """ - if self.speech_config.normalize_signal: - signal = tf_normalize_signal(signal) - signal = tf_preemphasis(signal, self.speech_config.preemphasis) - - if self.speech_config.feature_type == "spectrogram": - features = self.compute_spectrogram(signal) - elif self.speech_config.feature_type == "log_mel_spectrogram": - features = self.compute_log_mel_spectrogram(signal) - elif self.speech_config.feature_type == "mfcc": - features = self.compute_mfcc(signal) - elif self.speech_config.feature_type == "log_gammatone_spectrogram": - features = self.compute_log_gammatone_spectrogram(signal) - else: - raise ValueError("feature_type must be either 'mfcc', 'log_mel_spectrogram' or 'spectrogram'") - - if self.speech_config.normalize_feature: - features = tf_normalize_audio_features(features, per_frame=self.speech_config.normalize_per_frame) - - features = tf.expand_dims(features, axis=-1) - return features - - def compute_log_mel_spectrogram(self, signal): - spectrogram = self.stft(signal) - linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix( - num_mel_bins=self.speech_config.num_feature_bins, - num_spectrogram_bins=spectrogram.shape[-1], - sample_rate=self.speech_config.sample_rate, - lower_edge_hertz=self.speech_config.lower_edge_hertz, - upper_edge_hertz=self.speech_config.upper_edge_hertz, - ) - mel_spectrogram = tf.matmul(spectrogram, linear_to_weight_matrix) - return self.logarithm(mel_spectrogram) - - def compute_spectrogram(self, signal): - S = self.stft(signal) - spectrogram = self.logarithm(S) - return spectrogram[:, : self.speech_config.num_feature_bins] - - def compute_mfcc(self, signal): - log_mel_spectrogram = self.compute_log_mel_spectrogram(signal) - return tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrogram) - - def compute_log_gammatone_spectrogram(self, signal: np.ndarray) -> np.ndarray: - S = self.stft(signal) - gtone = gammatone.fft_weights( - self.nfft, - self.speech_config.sample_rate, - self.speech_config.num_feature_bins, - width=1.0, - fmin=int(self.speech_config.lower_edge_hertz), - fmax=int(self.speech_config.upper_edge_hertz), - maxlen=(self.nfft / 2 + 1), - ) - gtone_spectrogram = tf.matmul(S, gtone) - return self.logarithm(gtone_spectrogram) diff --git a/tensorflow_asr/featurizers/text_featurizers.py b/tensorflow_asr/featurizers/text_featurizers.py deleted file mode 100755 index b087bf6c28..0000000000 --- a/tensorflow_asr/featurizers/text_featurizers.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import codecs -import os -import unicodedata -from multiprocessing import cpu_count - -import numpy as np -import sentencepiece as sp -import tensorflow as tf -import tensorflow_datasets as tds -import tensorflow_text as tft -from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab - -from tensorflow_asr.configs.config import DecoderConfig -from tensorflow_asr.utils import file_util - -logger = tf.get_logger() - -TEXT_FEATURIZER_TYPES = ["characters", "wordpiece", "subwords", "sentencepiece"] - -ENGLISH_CHARACTERS = [ - "", - " ", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "'", -] - - -class TextFeaturizer: - def __init__(self, decoder_config: DecoderConfig): - self.scorer = None - self.decoder_config = decoder_config - self.blank = None - self.tokens2indices = {} - self.tokens = [] - self.num_classes = None - self.max_length = 0 - - @property - def shape(self) -> list: - return [self.max_length if self.max_length > 0 else None] - - @property - def prepand_shape(self) -> list: - return [self.max_length + 1 if self.max_length > 0 else None] - - def update_length( - self, - length: int, - ): - self.max_length = max(self.max_length, length) - - def reset_length(self): - self.max_length = 0 - - def preprocess_text(self, text): - text = unicodedata.normalize(self.decoder_config.normalization_form, text.lower()) - return text.strip("\n").strip() # remove trailing newline - - def tf_preprocess_text(self, text: tf.Tensor): - text = tft.normalize_utf8(text, self.decoder_config.normalization_form) - text = tf.strings.regex_replace(text, r"\p{Cc}|\p{Cf}", " ") - text = tf.strings.lower(text, encoding="utf-8") - text = tf.strings.strip(text) # remove trailing whitespace - return text - - def add_scorer(self, scorer: any = None): - """Add scorer to this instance""" - self.scorer = scorer - - def normalize_indices(self, indices: tf.Tensor) -> tf.Tensor: - """ - Remove -1 in indices by replacing them with blanks - Args: - indices (tf.Tensor): shape any - - Returns: - tf.Tensor: normalized indices with shape same as indices - """ - with tf.name_scope("normalize_indices"): - minus_one = -1 * tf.ones_like(indices, dtype=tf.int32) - blank_like = self.blank * tf.ones_like(indices, dtype=tf.int32) - return tf.where(indices == minus_one, blank_like, indices) - - def prepand_blank(self, text: tf.Tensor) -> tf.Tensor: - """Prepand blank index for transducer models""" - return tf.concat([[self.blank], text], 0) - - def extract(self, text: str) -> tf.Tensor: - raise NotImplementedError() - - def tf_extract(self, text: tf.Tensor) -> tf.Tensor: - raise NotImplementedError() - - def iextract(self, indices: tf.Tensor) -> tf.Tensor: - raise NotImplementedError() - - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: - raise NotImplementedError() - - -class CharFeaturizer(TextFeaturizer): - """ - Extract text feature based on char-level granularity. - By looking up the vocabulary table, each line of transcript will be - converted to a sequence of integer indexes. 
- """ - - def __init__(self, decoder_config: DecoderConfig): - super().__init__(decoder_config) - lines = [] - if self.decoder_config.vocabulary is not None: - with codecs.open(self.decoder_config.vocabulary, "r", "utf-8") as fin: - lines.extend(fin.readlines()) - else: - lines = ENGLISH_CHARACTERS - self.blank = self.decoder_config.blank_index - self.tokens = [] - for line in lines: - line = unicodedata.normalize(self.decoder_config.normalization_form, line.lower()).strip("\n") - if line.startswith("#") or not line: - continue - self.tokens.append(line) - if self.blank is None: - self.blank = len(self.tokens) # blank not at zero - self.num_classes = len(self.tokens) - self.indices = tf.range(self.num_classes, dtype=tf.int32) - self.tokenizer = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(keys=self.tokens, values=self.indices, key_dtype=tf.string, value_dtype=tf.int32), - default_value=self.blank, - ) - self.detokenizer = tf.lookup.StaticHashTable( - tf.lookup.KeyValueTensorInitializer(keys=self.indices, values=self.tokens, key_dtype=tf.int32, value_dtype=tf.string), - default_value=self.tokens[self.blank], - ) - self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1]) - - def extract(self, text: str): - return self.tf_extract(tf.convert_to_tensor(text)) - - def tf_extract(self, text): - text = self.tf_preprocess_text(text) - text = tf.strings.unicode_split(text, "UTF-8") - return self.tokenizer.lookup(text) - - def iextract(self, indices: tf.Tensor) -> tf.Tensor: - """ - Convert list of indices to string - Args: - indices: tf.Tensor with dim [B, None] - - Returns: - transcripts: tf.Tensor of dtype tf.string with dim [B] - """ - indices = self.normalize_indices(indices) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) - tokens = self.detokenizer.lookup(indices) - tokens = tf.strings.reduce_join(tokens, axis=-1) - return tokens - - @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: - """ - Transform Predicted Indices to Unicode Code Points (for using tflite) - Args: - indices: tf.Tensor of Classes in shape [None] - - Returns: - unicode code points transcript with dtype tf.int32 and shape [None] - """ - with tf.name_scope("indices2upoints"): - indices = self.normalize_indices(indices) - upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1)) - return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0))) - - -class SubwordFeaturizer(TextFeaturizer): - """ - Extract text feature based on char-level granularity. - By looking up the vocabulary table, each line of transcript will be - converted to a sequence of integer indexes. 
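`CharFeaturizer` above tokenizes by splitting a string into unicode characters and looking them up in static hash tables. A minimal sketch of the same lookup pattern with a toy three-character vocabulary (purely illustrative, not the repo's English character set):

```python
import tensorflow as tf

tokens = tf.constant(["a", "b", "c"])
indices = tf.range(3, dtype=tf.int32)
blank = 3  # unknown characters fall back to blank in this toy setup

tokenizer = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(tokens, indices, key_dtype=tf.string, value_dtype=tf.int32),
    default_value=blank,
)

chars = tf.strings.unicode_split(tf.constant("cab"), "UTF-8")
print(tokenizer.lookup(chars).numpy())  # [2 0 1]
```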
- """ - - def __init__(self, decoder_config: DecoderConfig, subwords=None): - super().__init__(decoder_config) - self.subwords = self.__load_subwords() if subwords is None else subwords - self.blank = 0 # subword treats blank as 0 - self.num_classes = self.subwords.vocab_size - # create upoints - self.__init_vocabulary() - - def __init_vocabulary(self): - self.tokens = [] - for idx in np.arange(1, self.num_classes, dtype=np.int32): - self.tokens.append(self.subwords.decode([idx])) - self.non_blank_tokens = self.tokens.copy() - self.tokens.insert(0, "") - self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8") - self.upoints = self.upoints.to_tensor() # [num_classes, max_subword_length] - - def __load_subwords(self): - filename_prefix = os.path.splitext(self.decoder_config.vocabulary)[0] - return tds.deprecated.text.SubwordTextEncoder.load_from_file(filename_prefix) - - @classmethod - def build_from_corpus(cls, decoder_config: DecoderConfig): - def corpus_generator(): - for file in decoder_config.corpus_files: - logger.info(f"Reading {file} ...") - with open(file, "r", encoding="utf-8") as f: - lines = f.read().splitlines() - lines = lines[1:] - for line in lines: - line = line.split("\t") - yield line[-1] - - def write_vocab_file(filepath, subwords): - filename_prefix = os.path.splitext(filepath)[0] - return subwords.save_to_file(filename_prefix) - - subwords = tds.deprecated.text.SubwordTextEncoder.build_from_corpus( - corpus_generator(), - decoder_config.vocab_size, - decoder_config.max_token_length, - decoder_config.max_unique_chars, - decoder_config.reserved_tokens, - ) - write_vocab_file(decoder_config.vocabulary, subwords) - - return cls(decoder_config, subwords) - - def extract(self, text: str) -> tf.Tensor: - """ - Convert string to a list of integers - Args: - text: string (sequence of characters) - - Returns: - sequence of ints in tf.Tensor - """ - text = self.preprocess_text(text) - indices = self.subwords.encode(text) - return tf.convert_to_tensor(indices, dtype=tf.int32) - - def tf_extract(self, text: tf.Tensor) -> tf.Tensor: - return self.extract(text) - - def iextract(self, indices: tf.Tensor) -> tf.Tensor: - """ - Convert list of indices to string - Args: - indices: tf.Tensor with dim [B, None] - - Returns: - transcripts: tf.Tensor of dtype tf.string with dim [B] - """ - with tf.device("/CPU:0"): # string data is not supported on GPU - total = tf.shape(indices)[0] - batch = tf.constant(0, dtype=tf.int32) - transcripts = tf.TensorArray( - dtype=tf.string, - size=total, - dynamic_size=False, - infer_shape=False, - clear_after_read=False, - element_shape=tf.TensorShape([]), - ) - - def cond(_batch, _total, _): - return tf.less(_batch, _total) - - def body(_batch, _total, _transcripts): - norm_indices = self.normalize_indices(indices[_batch]) - norm_indices = tf.gather_nd(norm_indices, tf.where(tf.not_equal(norm_indices, 0))) - decoded = tf.numpy_function(self.subwords.decode, inp=[norm_indices], Tout=tf.string) - _transcripts = _transcripts.write(_batch, decoded) - return _batch + 1, _total, _transcripts - - _, _, transcripts = tf.while_loop(cond, body, loop_vars=[batch, total, transcripts]) - - return transcripts.stack() - - @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: - """ - Transform Predicted Indices to Unicode Code Points (for using tflite) - Args: - indices: tf.Tensor of Classes in shape [None] - - Returns: - unicode code points transcript with dtype tf.int32 and shape [None] - 
""" - with tf.name_scope("indices2upoints"): - indices = self.normalize_indices(indices) - upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1)) - return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0))) - - -class SentencePieceFeaturizer(TextFeaturizer): - def __init__(self, decoder_config: DecoderConfig): - super().__init__(decoder_config) - self.blank = self.decoder_config.blank_index - self.tokenizer = tft.FastSentencepieceTokenizer(self.__load_model()) - self.num_classes = int(self.tokenizer.vocab_size()) - - def __load_model(self): - with file_util.read_file(self.decoder_config.vocabulary) as path: - with open(path, "rb") as f: - return f.read() - - @classmethod - def build_from_corpus(cls, decoder_config: DecoderConfig): - output_path_prefix = os.path.splitext(decoder_config.vocabulary)[0] - - def corpus_iterator(): - for file in decoder_config.corpus_files: - with open(file, "r", encoding="utf-8") as f: - lines = f.read().splitlines() - lines = lines[1:] - for line in lines: - line = line.split("\t") - yield line[-1] - - sp.SentencePieceTrainer.Train( - sentence_iterator=corpus_iterator(), - model_prefix=output_path_prefix, - model_type=decoder_config.model_type, - vocab_size=decoder_config.vocab_size, - num_threads=cpu_count(), - unk_id=decoder_config.unknown_index, - bos_id=decoder_config.bos_index, - eos_id=decoder_config.eos_index, - pad_id=decoder_config.pad_index, - unk_surface="__UNKNOWN__", # change default unk surface U+2047("⁇") by "__UNKNOWN__" - ) - - return cls(decoder_config) - - def extract(self, text: str) -> tf.Tensor: - return self.tf_extract(text) - - def tf_extract(self, text: tf.Tensor) -> tf.Tensor: - text = self.tf_preprocess_text(text) - text = tf.strings.split(text) - indices = self.tokenizer.tokenize(text).merge_dims(0, 1) - indices = tf.cast(indices, tf.int32) - return indices - - def iextract(self, indices: tf.Tensor) -> tf.Tensor: - """ - Convert list of indices to string - Args: - indices: tf.Tensor with dim [B, None] - - Returns: - transcripts: tf.Tensor of dtype tf.string with dim [B] - """ - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index)) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.bos_index)) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.eos_index)) - transcripts = self.tokenizer.detokenize(indices) - return transcripts - - @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: - """ - Transform Predicted Indices to Unicode Code Points (for using tflite) - Args: - indices: tf.Tensor of Classes in shape [None] - - Returns: - unicode code points transcript with dtype tf.int32 and shape [None] - """ - with tf.name_scope("indices2upoints"): - transcripts = self.iextract(tf.reshape(indices, [1, -1])) - upoints = tf.strings.unicode_decode(transcripts, "UTF-8").to_tensor() - return tf.reshape(upoints, [-1]) - - -class WordPieceFeaturizer(TextFeaturizer): - def __init__(self, decoder_config: DecoderConfig): - super().__init__(decoder_config) - self.blank = self.decoder_config.blank_index # treat [PAD] as blank - self.vocab = None - with tf.io.gfile.GFile(self.decoder_config.vocabulary, "r") as voc: - self.vocab = voc.read().splitlines() - if not self.vocab: - raise ValueError("Unable to read vocabulary") - self.tokenizer = 
tft.FastWordpieceTokenizer( - vocab=self.vocab, - token_out_type=tf.int32, - unknown_token=self.decoder_config.unknown_token, - no_pretokenization=True, # False is limited, so we manually do pretokenization - support_detokenization=True, - ) - self.num_classes = len(self.vocab) + 1 - - @classmethod - def build_from_corpus(cls, decoder_config: DecoderConfig): - def corpus_generator(): - for file_path in decoder_config.corpus_files: - logger.info(f"Reading {file_path} ...") - with tf.io.gfile.GFile(file_path, "r") as f: - temp_lines = f.read().splitlines() - for line in temp_lines[1:]: # Skip the header of tsv file - data = line.split("\t", 2)[-1] # get only transcript - yield data - - def write_vocab_file(filepath, vocab): - with tf.io.gfile.GFile(filepath, "w") as f: - for token in vocab: - print(token, file=f) - - dataset = tf.data.Dataset.from_generator(corpus_generator, output_signature=tf.TensorSpec(shape=(), dtype=tf.string)) - vocab = bert_vocab.bert_vocab_from_dataset( - dataset.batch(1000).prefetch(2), - vocab_size=decoder_config.vocab_size, - reserved_tokens=decoder_config.reserved_tokens, - bert_tokenizer_params=dict( - lower_case=False, # keep original from dataset - keep_whitespace=False, - normalization_form=decoder_config.normalization_form, - preserve_unused_token=False, - ), - learn_params=dict( - max_token_length=decoder_config.max_token_length, - max_unique_chars=decoder_config.max_unique_chars, - num_iterations=decoder_config.num_iterations, - ), - ) - write_vocab_file(decoder_config.vocabulary, vocab) - - return cls(decoder_config) - - def extract(self, text: str) -> tf.Tensor: - """ - Convert string to a list of integers - Args: - text: string (sequence of characters) - - Returns: - sequence of ints in tf.Tensor - """ - return self.tf_extract(text) - - def tf_extract(self, text: tf.Tensor) -> tf.Tensor: - text = self.tf_preprocess_text(text) - text = tf.strings.split(text) - indices = self.tokenizer.tokenize(text).merge_dims(0, 1) - return indices - - def iextract(self, indices: tf.Tensor) -> tf.Tensor: - """ - Convert list of indices to string - Args: - indices: tf.Tensor with dim [B, None] - - Returns: - transcripts: tf.Tensor of dtype tf.string with dim [B] - """ - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) - indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index)) - transcripts = self.tokenizer.detokenize(indices) - return transcripts - - @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) - def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor: - """ - Transform Predicted Indices to Unicode Code Points (for using tflite) - Args: - indices: tf.Tensor of Classes in shape [None] - - Returns: - unicode code points transcript with dtype tf.int32 and shape [None] - """ - with tf.name_scope("indices2upoints"): - transcripts = self.iextract(tf.reshape(indices, [1, -1])) - upoints = tf.strings.unicode_decode(transcripts, "UTF-8").to_tensor() - return tf.reshape(upoints, [-1]) diff --git a/tensorflow_asr/helpers/dataset_helpers.py b/tensorflow_asr/helpers/dataset_helpers.py deleted file mode 100644 index 604b5fb821..0000000000 --- a/tensorflow_asr/helpers/dataset_helpers.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
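The `indices2upoints` helpers above return transcripts as unicode code points because string handling inside TFLite graphs is limited. A small standalone illustration of that decode/encode round trip (the transcript text is just an example):

```python
import tensorflow as tf

transcripts = tf.constant(["hello"])
upoints = tf.strings.unicode_decode(transcripts, "UTF-8").to_tensor()
upoints = tf.reshape(upoints, [-1])
print(upoints.numpy())  # [104 101 108 108 111]

# Back to a string outside the TFLite graph:
print(tf.strings.unicode_encode(tf.reshape(upoints, [1, -1]), "UTF-8").numpy())  # [b'hello']
```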
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.datasets import asr_dataset -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import TextFeaturizer - - -def prepare_training_datasets( - config: Config, - speech_featurizer: SpeechFeaturizer, - text_featurizer: TextFeaturizer, - tfrecords: bool = False, -): - if tfrecords: - train_dataset = asr_dataset.ASRTFRecordDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.train_dataset_config), indefinite=True - ) - eval_dataset = asr_dataset.ASRTFRecordDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config), indefinite=True - ) - else: - train_dataset = asr_dataset.ASRSliceDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.train_dataset_config), indefinite=True - ) - eval_dataset = asr_dataset.ASRSliceDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.eval_dataset_config), indefinite=True - ) - return train_dataset, eval_dataset - - -def prepare_testing_datasets( - config: Config, - speech_featurizer: SpeechFeaturizer, - text_featurizer: TextFeaturizer, -): - test_dataset = asr_dataset.ASRSliceDataset( - speech_featurizer=speech_featurizer, text_featurizer=text_featurizer, **vars(config.learning_config.test_dataset_config) - ) - return test_dataset - - -def prepare_training_data_loaders( - config: Config, - train_dataset: asr_dataset.ASRDataset, - eval_dataset: asr_dataset.ASRDataset, - strategy, - batch_size: int = None, -): - global_batch_size = batch_size or config.learning_config.running_config.batch_size - global_batch_size *= strategy.num_replicas_in_sync - train_data_loader = train_dataset.create(global_batch_size) - eval_data_loader = eval_dataset.create(global_batch_size) - return train_data_loader, eval_data_loader, global_batch_size diff --git a/tensorflow_asr/helpers/exec_helpers.py b/tensorflow_asr/helpers/exec_helpers.py deleted file mode 100644 index de1b537c33..0000000000 --- a/tensorflow_asr/helpers/exec_helpers.py +++ /dev/null @@ -1,64 +0,0 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
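The deleted `prepare_training_data_loaders` above scales the configured batch size by the number of replicas before building the tf.data pipelines. A generic sketch of that scaling idea outside the helper, assuming an arbitrary per-replica batch size of 4 and a plain range dataset:

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # MirroredStrategy / TPUStrategy in real runs
per_replica_batch_size = 4
global_batch_size = per_replica_batch_size * strategy.num_replicas_in_sync

dataset = tf.data.Dataset.range(32).batch(global_batch_size, drop_remainder=True)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
```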
- -import tensorflow as tf -from tqdm import tqdm - -from tensorflow_asr.datasets.asr_dataset import ASRSliceDataset -from tensorflow_asr.models.base_model import BaseModel -from tensorflow_asr.utils import app_util, file_util - -logger = tf.get_logger() - - -def run_testing( - model: BaseModel, - test_dataset: ASRSliceDataset, - test_data_loader: tf.data.Dataset, - output: str, -): - with file_util.save_file(file_util.preprocess_paths(output)) as filepath: - overwrite = True - if tf.io.gfile.exists(filepath): - overwrite = input(f"Overwrite existing result file {filepath} ? (y/n): ").lower() == "y" - if overwrite: - results = model.predict(test_data_loader, verbose=1) - logger.info(f"Saving result to {output} ...") - with tf.io.gfile.GFile(filepath, "w") as openfile: - openfile.write("PATH\tDURATION\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\n") - progbar = tqdm(total=test_dataset.total_steps, unit="batch") - for i, pred in enumerate(results): - groundtruth, greedy, beamsearch = [x.decode("utf-8") for x in pred] - path, duration, _ = test_dataset.entries[i] - openfile.write(f"{path}\t{duration}\t{groundtruth}\t{greedy}\t{beamsearch}\n") - progbar.update(1) - progbar.close() - app_util.evaluate_results(filepath) - - -def convert_tflite( - model: BaseModel, - output: str, -): - concrete_func = model.make_tflite_function().get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.target_spec.supported_ops = [ - tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. - tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops. - ] - tflite_model = converter.convert() - - output = file_util.preprocess_paths(output) - with open(output, "wb") as tflite_out: - tflite_out.write(tflite_model) diff --git a/tensorflow_asr/helpers/featurizer_helpers.py b/tensorflow_asr/helpers/featurizer_helpers.py deleted file mode 100644 index 7f8a1fe696..0000000000 --- a/tensorflow_asr/helpers/featurizer_helpers.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
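The removed `convert_tflite` above serializes a concrete function with SELECT_TF_OPS enabled. A hedged sketch of loading such a file with the standard TFLite interpreter; the model path is a placeholder and the sketch assumes static input shapes, whereas real TensorFlowASR exports expose model-specific signatures:

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="model.tflite")  # hypothetical output path
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed dummy data matching the first input's declared shape/dtype.
dummy = np.zeros(input_details[0]["shape"], dtype=input_details[0]["dtype"])
interpreter.set_tensor(input_details[0]["index"], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]["index"]))
```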
- -import tensorflow as tf - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers import speech_featurizers, text_featurizers - -logger = tf.get_logger() - - -def prepare_featurizers( - config: Config, -): - speech_featurizer = speech_featurizers.SpeechFeaturizer(config.speech_config) - if config.decoder_config.type == "sentencepiece": - logger.info("Loading SentencePiece model ...") - text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config) - elif config.decoder_config.type == "subwords": - logger.info("Loading subwords ...") - text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config) - elif config.decoder_config.type == "wordpiece": - logger.info("Loading wordpiece ...") - text_featurizer = text_featurizers.WordPieceFeaturizer(config.decoder_config) - elif config.decoder_config.type == "characters": - logger.info("Use characters ...") - text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config) - else: - raise ValueError(f"type must be in {text_featurizers.TEXT_FEATURIZER_TYPES}, received {config.decoder_config.type}") - return speech_featurizer, text_featurizer diff --git a/tensorflow_asr/losses/base_loss.py b/tensorflow_asr/losses/base_loss.py new file mode 100644 index 0000000000..1276d847bb --- /dev/null +++ b/tensorflow_asr/losses/base_loss.py @@ -0,0 +1,42 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensorflow_asr import keras, schemas, tf +from tensorflow_asr.utils import env_util + +logger = tf.get_logger() + + +class BaseLoss(keras.losses.Loss): + def __init__(self, blank=0, reduction="sum_over_batch_size", name=None): + super().__init__(reduction=reduction, name=name) + assert blank == 0, "Only support blank=0" + self.blank = blank + self.use_tpu = env_util.has_devices("TPU") + + def call( + self, + y_true: schemas.TrainLabel, + y_pred: schemas.TrainOutput, + ): + logit_length = tf.cast(y_pred.logits_length, tf.int32) + labels = tf.cast(y_true.labels, tf.int32) + label_length = tf.cast(y_true.labels_length, tf.int32) + logit_length = tf.where(tf.less(logit_length, label_length), label_length, logit_length) # pad logit_length to label_length + return y_pred.logits, logit_length, labels, label_length + + def get_config(self): + config = super().get_config() + config.update({"blank": self.blank}) + return config diff --git a/tensorflow_asr/losses/ctc_loss.py b/tensorflow_asr/losses/ctc_loss.py index ce9fd2da7e..0e58e47c45 100644 --- a/tensorflow_asr/losses/ctc_loss.py +++ b/tensorflow_asr/losses/ctc_loss.py @@ -12,29 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +# Copyright 2021 Alexey Tochin +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
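`BaseLoss.call` above lifts `logit_length` up to `label_length` so CTC never sees a logit sequence reported shorter than its label. A tiny numeric illustration of that padding rule with made-up lengths:

```python
import tensorflow as tf

logit_length = tf.constant([50, 10], dtype=tf.int32)
label_length = tf.constant([20, 15], dtype=tf.int32)

padded = tf.where(tf.less(logit_length, label_length), label_length, logit_length)
print(padded.numpy())  # [50 15] -- the second sample is lifted to its label length
```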
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import logging +import os + +from tensorflow_asr import tf +from tensorflow_asr.losses.base_loss import BaseLoss +from tensorflow_asr.losses.impl.ctc_tpu import ctc_loss_tpu + +logger = logging.getLogger(__name__) -logger = tf.get_logger() +TFASR_USE_TF_CTC = os.getenv("TFASR_USE_TF_CTC", "False") in ("true", "True", "1") -class CtcLoss(tf.keras.losses.Loss): - def __init__( - self, - blank=0, - name=None, - ): - super().__init__(reduction=tf.keras.losses.Reduction.NONE, name=name) - self.blank = blank - logger.info("Use CTC loss") +class CtcLoss(BaseLoss): + def __init__(self, blank=0, reduction="sum_over_batch_size", name=None): + super().__init__(blank=blank, reduction=reduction, name=name) + logger.info("Use CTC loss TPU implementation" if self.use_tpu and not TFASR_USE_TF_CTC else "Use CTC loss") def call(self, y_true, y_pred): + logits, logit_length, labels, label_length = super().call(y_true, y_pred) + if self.use_tpu and not TFASR_USE_TF_CTC: + return ctc_loss_tpu( + labels=labels, + logits=logits, + label_length=label_length, + logit_length=logit_length, + blank_index=self.blank, + ) return tf.nn.ctc_loss( - logits=y_pred["logits"], - logit_length=y_pred["logits_length"], - labels=y_true["labels"], - label_length=y_true["labels_length"], + logits=logits, + logit_length=logit_length, + labels=labels, + label_length=label_length, logits_time_major=False, - unique=tf.nn.ctc_unique_labels(y_true["labels"]), # enable a faster, memory efficient implementation on TPU. + unique=tf.nn.ctc_unique_labels(labels) if self.use_tpu else None, blank_index=self.blank, name=self.name, ) diff --git a/tensorflow_asr/featurizers/methods/__init__.py b/tensorflow_asr/losses/impl/__init__.py similarity index 100% rename from tensorflow_asr/featurizers/methods/__init__.py rename to tensorflow_asr/losses/impl/__init__.py diff --git a/tensorflow_asr/losses/impl/ctc_tpu.py b/tensorflow_asr/losses/impl/ctc_tpu.py new file mode 100644 index 0000000000..4667b1d617 --- /dev/null +++ b/tensorflow_asr/losses/impl/ctc_tpu.py @@ -0,0 +1,1314 @@ +# pylint: disable=redefined-builtin,method-hidden,invalid-overridden-method +# -*- coding: utf-8 -*- +""" +Created on Tue Jul 18 20:29:39 2023 +""" + +# Copyright 2021 Alexey Tochin +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
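The rewritten `CtcLoss` above selects the TPU-friendly kernel automatically and falls back to `tf.nn.ctc_loss` otherwise, or when `TFASR_USE_TF_CTC` is truthy. A short usage sketch; note the flag is read at module import time, and `BaseLoss` currently only supports `blank=0`:

```python
import os

# Opt out of the TPU-specific kernel; must be set before the module is imported.
os.environ["TFASR_USE_TF_CTC"] = "1"

from tensorflow_asr.losses.ctc_loss import CtcLoss

loss_fn = CtcLoss(blank=0, reduction="sum_over_batch_size")
```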
+# ============================================================================== + +from abc import ABC, abstractmethod +from typing import Callable, List, Optional, Type, Union + +import numpy as np +import tensorflow as tf +from cached_property import cached_property + +inf = tf.constant(np.inf) + + +def logit_to_logproba(logit: tf.Tensor, axis: int) -> tf.Tensor: + """Converts logits to logarithmic probabilities: + logit_to_logproba(x) = x - log (sum along axis (exp(x)) + + Args: + logit: tf.Tensor, dtype = tf.float32 + axis: integer, like for tf.reduce_logsumexp + + Returns: tf.Tensor, of the same shape and size as input logit + """ + log_probas = logit - tf.reduce_logsumexp(input_tensor=logit, axis=axis, keepdims=True) + return log_probas + + +def apply_logarithmic_mask(tensor: tf.Tensor, mask: tf.Tensor) -> tf.Tensor: + """Masks a logarithmic representation of a tensor, namely + 1. Keeps the value of tensor unchanged for True values of mask + 2. Replace the value of tensor by -tf.inf for False values of mask + + Args: + tensor: tf.Tensor, dtype = tf.float32 of the same shape as mask or broadcastable + mask: tf.Tensor, dbool = tf.float32 of the same shape as tensor or broadcastable + + Returns: tf.Tensor, dtype = tf.float32 of the same shape as tensor + """ + return tensor + tf.math.log(tf.cast(mask, dtype=tf.float32)) + + +def logsumexp(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: + """A numerically stable version of elementwise function + logsumexp(x, y) = log (e ** x + e ** y) + + Args: + x: tf.Tensor of the shape and size as y or broadcastable + y: tf.Tensor of the shape and size as x or broadcastable + + Returns: tf.Tensor of the shape and size as x and y + """ + return tf.where( + condition=x < y, + x=y + tf.math.softplus(x - y), + y=tf.where(condition=x > y, x=x + tf.math.softplus(y - x), y=x + np.log(2.0)), + ) + + +def subexp(x: tf.Tensor, y: tf.Tensor) -> tf.Tensor: + """A numerically stable version of elementwise function + subexp(x,y) := exp x - exp y + + Args: + x: tf.Tensor, shape broadcastable to y + y: tf.Tensor, shape broadcastable to x + + Returns: tf.Tensor, shape, the same as x and y + """ + return tf.where( + condition=x > y, + x=-tf.exp(x) * tf.math.expm1(y - x), + y=tf.where( + condition=x < y, + x=tf.exp(y) * tf.math.expm1(x - y), + y=tf.zeros_like(x), + ), + ) + + +def unsorted_segment_logsumexp(data: tf.Tensor, segment_ids: tf.Tensor, num_segments: Union[int, tf.Tensor]) -> tf.Tensor: + """Computes the logarithmic sum of exponents along segments of a tensor + like other operators from tf.math.unsorted_segment_* family. + + Args: + data: tf.Tensor, shape = [...] 
+ data_dims, + segment_ids: tf.Tensor, shape = [...], dtype = tf.int32 + num_segments: tf.Tensor, shape = [], dtype = tf.int32 + + Returns: tf.Tensor, shape = [num_segments] + data_dims, for the same type as data + """ + data_max = tf.math.unsorted_segment_max(data=data, segment_ids=segment_ids, num_segments=num_segments) + data_normed = data - tf.gather(params=data_max, indices=segment_ids) + output = data_max + tf.math.log( + tf.math.unsorted_segment_sum( + data=tf.exp(data_normed), + segment_ids=segment_ids, + num_segments=num_segments, + ) + ) + return output + + +def pad_until(tensor: tf.Tensor, desired_size: Union[tf.Tensor, int], axis: int, pad_value: Union[tf.Tensor, int, float, bool] = 0) -> tf.Tensor: + """Pads tensor until desired dimension from right, + + Args: + tensor: tf.Tensor, of any shape and type + desired_size: tf.Tensor or pythonic static integer + axis: pythonic static integer for pad axes + pad_value: tf.Tensor or pythonic numerical for padding + + Returns: tf.Tensor, the same shape as tensor except axis that equals to desired_size. + """ + rank = len(tensor.shape) + if axis >= rank: + raise ValueError() + + current_size = tf.shape(tensor)[axis] + paddings = [[0, 0]] * axis + [[0, tf.maximum(desired_size - current_size, 0)]] + [[0, 0]] * (rank - axis - 1) + return tf.pad(tensor=tensor, paddings=paddings, constant_values=pad_value) + + +def insert_zeros(tensor: tf.Tensor, mask: tf.Tensor) -> tf.Tensor: + """Inserts zeros into tensor before each masked element. + For example: + ```python + output = insert_zeros( + tensor = tf.constant([[1, 2, 3, 4, 5], [10, 20, 30, 40, 50]], dtype = tf.int32), + mask = tf.constant([[False, True, False, False, True], [False, True, True, True, False]]), + ) + # -> [[1, 0, 2, 3, 4, 0, 5, 0], [10, 0, 20, 0, 30, 0, 40, 50]] + # We insert 0s 2, 5, 20, 30, and 40 because their positions in input tensor corresponds to True value + in mask. + ``` + + Args: + tensor: tf.Tensor, shape = [batch, length], any type and the same shape as mask + mask: tf.Tensor, shape = [batch, length], dtype = tf.bool and the same shape as tensor + + Returns: tf.Tensor, shape = [batch, length + max_num_insertions], + where max_num_insertions is the maximal number of True values along the 0 batch dimension of mask. + dtype = same as input tensor + """ + batch_size = tf.shape(tensor)[0] + length = tf.shape(mask)[1] + + delta = tf.cumsum(tf.cast(mask, dtype=tf.int32), exclusive=False, axis=1) + max_num_insertions = tf.reduce_max(delta[:, -1]) + + y, x = tf.meshgrid(tf.range(length), tf.range(batch_size)) + y = y + delta + indices = tf.reshape(tf.stack([x, y], 2), [-1, 2]) + + output = tf.scatter_nd(indices=indices, updates=tf.reshape(tensor, shape=[-1]), shape=tf.stack([batch_size, length + max_num_insertions])) + + return output + + +def unfold( + init_tensor: tf.Tensor, + iterfunc: Callable[[tf.Tensor, tf.Tensor], tf.Tensor], + num_iters: Union[int, tf.Tensor], + d_i: int, + element_shape: tf.TensorShape, + swap_memory: bool = False, + name: str = "unfold", +) -> tf.Tensor: + """Calculates a tensor by iterations over i that is the concatenation + for d_i = +1: + init_tensor + iterfunc(init_tensor, 0) + iterfunc(iterfunc(init_tensor, 0), 1) + ... + ..., num_iters - 1) + ..., num_iters - 1), num_iters) + for d_i = -1: + ..., 2), 1), 0) + ..., 2), 1) + ... 
+ iterfunc(iterfunc(init_tensor, num_iters - 1), num_iters - 2) + iterfunc(init_tensor, num_iters - 1) + init_tensor + For example: + ```python + unfold( + init_tensor=tf.constant(0), + iterfunc=lambda x, i: x + i, + num_iters=5, + d_i=1, + element_shape=tf.TensorShape([]), + ) + # -> [0, 0, 1, 3, 6, 10] + ``` + + Args: + init_tensor: tf.Tensor, of any shape that is the initial value of the iterations. + iterfunc: tf.Tensor, tf.Tensor -> tf.Tensor, that is the iteration function + from and onto the same shape as init_tensor + num_iters: tf.Tensor or static integer that is the number of iterations + d_i: either +1 or -1, where + +1 corresponds for the iterations from 0 to num_iters inclusive + -1 corresponds for the iterations from num_iters to 0 inclusive + element_shape: tf.TensorShape([]) that is the shape of init_tensor + swap_memory: the same as for tf.while_loop, argument + name: str, local tensor names scope + + Returns: tf.Tensor, shape = [num_iters + 1] + init_tensor.shape + dtype the same as init_tensor + """ + assert d_i in {-1, 1} + positive_direction = d_i == 1 + + with tf.name_scope(name): + num_iters = tf.convert_to_tensor(num_iters) + + tensor_array = tf.TensorArray( + dtype=init_tensor.dtype, + size=num_iters + 1, + element_shape=element_shape, + clear_after_read=False, + infer_shape=True, + dynamic_size=False, + ) + tensor_array = tensor_array.write(0 if positive_direction else num_iters, init_tensor) + + def body(i, tensor_slice): + last_value = tensor_slice.read(i if positive_direction else i + 1) + new_value = iterfunc(last_value, i) + tensor_slice = tensor_slice.write(i + 1 if positive_direction else i, new_value) + return i + d_i, tensor_slice + + n = tf.constant(0, dtype=tf.int32) if positive_direction else num_iters - 1 + _, array_out = tf.while_loop( + cond=lambda i, _: tf.constant(True), + body=body, + loop_vars=(n, tensor_array), + maximum_iterations=num_iters, + swap_memory=swap_memory, + name="unfold_while_loop", + ) + return array_out.stack() + + +def reduce_max_with_default(input_tensor: tf.Tensor, default: tf.Tensor) -> tf.Tensor: + """A version of tf.reduce_max function that supports default values for zero size input. + Support axis=None case only that corresponds to scalar output + + Args: + input_tensor: tf.Tensor, of any shape and numerical type + default: tf.Tensor, shape = [], dtype the same as input_tensor + + Returns: tf.Tensor, shape = [], dtype the same as input_tensor + """ + total_size = tf.shape(tf.reshape(input_tensor, [-1]))[0] + return tf.where(condition=total_size > 0, x=tf.reduce_max(input_tensor), y=default) + + +def expand_many_dims(input: tf.Tensor, axes: List[int]) -> tf.Tensor: + """Analogous of tf.expand_dims for multiple new dimensions. + Like for tf.expand_dims no new memory allocated for the output tensor. + + For example: + expand_many_dims(tf.zeros(shape=[5, 1, 3]), axes=[0, 4, 5]).shape + # -> [1, 5, 1, 3, 1, 1] + + Args: + input: tf.Tensor of any rank shape and type + axes: list of integer that are supposed to be the indexes of new dimensions. + + Returns: tf.Tensor of the same type an input and rank = rank(input) + len(axes) + """ + tensor = input + for axis in axes: + tensor = tf.expand_dims(input=tensor, axis=axis) + + return tensor + + +def smart_transpose(a: tf.Tensor, perm=List[int]) -> tf.Tensor: + """Extension of tf.transpose. + Parameter perm may be shorter list than rank on input tensor a. + This case all dimensions that are beyond the list perm remain unchanged. 
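`unfold` above accumulates `iterfunc` applications into a TensorArray; in the positive direction it behaves like a prefix scan with the initial value prepended. A small sketch that reproduces the docstring example with `tf.scan` (illustration only, the library uses its own while_loop so it can also iterate in reverse):

```python
import tensorflow as tf

init = tf.constant(0)
# unfold(init, lambda x, i: x + i, num_iters=5, d_i=+1) -> [0, 0, 1, 3, 6, 10]
scanned = tf.scan(lambda acc, i: acc + i, tf.range(5), initializer=init)
result = tf.concat([[init], scanned], axis=0)
print(result.numpy())  # [ 0  0  1  3  6 10]
```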
+ + For example: + smart_transpose(tf.zeros(shape=[2, 3, 4, 5, 6]), [2, 1, 0]).shape + # -> [4, 3, 2, 5, 6] + + Args: + a: tf.Tensor of any rank shape and type + perm: list of integers like for tf.transpose but in may be shorter than the shape of a. + + Returns: tf.Tensor of the same type and rank as th input tensor a. + """ + if len(perm) > len(a.shape): + raise ValueError(f"Tensor with shape '{a.shape}' cannot be reshaped to '{perm}'") + + perm_rest = list(range(len(perm), len(a.shape))) + + return tf.transpose(a=a, perm=perm + perm_rest) + + +def smart_reshape(tensor: tf.Tensor, shape: List[Optional[Union[int, tf.Tensor]]]) -> tf.Tensor: + """A version of tf.reshape. + 1. The ouput tensor is always of the same rank as input tensor. + 2. The parameter shape is supposed to be a list that is smaller or equal + than the tensor shape. + 3. The list shape may contain None, that means "keep this dimension unchanged". + 4. The list shape is appended with None value to be of the same length as the input tensor shape. + 5. Like for tf.reshape output tensor does not requre new memory for allocation. + + For example: + ```python + smart_reshape( + tensor=tf.zeros(shape=[2, 3, 4, 5]), + shape=[8, None, 1] + ) + # -> tf.Tensor([8, 3, 1, 5]) + ``` + + Args: + tensor: tf.Tensor of any shape and type + shape: list of optional static of dynamic integrates + + Returns: tf.Tensor of the same typey and rank as the input tensor + """ + if len(shape) > len(tensor.shape): + raise ValueError(f"Tensor with shape {tensor.shape} cannot be reshaped to {shape}.") + + shape = shape + [None] * (len(tensor.shape) - len(shape)) + + original_shape = tf.shape(tensor) + new_shape = [] + for index, dim in enumerate(shape): + if dim is None: + new_shape.append(original_shape[index]) + else: + new_shape.append(dim) + + return tf.reshape(tensor=tensor, shape=new_shape) + + +def ctc_loss( + labels: tf.Tensor, + logits: tf.Tensor, + label_length: tf.Tensor, + logit_length: tf.Tensor, + blank_index: Union[int, tf.Tensor], + ctc_loss_data_cls: "Type[BaseCtcLossData]", +) -> tf.Tensor: + """Computes a version of CTC loss from + http://www.cs.toronto.edu/~graves/icml_2006.pdf. + + Args: + labels: tf.Tensor, shape = [batch, max_label_length], dtype = tf.int32 + logits: tf.Tensor, shape = [batch, max_length, mum_tokens], dtype = tf.float32 + label_length: tf.Tensor, shape = [batch], dtype = tf.int32 + logit_length: tf.Tensor, shape = [batch], dtype = tf.int32 + blank_index: static integer >= 0 + ctc_loss_data_cls: BaseCtcLossData class + + Returns: tf.Tensor, shape = [batch, max_length, mum_tokens], dtype = tf.float32 + """ + log_probas = logit_to_logproba(logit=logits, axis=2) + loss = ctc_loss_from_logproba( + labels=labels, + logprobas=log_probas, + label_length=label_length, + logit_length=logit_length, + blank_index=blank_index, + ctc_loss_data_cls=ctc_loss_data_cls, + ) + return loss + + +def ctc_loss_from_logproba( + labels: tf.Tensor, + logprobas: tf.Tensor, + label_length: tf.Tensor, + logit_length: tf.Tensor, + blank_index: Union[int, tf.Tensor], + ctc_loss_data_cls: "Type[BaseCtcLossData]", +) -> tf.Tensor: + """Computes a version of CTC loss from logarothmic probabilities considered as independent parameters. 
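`ctc_loss` above first converts logits to log-probabilities with `logit_to_logproba`, which is the same normalization TensorFlow exposes as `tf.nn.log_softmax`. A quick numeric check of that equivalence (shapes are arbitrary):

```python
import tensorflow as tf

logits = tf.random.normal([2, 7, 5])  # [batch, time, tokens]
log_probas = logits - tf.reduce_logsumexp(logits, axis=2, keepdims=True)  # logit_to_logproba
tf.debugging.assert_near(log_probas, tf.nn.log_softmax(logits, axis=2))
```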
+ + Args: + labels: tf.Tensor, shape = [batch, max_label_length], dtype = tf.int32 + logprobas: tf.Tensor, shape = [batch, max_length, mum_tokens], dtype = tf.float32 + label_length: tf.Tensor, shape = [batch], dtype = tf.int32 + logit_length: tf.Tensor, shape = [batch], dtype = tf.int32 + blank_index: static integer >= 0 + ctc_loss_data_cls: BaseCtcLossData class + + Returns: tf.Tensor, shape = [batch, max_length, mum_tokens], dtype = tf.float32 + """ + loss_data = ctc_loss_data_cls( + labels=labels, + logprobas=tf.stop_gradient(logprobas), + label_length=label_length, + logit_length=logit_length, + blank_index=blank_index, + ) + + return loss_data.forward_fn(logprobas) + + +class BaseCtcLossData(ABC): + """Base class for CTC loss data.""" + + def __init__( + self, + labels: tf.Tensor, + logprobas: tf.Tensor, + label_length: tf.Tensor, + logit_length: tf.Tensor, + blank_index: Union[int, tf.Tensor], + swap_memory: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self._logprobas = logprobas + self._original_label = labels + self._logit_length = logit_length + self._original_label_length = label_length + self.max_label_length_plus_one = tf.shape(labels)[1] + self._verify_inputs() + + if isinstance(blank_index, (tf.Tensor, tf.Variable)): + self._blank_index = blank_index + else: + self._blank_index = tf.constant(blank_index, dtype=tf.int32) + + self._swap_memory = swap_memory + + def _verify_inputs(self) -> None: + assert len(self._logprobas.shape) == 3 + assert self._logprobas.dtype == tf.float32 + assert len(self._original_label.shape) == 2 + assert len(self._logit_length.shape) == 1 + assert len(self._original_label_length.shape) == 1 + + assert self._logprobas.shape[0] == self._original_label.shape[0] + assert self._logprobas.shape[0] == self._logit_length.shape[0] + assert self._logprobas.shape[0] == self._original_label_length.shape[0] + + @tf.custom_gradient + def forward_fn(self, unused_logprobas: tf.Tensor) -> tf.Tensor: + def backprop(d_loss): + return expand_many_dims(d_loss, axes=[1, 2]) * self.gradient_fn(unused_logprobas) + + return self.loss, backprop + + @tf.custom_gradient + def gradient_fn(self, unused_logprobas: tf.Tensor) -> tf.Tensor: + def backprop(d_gradient): + output = tf.reduce_sum(input_tensor=expand_many_dims(d_gradient, axes=[1, 2]) * self.hessian_fn(unused_logprobas), axis=[3, 4]) + return output + + return self.gradient, backprop + + @tf.custom_gradient + def hessian_fn(self, unused_logprobas: tf.Tensor) -> tf.Tensor: + def backprop(d_hessian): + raise NotImplementedError("Third order derivative over the ctc loss function is not implemented.") + + return self.hessian, backprop + + @cached_property + def hessian(self) -> tf.Tensor: + """Calculates Hessian of loss w.r.t. input logits. 
+ + Returns: tf.Tensor, shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens] + """ + alpha_gamma_term = self.combine_transition_probabilities(a=self.alpha[:, :-1], b=self.gamma[:, 1:]) + # shape = [batch_size, max_logit_length, num_tokens, max_logit_length + 1, max_label_length + 1] + alpha_gamma_beta_term = self.combine_transition_probabilities(a=alpha_gamma_term[:, :, :, :-1], b=self.beta[:, 1:]) + # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens] + alpha_gamma_beta_loss_term = expand_many_dims(self.loss, axes=[1, 2, 3, 4]) + alpha_gamma_beta_term + # shape = [batch_size, max_logit_length, num_tokens] + logit_length_x_num_tokens = self.max_logit_length * self.num_tokens + first_term = tf.reshape( + tf.linalg.set_diag( + input=tf.reshape(tensor=alpha_gamma_beta_loss_term, shape=[self.batch_size, logit_length_x_num_tokens, logit_length_x_num_tokens]), + diagonal=tf.reshape(tensor=self.logarithmic_logproba_gradient, shape=[self.batch_size, logit_length_x_num_tokens]), + ), + shape=tf.shape(alpha_gamma_beta_term), + ) + + mask = expand_many_dims(input=tf.linalg.band_part(tf.ones(shape=[self.max_logit_length] * 2, dtype=tf.bool), 0, -1), axes=[0, 2, 4]) + symmetrized_first_term = tf.where( + condition=mask, + x=first_term, + y=tf.transpose(first_term, [0, 3, 4, 1, 2]), + ) + # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens] + hessian = -tf.exp(symmetrized_first_term) + expand_many_dims(self.gradient, [3, 4]) * expand_many_dims(self.gradient, [1, 2]) + # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens] + + # Filter out samples with infinite loss + hessian = tf.where( + condition=expand_many_dims(self.loss == inf, [1, 2, 3, 4]), + x=tf.zeros(shape=[1, 1, 1, 1, 1]), + y=hessian, + ) + # shape = [batch_size, max_logit_length, num_tokens, max_logit_length, num_tokens] + + # Filter out logits that beyond logits length + hessian = tf.where(condition=expand_many_dims(self.logit_length_mask, axes=[2, 3, 4]), x=hessian, y=0.0) + hessian = tf.where(condition=expand_many_dims(self.logit_length_mask, axes=[1, 2, 4]), x=hessian, y=0.0) + + return hessian + + @cached_property + def gradient(self) -> tf.Tensor: + # shape = [batch_size, max_logit_length, num_tokens] + return -tf.exp(self.logarithmic_logproba_gradient) + + @cached_property + def logarithmic_logproba_gradient(self) -> tf.Tensor: + """Calculates logarithmic gradient of log loss w.r.t. input logarithmic probabilities. + + Returns: tf.Tensor, shape = [batch_size, max_logit_length, num_tokens] + """ + logarithmic_logproba_gradient = tf.reshape(self.loss, [-1, 1, 1]) + self.combine_transition_probabilities( + a=self.alpha[:, :-1], b=self.beta[:, 1:] + ) + # shape = [batch_size, max_logit_length, num_tokens] + + # Filter out samples infinite loss + logarithmic_logproba_gradient = tf.where( + condition=expand_many_dims(self.loss == inf, [1, 2]), + x=-inf, + y=logarithmic_logproba_gradient, + ) + # shape = [batch_size, max_logit_length, num_tokens] + + # Filter out logits that beyond logits length + logarithmic_logproba_gradient = apply_logarithmic_mask( + tensor=logarithmic_logproba_gradient, + mask=tf.expand_dims(self.logit_length_mask, axis=2), + ) + # shape = [batch_size, max_logit_length, num_tokens] + + return logarithmic_logproba_gradient + + @property + def alpha(self) -> tf.Tensor: + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, ...] 
+ raise NotImplementedError() + + @property + def beta(self) -> tf.Tensor: + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, ...] + raise NotImplementedError() + + @property + def gamma(self) -> tf.Tensor: + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, ..., + # max_logit_length + 1, max_label_length + 1, ...] + raise NotImplementedError() + + @cached_property + def expected_token_logproba(self) -> tf.Tensor: + """Logarithmic probability to predict label token. + + Returns:shape = [batch_size, max_logit_length, max_label_length + 1] + """ + label_logproba = tf.gather( + params=self.logproba, + indices=self.label, + axis=2, + batch_dims=1, + ) + expected_token_logproba = apply_logarithmic_mask(label_logproba, tf.expand_dims(self.label_length_mask, axis=1)) + # shape = [batch_size, max_logit_length, max_label_length + 1] + return expected_token_logproba + + @property + @abstractmethod + def loss(self) -> tf.Tensor: + """Samplewise loss function value that is minus logarithmic probability to predict label sequence. + + Returns: tf.Tensor, shape = [batch_size] + """ + raise NotImplementedError() + + @cached_property + def label_token_logproba(self) -> tf.Tensor: + """shape = [batch_size, max_logit_length, max_label_length + 1]""" + return tf.gather( + params=self.logproba, + indices=self.label, + axis=2, + batch_dims=1, + ) + + @cached_property + def blank_logproba(self): + """Calculates logarithmic probability to predict blank token for given logit. + + Returns: tf.Tensor, shape = [batch_size, max_logit_length] + """ + return self.logproba[:, :, self.blank_token_index] + + @cached_property + def input_proba(self) -> tf.Tensor: + """shape = [batch_size, input_logit_tensor_length, num_tokens], dtype = tf.float32""" + return tf.exp(self.logproba) + + @cached_property + def logproba(self) -> tf.Tensor: + mask = tf.expand_dims(tf.sequence_mask(lengths=self._logit_length, maxlen=self.max_logit_length), 2) + blank_logprobas = tf.reshape(tf.math.log(tf.one_hot(self.blank_token_index, self.num_tokens)), shape=[1, 1, -1]) + logprobas = tf.where( + condition=mask, + x=self._logprobas, + y=blank_logprobas, + ) + return logprobas + + ''' + def cleaned_label(self) -> tf.Tensor: + """ shape = [batch, max_label_length + 1] """ + _ = self.max_label_length_plus_one + ''' + + @cached_property + def cleaned_label(self): + # Repair padding- apparently, TPU/ GPU jit cannot handle the padding here; I'm not sure why. Anyway, it does not seem necessary in our case. 
+ # labels = self._original_label[:, : self.max_label_length_plus_one] + """ + labels = tf.cond( + pred=tf.shape(self._original_label)[1] > self.max_label_length, + true_fn=lambda: self._original_label[:, :self.max_label_length_plus_one], + false_fn=lambda: pad_until( + tensor=self._original_label, + desired_size=self.max_label_length_plus_one, + pad_value=self.pad_token_index, + axis=1 + ) + ) + """ + # mask = tf.sequence_mask(lengths=self._original_label_length, maxlen=tf.shape(labels)[1]) + # blank_label = tf.ones_like(labels) * self.pad_token_index + # cleaned_label = tf.where( + # condition=mask, + # x=labels, + # y=blank_label, + # ) + # return cleaned_label + cleaned_label = pad_until( + tensor=self._original_label, + desired_size=self.max_label_length_plus_one, + pad_value=self.pad_token_index, + axis=1, + ) + cleaned_label = cleaned_label[:, : self.max_label_length_plus_one] + return cleaned_label + + def select_from_act(self, act: tf.Tensor, label: tf.Tensor) -> tf.Tensor: + """Takes tensor of acts act_{b, a, t, u, ...} and labels label_{b,u}, + where b is the batch index, t is the logit index, and u is the label index, + and returns for each token index k the tensor + + output_{b,a,t,k,...} = logsumexp_u act_{b,a,t,u_k,...} * kroneker_delta(u_k = label_{b,u}) + + that is logarithmic sum of exponents of acts for all u_k = label_{b,u}, given b, t and k. + + Args: + act: tf.Tensor, shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, ...] + label: tf.Tensor, shape = [batch_size, max_label_length + 1] + + Returns: tf.Tensor, shape = [batch_size, max_label_length + 1, num_tokens, ...] + """ + data = smart_transpose(a=act, perm=[0, 3, 2, 1]) + # shape = [batch_size, max_label_length + 1, max_logit_length, dim_a, ...] + data = tf.squeeze( + input=smart_reshape(tensor=data, shape=[1, self.batch_size * self.max_label_length_plus_one, self.max_logit_length]), axis=0 + ) + # shape = [batch_size * (max_label_length + 1), max_logit_length, dim_a, ...] + + segment_ids = tf.reshape(label + tf.expand_dims(tf.range(self.batch_size), 1) * self.num_tokens, shape=[-1]) + # shape = [batch_size * (max_label_length + 1)] + num_segments = self.batch_size * self.num_tokens + + output = unsorted_segment_logsumexp(data=data, segment_ids=segment_ids, num_segments=num_segments) + # shape = [batch_size * num_tokens, max_logit_length, dim_a, ...] + output = smart_reshape(tf.expand_dims(output, 0), [self.batch_size, self.num_tokens, self.max_logit_length]) + # shape = [batch_size, num_tokens, max_logit_length, dim_a, ...] + output = smart_transpose(output, [0, 3, 2, 1]) + # shape = [batch_size, dim_a, max_logit_length, num_tokens, ...] 
+ return output + + @cached_property + def max_logit_length_plus_one(self) -> tf.Tensor: + return self.max_logit_length + tf.constant(1, dtype=tf.int32) + + @cached_property + def max_logit_length(self) -> tf.Tensor: + return tf.shape(self._logprobas)[1] + + @cached_property + def max_label_length_plus_one(self) -> tf.Tensor: + return self.max_label_length + tf.constant(1, dtype=tf.int32) + + @cached_property + def max_label_length(self) -> tf.Tensor: + return reduce_max_with_default(self._original_label_length, default=tf.constant(0, dtype=tf.int32)) + + @cached_property + def pad_token_index(self) -> tf.Tensor: + return self.blank_token_index + + @cached_property + def num_tokens(self) -> tf.Tensor: + return tf.shape(self._logprobas)[2] + + @cached_property + def blank_token_index(self) -> tf.Tensor: + return self._blank_index + + @cached_property + def logit_length_mask(self) -> tf.Tensor: + """shape = [batch_size, max_logit_length]""" + return tf.sequence_mask( + lengths=self._logit_length, + maxlen=self.max_logit_length, + ) + + @cached_property + def label_length_mask(self) -> tf.Tensor: + """shape = [batch_size, max_label_length + 1], dtype = tf.bool""" + return tf.sequence_mask(lengths=self.label_length, maxlen=self.max_label_length_plus_one) + + @property + def label_length(self) -> tf.Tensor: + return self._original_label_length + + @cached_property + def preceded_label(self) -> tf.Tensor: + """Preceded label. For example, for label "abc_" the sequence "_abc" is returned. + + Returns: tf.Tensor, shape = [batch_size, max_label_length + 1] + """ + return tf.roll(self.label, shift=1, axis=1) + + @cached_property + def label(self) -> tf.Tensor: + """shape = [batch, max_label_length + 1]""" + return self.cleaned_label + + @cached_property + def batch_size(self) -> tf.Tensor: + return tf.shape(self._logprobas)[0] + + @abstractmethod + def combine_transition_probabilities(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: + """Given logarithmic probabilities a and b are merges like + a, b -> log( exp a exp p exp b ) + """ + raise NotImplementedError() + + +def classic_ctc_loss( + labels: tf.Tensor, + logits: tf.Tensor, + label_length: tf.Tensor, + logit_length: tf.Tensor, + blank_index: Union[int, tf.Tensor] = 0, +) -> tf.Tensor: + """Computes CTC (Connectionist Temporal Classification) loss from + http://www.cs.toronto.edu/~graves/icml_2006.pdf. + + Repeated non-blank labels will be merged. + For example, predicted sequence + a_bb_ccc_cc + corresponds to label + abcc + where "_" is the blank token. + + If label length is longer then the logit length the output loss for the corresponding sample in the batch + is +tf.inf and the gradient is 0. For example, for label "abb" at least 4 tokens are needed. + It is because the output sequence must be at least "ab_b" in order to handle the repeated token. 
+ + Args: + labels: tf.Tensor, shape = [batch, max_label_length], dtype = tf.int32 + logits: tf.Tensor, shape = [batch, max_length, mum_tokens], dtype = tf.float32 + label_length: tf.Tensor, shape = [batch], dtype = tf.int32 + logit_length: tf.Tensor, shape = [batch], dtype = tf.int32 + blank_index: tf.Tensor or pythonic static integer between 0 <= blank_index < mum_tokens + + Returns: tf.Tensor, shape = [batch, max_length, mum_tokens], dtype = tf.float32 + """ + return ctc_loss( + labels=labels, + logits=logits, + label_length=label_length, + logit_length=logit_length, + blank_index=blank_index, + ctc_loss_data_cls=ClassicCtcLossData, + ) + + +class ClassicCtcLossData(BaseCtcLossData): + """Calculate loss data for CTC (Connectionist Temporal Classification) loss from + http://www.cs.toronto.edu/~graves/icml_2006.pdf. + + This loss is actually the logarithmic likelihood for the classification task with multiple expected class. + All predicated sequences consist of tokens (denoted like "a", "b", ... below) and the blank "_". + The classic CTC decoding merges all repeated non-blank labels and removes the blank. + For example, predicted sequence + a_bb_ccc_c is decoded as "abcc". + All predicated sequences that coincided with the label after the decoding are the expected classes + in the logarithmic likelihood loss function. + + Implementation: + + We calculate alpha_{b,t,l,s} and beta_{b,t,l,s} that are the logarithmic probabilities similar to + this the ones from the sited paper and defined precisely below. + Here, b corresponds to batch, t to logit position, l to label index, and s=0,1 to state (see below for details). + + During the decoding procedure, after handling of a part of the logit sequence, + we predict only a part of the target label tokens. We call this subsequence the in the target space as "state". + For example, two decode label "abc" we have to decode "a" first then add "b" and move tot the state "ab" and + then to the state "abc". + + In order to handle the token duplication swap in the classic CTC loss we extend the set of all possible labels. + For each token sequence we define two sequences called "closed" and "open". + For example, for label "abc" we consider its two states denoted "abc>" (closed) and "abc<" (open). + The difference between them is in their behaviour with respect to the token appending. The rules are: + "...a>" + "_" -> "...a>", + "...a<" + "_" -> "...a>", + "...a>" + "a" -> "...aa<", + "...a<" + "a" -> "...a<", + "...a>" + "b" -> "...ab<", + "...a<" + "b" -> "...ab<", + for any different tokens "a" and "b" and any token sequence denoted by "...". + Namely, appending a token the is equal to the last one to an open state does not change this state. + Appending a blank to a state always males this state closed. + + This is why alpha_{b,t,l,s} and beta_{b,t,l,s} in the code below are equipped with an additional index s=0,1. + Closed states corresponds s=0 and open ones to s=1. 
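Since `classic_ctc_loss` implements the standard CTC objective, it can be sanity-checked against `tf.nn.ctc_loss` on CPU/GPU with blank index 0 in both. A hedged comparison sketch with random data; shapes are arbitrary and small numerical differences are expected:

```python
import tensorflow as tf
from tensorflow_asr.losses.impl.ctc_tpu import classic_ctc_loss

batch, max_time, num_tokens, max_label = 2, 20, 6, 5
logits = tf.random.normal([batch, max_time, num_tokens])
labels = tf.random.uniform([batch, max_label], minval=1, maxval=num_tokens, dtype=tf.int32)
label_length = tf.constant([5, 3], dtype=tf.int32)
logit_length = tf.fill([batch], max_time)

ours = classic_ctc_loss(labels, logits, label_length, logit_length, blank_index=0)
ref = tf.nn.ctc_loss(
    labels=labels,
    logits=logits,
    label_length=label_length,
    logit_length=logit_length,
    logits_time_major=False,
    blank_index=0,
)
print(ours.numpy(), ref.numpy())  # should agree up to numerical precision
```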
+ + In particular, the flowing identity is satisfied + sum_s sum_l exp alpha_{b,t,l,s} * exp beta_{b,t,l,s} = loss_{b}, for any b and t + """ + + @cached_property + def diagonal_non_blank_grad_term(self) -> tf.Tensor: + """shape = [batch_size, max_logit_length, num_tokens]""" + input_tensor = self.alpha[:, :-1] + self.any_to_open_diagonal_step_log_proba + tf.roll(self.beta[:, 1:, :, 1:], shift=-1, axis=2) + # shape = [batch_size, max_logit_length, max_label_length + 1, states] + act = tf.reduce_logsumexp( + input_tensor=input_tensor, + axis=3, + ) + # shape = [batch_size, max_logit_length, max_label_length + 1] + diagonal_non_blank_grad_term = self.select_from_act(act=act, label=self.label) + # shape = [batch_size, max_logit_length, num_tokens] + return diagonal_non_blank_grad_term + + @cached_property + def horizontal_non_blank_grad_term(self) -> tf.Tensor: + """Horizontal steps from repeated token: open alpha state to open beta state. + + Returns: shape = [batch_size, max_logit_length, num_tokens] + """ + act = self.alpha[:, :-1, :, 1] + self.previous_label_token_log_proba + self.beta[:, 1:, :, 1] + # shape = [batch_size, max_logit_length, max_label_length + 1] + horizontal_non_blank_grad_term = self.select_from_act(act, self.preceded_label) + return horizontal_non_blank_grad_term + + @cached_property + def loss(self) -> tf.Tensor: + """shape = [batch_size]""" + params = tf.reduce_logsumexp(self.alpha[:, -1], -1) + # shape = [batch_size, max_label_length + 1] + loss = -tf.gather( + params=params, # shape = [batch_size, max_label_length + 1] + indices=self.label_length, # shape = [batch_size] + batch_dims=1, + ) + return loss + + @cached_property + def gamma(self) -> tf.Tensor: + """shape = [ + batch_size, + max_logit_length + 1, + max_label_length + 1, + state, + max_logit_length + 1, + max_label_length + 1, + state, + ], + """ + # This is to avoid InaccessibleTensorError in graph mode + _, _, _ = self.horizontal_step_log_proba, self.any_to_open_diagonal_step_log_proba, self.diagonal_gamma + + gamma_forward_transposed = unfold( + init_tensor=self.diagonal_gamma, + # init_tensor=tf.tile(self.diagonal_gamma, [self.batch_size, self.max_logit_length_plus_one, 1, 1, 1, 1]), + iterfunc=self.gamma_step, + d_i=1, + num_iters=self.max_logit_length, + element_shape=tf.TensorShape([None, None, None, None, None, None]), + name="gamma_1", + ) + # shape = [max_logit_length + 1, batch_size, max_logit_length + 1, max_label_length + 1, state, + # max_label_length + 1, state] + + gamma_forward = tf.transpose(gamma_forward_transposed, [1, 2, 3, 4, 0, 5, 6]) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, + # max_logit_length + 1, max_label_length + 1, state] + + mask = expand_many_dims( + input=tf.linalg.band_part(tf.ones(shape=[self.max_logit_length_plus_one] * 2, dtype=tf.bool), 0, -1), axes=[0, 2, 3, 5, 6] + ) + # shape = [1, max_logit_length + 1, 1, 1, max_logit_length + 1, 1, 1] + gamma = apply_logarithmic_mask(gamma_forward, mask) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, + # max_logit_length + 1, max_label_length + 1, state] + + return gamma + + def gamma_step( + self, + previous_slice: tf.Tensor, + i: tf.Tensor, + ) -> tf.Tensor: + """Args: + previous_slice: tf.Tensor, + shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, + max_label_length + 1, state] + i: tf.Tensor, + shape = [], 0 <= i < max_logit_length + 1 + + Returns: tf.Tensor, + shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, + 
max_label_length + 1, state] + """ + horizontal_step_states = expand_many_dims(self.horizontal_step_log_proba[:, i], axes=[1, 2, 3]) + tf.expand_dims(previous_slice, 5) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, + # max_label_length + 1, next_state, previous_state] + horizontal_step = tf.reduce_logsumexp(horizontal_step_states, axis=6) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state] + + diagonal_step_log_proba = tf.reduce_logsumexp( + expand_many_dims(self.any_to_open_diagonal_step_log_proba[:, i], axes=[1, 2, 3]) + previous_slice, axis=5 + ) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1] + + # We move by one token because it is a diagonal step + moved_diagonal_step_log_proba = tf.roll(diagonal_step_log_proba, shift=1, axis=4) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1] + + # Out state is always open: + diagonal_step = tf.pad( + tensor=tf.expand_dims(moved_diagonal_step_log_proba, 5), + paddings=[[0, 0], [0, 0], [0, 0], [0, 0], [0, 0], [1, 0]], + constant_values=-np.inf, + ) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state] + new_gamma_slice = logsumexp( + x=horizontal_step, + y=diagonal_step, + ) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state] + + condition = tf.reshape(tf.range(self.max_logit_length_plus_one) <= i, shape=[1, -1, 1, 1, 1, 1]) + # shape = [1, max_logit_length + 1, 1, 1, 1, 1, 1] + output_slice = tf.where( + condition=condition, + x=new_gamma_slice, + y=self.diagonal_gamma, + ) + # shape = [batch_size, max_logit_length + 1, max_label_length + 1, state, max_label_length + 1, state] + + return output_slice + + @cached_property + def diagonal_gamma(self) -> tf.Tensor: + """shape = [batch_size, max_logit_length_plus_one, max_label_length + 1, state, + max_label_length + 1, state] + """ + diagonal_gamma = tf.math.log( + tf.reshape( + tensor=tf.eye(self.max_label_length_plus_one * 2, dtype=tf.float32), + shape=[1, 1, self.max_label_length_plus_one, 2, self.max_label_length_plus_one, 2], + ) + ) + diagonal_gamma = tf.tile(diagonal_gamma, [self.batch_size, self.max_logit_length_plus_one, 1, 1, 1, 1]) + return diagonal_gamma + + @cached_property + def beta(self) -> tf.Tensor: + """Calculates the beta_{b,t,l,s} that is logarithmic probability of sample 0 <= b < batch_size - 1 in the batch + with logit subsequence from + t, t + 1, ... max_logit_length - 2, max_logit_length - 1, + for t < max_logit_length + to predict the sequence of tokens + w_max_label_length, w_{max_label_length + 1}, ... w_{max_label_length - 2}, w_{max_label_length - 1} + for l < max_label_length + that is either closed s=0 or open s=1. + from label_b = [w_0, w_1, ... w_{max_label_length - 2}, w_{max_label_length - 1}]. + + This logarithmic probability is calculated by iterations + exp beta_{t-1,l} = p_horizontal_step_{t-1,l} * exp beta_{t,l} + p_diagonal_step_{t-1,l} * exp beta_{t,l+1}, + for 0 <= t < max_logit_length, + where p_diagonal_step_{t,l} is the probability to predict label token w_l with logit l + and p_horizontal_step_{t,l} is the probability to skip token w_l prediction with logit l, for example, with + the blank prediction. 
+ + Returns: tf.Tensor, shape = [batch_size, max_logit_length + 1, max_label_length + 1, state], + dtype = tf.float32 + """ + # This is to avoid InaccessibleTensorError in graph mode + _, _ = self.horizontal_step_log_proba, self.any_to_open_diagonal_step_log_proba + + beta = unfold( + init_tensor=self.last_beta_slice, + iterfunc=self.beta_step, + d_i=-1, + num_iters=self.max_logit_length, + element_shape=tf.TensorShape([None, None, 2]), + name="beta", + ) + # shape = [logit_length + 1, batch, label_length + 1, state] + return tf.transpose(beta, [1, 0, 2, 3]) + + def beta_step(self, previous_slice: tf.Tensor, i: tf.Tensor) -> tf.Tensor: + """shape = [batch_size, max_label_length + 1, state]""" + horizontal_step = tf.reduce_logsumexp(self.horizontal_step_log_proba[:, i] + tf.expand_dims(previous_slice, 3), 2) + # shape = [batch_size, max_label_length + 1, state] + diagonal_step = self.any_to_open_diagonal_step_log_proba[:, i] + tf.roll(previous_slice[:, :, 1:], shift=-1, axis=1) + # shape = [batch_size, max_label_length + 1, state] + new_beta_slice = logsumexp( + x=horizontal_step, # shape = [batch_size, max_label_length + 1, state] + y=diagonal_step, # shape = [batch_size, max_label_length + 1, state] + ) + # shape = [batch_size, max_label_length + 1, state] + return new_beta_slice + + @cached_property + def last_beta_slice(self) -> tf.Tensor: + """shape = [batch_size, max_label_length + 1, state]""" + beta_last = tf.math.log(tf.one_hot(indices=self.label_length, depth=self.max_label_length_plus_one)) + beta_last = tf.tile(input=tf.expand_dims(beta_last, axis=2), multiples=[1, 1, 2]) + return beta_last + + @cached_property + def alpha(self) -> tf.Tensor: + """Calculates the alpha_{b,t,l,s} that is + the logarithmic probability of sample 0 <= b < batch_size - 1 in the batch + with logits subsequence from 0, 1, 2, ... t - 2, t - 1, for t < max_logit_length + to predict the sequence of tokens w_0, w_1, w_2, ... w_{l-2}, w_{l-1} for l < max_label_length + 1 + that is either closed s=0 or open s=1. + from label_b = [w_0, w_1, ... w_{max_label_length - 2}, w_{max_label_length - 1}]. + + This logarithmic probability is calculated by iterations + exp alpha_{t + 1,l} = p_horizontal_step_{t,l} * exp alpha_{t,l} + p_diagonal_step_{t,l} * exp alpha_{t,l-1}, + for 0 <= t < max_logit_length, + where p_diagonal_step_{t,l} is the probability to predict label token w_l with logit l + and p_horizontal_step_{t,l} is the probability to skip token w_l prediction with logit l, for example, with + the blank prediction. 
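As a numeric sanity check of the forward recursion just stated (editor's sketch with invented numbers, not part of the diff; the state index s is ignored for brevity), one log-space update can be written as:

import tensorflow as tf

# alpha_{t+1,l} = logsumexp(alpha_{t,l} + log p_horizontal, alpha_{t,l-1} + log p_diagonal)
alpha_t_lm1 = tf.math.log(1.0)  # alpha_{t, l-1}, i.e. probability 1.0
alpha_t_l = tf.math.log(0.5)    # alpha_{t, l},   i.e. probability 0.5
log_p_horizontal = tf.math.log(0.6)
log_p_diagonal = tf.math.log(0.4)
alpha_next_l = tf.reduce_logsumexp(tf.stack([alpha_t_l + log_p_horizontal, alpha_t_lm1 + log_p_diagonal]))
# exp(alpha_next_l) == 0.5 * 0.6 + 1.0 * 0.4 == 0.7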
+ + Returns: tf.Tensor, shape = [batch_size, max_logit_length + 1, max_label_length + 1, state], + dtype = tf.float32 + """ + # This is to avoid InaccessibleTensorError in graph mode + _, _ = self.horizontal_step_log_proba, self.any_to_open_diagonal_step_log_proba + + alpha = unfold( + init_tensor=self.first_alpha_slice, + iterfunc=self.alpha_step, + d_i=1, + num_iters=self.max_logit_length, + element_shape=tf.TensorShape([None, None, 2]), + name="alpha", + ) + # shape = [logit_length + 1, batch_size, label_length + 1, state] + return tf.transpose(alpha, [1, 0, 2, 3]) + + def alpha_step(self, previous_slice: tf.Tensor, i: tf.Tensor) -> tf.Tensor: + """Args: + previous_slice: shape = [batch_size, max_label_length + 1, state] + i: + + Returns: shape = [batch_size, max_label_length + 1, state] + """ + temp = self.horizontal_step_log_proba[:, i] + tf.expand_dims(previous_slice, 2) + # shape = [batch_size, max_label_length + 1, next_state, previous_state] + horizontal_step = tf.reduce_logsumexp(temp, 3) + # shape = [batch_size, max_label_length + 1, state] + diagonal_step_log_proba = tf.reduce_logsumexp(self.any_to_open_diagonal_step_log_proba[:, i] + previous_slice, 2) + # shape = [batch_size, max_label_length + 1] + + # We move by one token because it is a diagonal step + moved_diagonal_step_log_proba = tf.roll(diagonal_step_log_proba, shift=1, axis=1) + # shape = [batch_size, max_label_length + 1] + + # Out state is always open: + diagonal_step = tf.pad(tensor=tf.expand_dims(moved_diagonal_step_log_proba, 2), paddings=[[0, 0], [0, 0], [1, 0]], constant_values=-np.inf) + # shape = [batch_size, max_label_length + 1, state] + new_alpha_slice = logsumexp( + x=horizontal_step, + y=diagonal_step, + ) + # shape = [batch_size, max_label_length + 1, state] + return new_alpha_slice + + @cached_property + def first_alpha_slice(self) -> tf.Tensor: + """shape = [batch_size, max_label_length + 1, state]""" + alpha_0 = tf.math.log(tf.one_hot(indices=0, depth=self.max_label_length_plus_one * 2)) + alpha_0 = tf.tile(input=tf.reshape(alpha_0, [1, -1, 2]), multiples=[self.batch_size, 1, 1]) + return alpha_0 + + @cached_property + def any_to_open_diagonal_step_log_proba(self) -> tf.Tensor: + """Logarithmic probability to make a diagonal step from given state to an open state + + Returns:shape = [batch_size, max_logit_length, max_label_length + 1, state] + """ + return tf.stack(values=[self.closed_to_open_diagonal_step_log_proba, self.open_to_open_diagonal_step_log_proba], axis=3) + + @cached_property + def open_to_open_diagonal_step_log_proba(self) -> tf.Tensor: + """Logarithmic probability to make a diagonal step from an open state to an open state + with expected token prediction that is different from the previous one. + + Returns:shape = [batch_size, max_logit_length, max_label_length + 1] + """ + # We check that the predicting token does not equal to previous one + token_repetition_mask = self.label != tf.roll(self.label, shift=1, axis=1) + # shape = [batch_size, max_label_length + 1] + open_diagonal_step_log_proba = apply_logarithmic_mask( + self.closed_to_open_diagonal_step_log_proba, tf.expand_dims(token_repetition_mask, axis=1) + ) + return open_diagonal_step_log_proba + + @cached_property + def closed_to_open_diagonal_step_log_proba(self) -> tf.Tensor: + """Logarithmic probability to make a diagonal step from a closed state to an open state + with expected token prediction. 
+ + Returns:shape = [batch_size, max_logit_length, max_label_length + 1] + """ + return self.expected_token_logproba + + @cached_property + def horizontal_step_log_proba(self) -> tf.Tensor: + """Calculates logarithmic probability of the horizontal step for given logit x label position. + + This is possible in two alternative cases: + 1. Blank + 2. Not blank token from previous label position. + + Returns: tf.Tensor, shape = [batch_size, max_logit_length, max_label_length + 1, next_state, previous_state] + """ + # We map closed and open states to closed states + blank_term = tf.tile(input=tf.expand_dims(tf.expand_dims(self.blank_logproba, 2), 3), multiples=[1, 1, self.max_label_length_plus_one, 2]) + # shape = [batch_size, max_logit_length, max_label_length + 1, 2] + non_blank_term = tf.pad( + tf.expand_dims(self.not_blank_horizontal_step_log_proba, 3), + paddings=[[0, 0], [0, 0], [0, 0], [1, 0]], + constant_values=tf.constant(-np.inf), + ) + # shape = [batch_size, max_logit_length, max_label_length + 1, 2] + horizontal_step_log_proba = tf.stack([blank_term, non_blank_term], axis=3) + return horizontal_step_log_proba + + @cached_property + def not_blank_horizontal_step_log_proba(self) -> tf.Tensor: + """shape = [batch_size, max_logit_length, max_label_length + 1]""" + mask = tf.reshape(1 - tf.one_hot(self.blank_token_index, depth=self.num_tokens), shape=[1, 1, -1]) + not_blank_log_proba = apply_logarithmic_mask(self.logproba, mask) + not_blank_horizontal_step_log_proba = tf.gather( + params=not_blank_log_proba, + indices=tf.roll(self.label, shift=1, axis=1), + axis=2, + batch_dims=1, + ) + # shape = [batch_size, max_logit_length, max_label_length + 1] + return not_blank_horizontal_step_log_proba + + @cached_property + def previous_label_token_log_proba(self) -> tf.Tensor: + """Calculates the probability to predict token that preceded to label token. + + Returns: tf.Tensor, shape = [batch_size, max_logit_length, max_label_length + 1] + """ + previous_label_token_log_proba = tf.gather( + params=self.logproba, + indices=self.preceded_label, + axis=2, + batch_dims=1, + ) + # shape = [batch_size, max_logit_length, max_label_length + 1] + return previous_label_token_log_proba + + @cached_property + def blank_logproba(self) -> tf.Tensor: + """shape = [batch_size, max_logit_length]""" + return self.logproba[:, :, self.blank_token_index] + + def combine_transition_probabilities(self, a: tf.Tensor, b: tf.Tensor) -> tf.Tensor: + """Transforms logarithmic transition probabilities a and b. 
+ + Args: + a: shape = [batch, DIMS_A, max_logit_length, max_label_length + 1, state] + b: shape = [batch, max_logit_length, max_label_length + 1, state, DIMS_B] + + Returns: shape = [batch, DIMS_A, max_logit_length, num_tokens, DIMS_B] + """ + assert len(a.shape) >= 4 + assert len(b.shape) >= 4 + assert a.shape[-1] == 2 + assert b.shape[3] == 2 + + dims_a = tf.shape(a)[1:-3] + dims_b = tf.shape(b)[4:] + a = tf.reshape(a, shape=[self.batch_size, -1, self.max_logit_length, self.max_label_length_plus_one, 2, 1]) + # shape = [batch_size, dims_a, max_logit_length, max_label_length + 1, state, 1] + b = tf.reshape(b, shape=[self.batch_size, 1, self.max_logit_length, self.max_label_length_plus_one, 2, -1]) + # shape = [batch_size, 1, max_logit_length, max_label_length + 1, state, dims_b] + + # Either open or closed state from alpha and only closed state from beta + ab_term = tf.reduce_logsumexp(a, 4) + b[:, :, :, :, 0] + # shape = [batch_size, dims_a, max_logit_length, max_label_length + 1, dims_b] + + horizontal_blank_grad_term = expand_many_dims(self.blank_logproba, axes=[1, 3]) + tf.reduce_logsumexp(ab_term, axis=3) + # shape = [batch_size, dims_a, max_logit_length, dims_b] + + act = a[:, :, :, :, 1] + expand_many_dims(self.previous_label_token_log_proba, axes=[1, 4]) + b[:, :, :, :, 1] + # shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, dim_b] + + horizontal_non_blank_grad_term = self.select_from_act(act, self.preceded_label) + # shape = [batch_size, dim_a, max_logit_length, num_tokens, dim_b] + + input_tensor = a + expand_many_dims(self.any_to_open_diagonal_step_log_proba, axes=[1, 5]) + tf.roll(b[:, :, :, :, 1:], shift=-1, axis=3) + # shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, states, dim_b] + + act = tf.reduce_logsumexp(input_tensor=input_tensor, axis=4) + # shape = [batch_size, dim_a, max_logit_length, max_label_length + 1, dim_b] + + diagonal_non_blank_grad_term = self.select_from_act(act=act, label=self.label) + # shape = [batch_size, dim_a, max_logit_length, num_tokens, dim_b] + + non_blank_grad_term = logsumexp(horizontal_non_blank_grad_term, diagonal_non_blank_grad_term) + # shape = [batch_size, dim_a, max_logit_length, num_tokens, dim_b] + + blank_mask = self.blank_token_index == tf.range(self.num_tokens) + # shape = [num_tokens] + + output = tf.where( + condition=expand_many_dims(blank_mask, axes=[0, 1, 2, 4]), + x=tf.expand_dims(horizontal_blank_grad_term, 3), + y=non_blank_grad_term, + ) + # shape = [batch, dim_a, max_logit_length, num_tokens, dim_b] + output_shape = tf.concat( + [ + tf.expand_dims(self.batch_size, axis=0), + dims_a, + tf.expand_dims(self.max_logit_length, axis=0), + tf.expand_dims(self.num_tokens, axis=0), + dims_b, + ], + axis=0, + ) + output_reshaped = tf.reshape(output, shape=output_shape) + # shape = [batch, DIMS_A, max_logit_length, num_tokens, DIMS_B] + + return output_reshaped + + +def ctc_loss_tpu( + labels: tf.Tensor, + logits: tf.Tensor, + label_length: tf.Tensor, + logit_length: tf.Tensor, + blank_index: Union[int, tf.Tensor] = 0, +) -> tf.Tensor: + orig_dtype = logits.dtype + if orig_dtype in (tf.float16, tf.bfloat16): + logits = tf.cast(logits, tf.float32) + loss = classic_ctc_loss( + labels=labels, + logits=logits, + label_length=label_length, + logit_length=logit_length, + blank_index=blank_index, + ) + if orig_dtype in (tf.float16, tf.bfloat16): + loss = tf.cast(loss, orig_dtype) + return loss diff --git a/tensorflow_asr/losses/impl/rnnt.py b/tensorflow_asr/losses/impl/rnnt.py new file mode 100644 
index 0000000000..fde9d97f62 --- /dev/null +++ b/tensorflow_asr/losses/impl/rnnt.py @@ -0,0 +1,331 @@ +import importlib.util + +import numpy as np + +from tensorflow_asr import schemas, tf +from tensorflow_asr.utils import shape_util + +warp_rnnt_loss = importlib.import_module("warprnnt_tensorflow").rnnt_loss if importlib.util.find_spec("warprnnt_tensorflow") is not None else None + + +def rnnt_loss( + logits, + logits_length, + labels, + labels_length, + blank=0, + name=None, + use_cpu=False, + output_shapes=None, +): + kwargs = dict( + logits=logits, + labels=labels, + label_length=labels_length, + logit_length=logits_length, + blank=blank, + use_cpu=use_cpu, + output_shapes=output_shapes, + name=name, + ) + loss_fn = rnnt_loss_tf if warp_rnnt_loss is None else rnnt_loss_warprnnt + if use_cpu: + with tf.device("/CPU:0"): + return loss_fn(**kwargs) + return loss_fn(**kwargs) + + +# ------------------------ RNNT LOSS IN WARP TRANDUCER ----------------------- # + + +def rnnt_loss_warprnnt( + logits, + labels, + label_length, + logit_length, + blank=0, + use_cpu=False, + **kwargs, +): + orig_dtype = logits.dtype + if orig_dtype in (tf.float16, tf.bfloat16): + logits = tf.cast(logits, tf.float32) + if use_cpu: + logits = tf.nn.log_softmax(logits) + loss = warp_rnnt_loss(acts=logits, label_lengths=label_length, labels=labels, input_lengths=logit_length, blank_label=blank) + if orig_dtype in (tf.float16, tf.bfloat16): + loss = tf.cast(loss, orig_dtype) + return loss + + +# ------------------------------ RNNT LOSS IN TF ----------------------------- # + +LOG_0 = -np.inf + + +def nan_to_zero( + input_tensor, +): + return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor) + + +def reduce_logsumexp( + input_tensor, + axis, +): + maximum = tf.reduce_max(input_tensor, axis=axis) + input_tensor = nan_to_zero(input_tensor - maximum) + return tf.math.log(tf.reduce_sum(tf.exp(input_tensor), axis=axis)) + maximum + + +def extract_diagonals( + log_probs, +): + time_steps = tf.shape(log_probs)[1] # T + output_steps = tf.shape(log_probs)[2] # U + 1 + reverse_log_probs = tf.reverse(log_probs, axis=[-1]) + paddings = [[0, 0], [0, 0], [time_steps - 1, 0]] + padded_reverse_log_probs = tf.pad(reverse_log_probs, paddings, "CONSTANT", constant_values=LOG_0) + diagonals = tf.raw_ops.MatrixDiagPartV2(input=padded_reverse_log_probs, k=(0, time_steps + output_steps - 2), padding_value=LOG_0) + + return tf.transpose(diagonals, perm=[1, 0, 2]) + + +def transition_probs( + one_hot_labels, + log_probs, +): + """ + :return: blank_probs with shape batch_size x input_max_len x target_max_len + truth_probs with shape batch_size x input_max_len x (target_max_len-1) + """ + blank_probs = log_probs[:, :, :, 0] + truth_probs = tf.reduce_sum(tf.multiply(log_probs[:, :, :-1, :], one_hot_labels), axis=-1) + + return blank_probs, truth_probs + + +def forward_dp( + bp_diags, + tp_diags, + batch_size, + input_max_len, + target_max_len, +): + """ + :return: forward variable alpha with shape batch_size x input_max_len x target_max_len + """ + + def next_state(x, trans_probs): + blank_probs = trans_probs[0] + truth_probs = trans_probs[1] + + x_b = tf.concat([LOG_0 * tf.ones(shape=[batch_size, 1]), x[:, :-1] + blank_probs], axis=1) + x_t = x + truth_probs + + x = tf.math.reduce_logsumexp(tf.stack([x_b, x_t], axis=0), axis=0) + return x + + initial_alpha = tf.concat([tf.zeros(shape=[batch_size, 1]), tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0], axis=1) + + fwd = tf.scan(next_state, (bp_diags[:-1, :, 
:-1], tp_diags), initializer=initial_alpha) + + alpha = tf.transpose(tf.concat([tf.expand_dims(initial_alpha, axis=0), fwd], axis=0), perm=[1, 2, 0]) + alpha = tf.raw_ops.MatrixDiagPartV2(input=alpha, k=(0, target_max_len - 1), padding_value=LOG_0) + alpha = tf.transpose(tf.reverse(alpha, axis=[1]), perm=[0, 2, 1]) + + return alpha + + +def backward_dp( + bp_diags, + tp_diags, + batch_size, + input_max_len, + target_max_len, + label_length, + logit_length, + blank_sl, +): + """ + :return: backward variable beta with shape batch_size x input_max_len x target_max_len + """ + + def next_state(x, mask_and_trans_probs): + mask_s, blank_probs_s, truth_probs = mask_and_trans_probs + + beta_b = tf.concat([x[:, 1:] + blank_probs_s, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1) + beta_t = tf.concat([x[:, :-1] + truth_probs, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1) + + beta_next = reduce_logsumexp(tf.stack([beta_b, beta_t], axis=0), axis=0) + masked_beta_next = nan_to_zero(beta_next * tf.expand_dims(mask_s, axis=1)) + nan_to_zero(x * tf.expand_dims((1.0 - mask_s), axis=1)) + return tf.ensure_shape(masked_beta_next, x.shape) + + # Initial beta for batches. + initial_beta_mask = tf.one_hot(logit_length - 1, depth=input_max_len + 1) + initial_beta = tf.expand_dims(blank_sl, axis=1) * initial_beta_mask + nan_to_zero(LOG_0 * (1.0 - initial_beta_mask)) + + # Mask for scan iterations. + mask = tf.sequence_mask(logit_length + label_length - 1, input_max_len + target_max_len - 2, dtype=tf.dtypes.float32) + mask = tf.transpose(mask, perm=[1, 0]) + + bwd = tf.scan(next_state, (mask, bp_diags[:-1, :, :], tp_diags), initializer=initial_beta, reverse=True) + + beta = tf.transpose(tf.concat([bwd, tf.expand_dims(initial_beta, axis=0)], axis=0), perm=[1, 2, 0])[:, :-1, :] + beta = tf.raw_ops.MatrixDiagPartV2(input=beta, k=(0, target_max_len - 1), padding_value=LOG_0) + beta = tf.transpose(tf.reverse(beta, axis=[1]), perm=[0, 2, 1]) + + return beta + + +def compute_rnnt_loss_and_grad_helper( + logits, + labels, + label_length, + logit_length, + use_cpu=False, + output_shapes: schemas.TrainOutput = None, +): + if output_shapes is None: # dynamic shape + batch_size = shape_util.get_dim(logits, 0) + input_max_len = shape_util.get_dim(logits, 1) + target_max_len = shape_util.get_dim(logits, 2) + vocab_size = shape_util.get_dim(logits, 3) + else: + batch_size = output_shapes.logits[0] + input_max_len = output_shapes.logits[1] + target_max_len = output_shapes.logits[2] + vocab_size = output_shapes.logits[3] + + if batch_size is None: + batch_size = shape_util.get_dim(logits, 0) + if input_max_len is None: + input_max_len = shape_util.get_dim(logits, 1) + if target_max_len is None: + target_max_len = shape_util.get_dim(logits, 2) + if vocab_size is None: + vocab_size = shape_util.get_dim(logits, 3) + + one_hot_labels = tf.one_hot(tf.tile(tf.expand_dims(labels, axis=1), multiples=[1, input_max_len, 1]), depth=vocab_size) + + log_probs = tf.nn.log_softmax(logits) + blank_probs, truth_probs = transition_probs(one_hot_labels, log_probs) + bp_diags = extract_diagonals(blank_probs) + tp_diags = extract_diagonals(truth_probs) + + label_mask = tf.expand_dims(tf.sequence_mask(label_length + 1, maxlen=target_max_len, dtype=tf.float32), axis=1) + small_label_mask = tf.expand_dims(tf.sequence_mask(label_length, maxlen=target_max_len, dtype=tf.float32), axis=1) + input_mask = tf.expand_dims(tf.sequence_mask(logit_length, maxlen=input_max_len, dtype=tf.float32), axis=2) + small_input_mask = 
tf.expand_dims(tf.sequence_mask(logit_length - 1, maxlen=input_max_len, dtype=tf.float32), axis=2) + mask = label_mask * input_mask + grad_blank_mask = (label_mask * small_input_mask)[:, :-1, :] + grad_truth_mask = (small_label_mask * input_mask)[:, :, :-1] + + alpha = forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len) * mask + + indices = tf.stack([logit_length - 1, label_length], axis=1) + blank_sl = tf.gather_nd(blank_probs, indices, batch_dims=1) + + beta = backward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len, label_length, logit_length, blank_sl) * mask + final_state_probs = beta[:, 0, 0] + beta = nan_to_zero(beta) + + # Compute gradients of loss w.r.t. blank log-probabilities. + grads_blank = ( + -tf.exp( + (alpha[:, :-1, :] + beta[:, 1:, :] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + blank_probs[:, :-1, :]) * grad_blank_mask + ) + * grad_blank_mask + ) + grads_blank = tf.concat([grads_blank, tf.zeros(shape=(batch_size, 1, target_max_len))], axis=1) + last_grads_blank = -1 * tf.scatter_nd( + tf.concat([tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=[batch_size, 1]), tf.cast(indices, dtype=tf.int64)], axis=1), + tf.ones(batch_size, dtype=tf.float32), + [batch_size, input_max_len, target_max_len], + name="last_grads_blank_scatter", + ) + grads_blank = grads_blank + last_grads_blank + + # Compute gradients of loss w.r.t. truth log-probabilities. + grads_truth = ( + -tf.exp((alpha[:, :, :-1] + beta[:, :, 1:] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + truth_probs) * grad_truth_mask) + * grad_truth_mask + ) + + # Compute gradients of loss w.r.t. activations. + a = tf.tile(tf.reshape(tf.range(target_max_len - 1, dtype=tf.int64), shape=(1, 1, target_max_len - 1, 1)), multiples=[batch_size, 1, 1, 1]) + b = tf.cast(tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), dtype=tf.int64) + if use_cpu: + b = tf.where(tf.equal(b, -1), tf.zeros_like(b), b) # for cpu testing (index -1 on cpu will raise errors) + c = tf.concat([a, b], axis=3) + d = tf.tile(c, multiples=(1, input_max_len, 1, 1)) + e = tf.tile(tf.reshape(tf.range(input_max_len, dtype=tf.int64), shape=(1, input_max_len, 1, 1)), multiples=(batch_size, 1, target_max_len - 1, 1)) + f = tf.concat([e, d], axis=3) + g = tf.tile(tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=(batch_size, 1, 1, 1)), multiples=[1, input_max_len, target_max_len - 1, 1]) + scatter_idx = tf.concat([g, f], axis=3) + # TODO - improve the part of code for scatter_idx computation. 
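# Editor's illustration (not part of the diff): what the scatter step above does, with toy
# shapes and invented values -- each truth gradient is placed at the column of its label token.
import tensorflow as tf

toy_grads_truth = tf.constant([[[0.5], [0.25]]])                                    # [batch=1, T=2, U-1=1]
toy_scatter_idx = tf.constant([[[[0, 0, 0, 1]], [[0, 1, 0, 1]]]], dtype=tf.int64)   # [1, 2, 1, 4] = (b, t, u, token_id - 1)
toy_dense = tf.scatter_nd(toy_scatter_idx, toy_grads_truth, shape=[1, 2, 2, 3])     # [batch, T, U, vocab - 1]
# toy_dense[0, t, 0, 1] == toy_grads_truth[0, t, 0]; every other entry stays 0.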
+ probs = tf.exp(log_probs) + grads_truth_scatter = tf.scatter_nd( + scatter_idx, + grads_truth, + [batch_size, input_max_len, target_max_len, vocab_size - 1], + name="grads_truth_scatter", + ) + grads = tf.concat([tf.reshape(grads_blank, shape=(batch_size, input_max_len, target_max_len, -1)), grads_truth_scatter], axis=3) + grads_logits = grads - probs * (tf.reduce_sum(grads, axis=3, keepdims=True)) + + loss = -final_state_probs + return loss, grads_logits + + +def rnnt_loss_tf( + logits, + labels, + label_length, + logit_length, + name=None, + use_cpu=False, + output_shapes=None, + **kwargs, +): + name = "rnnt_loss" if name is None else name + with tf.name_scope(name): + logits = tf.convert_to_tensor(logits, name="logits") + labels = tf.convert_to_tensor(labels, name="labels") + label_length = tf.convert_to_tensor(label_length, name="label_length") + logit_length = tf.convert_to_tensor(logit_length, name="logit_length") + + orig_dtype = logits.dtype + if orig_dtype in (tf.float16, tf.bfloat16): + logits = tf.cast(logits, tf.float32) + + args = [logits, labels, label_length, logit_length] + + @tf.custom_gradient + def compute_rnnt_loss_and_grad(logits_t, labels_t, label_length_t, logit_length_t): + """Compute RNN-T loss and gradients.""" + logits_t.set_shape(logits.shape) + labels_t.set_shape(labels.shape) + label_length_t.set_shape(label_length.shape) + logit_length_t.set_shape(logit_length.shape) + kwargs = dict( + logits=logits_t, + labels=labels_t, + label_length=label_length_t, + logit_length=logit_length_t, + use_cpu=use_cpu, + output_shapes=output_shapes, + ) + result = compute_rnnt_loss_and_grad_helper(**kwargs) + + def grad(grad_loss): + grads = [tf.reshape(grad_loss, [-1, 1, 1, 1]) * result[1]] + grads += [None] * (len(args) - len(grads)) + return grads + + return result[0], grad + + loss = compute_rnnt_loss_and_grad(*args) + if orig_dtype in (tf.float16, tf.bfloat16): + loss = tf.cast(loss, orig_dtype) + return loss diff --git a/tensorflow_asr/losses/rnnt_loss.py b/tensorflow_asr/losses/rnnt_loss.py index fe751bfa4f..3d2978c45e 100644 --- a/tensorflow_asr/losses/rnnt_loss.py +++ b/tensorflow_asr/losses/rnnt_loss.py @@ -14,350 +14,53 @@ # limitations under the License. 
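The rnnt_loss_tf implementation above wires its analytically computed grads_logits into autodiff via tf.custom_gradient. A minimal standalone sketch of that pattern (toy function invented for illustration, not from the repo):

import tensorflow as tf

@tf.custom_gradient
def square_with_manual_grad(x):
    y = x * x

    def grad(upstream):
        # hand-written derivative, analogous to returning grads_logits above
        return upstream * 2.0 * x

    return y, grad

x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = square_with_manual_grad(x)
print(tape.gradient(y, x))  # tf.Tensor(6.0, shape=(), dtype=float32)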
# RNNT loss implementation in pure TensorFlow is borrowed from [iamjanvijay's repo](https://github.com/iamjanvijay/rnnt) -import os -import numpy as np -import tensorflow as tf +import logging +import os +from tensorflow_asr.losses.base_loss import BaseLoss +from tensorflow_asr.losses.impl.rnnt import rnnt_loss, warp_rnnt_loss from tensorflow_asr.utils import env_util -logger = tf.get_logger() - -USE_CPU_LOSS = os.getenv("USE_CPU_LOSS", "False") == "True" - -try: - from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss +TFASR_USE_CPU_LOSS = os.getenv("TFASR_USE_CPU_LOSS", "False") in ("true", "True", "1") - use_warprnnt = True - logger.info("Use RNNT loss in WarpRnnt") -except ImportError: - logger.info("Use RNNT loss in TensorFlow") - use_warprnnt = False +logger = logging.getLogger(__name__) -class RnntLoss(tf.keras.losses.Loss): +class RnntLoss(BaseLoss): def __init__( self, blank, + reduction="sum_over_batch_size", + output_shapes=None, name=None, ): - if blank != 0 and not use_warprnnt: # restrict blank index - raise ValueError("rnnt_loss in tensorflow must use blank = 0") - super().__init__(reduction=tf.keras.losses.Reduction.NONE, name=name) - self.blank = blank - self.use_cpu = USE_CPU_LOSS if USE_CPU_LOSS else (not env_util.has_devices("GPU") and not env_util.has_devices("TPU")) - if self.use_cpu: - logger.info("Use CPU implementation for RNNT loss") - else: - logger.info("Use GPU/TPU implementation for RNNT loss") + super().__init__(blank=blank, reduction=reduction, name=name) + self.use_cpu = TFASR_USE_CPU_LOSS or (not env_util.has_devices("GPU") and not env_util.has_devices("TPU")) + self.output_shapes = output_shapes + # fmt: off + logger.info(f"[RNNT loss] Use {'CPU' if self.use_cpu else 'GPU/TPU'} implementation in {'Tensorflow' if warp_rnnt_loss is None else 'WarpRNNT'}") # pylint: disable=line-too-long + # fmt: on + if self.output_shapes: + logger.info(f"[RNNT loss] Use model's output shapes: {self.output_shapes}") + if not all(self.output_shapes): + logger.info("[RNNT loss] Detected dynamic shape") + self.output_shapes = None def call(self, y_true, y_pred): - logits, logit_length, labels, label_length = y_pred["logits"], y_pred["logits_length"], y_true["labels"], y_true["labels_length"] - if self.use_cpu: - with tf.device("/CPU:0"): - return rnnt_loss( - logits=logits, - logit_length=logit_length, - labels=labels, - label_length=label_length, - blank=self.blank, - name=self.name, - use_cpu=self.use_cpu, - ) - else: - return rnnt_loss( - logits=logits, - logit_length=logit_length, - labels=labels, - label_length=label_length, - blank=self.blank, - name=self.name, - use_cpu=self.use_cpu, - ) - - -def rnnt_loss( - logits, - labels, - label_length, - logit_length, - blank=0, - name=None, - use_cpu=False, -): - if use_warprnnt: - return rnnt_loss_warprnnt( + logits, logit_length, labels, label_length = super().call(y_true, y_pred) + return rnnt_loss( logits=logits, + logits_length=logit_length, labels=labels, - label_length=label_length, - logit_length=logit_length, - blank=blank, - use_cpu=use_cpu, - ) - return rnnt_loss_tf( - logits=logits, - labels=labels, - label_length=label_length, - logit_length=logit_length, - name=name, - use_cpu=use_cpu, - ) - - -# ------------------------ RNNT LOSS IN WARP TRANDUCER ----------------------- # - - -def rnnt_loss_warprnnt( - logits, - labels, - label_length, - logit_length, - blank=0, - use_cpu=False, -): - if use_cpu: - logits = tf.nn.log_softmax(logits) - loss = warp_rnnt_loss(acts=logits, label_lengths=label_length, 
labels=labels, input_lengths=logit_length, blank_label=blank) - return loss - - -# ------------------------------ RNNT LOSS IN TF ----------------------------- # - -LOG_0 = -np.inf - - -def nan_to_zero( - input_tensor, -): - return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor) - - -def reduce_logsumexp( - input_tensor, - axis, -): - maximum = tf.reduce_max(input_tensor, axis=axis) - input_tensor = nan_to_zero(input_tensor - maximum) - return tf.math.log(tf.reduce_sum(tf.exp(input_tensor), axis=axis)) + maximum - - -def extract_diagonals( - log_probs, -): - time_steps = tf.shape(log_probs)[1] # T - output_steps = tf.shape(log_probs)[2] # U + 1 - reverse_log_probs = tf.reverse(log_probs, axis=[-1]) - paddings = [[0, 0], [0, 0], [time_steps - 1, 0]] - padded_reverse_log_probs = tf.pad(reverse_log_probs, paddings, "CONSTANT", constant_values=LOG_0) - diagonals = tf.raw_ops.MatrixDiagPartV2(input=padded_reverse_log_probs, k=(0, time_steps + output_steps - 2), padding_value=LOG_0) - - return tf.transpose(diagonals, perm=[1, 0, 2]) - - -def transition_probs( - one_hot_labels, - log_probs, -): - """ - :return: blank_probs with shape batch_size x input_max_len x target_max_len - truth_probs with shape batch_size x input_max_len x (target_max_len-1) - """ - blank_probs = log_probs[:, :, :, 0] - truth_probs = tf.reduce_sum(tf.multiply(log_probs[:, :, :-1, :], one_hot_labels), axis=-1) - - return blank_probs, truth_probs - - -def forward_dp( - bp_diags, - tp_diags, - batch_size, - input_max_len, - target_max_len, -): - """ - :return: forward variable alpha with shape batch_size x input_max_len x target_max_len - """ - - def next_state(x, trans_probs): - blank_probs = trans_probs[0] - truth_probs = trans_probs[1] - - x_b = tf.concat([LOG_0 * tf.ones(shape=[batch_size, 1]), x[:, :-1] + blank_probs], axis=1) - x_t = x + truth_probs - - x = tf.math.reduce_logsumexp(tf.stack([x_b, x_t], axis=0), axis=0) - return x - - initial_alpha = tf.concat([tf.zeros(shape=[batch_size, 1]), tf.ones(shape=[batch_size, input_max_len - 1]) * LOG_0], axis=1) - - fwd = tf.scan(next_state, (bp_diags[:-1, :, :-1], tp_diags), initializer=initial_alpha) - - alpha = tf.transpose(tf.concat([tf.expand_dims(initial_alpha, axis=0), fwd], axis=0), perm=[1, 2, 0]) - alpha = tf.raw_ops.MatrixDiagPartV2(input=alpha, k=(0, target_max_len - 1), padding_value=LOG_0) - alpha = tf.transpose(tf.reverse(alpha, axis=[1]), perm=[0, 2, 1]) - - return alpha - - -def backward_dp( - bp_diags, - tp_diags, - batch_size, - input_max_len, - target_max_len, - label_length, - logit_length, - blank_sl, -): - """ - :return: backward variable beta with shape batch_size x input_max_len x target_max_len - """ - - def next_state(x, mask_and_trans_probs): - mask_s, blank_probs_s, truth_probs = mask_and_trans_probs - - beta_b = tf.concat([x[:, 1:] + blank_probs_s, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1) - beta_t = tf.concat([x[:, :-1] + truth_probs, LOG_0 * tf.ones(shape=[batch_size, 1])], axis=1) - - beta_next = reduce_logsumexp(tf.stack([beta_b, beta_t], axis=0), axis=0) - masked_beta_next = nan_to_zero(beta_next * tf.expand_dims(mask_s, axis=1)) + nan_to_zero(x * tf.expand_dims((1.0 - mask_s), axis=1)) - return masked_beta_next - - # Initial beta for batches. - initial_beta_mask = tf.one_hot(logit_length - 1, depth=input_max_len + 1) - initial_beta = tf.expand_dims(blank_sl, axis=1) * initial_beta_mask + nan_to_zero(LOG_0 * (1.0 - initial_beta_mask)) - - # Mask for scan iterations. 
- mask = tf.sequence_mask(logit_length + label_length - 1, input_max_len + target_max_len - 2, dtype=tf.dtypes.float32) - mask = tf.transpose(mask, perm=[1, 0]) - - bwd = tf.scan(next_state, (mask, bp_diags[:-1, :, :], tp_diags), initializer=initial_beta, reverse=True) - - beta = tf.transpose(tf.concat([bwd, tf.expand_dims(initial_beta, axis=0)], axis=0), perm=[1, 2, 0])[:, :-1, :] - beta = tf.raw_ops.MatrixDiagPartV2(input=beta, k=(0, target_max_len - 1), padding_value=LOG_0) - beta = tf.transpose(tf.reverse(beta, axis=[1]), perm=[0, 2, 1]) - - return beta - - -def compute_rnnt_loss_and_grad_helper(logits, labels, label_length, logit_length, use_cpu=False): - batch_size = tf.shape(logits)[0] - input_max_len = tf.shape(logits)[1] - target_max_len = tf.shape(logits)[2] - vocab_size = tf.shape(logits)[3] - - one_hot_labels = tf.one_hot(tf.tile(tf.expand_dims(labels, axis=1), multiples=[1, input_max_len, 1]), depth=vocab_size) - - log_probs = tf.nn.log_softmax(logits) - blank_probs, truth_probs = transition_probs(one_hot_labels, log_probs) - bp_diags = extract_diagonals(blank_probs) - tp_diags = extract_diagonals(truth_probs) - - label_mask = tf.expand_dims(tf.sequence_mask(label_length + 1, maxlen=target_max_len, dtype=tf.float32), axis=1) - small_label_mask = tf.expand_dims(tf.sequence_mask(label_length, maxlen=target_max_len, dtype=tf.float32), axis=1) - input_mask = tf.expand_dims(tf.sequence_mask(logit_length, maxlen=input_max_len, dtype=tf.float32), axis=2) - small_input_mask = tf.expand_dims(tf.sequence_mask(logit_length - 1, maxlen=input_max_len, dtype=tf.float32), axis=2) - mask = label_mask * input_mask - grad_blank_mask = (label_mask * small_input_mask)[:, :-1, :] - grad_truth_mask = (small_label_mask * input_mask)[:, :, :-1] - - alpha = forward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len) * mask - - indices = tf.stack([logit_length - 1, label_length], axis=1) - blank_sl = tf.gather_nd(blank_probs, indices, batch_dims=1) - - beta = backward_dp(bp_diags, tp_diags, batch_size, input_max_len, target_max_len, label_length, logit_length, blank_sl) * mask - final_state_probs = beta[:, 0, 0] - beta = nan_to_zero(beta) - - # Compute gradients of loss w.r.t. blank log-probabilities. - grads_blank = ( - -tf.exp( - (alpha[:, :-1, :] + beta[:, 1:, :] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + blank_probs[:, :-1, :]) * grad_blank_mask + labels_length=label_length, + blank=self.blank, + name=self.name, + use_cpu=self.use_cpu, + output_shapes=self.output_shapes, ) - * grad_blank_mask - ) - grads_blank = tf.concat([grads_blank, tf.zeros(shape=(batch_size, 1, target_max_len))], axis=1) - last_grads_blank = -1 * tf.scatter_nd( - tf.concat([tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=[batch_size, 1]), tf.cast(indices, dtype=tf.int64)], axis=1), - tf.ones(batch_size, dtype=tf.float32), - [batch_size, input_max_len, target_max_len], - ) - grads_blank = grads_blank + last_grads_blank - - # Compute gradients of loss w.r.t. truth log-probabilities. - grads_truth = ( - -tf.exp((alpha[:, :, :-1] + beta[:, :, 1:] - tf.reshape(final_state_probs, shape=[batch_size, 1, 1]) + truth_probs) * grad_truth_mask) - * grad_truth_mask - ) - - # Compute gradients of loss w.r.t. activations. 
- a = tf.tile(tf.reshape(tf.range(target_max_len - 1, dtype=tf.int64), shape=(1, 1, target_max_len - 1, 1)), multiples=[batch_size, 1, 1, 1]) - b = tf.cast(tf.reshape(labels - 1, shape=(batch_size, 1, target_max_len - 1, 1)), dtype=tf.int64) - if use_cpu: - b = tf.where(tf.equal(b, -1), tf.zeros_like(b), b) # for cpu testing (index -1 on cpu will raise errors) - c = tf.concat([a, b], axis=3) - d = tf.tile(c, multiples=(1, input_max_len, 1, 1)) - e = tf.tile(tf.reshape(tf.range(input_max_len, dtype=tf.int64), shape=(1, input_max_len, 1, 1)), multiples=(batch_size, 1, target_max_len - 1, 1)) - f = tf.concat([e, d], axis=3) - g = tf.tile(tf.reshape(tf.range(batch_size, dtype=tf.int64), shape=(batch_size, 1, 1, 1)), multiples=[1, input_max_len, target_max_len - 1, 1]) - scatter_idx = tf.concat([g, f], axis=3) - # TODO - improve the part of code for scatter_idx computation. - probs = tf.exp(log_probs) - grads_truth_scatter = tf.scatter_nd(scatter_idx, grads_truth, [batch_size, input_max_len, target_max_len, vocab_size - 1]) - grads = tf.concat([tf.reshape(grads_blank, shape=(batch_size, input_max_len, target_max_len, -1)), grads_truth_scatter], axis=3) - grads_logits = grads - probs * (tf.reduce_sum(grads, axis=3, keepdims=True)) - - loss = -final_state_probs - return loss, grads_logits - - -def rnnt_loss_tf( - logits, - labels, - label_length, - logit_length, - name=None, - use_cpu=False, -): - name = "rnnt_loss" if name is None else name - with tf.name_scope(name): - logits = tf.convert_to_tensor(logits, name="logits") - labels = tf.convert_to_tensor(labels, name="labels") - label_length = tf.convert_to_tensor(label_length, name="label_length") - logit_length = tf.convert_to_tensor(logit_length, name="logit_length") - - orig_dtype = logits.dtype - if orig_dtype in (tf.float16, tf.bfloat16): - logits = tf.cast(logits, tf.float32) - - args = [logits, labels, label_length, logit_length] - - @tf.custom_gradient - def compute_rnnt_loss_and_grad(logits_t, labels_t, label_length_t, logit_length_t): - """Compute RNN-T loss and gradients.""" - logits_t.set_shape(logits.shape) - labels_t.set_shape(labels.shape) - label_length_t.set_shape(label_length.shape) - logit_length_t.set_shape(logit_length.shape) - kwargs = dict( - logits=logits_t, - labels=labels_t, - label_length=label_length_t, - logit_length=logit_length_t, - use_cpu=use_cpu, - ) - result = compute_rnnt_loss_and_grad_helper(**kwargs) - - def grad(grad_loss): - grads = [tf.reshape(grad_loss, [-1, 1, 1, 1]) * result[1]] - grads += [None] * (len(args) - len(grads)) - return grads - - return result[0], grad - loss = compute_rnnt_loss_and_grad(*args) - if orig_dtype in (tf.float16, tf.bfloat16): - loss = tf.cast(loss, orig_dtype) - return loss + def get_config(self): + conf = super().get_config() + conf.update({"output_shapes": self.output_shapes}) + return conf diff --git a/tensorflow_asr/metrics/error_rates.py b/tensorflow_asr/metrics/error_rates.py index d9f7480a56..988e8c9391 100644 --- a/tensorflow_asr/metrics/error_rates.py +++ b/tensorflow_asr/metrics/error_rates.py @@ -12,31 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
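The error-rate metric reworked in the hunk below now accumulates a ready-made (numerator, denominator) pair instead of calling a decode function itself. A hypothetical usage sketch (values invented for illustration):

import tensorflow as tf

wer = ErrorRate(name="wer")  # the class from the hunk below
wer.update_state((tf.constant([3.0, 1.0]), tf.constant([10.0, 5.0])))  # (edit distances, reference lengths)
print(wer.result())  # (3 + 1) / (10 + 5) = 0.2666...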
-import tensorflow as tf +from tensorflow_asr import keras, tf -class ErrorRate(tf.keras.metrics.Metric): +class ErrorRate(keras.metrics.Metric): """Metric for WER or CER""" - def __init__( - self, - func, - name="error_rate", - **kwargs, - ): - super(ErrorRate, self).__init__(name=name, **kwargs) + def __init__(self, name="error_rate", **kwargs): + super().__init__(name=name, **kwargs) self.numerator = self.add_weight(name="numerator", initializer="zeros") self.denominator = self.add_weight(name="denominator", initializer="zeros") - self.func = func - def update_state( - self, - decode: tf.Tensor, - target: tf.Tensor, - ): - n, d = self.func(decode, target) - self.numerator.assign_add(n) - self.denominator.assign_add(d) + def update_state(self, data): + numer, denom = data + self.numerator.assign_add(tf.reduce_sum(numer)) + self.denominator.assign_add(tf.reduce_sum(denom)) def result(self): - return tf.math.divide_no_nan(self.numerator, self.denominator) + return tf.math.divide(self.numerator, self.denominator) diff --git a/tensorflow_asr/models/__init__.py b/tensorflow_asr/models/__init__.py index ef7fcc7f55..9139bde684 100644 --- a/tensorflow_asr/models/__init__.py +++ b/tensorflow_asr/models/__init__.py @@ -1,2 +1,13 @@ -from tensorflow_asr.models.ctc import * -from tensorflow_asr.models.transducer import * +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/activations/__init__.py b/tensorflow_asr/models/activations/__init__.py index e69de29bb2..9139bde684 100644 --- a/tensorflow_asr/models/activations/__init__.py +++ b/tensorflow_asr/models/activations/__init__.py @@ -0,0 +1,13 @@ +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/activations/glu.py b/tensorflow_asr/models/activations/glu.py index 2497125acb..8d1d249176 100644 --- a/tensorflow_asr/models/activations/glu.py +++ b/tensorflow_asr/models/activations/glu.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer +@keras.utils.register_keras_serializable(package=__name__) class GLU(Layer): - def __init__(self, axis=-1, name="glu_activation", **kwargs): + def __init__(self, axis=-1, name="glu", **kwargs): super().__init__(name=name, **kwargs) self.axis = axis diff --git a/tensorflow_asr/models/base_layer.py b/tensorflow_asr/models/base_layer.py index e0517dddd2..15751d56b1 100644 --- a/tensorflow_asr/models/base_layer.py +++ b/tensorflow_asr/models/base_layer.py @@ -1,4 +1,4 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) +# Copyright 2023 Huy Le Nguyen (@nglehuy) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,43 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import keras.layers -from keras.utils import tf_utils - +from tensorflow_asr import keras from tensorflow_asr.utils import math_util +@keras.utils.register_keras_serializable(package=__name__) class Layer(keras.layers.Layer): def __init__( self, trainable=True, name=None, dtype=None, - dynamic=False, **kwargs, ): - super().__init__(trainable, name, dtype, dynamic, **kwargs) - self._output_shape = None + super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) self.supports_masking = True - @property - def output_shape(self): - if self._output_shape is None: - raise AttributeError(f"The layer {self.name} has never been called and thus has no defined output shape.") - return self._output_shape - - def build(self, input_shape): - self._output_shape = tf_utils.convert_shapes(self.compute_output_shape(input_shape), to_tuples=True) - super().build(input_shape) - def compute_output_shape(self, input_shape): return input_shape +@keras.utils.register_keras_serializable(package=__name__) class Reshape(Layer): def call(self, inputs): - return math_util.merge_two_last_dims(inputs) + outputs, outputs_length = inputs + outputs = math_util.merge_two_last_dims(outputs) + return outputs, outputs_length def compute_output_shape(self, input_shape): - b, h, w, d = input_shape - return (b, h, w * d) + output_shape, output_length_shape = input_shape + output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],) + return output_shape, output_length_shape diff --git a/tensorflow_asr/models/base_model.py b/tensorflow_asr/models/base_model.py index 70f3a34a07..fe6a753771 100644 --- a/tensorflow_asr/models/base_model.py +++ b/tensorflow_asr/models/base_model.py @@ -13,125 +13,119 @@ # See the License for the specific language governing permissions and # limitations under the License. 
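For context on the Reshape layer above, merging the last two feature dimensions can be sketched as follows (an assumed equivalent of math_util.merge_two_last_dims, not the repo's actual implementation):

import tensorflow as tf

def merge_two_last_dims(x):
    shape = tf.shape(x)
    return tf.reshape(x, [shape[0], shape[1], shape[2] * shape[3]])

x = tf.zeros([2, 10, 80, 4])          # [batch, time, features, channels]
print(merge_two_last_dims(x).shape)   # (2, 10, 320)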
-import tensorflow as tf +# import importlib +import logging +import typing -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import TextFeaturizer +from keras.src import tree +from keras.src.backend.tensorflow.trainer import TensorFlowTrainer, reduce_per_replica +from keras.src.losses import loss as loss_module + +from tensorflow_asr import keras, schemas, tf +from tensorflow_asr.models.layers.feature_extraction import FeatureExtraction from tensorflow_asr.optimizers.accumulation import GradientAccumulator -from tensorflow_asr.utils import env_util, file_util +from tensorflow_asr.tokenizers import Tokenizer +from tensorflow_asr.utils import file_util, math_util, shape_util -logger = tf.get_logger() +logger = logging.getLogger(__name__) -class BaseModel(tf.keras.Model): - def summary( - self, - line_length=127, - expand_nested=True, - show_trainable=True, - **kwargs, - ): +class BaseModel(keras.Model, TensorFlowTrainer): + optimizer: typing.Union[keras.optimizers.Optimizer, keras.optimizers.LossScaleOptimizer] + + def __init__(self, speech_config: dict, *args, **kwargs): + super().__init__(*args, **kwargs) + self.feature_extraction = FeatureExtraction(**speech_config) + + @property + def tokenizer(self): + return self._tokenizer + + @tokenizer.setter + def tokenizer(self, tokenizer: Tokenizer): + self._tokenizer = tokenizer + + def summary(self, line_length=120, expand_nested=True, show_trainable=True, **kwargs): super().summary(line_length=line_length, expand_nested=expand_nested, show_trainable=show_trainable, **kwargs) - def save( - self, - filepath, - overwrite=True, - include_optimizer=True, - save_format=None, - signatures=None, - options=None, - save_traces=True, - ): + def save(self, filepath, overwrite=True, zipped=None, **kwargs): with file_util.save_file(filepath) as path: - super().save( - filepath=path, - overwrite=overwrite, - include_optimizer=include_optimizer, - save_format=save_format, - signatures=signatures, - options=options, - save_traces=save_traces, - ) + super().save(filepath=path, overwrite=overwrite, zipped=zipped, **kwargs) - def save_weights( - self, - filepath, - overwrite=True, - save_format=None, - options=None, - ): + def save_weights(self, filepath, overwrite=True): with file_util.save_file(filepath) as path: - super().save_weights(filepath=path, overwrite=overwrite, save_format=save_format, options=options) + super().save_weights(filepath=path, overwrite=overwrite) - def load_weights( - self, - filepath, - by_name=False, - skip_mismatch=False, - options=None, - ): + def load_weights(self, filepath, skip_mismatch=False, **kwargs): with file_util.read_file(filepath) as path: - super().load_weights(filepath=path, by_name=by_name, skip_mismatch=skip_mismatch, options=options) + super().load_weights(filepath=path, skip_mismatch=skip_mismatch, **kwargs) - @property - def metrics(self): - if not hasattr(self, "_tfasr_metrics"): - self._tfasr_metrics = {} - return list(self._tfasr_metrics.values()) - - def reset_metrics(self): - super().reset_metrics() - self.reset_states() # reset all stateful states also - - def add_custom_metric(self, metric: tf.keras.metrics.Metric): + def add_custom_metric(self, metric: keras.metrics.Metric): if not hasattr(self, "_tfasr_metrics"): self._tfasr_metrics = {} self._tfasr_metrics[metric.name] = metric - def make(self, *args, **kwargs): - """Custom function for building model (uses self.build so cannot overwrite that function)""" - raise 
NotImplementedError() + def make(self, input_shape=[None], prediction_shape=[None], batch_size=None, **kwargs) -> schemas.TrainOutput: + """ + Custom function for building model (uses self.build so cannot overwrite that function) + + Parameters + ---------- + input_shape : list, optional + The shape of signal, by default [None] + prediction_shape : list, optional + The shape of prediction, by default [None] + batch_size : int, optional + Batch size, by default None + """ + assert batch_size is not None and batch_size > 0 + signals = keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32) + signals_length = keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + predictions = keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) + predictions_length = keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) + self._per_replica_batch_size = int(batch_size / self.distribute_strategy.num_replicas_in_sync) + self._batch_size = batch_size + outputs: schemas.TrainOutput = self( + schemas.TrainInput( + inputs=signals, + inputs_length=signals_length, + predictions=predictions, + predictions_length=predictions_length, + ), + training=False, + ) + return tf.nest.map_structure( + lambda x: shape_util.shape_list_per_replica(x, per_replica_batch_size=self._per_replica_batch_size), + outputs, + ) # compute output shape def compile( self, loss, - optimizer, - run_eagerly=None, - mxp="none", + optimizer=None, + run_eagerly=False, ga_steps=None, - apply_gwn_config=None, + gwn_config=None, + gradn_config=None, **kwargs, ): - optimizer = tf.keras.optimizers.get(optimizer) - if env_util.has_devices("TPU"): - self.use_loss_scale = False - else: - self.use_loss_scale = mxp != "none" - if self.use_loss_scale: - optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer) - logger.info("Using loss scale") + optimizer = keras.optimizers.get(optimizer) if isinstance(ga_steps, int) and ga_steps > 1: self.use_ga = True - self.ga = GradientAccumulator(ga_steps=ga_steps, trainable_variables=self.trainable_variables) + self.ga = GradientAccumulator(ga_steps=ga_steps, optimizer=optimizer) + self.ga.build(self.trainable_weights) + kwargs["steps_per_execution"] = 1 logger.info(f"Using gradient accumulation with accumulate steps = {ga_steps}") else: self.use_ga = False - self.apply_gwn_config = apply_gwn_config - self.add_custom_metric(metric=tf.keras.metrics.Mean(name="loss")) - self.distribute_reduction_method = "sum" - super().compile(optimizer=optimizer, loss=loss, run_eagerly=run_eagerly, **kwargs) + self.gwn_config = gwn_config + self.gradn_config = gradn_config + self.distribute_reduction_method = "auto" + self.tfasr_loss = loss + super().compile(optimizer=optimizer, run_eagerly=run_eagerly, **kwargs) - def add_featurizers(self, speech_featurizer: SpeechFeaturizer, text_featurizer: TextFeaturizer): - """ - Function to add featurizer to model to convert to end2end tflite - Args: - speech_featurizer: SpeechFeaturizer instance - text_featurizer: TextFeaturizer instance - scorer: external language model scorer - """ - self.speech_featurizer = speech_featurizer - self.text_featurizer = text_featurizer + def call(self, inputs: schemas.TrainInput, training=False): + raise NotImplementedError() # -------------------------------- STEP FUNCTIONS ------------------------------------- def apply_gwn(self) -> list: @@ -140,94 +134,231 @@ def apply_gwn(self) -> list: def remove_gwn(self, original_weights): pass - def _get_global_batch_size(self, y_pred): - global_batch_size = 
tf.shape(y_pred["logits"])[0] * self.distribute_strategy.num_replicas_in_sync - return global_batch_size - - def train_step(self, batch): - """ - Args: - batch ([tf.Tensor]): a batch of training data - - Returns: - Dict[tf.Tensor]: a dict of validation metrics with keys are the name of metric + def tfasr_compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + loss = self.tfasr_loss(y, y_pred) + self.add_loss(loss) + return super()._compute_loss(x, y, y_pred, sample_weight, training) - """ - inputs, y_true = batch + def _train_step(self, data: schemas.TrainData): + x, y = data + sample_weight = None with tf.GradientTape() as tape: + tape.watch(x.inputs) original_weights = self.apply_gwn() - y_pred = self(inputs, training=True) + y_pred: schemas.TrainOutput = self(x, training=True) + tape.watch(y_pred.logits) self.remove_gwn(original_weights) - tape.watch(y_pred["logits"]) - per_sample_loss = self.loss(y_true=y_true, y_pred=y_pred) - global_batch_size = self._get_global_batch_size(y_pred) - loss = tf.nn.compute_average_loss(per_sample_loss, global_batch_size=global_batch_size) - if self.use_loss_scale: - scaled_loss = self.optimizer.get_scaled_loss(loss) - - if self.use_loss_scale: - gradients = tape.gradient(scaled_loss, self.trainable_weights, unconnected_gradients=tf.UnconnectedGradients.ZERO) - gradients = self.optimizer.get_unscaled_gradients(gradients) - else: - gradients = tape.gradient(loss, self.trainable_weights, unconnected_gradients=tf.UnconnectedGradients.ZERO) + loss = self.tfasr_compute_loss( + x=x, + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + training=True, + ) + # loss is in shape [B] + # reduce_mean on all replicas = (sum_loss1 / B + ... + sum_lossN / B) / N = (sum_loss1 + ... + sum_lossN) / (B * N) + # (B * N) also total count of samples across all replicas of current batch + # (sum_loss1 + ... 
+ sum_lossN) is the total loss summed over all replicas of current batch, so the total number of loss = (B * N) + # B = mini_batch_size * ga_steps + # reduce_first = sum_loss1 / B + # => reduce_mean has the same effect as reduce_first + # the loss already divided by num_replicas for gradients reduce_sum when using _compute_loss, so unscale it + self._loss_tracker.update_state( + loss_module.unscale_loss_for_distribution(loss), + sample_weight=tf.shape(tree.flatten(x)[0])[0], # this is the count, which = B + ) + + if self.optimizer is not None: + loss = self.optimizer.scale_loss(loss) + + gradients = tape.gradient(loss, self.trainable_weights) + return gradients - if self.use_ga: # perform gradient accumulation - self.ga.accumulate(gradients=gradients) - self.optimizer.apply_gradients(zip(self.ga.gradients, self.trainable_variables)) - tf.cond(self.ga.is_apply_step, self.ga.reset, lambda: None) + def _apply_gradients(self, gradients): + if self.gradn_config is not None: + gradients = tf.cond( + tf.greater_equal(self.optimizer.iterations, self.gradn_config["step"]), + lambda: math_util.add_gauss_noise(gradients, stddev=self.gradn_config["stddev"]), + lambda: gradients, + ) + self.optimizer.apply(gradients, self.trainable_weights) + + def train_step(self, data): + gradients = self._train_step(data) + self._apply_gradients(gradients) + metrics = self.get_metrics_result() + return metrics + + def train_step_ga(self, data, do_apply=None): # avoid merge_call error as "Such behaviors are not yet supported" + gradients = self._train_step(data) + if do_apply is None: + self.ga.accumulate(gradients, self.trainable_weights) else: - self.optimizer.apply_gradients(zip(gradients, self.trainable_variables)) + gradients = self.ga.gradients(gradients, self.trainable_weights) + self._apply_gradients(gradients) + self.ga.reset() + metrics = self.get_metrics_result() + return metrics + + def _test_step(self, data: schemas.TrainData): + x, y = data + sample_weight = None + y_pred = self(x, training=False) + loss = self.tfasr_compute_loss( + x=x, + y=y, + y_pred=y_pred, + sample_weight=sample_weight, + training=False, + ) + self._loss_tracker.update_state( + loss_module.unscale_loss_for_distribution(loss), + sample_weight=tf.shape(tree.flatten(x)[0])[0], + ) + + def test_step(self, data: schemas.TrainData): + self._test_step(data) + metrics = self.get_metrics_result() + return metrics + + def predict_step(self, data: schemas.TrainData): + x, y_true = data + batch_size, *_ = shape_util.shape_list(x.inputs) + inputs = schemas.PredictInput( + inputs=x.inputs, + inputs_length=x.inputs_length, + previous_tokens=self.get_initial_tokens(batch_size=batch_size), + previous_encoder_states=self.get_initial_encoder_states(batch_size=batch_size), + previous_decoder_states=self.get_initial_decoder_states(batch_size=batch_size), + ) + _tokens = self.recognize(inputs=inputs).tokens + _beam_tokens = self.recognize_beam(inputs=inputs).tokens + return { + "tokens": _tokens, + "beam_tokens": _beam_tokens, + "labels": y_true.labels, + } + + # ------------------------------------ FIT ----------------------------------- # + + def _make_function(self, step_function): + @tf.autograph.experimental.do_not_convert + def one_step_on_data(data): + """Runs a single training step on a batch of data.""" + outputs = self.distribute_strategy.run(step_function, args=(data,)) + outputs = reduce_per_replica( + outputs, + self.distribute_strategy, + reduction=self.distribute_reduction_method, + ) + return outputs - 
self._tfasr_metrics["loss"].update_state(per_sample_loss) - result = {m.name: m.result() / tf.distribute.get_strategy().num_replicas_in_sync for m in self.metrics} - return result + if not self.run_eagerly: + one_step_on_data = tf.function( + one_step_on_data, + reduce_retracing=True, + jit_compile=self.jit_compile, + ) - def test_step(self, batch): - """ - Args: - batch ([tf.Tensor]: a batch of validation data + def function(iterator): + for step, data in zip(range(self.steps_per_execution), iterator): + outputs = one_step_on_data(data) + return outputs - Returns: - Dict[tf.Tensor]: a dict of validation metrics with keys are the name of metric prefixed with "val_" + return function - """ - inputs, y_true = batch - y_pred = self(inputs, training=False) - per_sample_loss = self.loss(y_true=y_true, y_pred=y_pred) - # global_batch_size = self._get_global_batch_size(y_pred) - # loss = tf.nn.compute_average_loss(per_sample_loss, global_batch_size=global_batch_size) - self._tfasr_metrics["loss"].update_state(per_sample_loss) - return {m.name: m.result() / tf.distribute.get_strategy().num_replicas_in_sync for m in self.metrics} - - def predict_step(self, batch): - """ - Args: - batch ([tf.Tensor]): a batch of testing data + def make_train_function(self, force=False): + if self.train_function is not None and not force: + return self.train_function - Returns: - [tf.Tensor]: stacked tensor of shape [B, 3] with each row is the text [truth, greedy, beam_search] - """ - inputs, y_true = batch - labels = self.text_featurizer.iextract(y_true["labels"]) - greedy_decoding = self.recognize(inputs) - if self.text_featurizer.decoder_config.beam_width == 0: - beam_search_decoding = tf.tile(tf.expand_dims(tf.convert_to_tensor("", tf.string), 0), [tf.shape(labels)[0]]) - else: - beam_search_decoding = self.recognize_beam(inputs) - return tf.stack([labels, greedy_decoding, beam_search_decoding], axis=-1) + if not self.use_ga: + self.train_function = self._make_function(self.train_step) + return self.train_function + + @tf.autograph.experimental.do_not_convert + def one_ga_step_on_data(data, do_apply=None): + outputs = self.distribute_strategy.run(self.train_step_ga, args=(data, do_apply)) + outputs = reduce_per_replica( + outputs, + self.distribute_strategy, + reduction=self.distribute_reduction_method, + ) + return outputs + + if not self.run_eagerly: + one_ga_step_on_data = tf.function( + one_ga_step_on_data, + reduce_retracing=True, + jit_compile=self.jit_compile, + ) + + def function(iterator): + for step, data in zip(range(self.ga.total_steps), iterator): + if step >= self.ga.total_steps - 1: + outputs = one_ga_step_on_data(data, True) + else: + outputs = one_ga_step_on_data(data) + return outputs + + self.train_function = function + return self.train_function # -------------------------------- INFERENCE FUNCTIONS ------------------------------------- - def recognize(self, *args, **kwargs): + def get_initial_tokens(self, batch_size=1): + return tf.ones([batch_size, 1], dtype=tf.int32) * self.tokenizer.blank + + def get_initial_encoder_states(self, batch_size=1): + return [] + + def get_initial_decoder_states(self, batch_size=1): + return [] + + def recognize(self, inputs: schemas.PredictInput, **kwargs) -> schemas.PredictOutput: """Greedy decoding function that used in self.predict_step""" raise NotImplementedError() - def recognize_beam(self, *args, **kwargs): + def recognize_beam(self, inputs: schemas.PredictInput, beam_width: int = 10, **kwargs) -> schemas.PredictOutput: """Beam search decoding function 
that used in self.predict_step""" raise NotImplementedError() # ---------------------------------- TFLITE ---------------------------------- # - def make_tflite_function(self, *args, **kwargs): - pass + def make_tflite_function(self, batch_size: int = 1, beam_width: int = 0): + + def tflite_func(inputs: schemas.PredictInput): + if beam_width > 0: + outputs = self.recognize_beam(inputs, beam_width=beam_width) + else: + outputs = self.recognize(inputs) + return schemas.PredictOutputWithTranscript( + transcript=self.tokenizer.detokenize(outputs.tokens), + tokens=outputs.tokens, + next_tokens=outputs.next_tokens, + next_encoder_states=outputs.next_encoder_states, + next_decoder_states=outputs.next_decoder_states, + ) + + input_signature = schemas.PredictInput( + inputs=tf.TensorSpec([batch_size, None], dtype=tf.float32), + inputs_length=tf.TensorSpec([batch_size], dtype=tf.int32), + previous_tokens=tf.TensorSpec.from_tensor(self.get_initial_tokens(batch_size)), + previous_encoder_states=tf.nest.map_structure(tf.TensorSpec.from_tensor, self.get_initial_encoder_states(batch_size)), + previous_decoder_states=tf.nest.map_structure(tf.TensorSpec.from_tensor, self.get_initial_decoder_states(batch_size)), + ) + + return tf.function( + tflite_func, + input_signature=[input_signature], + jit_compile=True, + reduce_retracing=True, + autograph=True, + ) diff --git a/tensorflow_asr/models/ctc/__init__.py b/tensorflow_asr/models/ctc/__init__.py index 1fab1570e5..9139bde684 100644 --- a/tensorflow_asr/models/ctc/__init__.py +++ b/tensorflow_asr/models/ctc/__init__.py @@ -1,4 +1,13 @@ -import tensorflow_asr.models.ctc.conformer -import tensorflow_asr.models.ctc.deepspeech2 -import tensorflow_asr.models.ctc.jasper -import tensorflow_asr.models.ctc.transformer +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/ctc/base_ctc.py b/tensorflow_asr/models/ctc/base_ctc.py index 7893425025..8a6ab83ae6 100644 --- a/tensorflow_asr/models/ctc/base_ctc.py +++ b/tensorflow_asr/models/ctc/base_ctc.py @@ -12,169 +12,138 @@ # See the License for the specific language governing permissions and # limitations under the License. 
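The reworked make_tflite_function above now returns a tf.function keyed on a structured PredictInput signature instead of a raw signal tensor. A minimal conversion sketch, assuming an already-built model instance; the helper name export_tflite and the Select TF ops fallback are illustrative assumptions, not part of this change set:

import tensorflow as tf


def export_tflite(model, output_path: str, batch_size: int = 1, beam_width: int = 0):
    """Convert a built TensorFlowASR model to TFLite via its make_tflite_function (sketch)."""
    # make_tflite_function returns a tf.function with a fixed input_signature,
    # so a concrete function can be traced without passing example inputs.
    tflite_fn = model.make_tflite_function(batch_size=batch_size, beam_width=beam_width)
    concrete_fn = tflite_fn.get_concrete_function()
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_fn], model)
    # Assumption: some ASR ops may not be TFLite builtins, so allow Select TF ops as a fallback.
    converter.target_spec.supported_ops = [
        tf.lite.OpsSet.TFLITE_BUILTINS,
        tf.lite.OpsSet.SELECT_TF_OPS,
    ]
    tflite_model = converter.convert()
    with open(output_path, "wb") as f:
        f.write(tflite_model)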
-from typing import Dict - -import numpy as np -import tensorflow as tf +from tensorflow_asr import keras, schemas, tf from tensorflow_asr.losses.ctc_loss import CtcLoss from tensorflow_asr.models.base_model import BaseModel -from tensorflow_asr.utils import data_util, layer_util, math_util, shape_util +from tensorflow_asr.utils import layer_util class CtcModel(BaseModel): def __init__( self, - encoder: tf.keras.layers.Layer, - decoder: tf.keras.layers.Layer, + blank: int, + speech_config: dict, + encoder: keras.layers.Layer, + decoder: keras.layers.Layer, **kwargs, ): - super().__init__(**kwargs) + super().__init__(speech_config=speech_config, **kwargs) + self.blank = blank self.encoder = encoder self.decoder = decoder self.time_reduction_factor = 1 - def make(self, input_shape, batch_size=None, **kwargs): - inputs = tf.keras.Input(input_shape, batch_size=batch_size, dtype=tf.float32) - inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) - self( - data_util.create_inputs( - inputs=inputs, - inputs_length=inputs_length, - ), - training=False, - ) - - def compile( - self, - optimizer, - blank=0, - run_eagerly=None, - mxp="none", - ga_steps=None, - **kwargs, - ): - loss = CtcLoss(blank=blank) - super().compile(loss=loss, optimizer=optimizer, run_eagerly=run_eagerly, mxp=mxp, ga_steps=ga_steps, **kwargs) + def compile(self, optimizer, output_shapes=None, **kwargs): + loss = CtcLoss(blank=self.blank, name="ctc_loss") + return super().compile(loss=loss, optimizer=optimizer, **kwargs) def apply_gwn(self): - if self.apply_gwn_config: + if self.gwn_config: original_weights = {} - if self.apply_gwn_config.get("encoder_step") is not None and self.apply_gwn_config.get("encoder_stddev") is not None: + if self.gwn_config.get("encoder_step") is not None and self.gwn_config.get("encoder_stddev") is not None: original_weights["encoder"] = tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["encoder_step"]), - lambda: layer_util.add_gwn(self.encoder.trainable_weights, stddev=self.apply_gwn_config["encoder_stddev"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["encoder_step"]), + lambda: layer_util.add_gwn(self.encoder.trainable_weights, stddev=self.gwn_config["encoder_stddev"]), lambda: self.encoder.trainable_weights, ) - if self.apply_gwn_config.get("decoder_step") is not None and self.apply_gwn_config.get("decoder_stddev") is not None: + if self.gwn_config.get("decoder_step") is not None and self.gwn_config.get("decoder_stddev") is not None: original_weights["decoder"] = tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["decoder_step"]), - lambda: layer_util.add_gwn(self.decoder.trainable_weights, stddev=self.apply_gwn_config["decoder_stddev"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["decoder_step"]), + lambda: layer_util.add_gwn(self.decoder.trainable_weights, stddev=self.gwn_config["decoder_stddev"]), lambda: self.decoder.trainable_weights, ) return original_weights return {} def remove_gwn(self, original_weights): - if self.apply_gwn_config: + if self.gwn_config: if original_weights.get("encoder") is not None: tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["encoder_step"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["encoder_step"]), lambda: layer_util.sub_gwn(original_weights["encoder"], self.encoder.trainable_weights), lambda: None, ) if original_weights.get("decoder") is not None: tf.cond( - 
tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["decoder_step"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["decoder_step"]), lambda: layer_util.sub_gwn(original_weights["decoder"], self.decoder.trainable_weights), lambda: None, ) - def call(self, inputs, training=False): - logits, logits_length = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=training) - logits, logits_length = self.decoder([logits, logits_length], training=training) - return data_util.create_logits(logits=logits, logits_length=logits_length) - - # -------------------------------- GREEDY ------------------------------------- - - def recognize(self, inputs: Dict[str, tf.Tensor]): - outputs = self(inputs, training=False) - decoded = self._perform_greedy(encoded=outputs["logits"], encoded_length=outputs["logits_length"]) - return self.text_featurizer.iextract(decoded) - - def _perform_greedy(self, encoded, encoded_length): - decoded, _ = tf.nn.ctc_greedy_decoder( - inputs=tf.transpose(encoded, perm=[1, 0, 2]), - sequence_length=encoded_length, - merge_repeated=True, - blank_index=self.text_featurizer.blank, + def call(self, inputs: schemas.TrainInput, training=False): + features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=training) + logits, logits_length, *_ = self.encoder((features, features_length), training=training) + logits, logits_length, *_ = self.decoder((logits, logits_length), training=training) + return schemas.TrainOutput( + logits=logits, + logits_length=logits_length, ) - decoded = tf.reshape(decoded[0].values, decoded[0].dense_shape) - return tf.cast(decoded, dtype=tf.int32) - - def recognize_tflite(self, signal): - """ - Function to convert to tflite using greedy decoding - Args: - signal: tf.Tensor with shape [None] indicating a single audio signal - Return: - transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32 - """ - inputs = self.speech_featurizer.tf_extract(signal) - inputs = tf.expand_dims(inputs, axis=0) - inputs_length = shape_util.shape_list(inputs)[1] - inputs_length = math_util.get_reduced_length(inputs_length, self.time_reduction_factor) - inputs_length = tf.expand_dims(inputs_length, axis=0) - outputs = self(data_util.create_inputs(inputs=inputs, inputs_length=inputs_length)) - decoded = self._perform_greedy(encoded=outputs["logits"], encoded_length=outputs["logits_length"]) - transcript = self.text_featurizer.indices2upoints(decoded) - return transcript - - # -------------------------------- BEAM SEARCH ------------------------------------- + def call_next( + self, + features, + features_length, + previous_encoder_states=None, + previous_decoder_states=None, + ): + outputs, outputs_length, next_encoder_states = self.encoder.call_next(features, features_length, previous_encoder_states) + outputs, outputs_length, next_decoder_states = self.decoder.call_next(outputs, outputs_length, previous_decoder_states) + return outputs, outputs_length, next_encoder_states, next_decoder_states - def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): - logits = self(inputs, training=False) - decoded = self._perform_beam_search(encoded=logits["logits"], encoded_length=logits["logits_length"]) - return self.text_featurizer.iextract(decoded) + def get_initial_encoder_states(self, batch_size=1): + return [] - def _perform_beam_search(self, encoded: np.ndarray, encoded_length): - decoded, _ = tf.nn.ctc_beam_search_decoder( - inputs=tf.transpose(encoded, perm=[1, 
0, 2]), - sequence_length=encoded_length, - beam_width=self.text_featurizer.decoder_config.beam_width, - ) - decoded = tf.reshape(decoded[0].values, decoded[0].dense_shape) - return tf.cast(decoded, dtype=tf.int32) + def get_initial_decoder_states(self, batch_size=1): + return [] - def recognize_beam_tflite(self, signal): - """ - Function to convert to tflite using beam search decoding - Args: - signal: tf.Tensor with shape [None] indicating a single audio signal + # -------------------------------- GREEDY ------------------------------------- - Return: - transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32 - """ - inputs = self.speech_featurizer.tf_extract(signal) - inputs = tf.expand_dims(inputs, axis=0) - inputs_length = shape_util.shape_list(inputs)[1] - inputs_length = math_util.get_reduced_length(inputs_length, self.time_reduction_factor) - inputs_length = tf.expand_dims(inputs_length, axis=0) - outputs = self(data_util.create_inputs(inputs=inputs, inputs_length=inputs_length)) - decoded = self._perform_beam_search(encoded=outputs["logits"], encoded_length=outputs["logits_length"]) - transcript = self.text_featurizer.indices2upoints(decoded) - return transcript + def recognize(self, inputs: schemas.PredictInput, **kwargs): + with tf.name_scope(f"{self.name}_recognize"): + features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=False) + ( + outputs, + outputs_length, + next_encoder_states, + next_decoder_states, + ) = self.call_next(features, features_length, inputs.previous_encoder_states, inputs.previous_decoder_states) + tokens, _ = tf.nn.ctc_greedy_decoder( + inputs=tf.transpose(outputs, perm=[1, 0, 2]), + sequence_length=outputs_length, + merge_repeated=True, + blank_index=self.blank, + ) + tokens = tf.sparse.to_dense(tokens[0]) + tokens = tf.cast(tokens, dtype=tf.int32) + return schemas.PredictOutput( + tokens=tokens, + next_tokens=None, + next_encoder_states=next_encoder_states, + next_decoder_states=next_decoder_states, + ) - # -------------------------------- TFLITE ------------------------------------- + # -------------------------------- BEAM SEARCH ------------------------------------- - def make_tflite_function(self, greedy: bool = False): - if greedy: - return tf.function( - self.recognize_tflite, - input_signature=[tf.TensorSpec([None], dtype=tf.float32)], + def recognize_beam(self, inputs: schemas.PredictInput, beam_width: int = 10, **kwargs): + with tf.name_scope(f"{self.name}_recognize_beam"): + features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=False) + ( + outputs, + outputs_length, + next_encoder_states, + next_decoder_states, + ) = self.call_next(features, features_length, inputs.previous_encoder_states, inputs.previous_decoder_states) + tokens, _ = tf.nn.ctc_beam_search_decoder( + inputs=tf.transpose(outputs, perm=[1, 0, 2]), + sequence_length=outputs_length, + beam_width=beam_width, + ) + tokens = tf.sparse.to_dense(tokens[0]) + tokens = tf.cast(tokens, dtype=tf.int32) + return schemas.PredictOutput( + tokens=tokens, + next_tokens=None, + next_encoder_states=next_encoder_states, + next_decoder_states=next_decoder_states, ) - return tf.function( - self.recognize_beam_tflite, - input_signature=[tf.TensorSpec([None], dtype=tf.float32)], - ) diff --git a/tensorflow_asr/models/ctc/conformer.py b/tensorflow_asr/models/ctc/conformer.py index c8aab736e0..63ee3d86a5 100644 --- a/tensorflow_asr/models/ctc/conformer.py +++ 
b/tensorflow_asr/models/ctc/conformer.py @@ -12,36 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - +from tensorflow_asr import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel from tensorflow_asr.models.encoders.conformer import L2, ConformerEncoder +@keras.utils.register_keras_serializable(package=__name__) class ConformerDecoder(Layer): def __init__( self, vocab_size: int, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, + activity_regularizer=None, **kwargs, ): super().__init__(**kwargs) self._vocab_size = vocab_size - self.vocab = tf.keras.layers.Conv1D( - filters=vocab_size, - kernel_size=1, - strides=1, + self.vocab = keras.layers.Dense( + units=vocab_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, name="logits", + dtype=self.dtype, ) def call(self, inputs, training=False): - logits, logits_length = inputs + logits, logits_length, *_ = inputs logits = self.vocab(logits, training=training) - return logits, logits_length + return logits, logits_length, None + + def call_next(self, logits, logits_length, *args, **kwargs): + return self((logits, logits_length), training=False) def compute_output_shape(self, input_shape): logits_shape, logits_length_shape = input_shape @@ -49,11 +53,13 @@ def compute_output_shape(self, input_shape): return tuple(outputs_shape), tuple(logits_length_shape) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable(package=__name__) class Conformer(CtcModel): def __init__( self, + blank: int, vocab_size: int, + speech_config: dict, encoder_subsampling: dict, encoder_dmodel: int = 144, encoder_num_blocks: int = 16, @@ -63,26 +69,35 @@ def __init__( encoder_interleave_relpe: bool = True, encoder_use_attention_causal_mask: bool = False, encoder_use_attention_auto_mask: bool = True, - encoder_kernel_size: int = 32, + encoder_kernel_size: int = 31, encoder_padding: str = "causal", encoder_ffm_scale_factor: int = 4, encoder_ffm_residual_factor: float = 0.5, encoder_mhsam_residual_factor: float = 1.0, + encoder_mhsam_use_attention_bias: bool = False, + encoder_mhsam_causal: bool = False, + encoder_mhsam_flash_attention: bool = False, encoder_convm_scale_factor: int = 2, encoder_convm_residual_factor: float = 1.0, + encoder_convm_use_group_conv: bool = False, + encoder_convm_dw_norm_type: str = "batch", encoder_dropout: float = 0.1, encoder_module_norm_position: str = "pre", encoder_block_norm_position: str = "post", encoder_memory_length: int = None, - encoder_mhsam_before_convm: bool = True, + encoder_history_size: int = None, + encoder_chunk_size: int = None, encoder_trainable: bool = True, decoder_trainable: bool = True, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, + activity_regularizer=None, name: str = "conformer", **kwargs, ): super().__init__( + blank=blank, + speech_config=speech_config, encoder=ConformerEncoder( subsampling=encoder_subsampling, dmodel=encoder_dmodel, @@ -98,15 +113,22 @@ def __init__( ffm_scale_factor=encoder_ffm_scale_factor, ffm_residual_factor=encoder_ffm_residual_factor, mhsam_residual_factor=encoder_mhsam_residual_factor, + mhsam_use_attention_bias=encoder_mhsam_use_attention_bias, + mhsam_causal=encoder_mhsam_causal, + mhsam_flash_attention=encoder_mhsam_flash_attention, 
convm_scale_factor=encoder_convm_scale_factor, convm_residual_factor=encoder_convm_residual_factor, + convm_use_group_conv=encoder_convm_use_group_conv, + convm_dw_norm_type=encoder_convm_dw_norm_type, dropout=encoder_dropout, module_norm_position=encoder_module_norm_position, block_norm_position=encoder_block_norm_position, memory_length=encoder_memory_length, - mhsam_before_convm=encoder_mhsam_before_convm, + history_size=encoder_history_size, + chunk_size=encoder_chunk_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, trainable=encoder_trainable, name="encoder", ), @@ -122,3 +144,6 @@ def __init__( ) self.dmodel = encoder_dmodel self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor + + def get_initial_encoder_states(self, batch_size=1): + return self.encoder.get_initial_state(batch_size) diff --git a/tensorflow_asr/models/ctc/deepspeech2.py b/tensorflow_asr/models/ctc/deepspeech2.py index 7e158b5eee..3204c6c777 100644 --- a/tensorflow_asr/models/ctc/deepspeech2.py +++ b/tensorflow_asr/models/ctc/deepspeech2.py @@ -12,362 +12,119 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel -from tensorflow_asr.models.layers.row_conv_1d import RowConv1D -from tensorflow_asr.models.layers.sequence_wise_bn import SequenceBatchNorm -from tensorflow_asr.utils import layer_util, math_util - - -class Reshape(tf.keras.layers.Layer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.supports_masking = True - - def call(self, inputs): - return math_util.merge_two_last_dims(inputs) - - -class ConvBlock(tf.keras.layers.Layer): - def __init__( - self, - conv_type: str = "conv2d", - kernels: list = [11, 41], - strides: list = [2, 2], - filters: int = 32, - padding: str = "same", - dropout: float = 0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - CnnClass = layer_util.get_conv(conv_type) - self.conv = CnnClass(filters=filters, kernel_size=kernels, strides=strides, padding=padding, name=conv_type) - self.bn = tf.keras.layers.BatchNormalization(name="bn") - self.relu = tf.keras.layers.ReLU(name="relu") - self.do = tf.keras.layers.Dropout(dropout, name="dropout") - - def call(self, inputs, training=False): - outputs = self.conv(inputs, training=training) - outputs = self.bn(outputs, training=training) - outputs = self.relu(outputs, training=training) - outputs = self.do(outputs, training=training) - return outputs - - -class ConvModule(tf.keras.layers.Layer): - def __init__( - self, - conv_type: str = "conv2d", - kernels: list = [[11, 41], [11, 21], [11, 21]], - strides: list = [[2, 2], [1, 2], [1, 2]], - filters: list = [32, 32, 96], - padding: str = "same", - dropout: float = 0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - - assert len(kernels) == len(strides) == len(filters) - assert dropout >= 0.0 - - self.preprocess = None # reshape from [B, T, F, C] to [B, T, F * C] - if conv_type == "conv1d": - self.preprocess = Reshape(name="preprocess") - - self.blocks = [ - ConvBlock( - conv_type=conv_type, - kernels=kernels[i], - strides=strides[i], - filters=filters[i], - dropout=dropout, - padding=padding, - name=f"block_{i}", - ) - for i in range(len(filters)) - ] - - self.postprocess = None # reshape from 
[B, T, F, C] to [B, T, F * C] - if conv_type == "conv2d": - self.postprocess = Reshape(name="postprocess") - - self.reduction_factor = 1 - for s in strides: - self.reduction_factor *= s[0] - - def call(self, inputs, training=False): - outputs = inputs - if self.preprocess is not None: - outputs = self.preprocess(outputs) - for block in self.blocks: - outputs = block(outputs, training=training) - if self.postprocess is not None: - outputs = self.postprocess(outputs) - return outputs - - -class RnnBlock(tf.keras.layers.Layer): - def __init__( - self, - rnn_type: str = "lstm", - bn_type: str = "sbn", - units: int = 1024, - bidirectional: bool = True, - unroll: bool = False, - rowconv: int = 0, - dropout: float = 0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - RnnClass = layer_util.get_rnn(rnn_type) - self.rnn = RnnClass( - units, - dropout=dropout, - unroll=unroll, - return_sequences=True, - use_bias=True, - name=rnn_type, - zero_output_for_mask=True, - ) - if bidirectional: - self.rnn = tf.keras.layers.Bidirectional(self.rnn, name=f"b{rnn_type}") - if bn_type not in ("bn", "sbn"): - raise ValueError(f"bn_type must be in {('bn', 'sbn')}") - self.bn = SequenceBatchNorm(time_major=False, name="bn") if bn_type == "sbn" else tf.keras.layers.BatchNormalization(name="bn") - self.rowconv = None - if not bidirectional and rowconv > 0: - self.rowconv = RowConv1D(filters=units, future_context=rowconv, name="rowconv") - - def call(self, inputs, training=False): - outputs = inputs - outputs = self.rnn(outputs, training=training, mask=getattr(outputs, "_keras_mask", None)) - outputs = self.bn(outputs, training=training) - if self.rowconv is not None: - outputs = self.rowconv(outputs, training=training) - return outputs - - -class RnnModule(tf.keras.layers.Layer): - def __init__( - self, - nlayers: int = 5, - rnn_type: str = "lstm", - bn_type: str = "sbn", - units: int = 1024, - bidirectional: bool = True, - unroll: bool = False, - rowconv: int = 0, - dropout: float = 0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - self.blocks = [ - RnnBlock( - rnn_type=rnn_type, - bn_type=bn_type, - units=units, - bidirectional=bidirectional, - unroll=unroll, - rowconv=rowconv, - dropout=dropout, - name=f"block_{i}", - ) - for i in range(nlayers) - ] +from tensorflow_asr.models.encoders.deepspeech2 import DeepSpeech2Encoder - def call(self, inputs, training=False): - outputs = inputs - for block in self.blocks: - outputs = block(outputs, training=training) - return outputs - - -class FcBlock(tf.keras.layers.Layer): - def __init__( - self, - units: int = 1024, - dropout: float = 0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - self.fc = tf.keras.layers.Dense(units, name="fc") - self.bn = tf.keras.layers.BatchNormalization(name="bn") - self.relu = tf.keras.layers.ReLU(name="relu") - self.do = tf.keras.layers.Dropout(dropout, name="dropout") - - def call(self, inputs, training=False): - outputs = self.fc(inputs, training=training) - outputs = self.bn(outputs, training=training) - outputs = self.relu(outputs, training=training) - outputs = self.do(outputs, training=training) - return outputs - -class FcModule(tf.keras.layers.Layer): - def __init__( - self, - nlayers: int = 0, - units: int = 1024, - dropout: float = 0.1, - **kwargs, - ): - super().__init__(**kwargs) - self.supports_masking = True - self.blocks = [FcBlock(units=units, dropout=dropout, name=f"block_{i}") for i in range(nlayers)] - - def call(self, inputs, 
training=False): - outputs = inputs - for block in self.blocks: - outputs = block(outputs, training=training) - return outputs - - -class DeepSpeech2Encoder(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class DeepSpeech2Decoder(Layer): def __init__( self, - conv_type: str = "conv2d", - conv_kernels: list = [[11, 41], [11, 21], [11, 21]], - conv_strides: list = [[2, 2], [1, 2], [1, 2]], - conv_filters: list = [32, 32, 96], - conv_padding: str = "same", - conv_dropout: float = 0.1, - rnn_nlayers: int = 5, - rnn_type: str = "lstm", - rnn_bn_type: str = "sbn", - rnn_units: int = 1024, - rnn_bidirectional: bool = True, - rnn_unroll: bool = False, - rnn_rowconv: int = 0, - rnn_dropout: float = 0.1, - fc_nlayers: int = 0, - fc_units: int = 1024, - fc_dropout: float = 0.1, + vocab_size: int, + kernel_regularizer=None, + bias_regularizer=None, + initializer="glorot_uniform", **kwargs, ): - super().__init__(**kwargs) - self.conv_module = ConvModule( - conv_type=conv_type, - kernels=conv_kernels, - strides=conv_strides, - filters=conv_filters, - padding=conv_padding, - dropout=conv_dropout, - name="conv_module", - ) - self.rnn_module = RnnModule( - nlayers=rnn_nlayers, - rnn_type=rnn_type, - bn_type=rnn_bn_type, - units=rnn_units, - bidirectional=rnn_bidirectional, - unroll=rnn_unroll, - rowconv=rnn_rowconv, - dropout=rnn_dropout, - name="rnn_module", + super().__init__(dtype=tf.float32, **kwargs) + self.vocab = keras.layers.Dense( + vocab_size, + name="logits", + kernel_regularizer=kernel_regularizer, + kernel_initializer=initializer, + bias_regularizer=bias_regularizer, + dtype=self.dtype, ) - self._rnn_units = rnn_units - self.fc_module = FcModule( - nlayers=fc_nlayers, - units=fc_units, - dropout=fc_dropout, - name="fc_module", - ) - self._fc_nlayers = fc_nlayers - self._fc_units = fc_units - self.time_reduction_factor = self.conv_module.reduction_factor def call(self, inputs, training=False): - outputs, inputs_length = inputs - outputs = self.conv_module(outputs, training=training) - outputs_length = math_util.get_reduced_length(inputs_length, self.time_reduction_factor) - outputs = math_util.apply_mask(outputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool)) - outputs = self.rnn_module(outputs, training=training) - outputs = self.fc_module(outputs, training=training) - return outputs, outputs_length - - def compute_output_shape(self, input_shape): - inputs_shape, inputs_length_shape = input_shape - outputs_time = None if inputs_shape[1] is None else math_util.legacy_get_reduced_length(inputs_shape[1], self.time_reduction_factor) - outputs_batch = inputs_shape[0] - outputs_size = self._fc_units if self._fc_nlayers > 0 else self._rnn_units - outputs_shape = (outputs_batch, outputs_time, outputs_size) - return outputs_shape, inputs_length_shape - - -class DeepSpeech2Decoder(Layer): - def __init__(self, vocab_size: int, **kwargs): - super().__init__(**kwargs) - self.vocab = tf.keras.layers.Dense(vocab_size, name="logits") - self.bn = tf.keras.layers.BatchNormalization(name="bn") - self._vocab_size = vocab_size - - def call(self, inputs, training=False): - logits, logits_length = inputs + logits, logits_length, *_ = inputs logits = self.vocab(logits, training=training) - logits = self.bn(logits, training=training) return logits, logits_length + def call_next(self, logits, logits_length, *args, **kwargs): + outputs, outputs_length = self((logits, logits_length), training=False) + return outputs, outputs_length, None + def 
compute_output_shape(self, input_shape): - logits_shape, logits_length_shape = input_shape - outputs_shape = logits_shape[:-1] + (self._vocab_size,) - return tuple(outputs_shape), tuple(logits_length_shape) + output_shape, output_length_shape = input_shape + output_shape = self.vocab.compute_output_shape(output_shape) + return output_shape, output_length_shape -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable(package=__name__) class DeepSpeech2(CtcModel): def __init__( self, + blank: int, vocab_size: int, + speech_config: dict, conv_type: str = "conv2d", conv_kernels: list = [[11, 41], [11, 21], [11, 21]], - conv_strides: list = [[2, 2], [1, 2], [1, 2]], + conv_strides: list = [[3, 2], [1, 2], [1, 2]], conv_filters: list = [32, 32, 96], conv_padding: str = "same", - conv_dropout: float = 0.1, + conv_activation: str = "relu", + conv_initializer: str = None, rnn_nlayers: int = 5, rnn_type: str = "lstm", - rnn_bn_type: str = "sbn", rnn_units: int = 1024, rnn_bidirectional: bool = True, rnn_unroll: bool = False, rnn_rowconv: int = 0, + rnn_rowconv_activation: str = "relu", rnn_dropout: float = 0.1, + rnn_initializer: str = None, fc_nlayers: int = 0, fc_units: int = 1024, + fc_activation: str = "relu", fc_dropout: float = 0.1, + fc_initializer: str = None, name: str = "deepspeech2", + kernel_regularizer=None, + bias_regularizer=None, + initializer="glorot_uniform", **kwargs, ): super().__init__( + blank=blank, + speech_config=speech_config, encoder=DeepSpeech2Encoder( conv_type=conv_type, conv_kernels=conv_kernels, conv_strides=conv_strides, conv_filters=conv_filters, conv_padding=conv_padding, - conv_dropout=conv_dropout, + conv_activation=conv_activation, + conv_initializer=conv_initializer, rnn_nlayers=rnn_nlayers, rnn_type=rnn_type, - rnn_bn_type=rnn_bn_type, rnn_units=rnn_units, rnn_bidirectional=rnn_bidirectional, rnn_unroll=rnn_unroll, rnn_rowconv=rnn_rowconv, + rnn_rowconv_activation=rnn_rowconv_activation, rnn_dropout=rnn_dropout, + rnn_initializer=rnn_initializer, fc_nlayers=fc_nlayers, fc_units=fc_units, + fc_activation=fc_activation, fc_dropout=fc_dropout, + fc_initializer=fc_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=initializer, name="encoder", ), decoder=DeepSpeech2Decoder(vocab_size=vocab_size, name="decoder"), name=name, **kwargs, ) - self.time_reduction_factor = self.encoder.conv_module.reduction_factor + self.time_reduction_factor = self.encoder.time_reduction_factor + + def get_initial_encoder_states(self, batch_size=1): + return self.encoder.get_initial_state(batch_size) + + def get_initial_decoder_states(self, batch_size=1): + return None diff --git a/tensorflow_asr/models/ctc/jasper.py b/tensorflow_asr/models/ctc/jasper.py index 040c0dd1be..1e70c4c56e 100644 --- a/tensorflow_asr/models/ctc/jasper.py +++ b/tensorflow_asr/models/ctc/jasper.py @@ -12,303 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
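For reference, the greedy path shared by these CTC models (CtcModel.recognize above) reduces to tf.nn.ctc_greedy_decoder on time-major logits followed by a sparse-to-dense conversion. A minimal self-contained sketch with toy shapes, assuming the blank token sits at index 0:

import tensorflow as tf

# Toy decoder output: batch=2, time=5, vocab=4 (blank assumed at index 0).
logits = tf.random.normal([2, 5, 4])
logits_length = tf.constant([5, 3], dtype=tf.int32)

decoded, _ = tf.nn.ctc_greedy_decoder(
    inputs=tf.transpose(logits, perm=[1, 0, 2]),  # [T, B, V] layout expected by the op
    sequence_length=logits_length,
    merge_repeated=True,
    blank_index=0,
)
tokens = tf.cast(tf.sparse.to_dense(decoded[0]), tf.int32)  # [B, max_decoded_len], padded with 0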
-import tensorflow as tf - +from tensorflow_asr import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel +from tensorflow_asr.models.encoders.jasper import JasperEncoder from tensorflow_asr.models.layers.convolution import Conv1D -from tensorflow_asr.utils import math_util - - -class Reshape(tf.keras.layers.Layer): - def call(self, inputs): - return math_util.merge_two_last_dims(inputs) - - -class JasperSubBlock(tf.keras.layers.Layer): - def __init__( - self, - channels: int = 256, - kernels: int = 11, - strides: int = 1, - dropout: float = 0.1, - padding: str = "causal", - dilation: int = 1, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__(**kwargs) - self.conv1d = Conv1D( - filters=channels, - kernel_size=kernels, - strides=strides, - dilation_rate=dilation, - padding=padding, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name="conv1d", - ) - self.bn = tf.keras.layers.BatchNormalization(name="bn") - self.relu = tf.keras.layers.ReLU(name="relu") - self.do = tf.keras.layers.Dropout(dropout, name="dropout") - self.reduction_factor = strides - - def call(self, inputs, training=False): - outputs = inputs - outputs = self.conv1d(outputs, training=training) - outputs = self.bn(outputs, training=training) - outputs = self.relu(outputs, training=training) - outputs = self.do(outputs, training=training) - return outputs - - -class JasperResidual(tf.keras.layers.Layer): - def __init__( - self, - channels: int = 256, - padding: str = "causal", - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__(**kwargs) - self.pointwise_conv1d = Conv1D( - filters=channels, - kernel_size=1, - strides=1, - padding=padding, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name="pointwise_conv1d", - ) - self.bn = tf.keras.layers.BatchNormalization(name="bn") - - def call(self, inputs, training=False): - outputs = self.pointwise_conv1d(inputs, training=training) - outputs = self.bn(outputs, training=training) - return outputs - - -class JasperSubBlockResidual(JasperSubBlock): - def __init__( - self, - channels: int = 256, - kernels: int = 11, - strides: int = 1, - dropout: float = 0.1, - padding: str = "causal", - dilation: int = 1, - nresiduals: int = 1, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__( - channels=channels, - kernels=kernels, - strides=strides, - dropout=dropout, - padding=padding, - dilation=dilation, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - **kwargs, - ) - - self.residuals = [ - JasperResidual( - channels=channels, - padding=padding, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=f"residual_{i}", - ) - for i in range(nresiduals) - ] - - self.add = tf.keras.layers.Add(name="add") - - def call(self, inputs, training=False): - outputs, residuals = inputs - outputs = self.conv1d(outputs, training=training) - outputs = self.bn(outputs, training=training) - for i, res in enumerate(residuals): - res = self.residuals[i](res, training=training) - outputs = self.add([outputs, res], training=training) - outputs = self.relu(outputs, training=training) - outputs = self.do(outputs, training=training) - return outputs - - -class JasperBlock(tf.keras.layers.Layer): - def __init__( - self, - nsubblocks: int = 3, - channels: int = 256, - kernels: int = 11, - dropout: float = 0.1, - padding: str = 
"causal", - dense: bool = False, - nresiduals: int = 1, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__(**kwargs) - - self.dense = dense - - self.subblocks = [ - JasperSubBlock( - channels=channels, - kernels=kernels, - dropout=dropout, - padding=padding, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=f"subordinate_{i}", - ) - for i in range(nsubblocks - 1) - ] - - self.subblock_residual = JasperSubBlockResidual( - channels=channels, - kernels=kernels, - dropout=dropout, - nresiduals=nresiduals, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=f"subordinate_{nsubblocks - 1}", - ) - - self.reduction_factor = 1 - - def call(self, inputs, training=False): - inputs, residuals = inputs - outputs = inputs - for subblock in self.subblocks: - outputs = subblock(outputs, training=training) - if self.dense: - residuals.append(inputs) - outputs = self.subblock_residual([outputs, residuals], training=training) - else: - outputs = self.subblock_residual([outputs, [inputs]], training=training) - return outputs, residuals - - -class JasperEncoder(Layer): - def __init__( - self, - dense: bool = False, - padding: str = "causal", - first_additional_block_channels: int = 256, - first_additional_block_kernels: int = 11, - first_additional_block_strides: int = 2, - first_additional_block_dilation: int = 1, - first_additional_block_dropout: int = 0.2, - nsubblocks: int = 5, - block_channels: list = [256, 384, 512, 640, 768], - block_kernels: list = [11, 13, 17, 21, 25], - block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], - second_additional_block_channels: int = 896, - second_additional_block_kernels: int = 1, - second_additional_block_strides: int = 1, - second_additional_block_dilation: int = 2, - second_additional_block_dropout: int = 0.4, - third_additional_block_channels: int = 1024, - third_additional_block_kernels: int = 1, - third_additional_block_strides: int = 1, - third_additional_block_dilation: int = 1, - third_additional_block_dropout: int = 0.4, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__(**kwargs) - - assert len(block_channels) == len(block_kernels) == len(block_dropout) - - self.reshape = Reshape(name="reshape") - - self.first_additional_block = JasperSubBlock( - channels=first_additional_block_channels, - kernels=first_additional_block_kernels, - strides=first_additional_block_strides, - dropout=first_additional_block_dropout, - padding=padding, - dilation=first_additional_block_dilation, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name="first_block", - ) - - self.blocks = [ - JasperBlock( - nsubblocks=nsubblocks, - channels=block_channels[i], - kernels=block_kernels[i], - dropout=block_dropout[i], - dense=dense, - nresiduals=(i + 1) if dense else 1, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=f"block_{i}", - ) - for i in range(len(block_channels)) - ] - - self.second_additional_block = JasperSubBlock( - channels=second_additional_block_channels, - kernels=second_additional_block_kernels, - strides=second_additional_block_strides, - dropout=second_additional_block_dropout, - padding=padding, - dilation=second_additional_block_dilation, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name="second_block", - ) - - self.third_additional_block = JasperSubBlock( - channels=third_additional_block_channels, - 
kernels=third_additional_block_kernels, - strides=third_additional_block_strides, - dropout=third_additional_block_dropout, - padding=padding, - dilation=third_additional_block_dilation, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name="third_block", - ) - self.time_reduction_factor = self.first_additional_block.reduction_factor - self.time_reduction_factor *= self.second_additional_block.reduction_factor - self.time_reduction_factor *= self.third_additional_block.reduction_factor - - def call(self, inputs, training=False): - outputs, inputs_length = inputs - outputs = self.reshape(outputs) - outputs = self.first_additional_block(outputs, training=training) - - residuals = [] - for block in self.blocks: - outputs, residuals = block([outputs, residuals], training=training) - - outputs = self.second_additional_block(outputs, training=training) - outputs = self.third_additional_block(outputs, training=training) - outputs_length = math_util.get_reduced_length(inputs_length, self.time_reduction_factor) - outputs = math_util.apply_mask(outputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool)) - return outputs, outputs_length - - def compute_output_shape(self, input_shape): - inputs_shape, inputs_length_shape = input_shape - outputs_time = None if inputs_shape[1] is None else math_util.legacy_get_reduced_length(inputs_shape[1], self.time_reduction_factor) - outputs_batch = inputs_shape[0] - outputs_size = self.third_additional_block.conv1d.filters - outputs_shape = [outputs_batch, outputs_time, outputs_size] - return tuple(outputs_shape), tuple(inputs_length_shape) +@keras.utils.register_keras_serializable(package=__name__) class JasperDecoder(Layer): def __init__( self, @@ -327,25 +38,32 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="logits", + dtype=self.dtype, ) self._vocab_size = vocab_size def call(self, inputs, training=False): - logits, logits_length = inputs + logits, logits_length, *_ = inputs logits = self.vocab(logits, training=training) return logits, logits_length + def call_next(self, logits, logits_length, *args, **kwargs): + outputs, outputs_length = self((logits, logits_length), training=False) + return outputs, outputs_length, None + def compute_output_shape(self, input_shape): logits_shape, logits_length_shape = input_shape outputs_shape = logits_shape[:-1] + (self._vocab_size,) return tuple(outputs_shape), tuple(logits_length_shape) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable(package=__name__) class Jasper(CtcModel): def __init__( self, + blank: int, vocab_size: int, + speech_config: dict, dense: bool = False, padding: str = "causal", first_additional_block_channels: int = 256, @@ -373,6 +91,8 @@ def __init__( **kwargs, ): super().__init__( + blank=blank, + speech_config=speech_config, encoder=JasperEncoder( dense=dense, padding=padding, @@ -399,13 +119,7 @@ def __init__( bias_regularizer=bias_regularizer, name="encoder", ), - decoder=JasperDecoder( - vocab_size=vocab_size, - padding=padding, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name="decoder", - ), + decoder=JasperDecoder(vocab_size=vocab_size, padding=padding, name="decoder"), name=name, **kwargs, ) diff --git a/tensorflow_asr/models/ctc/transformer.py b/tensorflow_asr/models/ctc/transformer.py index dfb954568e..5030506b16 100644 --- a/tensorflow_asr/models/ctc/transformer.py +++ 
b/tensorflow_asr/models/ctc/transformer.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - +from tensorflow_asr import keras from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.ctc.base_ctc import CtcModel from tensorflow_asr.models.encoders.transformer import TransformerEncoder +@keras.utils.register_keras_serializable(package=__name__) class TransformerDecoder(Layer): def __init__( self, @@ -29,29 +29,36 @@ def __init__( ): super().__init__(**kwargs) self._vocab_size = vocab_size - self.vocab = tf.keras.layers.Dense( + self.vocab = keras.layers.Dense( vocab_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="logits", + dtype=self.dtype, ) def call(self, inputs, training=False): - logits, logits_length = inputs + logits, logits_length, *_ = inputs logits = self.vocab(logits, training=training) return logits, logits_length + def call_next(self, logits, logits_length, *args, **kwargs): + outputs, outputs_length = self((logits, logits_length), training=False) + return outputs, outputs_length, None + def compute_output_shape(self, input_shape): logits_shape, logits_length_shape = input_shape outputs_shape = logits_shape[:-1] + (self._vocab_size,) return tuple(outputs_shape), tuple(logits_length_shape) -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.ctc") +@keras.utils.register_keras_serializable(package=__name__) class Transformer(CtcModel): def __init__( self, + blank: int, vocab_size: int, + speech_config: dict, encoder_subsampling: dict, encoder_dmodel: int = 512, encoder_dff: int = 1024, @@ -67,6 +74,10 @@ def __init__( encoder_pwffn_activation: str = "relu", encoder_dropout: float = 0.1, encoder_memory_length: int = None, + encoder_history_size: int = None, + encoder_chunk_size: int = None, + encoder_mha_causal: bool = False, + encoder_flash_attention: bool = False, encoder_trainable: bool = True, decoder_trainable: bool = True, kernel_regularizer=None, @@ -75,6 +86,8 @@ def __init__( **kwargs, ): super().__init__( + blank=blank, + speech_config=speech_config, encoder=TransformerEncoder( subsampling=encoder_subsampling, num_blocks=encoder_num_blocks, @@ -91,18 +104,16 @@ def __init__( pwffn_activation=encoder_pwffn_activation, dropout=encoder_dropout, memory_length=encoder_memory_length, + history_size=encoder_history_size, + chunk_size=encoder_chunk_size, + relmha_causal=encoder_mha_causal, + flash_attention=encoder_flash_attention, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, name="encoder", ), - decoder=TransformerDecoder( - vocab_size=vocab_size, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - trainable=decoder_trainable, - name="decoder", - ), + decoder=TransformerDecoder(vocab_size=vocab_size, trainable=decoder_trainable, name="decoder"), name=name, **kwargs, ) diff --git a/tensorflow_asr/models/decoders/__init__.py b/tensorflow_asr/models/decoders/__init__.py new file mode 100644 index 0000000000..9139bde684 --- /dev/null +++ b/tensorflow_asr/models/decoders/__init__.py @@ -0,0 +1,13 @@ +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + 
__import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/encoders/__init__.py b/tensorflow_asr/models/encoders/__init__.py index e69de29bb2..9139bde684 100644 --- a/tensorflow_asr/models/encoders/__init__.py +++ b/tensorflow_asr/models/encoders/__init__.py @@ -0,0 +1,13 @@ +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py index 7ecb3245ed..158afd2403 100644 --- a/tensorflow_asr/models/encoders/conformer.py +++ b/tensorflow_asr/models/encoders/conformer.py @@ -12,21 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""http://arxiv.org/abs/2005.08100""" -import tensorflow as tf - +from tensorflow_asr import keras, tf from tensorflow_asr.models.activations.glu import GLU -from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.layers.convolution import Conv1D, DepthwiseConv1D +from tensorflow_asr.models.layers.general import Activation, Dropout, Identity from tensorflow_asr.models.layers.multihead_attention import MultiHeadAttention, MultiHeadRelativeAttention -from tensorflow_asr.models.layers.positional_encoding import PositionalEncoding, RelativePositionalEncoding +from tensorflow_asr.models.layers.positional_encoding import RelativeSinusoidalPositionalEncoding, SinusoidalPositionalEncoding from tensorflow_asr.models.layers.residual import Residual -from tensorflow_asr.models.layers.subsampling import Conv1dSubsampling, Conv2dSubsampling, VggSubsampling +from tensorflow_asr.utils import data_util -L2 = tf.keras.regularizers.l2(1e-6) +L2 = keras.regularizers.l2(1e-6) -class FFModule(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class FFModule(keras.Model): r""" architecture:: input @@ -52,43 +53,64 @@ def __init__( residual_factor=0.5, norm_position="pre", kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, name="ff_module", **kwargs, ): super().__init__(name=name, **kwargs) assert norm_position in ("pre", "post", "none") - self._norm_position = norm_position - self.norm = ( - None - if norm_position == "none" - else tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer) + self.pre_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if norm_position == "pre" + else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype) ) - self.ffn1 = tf.keras.layers.Dense( - scale_factor * input_dim, + self.ffn1 = keras.layers.Dense( + units=scale_factor * input_dim, name="dense_1", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activation="swish", + dtype=self.dtype, + ) + self.do1 = Dropout(rate=dropout, name="dropout_1", dtype=self.dtype) + self.ffn2 = keras.layers.Dense( + units=input_dim, + name="dense_2", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + dtype=self.dtype, ) - self.swish = tf.keras.layers.Activation(tf.nn.swish, 
name="swish_activation") - self.do1 = tf.keras.layers.Dropout(dropout, name="dropout_1") - self.ffn2 = tf.keras.layers.Dense(input_dim, name="dense_2", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer) - self.do2 = tf.keras.layers.Dropout(dropout, name="dropout_2") - self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual") + self.do2 = Dropout(rate=dropout, name="dropout_2", dtype=self.dtype) + self.post_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if norm_position == "post" + else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype) + ) + self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual", dtype=self.dtype) def call(self, inputs, training=False): - outputs = self.norm(inputs, training=training) if self._norm_position == "pre" else inputs + outputs = self.pre_norm(inputs, training=training) outputs = self.ffn1(outputs, training=training) - outputs = self.swish(outputs) outputs = self.do1(outputs, training=training) outputs = self.ffn2(outputs, training=training) outputs = self.do2(outputs, training=training) - outputs = self.norm(outputs, training=training) if self._norm_position == "post" else outputs - outputs = self.residual([inputs, outputs], training=training) + outputs = self.post_norm(outputs, training=training) + outputs = self.residual((inputs, outputs), training=training) return outputs -class MHSAModule(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class MHSAModule(keras.Model): r""" architecture:: input @@ -110,94 +132,125 @@ def __init__( residual_factor=1.0, dropout=0.0, mha_type="relmha", + relmha_causal=False, + flash_attention=None, norm_position="pre", memory_length=None, + history_size=None, + chunk_size=None, + use_attention_bias=False, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, + activity_regularizer=None, name="mhsa_module", **kwargs, ): super().__init__(name=name, **kwargs) assert norm_position in ("pre", "post", "none") - self._norm_position = norm_position - self.norm = ( - None - if norm_position == "none" - else tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer) + assert mha_type in ("relmha", "mha") + self.pre_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if norm_position == "pre" + else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype) ) if mha_type == "relmha": self.mha = MultiHeadRelativeAttention( + causal=relmha_causal, num_heads=num_heads, key_dim=head_size, output_shape=dmodel, memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + flash_attention=flash_attention, + use_attention_bias=use_attention_bias, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - dtype=tf.float32, # for stable training + activity_regularizer=activity_regularizer, name="mhsa", + dtype=self.dtype, ) - elif mha_type == "mha": + else: self.mha = MultiHeadAttention( num_heads=num_heads, key_dim=head_size, output_shape=dmodel, memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + flash_attention=flash_attention, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - 
dtype=tf.float32, # for stable training + activity_regularizer=activity_regularizer, name="mhsa", + dtype=self.dtype, ) - else: - raise ValueError("mha_type must be either 'mha' or 'relmha'") - self.do = tf.keras.layers.Dropout(dropout, name="dropout") - self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual") - self.mha_type = mha_type + self.do = Dropout(dropout, name="dropout", dtype=self.dtype) + self.post_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if norm_position == "post" + else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype) + ) + self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual", dtype=self.dtype) + + def get_initial_state(self, batch_size: int): + return self.mha.get_initial_state(batch_size) def call( self, inputs, - relative_position_encoding=None, content_attention_bias=None, positional_attention_bias=None, + initial_state=None, training=False, attention_mask=None, use_causal_mask=False, use_auto_mask=True, + return_states=False, ): - outputs = self.norm(inputs, training=training) if self._norm_position == "pre" else inputs - mha_inputs = ( - dict( - inputs=[outputs, outputs, outputs, relative_position_encoding], - content_attention_bias=content_attention_bias, - positional_attention_bias=positional_attention_bias, - ) - if self.mha_type == "relmha" - else dict(inputs=[outputs, outputs, outputs]) - ) - outputs = self.mha( - **mha_inputs, + _inputs, relative_position_encoding = inputs + outputs = self.pre_norm(_inputs, training=training) + outputs, *states = self.mha( + [outputs, outputs, outputs, relative_position_encoding], + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + initial_state=initial_state, training=training, attention_mask=attention_mask, use_causal_mask=use_causal_mask, use_auto_mask=use_auto_mask, + return_states=return_states, ) outputs = self.do(outputs, training=training) - outputs = self.norm(outputs, training=training) if self._norm_position == "post" else outputs - outputs = self.residual([inputs, outputs], training=training) - return outputs + outputs = self.post_norm(outputs, training=training) + outputs = self.residual((_inputs, outputs), training=training) + if return_states: + return [outputs] + states + return [outputs] -class ConvModule(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class ConvModule(keras.Model): r""" architecture:: input / \ | ln(.) # input_dim - | conv1d(.) # 2 * input_dim + | conv1d(.) # 2 * input_dim | | | glu(.) # input_dim | depthwise_conv_1d(.) - | bnorm(.) + | norm(.) # batch or layer | swish(.) | | | conv1d(.) 
@@ -217,70 +270,115 @@ def __init__( scale_factor=2, residual_factor=1.0, norm_position="pre", + dw_norm_type="batch", + use_group_conv=False, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, name="conv_module", **kwargs, ): super().__init__(name=name, **kwargs) assert norm_position in ("pre", "post", "none") - self._norm_position = norm_position - self.norm = ( - None - if norm_position == "none" - else tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer) + assert dw_norm_type in ("batch", "layer") + self.pre_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if norm_position == "pre" + else Identity(name="preiden" if norm_position == "none" else "iden", dtype=self.dtype) ) self.pw_conv_1 = Conv1D( filters=scale_factor * input_dim, kernel_size=1, strides=1, - padding=padding, + padding="valid", name="pw_conv_1", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) - self.glu = GLU(axis=-1, name="glu_activation") - self.dw_conv = Conv1D( - filters=input_dim, - kernel_size=kernel_size, - strides=1, - groups=input_dim, - padding=padding, - name="dw_conv", - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - ) - self.bn = tf.keras.layers.BatchNormalization( - name="bn", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer, synchronized=True + self.glu = GLU(axis=-1, name="glu", dtype=self.dtype) + if use_group_conv: + self.dw_conv = Conv1D( + filters=input_dim, + kernel_size=kernel_size, + strides=1, + padding=padding, + groups=input_dim, + name="dw_conv", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + dtype=self.dtype, + ) + else: + self.dw_conv = DepthwiseConv1D( + kernel_size=kernel_size, + strides=1, + padding=padding, + name="dw_conv", + depthwise_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + dtype=self.dtype, + ) + self.dw_norm = ( + keras.layers.BatchNormalization( + name="dw_bn", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, + ) + if dw_norm_type == "batch" + else keras.layers.LayerNormalization( + name="dw_ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) ) - self.swish = tf.keras.layers.Activation(tf.nn.swish, name="swish_activation") + self.swish = Activation(tf.nn.swish, name="swish", dtype=self.dtype) self.pw_conv_2 = Conv1D( filters=input_dim, kernel_size=1, strides=1, - padding=padding, + padding="valid", name="pw_conv_2", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) - self.do = tf.keras.layers.Dropout(dropout, name="dropout") - self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual") + self.do = Dropout(rate=dropout, name="dropout", dtype=self.dtype) + self.post_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if norm_position == "post" + else Identity(name="postiden" if norm_position == "none" else "iden", dtype=self.dtype) + ) + self.residual = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual", dtype=self.dtype) def call(self, inputs, training=False): - outputs = 
self.norm(inputs, training=training) if self._norm_position == "pre" else inputs + outputs = self.pre_norm(inputs, training=training) outputs = self.pw_conv_1(outputs, training=training) - outputs = self.glu(outputs) + outputs = self.glu(outputs, training=training) outputs = self.dw_conv(outputs, training=training) - outputs = self.bn(outputs, training=training) - outputs = self.swish(outputs) + outputs = self.dw_norm(outputs, training=training) + outputs = self.swish(outputs, training=training) outputs = self.pw_conv_2(outputs, training=training) outputs = self.do(outputs, training=training) - outputs = self.norm(outputs, training=training) if self._norm_position == "post" else outputs - outputs = self.residual([inputs, outputs], training=training) + outputs = self.post_norm(outputs, training=training) + outputs = self.residual((inputs, outputs), training=training) return outputs -class ConformerBlock(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class ConformerBlock(keras.Model): r""" architecture:: x = x + 1/2 * FFN(x) @@ -300,27 +398,37 @@ def __init__( num_heads=4, mha_type="relmha", mhsam_residual_factor=1.0, + mhsam_use_attention_bias=False, + mhsam_causal=False, + mhsam_flash_attention=None, kernel_size=32, padding="causal", convm_scale_factor=2, convm_residual_factor=1.0, + convm_use_group_conv=False, + convm_dw_norm_type="batch", module_norm_position="pre", block_norm_position="post", memory_length=None, - mhsam_before_convm=True, + history_size=None, + chunk_size=None, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, + activity_regularizer=None, name="conformer_block", **kwargs, ): super().__init__(name=name, **kwargs) assert block_norm_position in ("pre", "post", "none") - self._norm_position = block_norm_position - self._mhsam_before_convm = mhsam_before_convm - self.norm = ( - None - if block_norm_position == "none" - else tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer) + self.pre_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if block_norm_position == "pre" + else Identity(name="preiden" if block_norm_position == "none" else "iden", dtype=self.dtype) ) self.ffm1 = FFModule( input_dim=input_dim, @@ -331,19 +439,27 @@ def __init__( name="ff_module_1", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) self.mhsam = MHSAModule( dmodel=input_dim, head_size=head_size, num_heads=num_heads, residual_factor=mhsam_residual_factor, + use_attention_bias=mhsam_use_attention_bias, dropout=dropout, mha_type=mha_type, + relmha_causal=mhsam_causal, norm_position=module_norm_position, memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + flash_attention=mhsam_flash_attention, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, name="mhsa_module", + dtype=self.dtype, ) self.convm = ConvModule( input_dim=input_dim, @@ -354,8 +470,11 @@ def __init__( scale_factor=convm_scale_factor, residual_factor=convm_residual_factor, norm_position=module_norm_position, + dw_norm_type=convm_dw_norm_type, + use_group_conv=convm_use_group_conv, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) self.ffm2 = FFModule( input_dim=input_dim, @@ -366,51 +485,58 @@ def __init__( name="ff_module_2", 
kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, + ) + self.post_norm = ( + keras.layers.LayerNormalization( + name="ln", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, + ) + if block_norm_position == "post" + else Identity(name="postiden" if block_norm_position == "none" else "iden", dtype=self.dtype) ) + def get_initial_state(self, batch_size: int): + return self.mhsam.get_initial_state(batch_size) + def call( self, inputs, - relative_position_encoding=None, content_attention_bias=None, positional_attention_bias=None, + initial_state=None, training=False, attention_mask=None, use_causal_mask=False, use_auto_mask=True, + return_states=False, ): - outputs = self.norm(inputs, training=training) if self._norm_position == "pre" else inputs + _inputs, relative_position_encoding = inputs + outputs = self.pre_norm(_inputs, training=training) outputs = self.ffm1(outputs, training=training) - if self._mhsam_before_convm: - outputs = self.mhsam( - outputs, - relative_position_encoding=relative_position_encoding, - content_attention_bias=content_attention_bias, - positional_attention_bias=positional_attention_bias, - training=training, - attention_mask=attention_mask, - use_causal_mask=use_causal_mask, - use_auto_mask=use_auto_mask, - ) - outputs = self.convm(outputs, training=training) - else: - outputs = self.convm(outputs, training=training) - outputs = self.mhsam( - outputs, - relative_position_encoding=relative_position_encoding, - content_attention_bias=content_attention_bias, - positional_attention_bias=positional_attention_bias, - training=training, - attention_mask=attention_mask, - use_causal_mask=use_causal_mask, - use_auto_mask=use_auto_mask, - ) + outputs, *states = self.mhsam( + [outputs, relative_position_encoding], + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + initial_state=initial_state, + training=training, + attention_mask=attention_mask, + use_causal_mask=use_causal_mask, + use_auto_mask=use_auto_mask, + return_states=return_states, + ) + outputs = self.convm(outputs, training=training) outputs = self.ffm2(outputs, training=training) - outputs = self.norm(outputs, training=training) if self._norm_position == "post" else outputs - return outputs + outputs = self.post_norm(outputs, training=training) + if return_states: + return [outputs] + states + return [outputs] -class ConformerEncoder(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class ConformerEncoder(keras.Model): def __init__( self, subsampling, @@ -427,64 +553,66 @@ def __init__( ffm_scale_factor=4, ffm_residual_factor=0.5, mhsam_residual_factor=1.0, + mhsam_use_attention_bias=False, + mhsam_causal=False, + mhsam_flash_attention=None, convm_scale_factor=2, convm_residual_factor=1.0, + convm_use_group_conv=False, + convm_dw_norm_type="batch", dropout=0.1, module_norm_position="pre", block_norm_position="post", memory_length=None, - mhsam_before_convm=True, + history_size=None, + chunk_size=None, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, + activity_regularizer=None, name="conformer_encoder", **kwargs, ): super().__init__(name=name, **kwargs) + assert mha_type in ("relmha", "mha") self._dmodel = dmodel self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer self._num_blocks = num_blocks - subsampling_name = subsampling.pop("type", None) - if subsampling_name == "vgg": - subsampling_class = 
VggSubsampling - elif subsampling_name == "conv2d": - subsampling_class = Conv2dSubsampling - elif subsampling_name == "conv1d": - subsampling_class = Conv1dSubsampling - else: - raise ValueError("subsampling must be either 'vgg', 'conv2d', 'conv1d'") - - self.conv_subsampling = subsampling_class( - **subsampling, + self.conv_subsampling = keras.utils.get_registered_object(name=subsampling["class_name"])( + **subsampling["config"], name="subsampling", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) self.time_reduction_factor = self.conv_subsampling.time_reduction_factor - self.linear = tf.keras.layers.Dense( - dmodel, - name="linear", - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, + self.linear = keras.layers.Dense( + dmodel, name="linear", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, dtype=self.dtype ) - self.do = tf.keras.layers.Dropout(dropout, name="dropout") + self.do = Dropout(dropout, name="dropout", dtype=self.dtype) self._mha_type = mha_type self._num_heads = num_heads self._key_dim = head_size + self._memory_length = memory_length self._use_attention_causal_mask = use_attention_causal_mask self._use_attention_auto_mask = use_attention_auto_mask if self._mha_type == "relmha": - self.relpe = RelativePositionalEncoding(interleave=interleave_relpe, memory_length=memory_length, name="relpe") + self.relpe = RelativeSinusoidalPositionalEncoding( + interleave=interleave_relpe, + memory_length=memory_length, + causal=mhsam_causal, + name="relpe", + dtype=self.dtype, + ) else: - self.relpe = PositionalEncoding(interleave=interleave_relpe, name="pe") + self.relpe = SinusoidalPositionalEncoding(interleave=interleave_relpe, name="pe", dtype=self.dtype) - self.conformer_blocks = [] - for i in range(self._num_blocks): - conformer_block = ConformerBlock( + self.conformer_blocks = [ + ConformerBlock( input_dim=dmodel, dropout=dropout, ffm_scale_factor=ffm_scale_factor, @@ -493,27 +621,37 @@ def __init__( num_heads=num_heads, mha_type=mha_type, mhsam_residual_factor=mhsam_residual_factor, + mhsam_use_attention_bias=mhsam_use_attention_bias, + mhsam_causal=mhsam_causal, + mhsam_flash_attention=mhsam_flash_attention, kernel_size=kernel_size, padding=padding, convm_scale_factor=convm_scale_factor, convm_residual_factor=convm_residual_factor, + convm_use_group_conv=convm_use_group_conv, + convm_dw_norm_type=convm_dw_norm_type, module_norm_position=module_norm_position, block_norm_position=block_norm_position, memory_length=memory_length, - mhsam_before_convm=mhsam_before_convm, + history_size=history_size, + chunk_size=chunk_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, name=f"block_{i}", + dtype=self.dtype, ) - self.conformer_blocks.append(conformer_block) + for i in range(self._num_blocks) + ] - if self._mha_type == "relmha": + if self._mha_type == "relmha" and not mhsam_use_attention_bias: self.content_attention_bias = self.add_weight( name="content_attention_bias", shape=[self._num_heads, self._key_dim], trainable=True, initializer="zeros", regularizer=self._bias_regularizer, + dtype=self.variable_dtype, ) self.positional_attention_bias = self.add_weight( name="positional_attention_bias", @@ -521,41 +659,63 @@ def __init__( trainable=True, initializer="zeros", regularizer=self._bias_regularizer, + dtype=self.variable_dtype, ) else: self.content_attention_bias, self.positional_attention_bias = None, None - def 
get_states(self): - return [block.mhsam.mha.get_states() for block in self.conformer_blocks] - - def reset_states(self, states=None): - if states is None: - states = [(None, None) for _ in range(self._num_blocks)] - for i, memory_states in enumerate(states): - self.conformer_blocks[i].mhsam.mha.reset_states(memory_states) + def get_initial_state(self, batch_size: int): + states = [block.get_initial_state(batch_size) for block in self.conformer_blocks] + states = [s for s in states if s is not None] + return states - def call(self, inputs, training=False): + def call( + self, + inputs, + initial_state=None, + training=False, + return_states=False, + ): outputs, outputs_length = inputs - outputs, outputs_length = self.conv_subsampling([outputs, outputs_length], training=training) + outputs, outputs_length = self.conv_subsampling((outputs, outputs_length), training=training) outputs = self.linear(outputs, training=training) - outputs, relative_position_encoding = self.relpe(outputs, training=training) outputs = self.do(outputs, training=training) - - for _, cblock in enumerate(self.conformer_blocks): - outputs = cblock( - outputs, - relative_position_encoding=relative_position_encoding, + outputs, relative_position_encoding = self.relpe((outputs, outputs_length), training=training) + states = None if self._memory_length is None else [] + for i, cblock in enumerate(self.conformer_blocks): + outputs, *_states = cblock( + (outputs, relative_position_encoding), content_attention_bias=self.content_attention_bias, positional_attention_bias=self.positional_attention_bias, + initial_state=data_util.get(initial_state, i, None), training=training, use_causal_mask=self._use_attention_causal_mask, use_auto_mask=self._use_attention_auto_mask, + return_states=return_states, ) - + if not states: + continue + states.extend(_states) + if return_states: + return outputs, outputs_length, states return outputs, outputs_length - def compute_output_shape(self, input_shape): - outputs_shape, outputs_length_shape = self.conv_subsampling.compute_output_shape(input_shape) - outputs_shape = list(outputs_shape) - outputs_shape[-1] = self._dmodel - return outputs_shape, outputs_length_shape + def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs): + """ + Recognize function for encoder network + + Parameters + ---------- + features : tf.Tensor, shape [B, T, F, C] + features_length : tf.Tensor, shape [B] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], None) + Outputs, outputs_length, new_states + """ + with tf.name_scope(f"{self.name}_call_next"): + return self((features, features_length), initial_state=previous_encoder_states, training=False, return_states=True) + + def compute_mask(self, inputs, mask=None): + return self.conv_subsampling.compute_mask(inputs, mask=mask) diff --git a/tensorflow_asr/models/encoders/contextnet.py b/tensorflow_asr/models/encoders/contextnet.py index d352dea6c3..013e5a97cd 100644 --- a/tensorflow_asr/models/encoders/contextnet.py +++ b/tensorflow_asr/models/encoders/contextnet.py @@ -11,15 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
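# --- Illustrative usage sketch, not taken from the patch above ---
# Shows one way the streaming API added to ConformerEncoder (get_initial_state +
# call_next) could be driven chunk by chunk. The encoder construction, chunk
# shapes, and the assumption that memory_length/history_size are configured are
# placeholders, not values from this diff.
import tensorflow as tf

def stream_conformer(encoder, feature_chunks, chunk_lengths, batch_size=1):
    """Run a ConformerEncoder over consecutive feature chunks, carrying states."""
    states = encoder.get_initial_state(batch_size)  # per-block attention memories
    encoded_chunks = []
    for features, features_length in zip(feature_chunks, chunk_lengths):
        # call_next returns (outputs, outputs_length, new_states) per its docstring
        outputs, _, states = encoder.call_next(features, features_length, states)
        encoded_chunks.append(outputs)
    return tf.concat(encoded_chunks, axis=1), states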
+""" http://arxiv.org/abs/2005.03191 """ from typing import List -import tensorflow as tf - -from tensorflow_asr.models.base_layer import Layer +from tensorflow_asr import keras, tf +from tensorflow_asr.models.base_layer import Layer, Reshape +from tensorflow_asr.models.layers.convolution import SeparableConv1D from tensorflow_asr.utils import math_util -L2 = tf.keras.regularizers.l2(1e-6) +L2 = keras.regularizers.l2(1e-6) def get_activation( @@ -31,16 +32,12 @@ def get_activation( if activation == "relu": return tf.nn.relu if activation == "linear": - return tf.keras.activations.linear + return keras.activations.linear raise ValueError("activation must be either 'silu', 'swish', 'relu' or 'linear'") -class Reshape(tf.keras.layers.Layer): - def call(self, inputs): - return math_util.merge_two_last_dims(inputs) - - -class ConvModule(tf.keras.layers.Layer): +@keras.utils.register_keras_serializable(package=__name__) +class ConvModule(Layer): def __init__( self, kernel_size: int = 3, @@ -54,7 +51,7 @@ def __init__( ): super().__init__(**kwargs) self.strides = strides - self.conv = tf.keras.layers.SeparableConv1D( + self.conv = SeparableConv1D( filters=filters, kernel_size=kernel_size, strides=strides, @@ -63,18 +60,55 @@ def __init__( pointwise_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="conv", + dtype=self.dtype, + ) + self.bn = keras.layers.BatchNormalization( + name="bn", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, ) - self.bn = tf.keras.layers.BatchNormalization(name="bn") self.activation = get_activation(activation) def call(self, inputs, training=False): - outputs = self.conv(inputs, training=training) + outputs, outputs_length = inputs + outputs = self.conv(outputs, training=training) + outputs_length = math_util.conv_output_length( + outputs_length, + filter_size=self.conv.kernel_size[0], + padding=self.conv._padding, + stride=self.conv.strides[0], + dilation=self.conv.dilation_rate[0], + ) outputs = self.bn(outputs, training=training) outputs = self.activation(outputs) - return outputs + return outputs, outputs_length + + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + maxlen = tf.shape(outputs)[1] + maxlen, outputs_length = ( + math_util.conv_output_length( + length, + filter_size=self.conv.kernel_size[0], + padding=self.conv._padding, + stride=self.conv.strides[0], + dilation=self.conv.dilation_rate[0], + ) + for length in (maxlen, outputs_length) + ) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + output_shape = self.conv.compute_output_shape(output_shape) + return output_shape, output_length_shape -class SEModule(tf.keras.layers.Layer): +@keras.utils.register_keras_serializable(package=__name__) +class SEModule(Layer): def __init__( self, kernel_size: int = 3, @@ -96,18 +130,30 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="conv_module", + dtype=self.dtype, ) - self.global_avg_pool = tf.keras.layers.GlobalAveragePooling1D(keepdims=True, name="global_avg_pool") + self.global_avg_pool = keras.layers.GlobalAveragePooling1D(keepdims=True, name="global_avg_pool", dtype=self.dtype) self.activation = get_activation(activation) - self.fc1 = tf.keras.layers.Dense(filters // 8, name="fc1") - self.fc2 = tf.keras.layers.Dense(filters, name="fc2") + self.fc1 = 
keras.layers.Dense( + filters // 8, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="fc1", + dtype=self.dtype, + ) + self.fc2 = keras.layers.Dense( + filters, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="fc2", + dtype=self.dtype, + ) def call(self, inputs, training=False): - features, inputs_length = inputs - outputs = self.conv(features, training=training) # [B, T, E] + outputs, outputs_length = inputs + outputs, outputs_length = self.conv((outputs, outputs_length), training=training) # [B, T, E] - mask = tf.sequence_mask(inputs_length, maxlen=tf.shape(outputs)[1]) - se = self.global_avg_pool(outputs, mask=mask) # [B, 1, E] + se = self.global_avg_pool(outputs) # [B, 1, E], mask auto populate se = self.fc1(se, training=training) se = self.activation(se) se = self.fc2(se, training=training) @@ -115,10 +161,17 @@ def call(self, inputs, training=False): se = tf.tile(se, [1, tf.shape(outputs)[1], 1]) # [B, 1, E] => [B, T, E] outputs = tf.multiply(outputs, se) # [B, T, E] - return outputs + return outputs, outputs_length + + def compute_mask(self, inputs, mask=None): + return self.conv.compute_mask(inputs, mask=mask) + + def compute_output_shape(self, input_shape): + return self.conv.compute_output_shape(input_shape) -class ConvBlock(tf.keras.layers.Layer): +@keras.utils.register_keras_serializable(package=__name__) +class ConvBlock(keras.Model): def __init__( self, nlayers: int = 3, @@ -151,6 +204,7 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name=f"conv_module_{i}", + dtype=self.dtype, ) ) @@ -163,6 +217,7 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name=f"conv_module_{nlayers - 1}", + dtype=self.dtype, ) self.se = SEModule( @@ -174,6 +229,7 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="se", + dtype=self.dtype, ) self.residual = None @@ -187,28 +243,38 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="residual", + dtype=self.dtype, ) self.activation = get_activation(activation) def call(self, inputs, training=False): - features, inputs_length = inputs - outputs = features + _inputs, _inputs_length = inputs + outputs, outputs_length = _inputs, _inputs_length for conv in self.convs: - outputs = conv(outputs, training=training) - outputs = self.last_conv(outputs, training=training) - inputs_length = math_util.conv_output_length( - inputs_length, filter_size=self.last_conv.conv.kernel_size[0], padding=self.last_conv.conv.padding, stride=self.last_conv.strides - ) - outputs = self.se([outputs, inputs_length], training=training) + outputs, outputs_length = conv((outputs, outputs_length), training=training) + outputs, outputs_length = self.last_conv((outputs, outputs_length), training=training) + outputs, outputs_length = self.se((outputs, outputs_length), training=training) if self.residual is not None: - res = self.residual(features, training=training) + res, _ = self.residual((_inputs, _inputs_length), training=training) outputs = tf.add(outputs, res) outputs = self.activation(outputs) - return outputs, inputs_length + return outputs, outputs_length + + def compute_mask(self, inputs, mask=None): + return self.last_conv.compute_mask(inputs, mask=mask) + def compute_output_shape(self, input_shape): + output_shape = input_shape + for conv in self.convs: + output_shape = conv.compute_output_shape(output_shape) + output_shape = 
self.last_conv.compute_output_shape(output_shape) + output_shape = self.se.compute_output_shape(output_shape) + return output_shape -class ContextNetEncoder(Layer): + +@keras.utils.register_keras_serializable(package=__name__) +class ContextNetEncoder(keras.Model): def __init__( self, blocks: List[dict] = [], @@ -219,36 +285,57 @@ def __init__( ): super().__init__(**kwargs) - self.reshape = Reshape(name="reshape") + self.reshape = Reshape(name="reshape", dtype=self.dtype) self.blocks = [] + self.time_reduction_factor = 1 for i, config in enumerate(blocks): - self.blocks.append( - ConvBlock( - **config, - alpha=alpha, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=f"block_{i}", - ) + block = ConvBlock( + **config, + alpha=alpha, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"block_{i}", + dtype=self.dtype, ) + self.blocks.append(block) + self.time_reduction_factor *= block.time_reduction_factor self.dmodel = self.blocks[-1].dmodel - self.time_reduction_factor = 1 - for block in self.blocks: - self.time_reduction_factor *= block.time_reduction_factor def call(self, inputs, training=False): - outputs, inputs_length = inputs - outputs = self.reshape(outputs) + outputs, outputs_length = inputs + outputs, outputs_length = self.reshape((outputs, outputs_length)) for block in self.blocks: - outputs, inputs_length = block([outputs, inputs_length], training=training) - return outputs, inputs_length + outputs, outputs_length = block((outputs, outputs_length), training=training) + return outputs, outputs_length + + def call_next(self, features, features_length, *args, **kwargs): + """ + Recognize function for encoder network + + Parameters + ---------- + features : tf.Tensor, shape [B, T, F, C] + features_length : tf.Tensor, shape [B] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], None) + Outputs, outputs_length, new_states + """ + with tf.name_scope(f"{self.name}_call_next"): + return self.call((features, features_length), training=False) + + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + maxlen = tf.shape(outputs)[1] + maxlen, outputs_length = (math_util.get_reduced_length(length, self.time_reduction_factor) for length in (maxlen, outputs_length)) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None def compute_output_shape(self, input_shape): - inputs_shape, inputs_length_shape = input_shape - outputs_size = self.dmodel - outputs_time = None if inputs_shape[1] is None else math_util.legacy_get_reduced_length(inputs_shape[1], self.time_reduction_factor) - outputs_batch = inputs_shape[0] - outputs_shape = [outputs_batch, outputs_time, outputs_size] - return tuple(outputs_shape), tuple(inputs_length_shape) + output_shape = self.reshape.compute_output_shape(input_shape) + for block in self.blocks: + output_shape = block.compute_output_shape(output_shape) + return output_shape diff --git a/tensorflow_asr/models/encoders/deepspeech2.py b/tensorflow_asr/models/encoders/deepspeech2.py new file mode 100644 index 0000000000..78ca228e26 --- /dev/null +++ b/tensorflow_asr/models/encoders/deepspeech2.py @@ -0,0 +1,533 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
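# --- Illustrative sketch, not taken from the patch ---
# The ContextNetEncoder above tracks a time_reduction_factor (the product of the
# per-block conv strides) and uses math_util.get_reduced_length in compute_mask.
# Assuming that helper is the usual ceiling division, output lengths relate to
# input lengths as below; the function name here is hypothetical.
import tensorflow as tf

def reduced_length(inputs_length, time_reduction_factor):
    # ceil(inputs_length / time_reduction_factor), kept as int32 frame counts
    return tf.cast(
        tf.math.ceil(tf.cast(inputs_length, tf.float32) / float(time_reduction_factor)),
        tf.int32,
    )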
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tensorflow_asr import keras, tf +from tensorflow_asr.models.base_layer import Layer, Reshape +from tensorflow_asr.models.layers.convolution import DepthwiseConv1D +from tensorflow_asr.models.layers.general import Activation, Dropout, Identity +from tensorflow_asr.utils import env_util, layer_util, math_util + +# ----------------------------------- CONV ----------------------------------- # + + +@keras.utils.register_keras_serializable(package=__name__) +class RowConv1D(Layer): + def __init__( + self, + future_width=2, + activation="relu", + regularizer=None, + initializer="glorot_uniform", + **kwargs, + ): + assert future_width >= 0, "Future context must be positive" + super().__init__(**kwargs) + self.conv = DepthwiseConv1D( + kernel_size=future_width * 2 + 1, + strides=1, + padding="causal", + use_bias=False, + depthwise_regularizer=regularizer, + depthwise_initializer=initializer, + bias_regularizer=regularizer, + name="conv", + dtype=self.dtype, + ) + self.bn = keras.layers.BatchNormalization( + name="bn", + gamma_regularizer=regularizer, + beta_regularizer=regularizer, + synchronized=True, + dtype=self.dtype, + ) + self.activation = keras.activations.get(activation) + + def call(self, inputs, training=False): + outputs = self.conv(inputs, training=training) + outputs = self.bn(outputs, training=training) + outputs = self.activation(outputs) + return outputs + + def compute_output_shape(self, input_shape): + output_shape = self.conv.compute_output_shape(input_shape) + output_shape = self.bn.compute_output_shape(output_shape) + return output_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class ConvBlock(Layer): + def __init__( + self, + conv_type: str = "conv2d", + kernels: list = [11, 41], + strides: list = [2, 2], + filters: int = 32, + padding: str = "causal", + activation: str = "relu", + kernel_regularizer=None, + bias_regularizer=None, + initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.conv = layer_util.get_conv(conv_type)( + filters=filters, + kernel_size=kernels, + strides=strides, + padding=padding, + name=conv_type, + kernel_regularizer=kernel_regularizer, + kernel_initializer=initializer, + bias_regularizer=bias_regularizer, + dtype=self.dtype, + ) + self.bn = keras.layers.BatchNormalization( + name="bn", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, + ) + self.act = Activation(activation=activation, name=activation, dtype=self.dtype) + self.time_reduction_factor = self.conv.strides[0] + + def call(self, inputs, training=False): + outputs, outputs_length = inputs + outputs = self.conv(outputs, training=training) + outputs = self.bn(outputs, training=training) + outputs = self.act(outputs, training=training) + outputs_length = math_util.conv_output_length( + outputs_length, + filter_size=self.conv.kernel_size[0], + padding=self.conv._padding, + stride=self.conv.strides[0], + dilation=self.conv.dilation_rate[0], + ) + return outputs, outputs_length + + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + 
maxlen = tf.shape(outputs)[1] + maxlen, outputs_length = ( + math_util.conv_output_length( + length, + filter_size=self.conv.kernel_size[0], + padding=self.conv._padding, + stride=self.conv.strides[0], + dilation=self.conv.dilation_rate[0], + ) + for length in (maxlen, outputs_length) + ) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + output_shape = self.conv.compute_output_shape(output_shape) + output_shape = self.bn.compute_output_shape(output_shape) + output_shape = self.act.compute_output_shape(output_shape) + return output_shape, output_length_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class ConvModule(keras.Model): + def __init__( + self, + conv_type: str = "conv2d", + kernels: list = [[11, 41], [11, 21], [11, 21]], + strides: list = [[2, 2], [1, 2], [1, 2]], + filters: list = [32, 32, 96], + padding: str = "causal", + activation: str = "relu", + kernel_regularizer=None, + bias_regularizer=None, + initializer=None, + **kwargs, + ): + super().__init__(**kwargs) + assert conv_type in ("conv1d", "conv2d") + assert len(kernels) == len(strides) == len(filters) + + self.pre = Reshape(name="preprocess", dtype=self.dtype) if conv_type == "conv1d" else Identity(name="iden", dtype=self.dtype) + + self.convs = [] + self.time_reduction_factor = 1 + for i in range(len(filters)): + conv_block = ConvBlock( + conv_type=conv_type, + kernels=kernels[i], + strides=strides[i], + filters=filters[i], + padding=padding, + activation=activation, + name=f"block_{i}", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=initializer, + dtype=self.dtype, + ) + self.convs.append(conv_block) + self.time_reduction_factor *= conv_block.time_reduction_factor + + self.post = Reshape(name="postprocess", dtype=self.dtype) if conv_type == "conv2d" else Identity(name="iden", dtype=self.dtype) + + def call(self, inputs, training=False): + outputs = self.pre(inputs, training=training) + for conv in self.convs: + outputs = conv(outputs, training=training) + outputs = self.post(outputs, training=training) + return outputs + + +# ------------------------------------ RNN ----------------------------------- # + + +@keras.utils.register_keras_serializable(package=__name__) +class RnnBlock(Layer): + def __init__( + self, + rnn_type: str = "lstm", + units: int = 1024, + bidirectional: bool = True, + unroll: bool = False, + rowconv: int = 0, + rowconv_activation: str = "relu", + dropout: float = 0.1, + kernel_regularizer=None, + bias_regularizer=None, + initializer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.rnn = layer_util.get_rnn(rnn_type)( + units, + unroll=unroll, + return_sequences=True, + return_state=True, + use_bias=True, + name=rnn_type, + zero_output_for_mask=True, + kernel_regularizer=kernel_regularizer, + kernel_initializer=initializer or "glorot_uniform", + bias_regularizer=bias_regularizer, + use_cudnn=env_util.TF_CUDNN, + dtype=self.dtype, + ) + self._bidirectional = bidirectional + if bidirectional: + self.rnn = keras.layers.Bidirectional(self.rnn, name=f"b{rnn_type}", dtype=self.dtype) + self.rowconv = None + if not bidirectional and rowconv > 0: + self.rowconv = RowConv1D( + future_width=rowconv, + name="rowconv", + regularizer=kernel_regularizer, + initializer=initializer, + activation=rowconv_activation, + dtype=self.dtype, + ) + self.do = Dropout(dropout, name="dropout", 
dtype=self.dtype) + + def get_initial_state(self, batch_size: int): + if self._bidirectional: + states = self.rnn.forward_layer.get_initial_state(batch_size) + states += self.rnn.backward_layer.get_initial_state(batch_size) + else: + states = self.rnn.get_initial_state(batch_size=batch_size) + return states + + def call(self, inputs, training=False): + outputs, outputs_length = inputs + outputs, *_ = self.rnn(outputs, training=training) # mask auto populate + if self.rowconv is not None: + outputs = self.rowconv(outputs, training=training) + outputs = self.do(outputs, training=training) + return outputs, outputs_length + + def call_next(self, inputs, previous_encoder_states): + with tf.name_scope(f"{self.name}_call_next"): + outputs, outputs_length = inputs + outputs, *_states = self.rnn(outputs, training=False, initial_state=tf.unstack(previous_encoder_states, axis=0)) + if self.rowconv is not None: + outputs = self.rowconv(outputs, training=False) + return outputs, outputs_length, tf.stack(_states) + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + output_shape, *_ = self.rnn.compute_output_shape(output_shape) + if self.rowconv is not None: + output_shape = self.rowconv.compute_output_shape(output_shape) + return output_shape, output_length_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class RnnModule(keras.Model): + def __init__( + self, + nlayers: int = 5, + rnn_type: str = "lstm", + units: int = 1024, + bidirectional: bool = True, + unroll: bool = False, + rowconv: int = 0, + rowconv_activation: str = "relu", + dropout: float = 0.1, + kernel_regularizer=None, + bias_regularizer=None, + initializer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.blocks = [ + RnnBlock( + rnn_type=rnn_type, + units=units, + bidirectional=bidirectional, + unroll=unroll, + rowconv=rowconv, + rowconv_activation=rowconv_activation, + dropout=dropout, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=initializer, + name=f"block_{i}", + dtype=self.dtype, + ) + for i in range(nlayers) + ] + + def get_initial_state(self, batch_size: int): + """ + Get zeros states + + Returns + ------- + tf.Tensor, shape [B, num_rnns, nstates, state_size] + Zero initialized states + """ + states = [] + for block in self.blocks: + states.append(tf.stack(block.get_initial_state(batch_size=batch_size), axis=0)) + return tf.transpose(tf.stack(states, axis=0), perm=[2, 0, 1, 3]) + + def call(self, inputs, training=False): + outputs = inputs + for block in self.blocks: + outputs = block(outputs, training=training) + return outputs + + def call_next(self, inputs, previous_encoder_states): + outputs = inputs + previous_encoder_states = tf.transpose(previous_encoder_states, perm=[1, 2, 0, 3]) + new_states = [] + for i, block in enumerate(self.blocks): + *outputs, _states = block.call_next(outputs, previous_encoder_states=previous_encoder_states[i]) + new_states.append(_states) + return outputs, tf.transpose(tf.stack(new_states, axis=0), perm=[2, 0, 1, 3]) + + +# ------------------------------ FULLY CONNECTED ----------------------------- # + + +@keras.utils.register_keras_serializable(package=__name__) +class FcBlock(Layer): + def __init__( + self, + units: int = 1024, + activation: str = "relu", + dropout: float = 0.1, + kernel_regularizer=None, + bias_regularizer=None, + initializer="glorot_uniform", + **kwargs, + ): + super().__init__(**kwargs) + self.fc = keras.layers.Dense( + units, + 
kernel_regularizer=kernel_regularizer, + kernel_initializer=initializer, + bias_regularizer=bias_regularizer, + name="fc", + dtype=self.dtype, + ) + self.act = Activation(activation=activation, name=activation, dtype=self.dtype) + self.do = Dropout(dropout, name="dropout", dtype=self.dtype) + + def call(self, inputs, training=False): + outputs, outputs_length = inputs + outputs = self.fc(outputs, training=training) + outputs = self.act(outputs, training=training) + outputs = self.do(outputs, training=training) + return outputs, outputs_length + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + output_shape = self.fc.compute_output_shape(output_shape) + return output_shape, output_length_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class FcModule(keras.Model): + def __init__( + self, + nlayers: int = 0, + units: int = 1024, + activation: str = "relu", + dropout: float = 0.1, + kernel_regularizer=None, + bias_regularizer=None, + initializer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.blocks = [ + FcBlock( + units=units, + activation=activation, + dropout=dropout, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=initializer, + name=f"block_{i}", + dtype=self.dtype, + ) + for i in range(nlayers) + ] + + def call(self, inputs, training=False): + outputs = inputs + for block in self.blocks: + outputs = block(outputs, training=training) + return outputs + + +@keras.utils.register_keras_serializable(package=__name__) +class DeepSpeech2Encoder(keras.Model): + def __init__( + self, + conv_type: str = "conv2d", + conv_kernels: list = [[11, 41], [11, 21], [11, 21]], + conv_strides: list = [[2, 2], [1, 2], [1, 2]], + conv_filters: list = [32, 32, 96], + conv_padding: str = "same", + conv_activation: str = "relu", + conv_initializer: str = None, + rnn_nlayers: int = 5, + rnn_type: str = "lstm", + rnn_units: int = 1024, + rnn_bidirectional: bool = True, + rnn_unroll: bool = False, + rnn_rowconv: int = 0, + rnn_rowconv_activation: str = "relu", + rnn_dropout: float = 0.1, + rnn_initializer: str = None, + fc_nlayers: int = 0, + fc_units: int = 1024, + fc_activation: str = "relu", + fc_dropout: float = 0.1, + fc_initializer: str = None, + kernel_regularizer=None, + bias_regularizer=None, + initializer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.conv_module = ConvModule( + conv_type=conv_type, + kernels=conv_kernels, + strides=conv_strides, + filters=conv_filters, + padding=conv_padding, + activation=conv_activation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=conv_initializer or initializer, + name="conv_module", + dtype=self.dtype, + ) + self.rnn_module = RnnModule( + nlayers=rnn_nlayers, + rnn_type=rnn_type, + units=rnn_units, + bidirectional=rnn_bidirectional, + unroll=rnn_unroll, + rowconv=rnn_rowconv, + rowconv_activation=rnn_rowconv_activation, + dropout=rnn_dropout, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=rnn_initializer or initializer, + name="rnn_module", + dtype=self.dtype, + ) + self.fc_module = FcModule( + nlayers=fc_nlayers, + units=fc_units, + activation=fc_activation, + dropout=fc_dropout, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + initializer=fc_initializer or initializer, + name="fc_module", + dtype=self.dtype, + ) + self.time_reduction_factor = self.conv_module.time_reduction_factor + + def 
get_initial_state(self, batch_size: int): + """ + Get zeros states + + Returns + ------- + tf.Tensor, shape [B, num_rnns, nstates, state_size] + Zero initialized states + """ + return self.rnn_module.get_initial_state(batch_size=batch_size) + + def call(self, inputs, training=False): + outputs = inputs + outputs = self.conv_module(outputs, training=training) + outputs = self.rnn_module(outputs, training=training) + outputs = self.fc_module(outputs, training=training) + return outputs + + def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs): + """ + Recognize function for encoder network from previous encoder states + + Parameters + ---------- + features : tf.Tensor, shape [B, T, F, C] + features_length : tf.Tensor, shape [B] + previous_encoder_states : tf.Tensor, shape [B, nlayers, nstates, rnn_units] -> [nlayers, nstates, B, rnn_units] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], [nlayers, nstates, B, rnn_units] -> [B, nlayers, nstates, rnn_units]) + """ + with tf.name_scope(f"{self.name}_call_next"): + outputs = (features, features_length) + outputs = self.conv_module(outputs, training=False) + outputs, new_encoder_states = self.rnn_module.call_next(outputs, previous_encoder_states=previous_encoder_states) + outputs, outputs_length = self.fc_module(outputs, training=False) + return outputs, outputs_length, new_encoder_states + + def compute_mask(self, inputs, mask=None): + return self.conv_module.compute_mask(inputs, mask) + + def compute_output_shape(self, input_shape): + output_shape = self.conv_module.compute_output_shape(input_shape) + output_shape = self.rnn_module.compute_output_shape(output_shape) + output_shape = self.fc_module.compute_output_shape(output_shape) + return output_shape diff --git a/tensorflow_asr/models/encoders/jasper.py b/tensorflow_asr/models/encoders/jasper.py new file mode 100644 index 0000000000..07b96dfa8d --- /dev/null +++ b/tensorflow_asr/models/encoders/jasper.py @@ -0,0 +1,359 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
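# --- Illustrative sketch, not taken from the patch ---
# Stepping the DeepSpeech2Encoder above with explicit RNN states, following the
# shapes given in its docstrings ([B, nlayers, nstates, rnn_units]). The batch
# size and feature tensors are placeholders.
def ds2_stream_step(encoder, features, features_length, states=None, batch_size=1):
    if states is None:
        states = encoder.get_initial_state(batch_size=batch_size)
    # call_next -> (outputs, outputs_length, new_states) per the docstring above
    outputs, outputs_length, states = encoder.call_next(features, features_length, states)
    return outputs, outputs_length, states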
+ +import tensorflow as tf + +from tensorflow_asr import keras +from tensorflow_asr.models.base_layer import Reshape +from tensorflow_asr.models.layers.convolution import Conv1D +from tensorflow_asr.models.layers.general import Dropout +from tensorflow_asr.utils import math_util + + +@keras.utils.register_keras_serializable(package=__name__) +class JasperSubBlock(keras.layers.Layer): + def __init__( + self, + channels: int = 256, + kernels: int = 11, + strides: int = 1, + dropout: float = 0.1, + padding: str = "causal", + dilation: int = 1, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.conv1d = Conv1D( + filters=channels, + kernel_size=kernels, + strides=strides, + dilation_rate=dilation, + padding=padding, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="conv1d", + dtype=self.dtype, + ) + self.bn = keras.layers.BatchNormalization( + name="bn", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, + ) + self.relu = keras.layers.ReLU(name="relu", dtype=self.dtype) + self.do = Dropout(dropout, name="dropout", dtype=self.dtype) + self.reduction_factor = strides + + def call(self, inputs, training=False): + outputs = inputs + outputs = self.conv1d(outputs, training=training) + outputs = self.bn(outputs, training=training) + outputs = self.relu(outputs, training=training) + outputs = self.do(outputs, training=training) + return outputs + + def compute_output_shape(self, input_shape): + return self.conv1d.compute_output_shape(input_shape) + + +@keras.utils.register_keras_serializable(package=__name__) +class JasperResidual(keras.layers.Layer): + def __init__( + self, + channels: int = 256, + padding: str = "causal", + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.pointwise_conv1d = Conv1D( + filters=channels, + kernel_size=1, + strides=1, + padding=padding, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="pointwise_conv1d", + dtype=self.dtype, + ) + self.bn = keras.layers.BatchNormalization( + name="bn", + gamma_regularizer=kernel_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, + ) + + def call(self, inputs, training=False): + outputs = self.pointwise_conv1d(inputs, training=training) + outputs = self.bn(outputs, training=training) + return outputs + + def compute_output_shape(self, input_shape): + return self.pointwise_conv1d.compute_output_shape(input_shape) + + +@keras.utils.register_keras_serializable(package=__name__) +class JasperSubBlockResidual(JasperSubBlock): + def __init__( + self, + channels: int = 256, + kernels: int = 11, + strides: int = 1, + dropout: float = 0.1, + padding: str = "causal", + dilation: int = 1, + nresiduals: int = 1, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__( + channels=channels, + kernels=kernels, + strides=strides, + dropout=dropout, + padding=padding, + dilation=dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + **kwargs, + ) + + self.residuals = [ + JasperResidual( + channels=channels, + padding=padding, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"residual_{i}", + dtype=self.dtype, + ) + for i in range(nresiduals) + ] + + self.add = keras.layers.Add(name="add") + + def call(self, inputs, training=False): + outputs, residuals = inputs + 
outputs = self.conv1d(outputs, training=training) + outputs = self.bn(outputs, training=training) + for i, res in enumerate(residuals): + res = self.residuals[i](res, training=training) + outputs = self.add([outputs, res], training=training) + outputs = self.relu(outputs, training=training) + outputs = self.do(outputs, training=training) + return outputs + + +@keras.utils.register_keras_serializable(package=__name__) +class JasperBlock(keras.layers.Layer): + def __init__( + self, + nsubblocks: int = 3, + channels: int = 256, + kernels: int = 11, + dropout: float = 0.1, + padding: str = "causal", + dense: bool = False, + nresiduals: int = 1, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + + self.dense = dense + + self.subblocks = [ + JasperSubBlock( + channels=channels, + kernels=kernels, + dropout=dropout, + padding=padding, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"subordinate_{i}", + dtype=self.dtype, + ) + for i in range(nsubblocks - 1) + ] + + self.subblock_residual = JasperSubBlockResidual( + channels=channels, + kernels=kernels, + dropout=dropout, + nresiduals=nresiduals, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"subordinate_{nsubblocks - 1}", + dtype=self.dtype, + ) + + self.reduction_factor = 1 + + def call(self, inputs, training=False): + inputs, residuals = inputs + outputs = inputs + for subblock in self.subblocks: + outputs = subblock(outputs, training=training) + if self.dense: + residuals.append(inputs) + outputs = self.subblock_residual([outputs, residuals], training=training) + else: + outputs = self.subblock_residual([outputs, [inputs]], training=training) + return outputs, residuals + + def compute_output_shape(self, input_shape): + output_shape, residuals_shape = input_shape + for subblock in self.subblocks: + output_shape = subblock.compute_output_shape(output_shape) + return output_shape, residuals_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class JasperEncoder(keras.Model): + def __init__( + self, + dense: bool = False, + padding: str = "causal", + first_additional_block_channels: int = 256, + first_additional_block_kernels: int = 11, + first_additional_block_strides: int = 2, + first_additional_block_dilation: int = 1, + first_additional_block_dropout: int = 0.2, + nsubblocks: int = 5, + block_channels: list = [256, 384, 512, 640, 768], + block_kernels: list = [11, 13, 17, 21, 25], + block_dropout: list = [0.2, 0.2, 0.2, 0.3, 0.3], + second_additional_block_channels: int = 896, + second_additional_block_kernels: int = 1, + second_additional_block_strides: int = 1, + second_additional_block_dilation: int = 2, + second_additional_block_dropout: int = 0.4, + third_additional_block_channels: int = 1024, + third_additional_block_kernels: int = 1, + third_additional_block_strides: int = 1, + third_additional_block_dilation: int = 1, + third_additional_block_dropout: int = 0.4, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + + assert len(block_channels) == len(block_kernels) == len(block_dropout) + + self.reshape = Reshape(name="reshape") + + self.first_additional_block = JasperSubBlock( + channels=first_additional_block_channels, + kernels=first_additional_block_kernels, + strides=first_additional_block_strides, + dropout=first_additional_block_dropout, + padding=padding, + dilation=first_additional_block_dilation, + kernel_regularizer=kernel_regularizer, 
+ bias_regularizer=bias_regularizer, + name="first_block", + dtype=self.dtype, + ) + + self.blocks = [ + JasperBlock( + nsubblocks=nsubblocks, + channels=block_channels[i], + kernels=block_kernels[i], + dropout=block_dropout[i], + dense=dense, + nresiduals=(i + 1) if dense else 1, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"block_{i}", + dtype=self.dtype, + ) + for i in range(len(block_channels)) + ] + + self.second_additional_block = JasperSubBlock( + channels=second_additional_block_channels, + kernels=second_additional_block_kernels, + strides=second_additional_block_strides, + dropout=second_additional_block_dropout, + padding=padding, + dilation=second_additional_block_dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="second_block", + dtype=self.dtype, + ) + + self.third_additional_block = JasperSubBlock( + channels=third_additional_block_channels, + kernels=third_additional_block_kernels, + strides=third_additional_block_strides, + dropout=third_additional_block_dropout, + padding=padding, + dilation=third_additional_block_dilation, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="third_block", + dtype=self.dtype, + ) + self.time_reduction_factor = self.first_additional_block.reduction_factor + self.time_reduction_factor *= self.second_additional_block.reduction_factor + self.time_reduction_factor *= self.third_additional_block.reduction_factor + + def call(self, inputs, training=False): + outputs, outputs_length = inputs + outputs, outputs_length = self.reshape((outputs, outputs_length)) + outputs = self.first_additional_block(outputs, training=training) + + residuals = [] + for block in self.blocks: + outputs, residuals = block([outputs, residuals], training=training) + + outputs = self.second_additional_block(outputs, training=training) + outputs = self.third_additional_block(outputs, training=training) + outputs_length = math_util.get_reduced_length(outputs_length, self.time_reduction_factor) + return outputs, outputs_length + + def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs): + """ + Recognize function for encoder network from previous encoder states + + Parameters + ---------- + features : tf.Tensor, shape [B, T, F, C] + features_length : tf.Tensor, shape [B] + previous_encoder_states : tf.Tensor, shape [B, nlayers, nstates, rnn_units] -> [nlayers, nstates, B, rnn_units] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], [nlayers, nstates, B, rnn_units] -> [B, nlayers, nstates, rnn_units]) + """ + with tf.name_scope(f"{self.name}_call_next"): + return self.call((features, features_length), training=False) + + def compute_output_shape(self, input_shape): + inputs_shape, inputs_length_shape = input_shape + outputs_time = None if inputs_shape[1] is None else math_util.legacy_get_reduced_length(inputs_shape[1], self.time_reduction_factor) + outputs_batch = inputs_shape[0] + outputs_size = self.third_additional_block.conv1d.filters + outputs_shape = [outputs_batch, outputs_time, outputs_size] + return tuple(outputs_shape), tuple(inputs_length_shape) diff --git a/tensorflow_asr/models/encoders/rnnt.py b/tensorflow_asr/models/encoders/rnnt.py new file mode 100644 index 0000000000..844737fa25 --- /dev/null +++ b/tensorflow_asr/models/encoders/rnnt.py @@ -0,0 +1,224 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# 
you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" http://arxiv.org/abs/1811.06621 """ + +import typing + +from keras.src import backend + +from tensorflow_asr import keras, tf +from tensorflow_asr.models.base_layer import Reshape +from tensorflow_asr.models.layers.subsampling import TimeReduction +from tensorflow_asr.utils import env_util, layer_util, math_util + + +@keras.utils.register_keras_serializable(package=__name__) +class RnnTransducerBlock(keras.Model): + def __init__( + self, + reduction_position: str = "pre", + reduction_factor: int = 0, + dmodel: int = 640, + rnn_type: str = "lstm", + rnn_units: int = 2048, + rnn_unroll: bool = False, + layer_norm: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + assert reduction_position in ["post", "pre"], "reduction_position must be 'post' or 'pre'" + self._reduction_position = reduction_position + self.reduction = TimeReduction(reduction_factor, name="reduction", dtype=self.dtype) if reduction_factor > 0 else None + self.rnn: typing.Union[keras.layers.GRU, keras.layers.LSTM, keras.layers.SimpleRNN] = layer_util.get_rnn(rnn_type)( + units=rnn_units, + return_sequences=True, + name=rnn_type, + unroll=rnn_unroll, + return_state=True, + zero_output_for_mask=True, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + use_cudnn=env_util.TF_CUDNN, + dtype=self.dtype, + ) + self.ln = ( + keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype) + if layer_norm + else None + ) + self.projection = keras.layers.Dense( + dmodel, + name="projection", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + dtype=self.dtype, + ) + + def call(self, inputs, training=False): + outputs, outputs_length = inputs + if self._reduction_position == "pre": + if self.reduction is not None: + outputs, outputs_length = self.reduction((outputs, outputs_length)) + outputs, *_ = self.rnn(outputs, training=training) + if self.ln is not None: + outputs = self.ln(outputs, training=training) + outputs = self.projection(outputs, training=training) + if self._reduction_position == "post": + if self.reduction is not None: + outputs, outputs_length = self.reduction((outputs, outputs_length)) + return outputs, outputs_length + + def compute_mask(self, inputs, mask=None): + if self.reduction is not None: + mask = self.reduction.compute_mask(inputs) + return mask + + def call_next(self, inputs, inputs_length, previous_encoder_states): + """ + Recognize function for encoder network from the previous encoder states + + Parameters + ---------- + inputs : tf.Tensor, shape [B, T, E] + previous_encoder_states : tf.Tensor, shape [nstates, B, rnn_units] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shapes ([B, T, dmodel], [B], [nstates, B, rnn_units]) + """ + with tf.name_scope(f"{self.name}_call_next"): + outputs, outputs_length = inputs, inputs_length + if self._reduction_position == "pre": + if self.reduction is not None: + outputs, outputs_length = 
self.reduction([outputs, outputs_length]) + outputs, *_states = self.rnn( + outputs, + training=False, + initial_state=tf.unstack(previous_encoder_states, axis=0), + mask=backend.get_keras_mask(inputs), + ) + new_states = tf.stack(_states, axis=0) + if self.ln is not None: + outputs = self.ln(outputs, training=False) + outputs = self.projection(outputs, training=False) + if self._reduction_position == "post": + if self.reduction is not None: + outputs, outputs_length = self.reduction([outputs, outputs_length]) + return outputs, outputs_length, new_states + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + if self.reduction is not None: + output_shape, output_length_shape = self.reduction.compute_output_shape((output_shape, output_length_shape)) + output_shape = self.projection.compute_output_shape(output_shape) + return output_shape, output_length_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class RnnTransducerEncoder(keras.Model): + def __init__( + self, + reduction_positions: list = ["pre", "pre", "pre", "pre", "pre", "pre", "pre", "pre"], + reduction_factors: list = [6, 0, 0, 0, 0, 0, 0, 0], + dmodel: int = 640, + nlayers: int = 8, + rnn_type: str = "lstm", + rnn_units: int = 2048, + rnn_unroll: bool = False, + layer_norm: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + assert len(reduction_positions) == nlayers, "reduction_positions length must be equal to nlayers" + assert len(reduction_factors) == nlayers, "reduction_factors length must be equal to nlayers" + self.reshape = Reshape(name="reshape", dtype=self.dtype) + + self.time_reduction_factor = 1 + self.blocks: typing.List[RnnTransducerBlock] = [] + for i in range(nlayers): + block = RnnTransducerBlock( + reduction_position=reduction_positions[i], + reduction_factor=reduction_factors[i], + dmodel=dmodel, + rnn_type=rnn_type, + rnn_units=rnn_units, + rnn_unroll=rnn_unroll, + layer_norm=layer_norm, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=f"block_{i}", + dtype=self.dtype, + ) + self.blocks.append(block) + self.time_reduction_factor *= getattr(block.reduction, "time_reduction_factor", 1) + + def get_initial_state(self, batch_size=1): + """Get zeros states + + Returns: + tf.Tensor, shape [B, num_rnns, nstates, state_size] + Zero initialized states + """ + states = [] + for block in self.blocks: + states.append(tf.stack(block.rnn.get_initial_state(batch_size=batch_size), axis=0)) + return tf.transpose(tf.stack(states, axis=0), perm=[2, 0, 1, 3]) + + def call(self, inputs, training=False): + outputs, outputs_length = inputs + outputs, outputs_length = self.reshape((outputs, outputs_length)) + for block in self.blocks: + outputs, outputs_length = block((outputs, outputs_length), training=training) + return outputs, outputs_length + + def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs): + """ + Recognize function for encoder network from previous encoder states + + Parameters + ---------- + features : tf.Tensor, shape [B, T, F, C] + features_length : tf.Tensor, shape [B] + previous_encoder_states : tf.Tensor, shape [B, nlayers, nstates, rnn_units] -> [nlayers, nstates, B, rnn_units] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], [nlayers, nstates, B, rnn_units] -> [B, nlayers, nstates, rnn_units]) + """ + with tf.name_scope(f"{self.name}_call_next"): + previous_encoder_states = 
tf.transpose(previous_encoder_states, perm=[1, 2, 0, 3]) + outputs, outputs_length = self.reshape((features, features_length)) + new_states = [] + for i, block in enumerate(self.blocks): + outputs, outputs_length, block_states = block.call_next(outputs, outputs_length, previous_encoder_states=previous_encoder_states[i]) + new_states.append(block_states) + return outputs, outputs_length, tf.transpose(tf.stack(new_states, axis=0), perm=[2, 0, 1, 3]) + + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + maxlen = tf.shape(outputs)[1] + maxlen, outputs_length = (math_util.get_reduced_length(length, self.time_reduction_factor) for length in (maxlen, outputs_length)) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None + + def compute_output_shape(self, input_shape): + output_shape = input_shape + output_shape = self.reshape.compute_output_shape(output_shape) + for block in self.blocks: + output_shape = block.compute_output_shape(output_shape) + return output_shape diff --git a/tensorflow_asr/models/encoders/transformer.py b/tensorflow_asr/models/encoders/transformer.py index 80cdceaa87..cde9e03daf 100644 --- a/tensorflow_asr/models/encoders/transformer.py +++ b/tensorflow_asr/models/encoders/transformer.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf - +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer +from tensorflow_asr.models.layers.general import Dropout from tensorflow_asr.models.layers.multihead_attention import MultiHeadAttention, MultiHeadRelativeAttention -from tensorflow_asr.models.layers.positional_encoding import PositionalEncoding, RelativePositionalEncoding +from tensorflow_asr.models.layers.positional_encoding import RelativeSinusoidalPositionalEncoding, SinusoidalPositionalEncoding from tensorflow_asr.models.layers.residual import Residual from tensorflow_asr.models.layers.subsampling import Conv1dSubsampling, Conv2dSubsampling, VggSubsampling +from tensorflow_asr.utils import data_util -class Pointwiseffn(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class PointwiseFFN(Layer): def __init__( self, dmodel, @@ -33,18 +35,20 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - self.ffn1 = tf.keras.layers.Dense( + self.ffn1 = keras.layers.Dense( units=dff, activation=activation, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="ffn_1", + dtype=self.dtype, ) - self.ffn2 = tf.keras.layers.Dense( + self.ffn2 = keras.layers.Dense( units=dmodel, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="ffn_2", + dtype=self.dtype, ) def call(self, inputs, training=False): @@ -52,8 +56,12 @@ def call(self, inputs, training=False): outputs = self.ffn2(outputs, training=training) return outputs + def compute_output_shape(self, input_shape): + return input_shape[:-1] + (self.ffn2.units,) + -class TransformerBlock(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class TransformerBlock(keras.Model): def __init__( self, dmodel, @@ -61,11 +69,16 @@ def __init__( num_heads, head_size, mha_type="mha", + relmha_causal=False, + flash_attention=None, norm_position="post", residual_factor=1.0, pwffn_activation="relu", dropout=0.1, memory_length=None, + history_size=None, + chunk_size=None, + use_attention_bias=False, kernel_regularizer=None, bias_regularizer=None, **kwargs, @@ -78,7 +91,9 @@ def 
__init__( self.norm1 = ( None if self._norm_position == "none" - else tf.keras.layers.LayerNormalization(beta_regularizer=kernel_regularizer, gamma_regularizer=bias_regularizer, name="ln_1") + else keras.layers.LayerNormalization( + beta_regularizer=kernel_regularizer, gamma_regularizer=bias_regularizer, name="ln_1", dtype=self.dtype + ) ) self.mha = ( MultiHeadAttention( @@ -86,69 +101,79 @@ def __init__( key_dim=head_size, output_shape=dmodel, memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + flash_attention=flash_attention, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - dtype=tf.float32, # stable training name="mhsa", + dtype=self.dtype, ) if mha_type == "mha" else MultiHeadRelativeAttention( + causal=relmha_causal, num_heads=num_heads, key_dim=head_size, output_shape=dmodel, memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + flash_attention=flash_attention, + use_attention_bias=use_attention_bias, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - dtype=tf.float32, # stable training name="mhsa", + dtype=self.dtype, ) ) - self.do1 = tf.keras.layers.Dropout(dropout, name="do_1") - self.residual1 = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual_1") + self.do1 = Dropout(dropout, name="do_1", dtype=self.dtype) + self.residual1 = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual_1", dtype=self.dtype) self.norm2 = ( None if self._norm_position == "none" - else tf.keras.layers.LayerNormalization(beta_regularizer=kernel_regularizer, gamma_regularizer=bias_regularizer, name="ln_2") + else keras.layers.LayerNormalization( + beta_regularizer=kernel_regularizer, gamma_regularizer=bias_regularizer, name="ln_2", dtype=self.dtype + ) ) - self.pwffn = Pointwiseffn( + self.pwffn = PointwiseFFN( dmodel=dmodel, dff=dff, activation=pwffn_activation, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="pwffn", + dtype=self.dtype, ) - self.do2 = tf.keras.layers.Dropout(dropout, name="do_2") - self.residual2 = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual_2") + self.do2 = Dropout(dropout, name="do_2", dtype=self.dtype) + self.residual2 = Residual(factor=residual_factor, regularizer=bias_regularizer, name="residual_2", dtype=self.dtype) + + def get_initial_state(self, batch_size): + return self.mha.get_initial_state(batch_size) def call( self, inputs, - relative_position_encoding=None, content_attention_bias=None, positional_attention_bias=None, + initial_state=None, training=False, attention_mask=None, use_causal_mask=False, use_auto_mask=True, + return_states=False, ): - original_outputs = inputs + original_outputs, relative_position_encoding = inputs outputs = self.norm1(original_outputs, training=training) if self._norm_position == "pre" else original_outputs - mha_inputs = ( - dict( - inputs=[outputs, outputs, outputs, relative_position_encoding], - content_attention_bias=content_attention_bias, - positional_attention_bias=positional_attention_bias, - ) - if self._mha_type == "relmha" - else dict(inputs=[outputs, outputs, outputs]) - ) - outputs = self.mha( - **mha_inputs, + outputs, *states = self.mha( + [outputs, outputs, outputs, relative_position_encoding], + content_attention_bias=content_attention_bias, + positional_attention_bias=positional_attention_bias, + initial_state=initial_state, training=training, attention_mask=attention_mask, 
use_causal_mask=use_causal_mask, use_auto_mask=use_auto_mask, + return_states=return_states, ) outputs = self.do1(outputs, training=training) outputs = self.norm1(outputs, training=training) if self._norm_position == "post" else outputs @@ -158,10 +183,17 @@ def call( outputs = self.do2(outputs, training=training) outputs = self.norm2(outputs, training=training) if self._norm_position == "post" else outputs outputs = self.residual2([original_outputs, outputs], training=training) - return outputs + if return_states: + return (outputs,) + states + return (outputs,) + + def compute_output_shape(self, input_shape): + output_shape, *_ = input_shape + return output_shape -class TransformerEncoder(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class TransformerEncoder(keras.Model): def __init__( self, subsampling, @@ -172,13 +204,18 @@ def __init__( head_size=128, dropout=0.1, mha_type="mha", + relmha_causal=False, norm_position="post", residual_factor=1.0, interleave_relpe=True, use_attention_causal_mask=False, use_attention_auto_mask=True, + use_attention_bias=False, pwffn_activation="relu", memory_length=None, + history_size=None, + chunk_size=None, + flash_attention=None, kernel_regularizer=None, bias_regularizer=None, name="transformer_encoder", @@ -188,6 +225,8 @@ def __init__( self._use_attention_causal_mask = use_attention_causal_mask self._use_attention_auto_mask = use_attention_auto_mask self._num_blocks = num_blocks + self._dmodel = dmodel + self._memory_length = memory_length subsampling_name = subsampling.pop("type", None) if subsampling_name == "vgg": @@ -203,15 +242,28 @@ def __init__( name="subsampling", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) self.time_reduction_factor = self.subsampling.time_reduction_factor - self.linear = tf.keras.layers.Dense(units=dmodel, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name="linear") - self.do = tf.keras.layers.Dropout(dropout, name="dropout") + self.linear = keras.layers.Dense( + units=dmodel, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name="linear", + dtype=self.dtype, + ) + self.do = Dropout(dropout, name="dropout", dtype=self.dtype) if mha_type == "relmha": - self.relpe = RelativePositionalEncoding(interleave=interleave_relpe, memory_length=memory_length, name="relpe") + self.relpe = RelativeSinusoidalPositionalEncoding( + interleave=interleave_relpe, + memory_length=memory_length, + causal=relmha_causal, + name="relpe", + dtype=self.dtype, + ) else: - self.relpe = PositionalEncoding(interleave=interleave_relpe, name="pe") + self.relpe = SinusoidalPositionalEncoding(interleave=interleave_relpe, name="pe", dtype=self.dtype) self.blocks = [ TransformerBlock( @@ -220,25 +272,32 @@ def __init__( num_heads=num_heads, head_size=head_size, mha_type=mha_type, + relmha_causal=relmha_causal, norm_position=norm_position, residual_factor=residual_factor, pwffn_activation=pwffn_activation, dropout=dropout, memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + flash_attention=flash_attention, + use_attention_bias=use_attention_bias, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, name=f"block_{i}", + dtype=self.dtype, ) for i in range(self._num_blocks) ] - if mha_type == "relmha": + if mha_type == "relmha" and not use_attention_bias: self.content_attention_bias = self.add_weight( name="content_attention_bias", shape=[num_heads, head_size], trainable=True, 
initializer="zeros", regularizer=bias_regularizer, + dtype=self.variable_dtype, ) self.positional_attention_bias = self.add_weight( name="positional_attention_bias", @@ -246,42 +305,71 @@ def __init__( trainable=True, initializer="zeros", regularizer=bias_regularizer, + dtype=self.variable_dtype, ) else: self.content_attention_bias, self.positional_attention_bias = None, None - def get_states(self): - return [block.mha.get_states() for block in self.blocks] - - def reset_states(self, states=None): - if states is None: - states = [(None, None) for _ in range(self._num_blocks)] - for i, memory_states in enumerate(states): - self.blocks[i].mha.reset_states(memory_states) + def get_initial_state(self, batch_size): + return [block.get_initial_state(batch_size) for block in self.blocks] - def call(self, inputs, training=False): + def call( + self, + inputs, + initial_state=None, + training=False, + return_states=False, + ): outputs, outputs_length = inputs outputs, outputs_length = self.subsampling([outputs, outputs_length], training=training) outputs = self.linear(outputs, training=training) - outputs, relative_position_encoding = self.relpe(outputs, training=training) + outputs, relative_position_encoding = self.relpe([outputs, outputs_length], training=training) outputs = self.do(outputs, training=training) - for block in self.blocks: - outputs = block( - outputs, - relative_position_encoding=relative_position_encoding, + states = None if self._memory_length is None else [] + for i, block in enumerate(self.blocks): + outputs, *_states = block( + [outputs, relative_position_encoding], content_attention_bias=self.content_attention_bias, positional_attention_bias=self.positional_attention_bias, + initial_state=data_util.get(initial_state, i, None), training=training, use_causal_mask=self._use_attention_causal_mask, use_auto_mask=self._use_attention_auto_mask, + return_states=return_states, ) + if not _states: + continue + states.extend(_states) + if return_states: + return outputs, outputs_length, states return outputs, outputs_length + def call_next(self, features, features_length, previous_encoder_states, *args, **kwargs): + """ + Recognize function for encoder network + + Parameters + ---------- + features : tf.Tensor, shape [B, T, F, C] + features_length : tf.Tensor, shape [B] + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor, tf.Tensor], shape ([B, T, dmodel], [B], None) + Outputs, outputs_length, new_states + """ + with tf.name_scope(f"{self.name}_call_next"): + return self.call((features, features_length), initial_state=previous_encoder_states, training=False) + + def compute_mask(self, inputs, mask=None): + return self.subsampling.compute_mask(inputs, mask=mask) + def compute_output_shape(self, input_shape): - output_shape, output_length_shape = self.subsampling.compute_output_shape(input_shape) + output_shape, output_length_shape = input_shape + output_shape, output_length_shape = self.subsampling.compute_output_shape((output_shape, output_length_shape)) output_shape = self.linear.compute_output_shape(output_shape) - output_shape, _ = self.relpe.compute_output_shape(output_shape) + output_shape, relative_position_encoding_shape = self.relpe.compute_output_shape((output_shape, output_length_shape)) output_shape = self.do.compute_output_shape(output_shape) for block in self.blocks: - output_shape = block.compute_output_shape(output_shape) + output_shape = block.compute_output_shape((output_shape, relative_position_encoding_shape, None, None)) return output_shape, output_length_shape diff 
--git a/tensorflow_asr/models/layers/__init__.py b/tensorflow_asr/models/layers/__init__.py index e69de29bb2..9139bde684 100755 --- a/tensorflow_asr/models/layers/__init__.py +++ b/tensorflow_asr/models/layers/__init__.py @@ -0,0 +1,13 @@ +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/layers/blurpool.py b/tensorflow_asr/models/layers/blurpool.py index 81f0a5c371..372c62c5fc 100644 --- a/tensorflow_asr/models/layers/blurpool.py +++ b/tensorflow_asr/models/layers/blurpool.py @@ -14,11 +14,12 @@ # limitations under the License. import numpy as np -import tensorflow as tf +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer +@keras.utils.register_keras_serializable(package=__name__) class BlurPool2D(Layer): def __init__( self, @@ -29,10 +30,9 @@ def __init__( trainable=True, name="blurpool2d", dtype=None, - dynamic=False, **kwargs, ): - super().__init__(trainable, name, dtype, dynamic, **kwargs) + super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) self.filters = filters self.kernel_size = kernel_size self.strides = strides @@ -61,6 +61,8 @@ def __init__( a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) elif self.kernel_size == 7: a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) + else: + raise ValueError("Kernel size must be in [1, 2, 3, 4, 5, 6, 7]") self.kernel = tf.constant(a[:, None] * a[None, :], dtype=self.compute_dtype) self.kernel = tf.divide(self.kernel, tf.reduce_sum(self.kernel)) @@ -74,6 +76,7 @@ def call(self, inputs): return tf.nn.conv2d(inputs, filters=kernel, strides=self.strides, padding="VALID") +@keras.utils.register_keras_serializable(package=__name__) class BlurPool1D(Layer): def __init__( self, @@ -84,10 +87,9 @@ def __init__( trainable=True, name="blurpool1d", dtype=None, - dynamic=False, **kwargs, ): - super().__init__(trainable, name, dtype, dynamic, **kwargs) + super().__init__(trainable=trainable, name=name, dtype=dtype, **kwargs) self.filters = filters self.kernel_size = kernel_size self.strides = strides @@ -115,6 +117,8 @@ def __init__( a = np.array([1.0, 5.0, 10.0, 10.0, 5.0, 1.0]) elif self.kernel_size == 7: a = np.array([1.0, 6.0, 15.0, 20.0, 15.0, 6.0, 1.0]) + else: + raise ValueError("Kernel size must be in [1, 2, 3, 4, 5, 6, 7]") self.kernel = tf.constant(a, dtype=self.compute_dtype) self.kernel = tf.divide(self.kernel, tf.reduce_sum(self.kernel)) diff --git a/tensorflow_asr/models/layers/convolution.py b/tensorflow_asr/models/layers/convolution.py index dec4c1492b..0006e9ba18 100644 --- a/tensorflow_asr/models/layers/convolution.py +++ b/tensorflow_asr/models/layers/convolution.py @@ -17,45 +17,146 @@ Causal padding supported Conv1D, Conv2D, DepthwiseConv1D, DepthwiseConv2D """ -import tensorflow as tf -from keras.layers.convolutional.base_conv import Conv +from keras.src.ops.operation_utils import compute_conv_output_shape - -def _validate_init(self): # removed check padding causal - if self.filters is not None and self.filters % self.groups != 0: - raise ValueError( - f"The number of filters must be evenly divisible by the number of groups. 
Received: groups={self.groups}, filters={self.filters}" - ) - if not all(self.kernel_size): - raise ValueError(f"The argument `kernel_size` cannot contain 0(s). Received: {(self.kernel_size,)}") - if not all(self.strides): - raise ValueError(f"The argument `strides` cannot contains 0(s). Received: {(self.strides,)}") +from tensorflow_asr import keras, tf -def _compute_causal_padding(self, inputs): +def _compute_causal_padding(inputs, rank, data_format, dilation_rate, kernel_size): """Calculates padding for 'causal' option for 1-d and 2-d conv layers.""" batch_pad = [[0, 0]] channel_pad = [[0, 0]] - height_pad = [[self.dilation_rate[0] * (self.kernel_size[0] - 1), 0]] - if self.rank == 1: - if self.data_format == "channels_last": + height_pad = [[dilation_rate[0] * (kernel_size[0] - 1), 0]] + if rank == 1: + if data_format == "channels_last": return batch_pad + height_pad + channel_pad return batch_pad + channel_pad + height_pad - width_pad = [[self.dilation_rate[1] * (self.kernel_size[1] - 1), 0]] - if self.data_format == "channels_last": + width_pad = [[dilation_rate[1] * (kernel_size[1] - 1), 0]] + if data_format == "channels_last": return batch_pad + height_pad + width_pad + channel_pad return batch_pad + channel_pad + height_pad + width_pad -# Monkey patch -Conv._validate_init = _validate_init -Conv._compute_causal_padding = _compute_causal_padding +@keras.utils.register_keras_serializable(package=__name__) +class Conv1D(keras.layers.Conv1D): + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + super().__init__( + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + activation, + use_bias, + kernel_initializer, + bias_initializer, + kernel_regularizer, + bias_regularizer, + activity_regularizer, + kernel_constraint, + bias_constraint, + **kwargs, + ) + self._padding = padding -import keras.layers.convolutional -from keras.layers.convolutional import Conv1D, Conv2D # pylint: disable=unused-import +@keras.utils.register_keras_serializable(package=__name__) +class Conv2D(keras.layers.Conv2D): + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + groups=1, + activation=None, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + bias_constraint=None, + **kwargs, + ): + self._padding = padding + if padding == "causal": + self._is_causal = True + padding = "valid" + else: + self._is_causal = False + super().__init__( + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + groups, + activation, + use_bias, + kernel_initializer, + bias_initializer, + kernel_regularizer, + bias_regularizer, + activity_regularizer, + kernel_constraint, + bias_constraint, + **kwargs, + ) + + def call(self, inputs): + if self._is_causal: + inputs = tf.pad( + inputs, + _compute_causal_padding( + inputs, + rank=self.rank, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + kernel_size=self.kernel_size, + ), + ) + return super().call(inputs) + + def compute_output_shape(self, input_shape): + return 
compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding="causal" if self._is_causal else self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) -class DepthwiseConv1D(keras.layers.convolutional.DepthwiseConv1D): + +@keras.utils.register_keras_serializable(package=__name__) +class DepthwiseConv1D(keras.layers.DepthwiseConv1D): def __init__( self, kernel_size, @@ -75,6 +176,12 @@ def __init__( bias_constraint=None, **kwargs, ): + self._padding = padding + if padding == "causal": + self._is_causal = True + padding = "valid" + else: + self._is_causal = False super().__init__( kernel_size, strides, @@ -93,24 +200,44 @@ def __init__( bias_constraint, **kwargs, ) - if self._is_causal: - self.padding = "VALID" def call(self, inputs): if self._is_causal: - inputs = tf.pad(inputs, self._compute_causal_padding(inputs)) + inputs = tf.pad( + inputs, + _compute_causal_padding( + inputs, + rank=self.rank, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + kernel_size=self.kernel_size, + ), + ) return super().call(inputs) + def compute_output_shape(self, input_shape): + input_channel = self._get_input_channel(input_shape) + return compute_conv_output_shape( + input_shape, + self.depth_multiplier * input_channel, + self.kernel_size, + strides=self.strides, + padding="causal" if self._is_causal else self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + -class DepthwiseConv2D(keras.layers.convolutional.DepthwiseConv2D): +@keras.utils.register_keras_serializable(package=__name__) +class DepthwiseConv2D(keras.layers.DepthwiseConv2D): def __init__( self, kernel_size, - strides=..., + strides=(1, 1), padding="valid", depth_multiplier=1, data_format=None, - dilation_rate=..., + dilation_rate=(1, 1), activation=None, use_bias=True, depthwise_initializer="glorot_uniform", @@ -122,6 +249,12 @@ def __init__( bias_constraint=None, **kwargs, ): + self._padding = padding + if padding == "causal": + self._is_causal = True + padding = "valid" + else: + self._is_causal = False super().__init__( kernel_size, strides, @@ -140,10 +273,189 @@ def __init__( bias_constraint, **kwargs, ) + + def call(self, inputs): if self._is_causal: - self.padding = "VALID" + inputs = tf.pad( + inputs, + _compute_causal_padding( + inputs, + rank=self.rank, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + kernel_size=self.kernel_size, + ), + ) + return super().call(inputs) + + def compute_output_shape(self, input_shape): + input_channel = self._get_input_channel(input_shape) + return compute_conv_output_shape( + input_shape, + self.depth_multiplier * input_channel, + self.kernel_size, + strides=self.strides, + padding="causal" if self._is_causal else self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + +@keras.utils.register_keras_serializable(package=__name__) +class SeparableConv1D(keras.layers.SeparableConv1D): + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format=None, + dilation_rate=1, + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + **kwargs, + ): + self._padding = padding + if padding == 
"causal": + self._is_causal = True + padding = "valid" + else: + self._is_causal = False + super().__init__( + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + depth_multiplier, + activation, + use_bias, + depthwise_initializer, + pointwise_initializer, + bias_initializer, + depthwise_regularizer, + pointwise_regularizer, + bias_regularizer, + activity_regularizer, + depthwise_constraint, + pointwise_constraint, + bias_constraint, + **kwargs, + ) + + def call(self, inputs): + if self._is_causal: + inputs = tf.pad( + inputs, + _compute_causal_padding( + inputs, + rank=self.rank, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + kernel_size=self.kernel_size, + ), + ) + return super().call(inputs) + + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding="causal" if self._is_causal else self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + +@keras.utils.register_keras_serializable(package=__name__) +class SeparableConv2D(keras.layers.SeparableConv2D): + def __init__( + self, + filters, + kernel_size, + strides=..., + padding="valid", + data_format=None, + dilation_rate=..., + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + **kwargs, + ): + self._padding = padding + if padding == "causal": + self._is_causal = True + padding = "valid" + else: + self._is_causal = False + super().__init__( + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + depth_multiplier, + activation, + use_bias, + depthwise_initializer, + pointwise_initializer, + bias_initializer, + depthwise_regularizer, + pointwise_regularizer, + bias_regularizer, + activity_regularizer, + depthwise_constraint, + pointwise_constraint, + bias_constraint, + **kwargs, + ) def call(self, inputs): if self._is_causal: - inputs = tf.pad(inputs, self._compute_causal_padding(inputs)) + inputs = tf.pad( + inputs, + _compute_causal_padding( + inputs, + rank=self.rank, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + kernel_size=self.kernel_size, + ), + ) return super().call(inputs) + + def compute_output_shape(self, input_shape): + return compute_conv_output_shape( + input_shape, + self.filters, + self.kernel_size, + strides=self.strides, + padding="causal" if self._is_causal else self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) diff --git a/tensorflow_asr/models/layers/embedding.py b/tensorflow_asr/models/layers/embedding.py index d9f47fd385..64685f2ac4 100644 --- a/tensorflow_asr/models/layers/embedding.py +++ b/tensorflow_asr/models/layers/embedding.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorflow as tf +from tensorflow_asr import keras, tf +from tensorflow_asr.models.base_layer import Layer -class Embedding(tf.keras.layers.Embedding): +@keras.utils.register_keras_serializable(package=__name__) +class Embedding(keras.layers.Embedding): def __init__( self, vocab_size, @@ -36,6 +38,56 @@ def __init__( ) self.supports_masking = True - def recognize_tflite(self, inputs): + def call(self, inputs): + outputs, outputs_length = inputs + outputs = super().call(outputs) + return outputs, outputs_length + + def call_next(self, inputs): outputs = tf.cast(tf.expand_dims(inputs, axis=-1), dtype=tf.int32) return tf.gather_nd(self.embeddings, outputs) # https://github.com/tensorflow/tensorflow/issues/42410 + + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + mask = tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool) + return mask, None + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + output_shape = super().compute_output_shape(output_shape) + return output_shape, output_length_shape + + +@keras.utils.register_keras_serializable(package=__name__) +class OneHotBlank(Layer): + """ + https://arxiv.org/pdf/1211.3711.pdf + The inputs are encoded as one-hot vectors; + that is, if Y consists of K labels and yu = k, then y^u is a length K vector whose elements are all zero + except the k-th, which is one. ∅ is encoded as a length K vector of zeros + """ + + def __init__(self, blank, depth, name="one_hot_blank", **kwargs): + super().__init__(name=name, **kwargs) + self.blank = blank + self.depth = depth + + def call(self, inputs): + outputs, outputs_length = inputs + minus_one_at_blank = tf.where(tf.equal(outputs, self.blank), -1, outputs) + outputs = tf.one_hot(minus_one_at_blank, depth=self.depth, dtype=self.dtype) + return outputs, outputs_length + + def call_next(self, inputs): + outputs, _ = self.call((inputs, None)) + return outputs + + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + mask = tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool) + return mask, None + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + output_shape = output_shape + (self.depth,) + return output_shape, output_length_shape diff --git a/tensorflow_asr/models/layers/feature_extraction.py b/tensorflow_asr/models/layers/feature_extraction.py new file mode 100644 index 0000000000..0a12751be1 --- /dev/null +++ b/tensorflow_asr/models/layers/feature_extraction.py @@ -0,0 +1,331 @@ +# Copyright 2023 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
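Editor's note: a small worked example with assumed values (not part of this patch) of the OneHotBlank layer added in embedding.py above; remapping the blank label to -1 before tf.one_hot yields an all-zero vector for blank, matching the RNN-T label encoding the docstring cites.

import tensorflow as tf

blank, depth = 0, 4
labels = tf.constant([[0, 2, 3]])                                  # blank followed by two real labels
one_hot = tf.one_hot(tf.where(tf.equal(labels, blank), -1, labels), depth=depth)
# one_hot[0] == [[0, 0, 0, 0],   <- blank encoded as all zeros
#                [0, 0, 1, 0],
#                [0, 0, 0, 1]]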
+ +from dataclasses import asdict, dataclass + +from tensorflow_asr import keras, tf +from tensorflow_asr.augmentations.augmentation import Augmentation +from tensorflow_asr.features import gammatone +from tensorflow_asr.models.base_layer import Layer +from tensorflow_asr.utils import math_util + + +@dataclass +class FEATURE_TYPES: + SPECTROGRAM: str = "spectrogram" + LOG_MEL_SPECTROGRAM: str = "log_mel_spectrogram" + MFCC: str = "mfcc" + LOG_GAMMATONE_SPECTROGRAM: str = "log_gammatone_spectrogram" + + +@keras.utils.register_keras_serializable(package=__name__) +class FeatureExtraction(Layer): + def __init__( + self, + sample_rate=16000, + frame_ms=25, + stride_ms=10, + num_feature_bins=80, + feature_type="log_mel_spectrogram", + preemphasis=0.97, + pad_end=True, + use_librosa_like_stft=False, + epsilon=1e-6, + lower_edge_hertz=0.0, + upper_edge_hertz=8000.0, + log_base="e", # "10", "e" + nfft=512, + normalize_signal=False, + normalize_zscore=False, + normalize_min_max=False, + padding=0, + augmentation_config={}, + **kwargs, + ): + """ + Audio Features Extraction Keras Layer + + Parameters + ---------- + sample_rate : int, optional + Sample rate of audio signals in Hz, by default 16000 + frame_ms : int, optional + Amount of data grabbed for each frame during analysis in ms, by default 25 + stride_ms : int, optional + Number of ms to jump between frames, by default 10 + num_feature_bins : int, optional + Number of bins in the feature output, by default 80 + feature_type : str, optional + Type of feature extraction, by default "log_mel_spectrogram" + preemphasis : float, optional + The first-order filter coefficient used for preemphasis, when it is 0.0, preemphasis is turned off, by default 0.0 + pad_end : bool, optional + Whether to pad the end of `signals` with zeros when framing produces a frame that lies partially past its end, by default True + use_librosa_like_stft : bool, optional + Use librosa like stft, by default False + epsilon : float, optional + Epsilon value to avoid log(0.0) causes Inf, by default 1e-6 + lower_edge_hertz : float, optional + The lowest frequency of the feature analysis, by default 125.0 + upper_edge_hertz : float, optional + The highest frequency of the feature analysis, by default 8000.0 + log_base : str, optional + The base of logarithm, by default 'e' + nfft : int, optional + NFFT, if None, equals frame_length derived from frame_ms, by default None + normalize_signal : bool, optional + Normalize signals to [-1,1] range, by default False + normalize_zscore : bool, optional + Normalize features using z-score, by default False + normalize_min_max : bool, optional + Normalize features as (value - min) / (max - min), by default True + padding : int, optional + Number of samples to pad with 0 before feature extraction, by default 0 + augmentation_config : dict, optional + Dictionary of augmentation config for training + """ + assert feature_type in asdict(FEATURE_TYPES()).values(), f"feature_type must be in {asdict(FEATURE_TYPES()).values()}" + + super().__init__(name=feature_type, **kwargs) + self.sample_rate = sample_rate + + self.frame_ms = frame_ms + self.frame_length = int(round(self.sample_rate * self.frame_ms / 1000.0)) + + self.stride_ms = stride_ms + self.frame_step = int(round(self.sample_rate * self.stride_ms / 1000.0)) + + self.num_feature_bins = num_feature_bins + + self.feature_type = feature_type + + self.preemphasis = preemphasis + + self.pad_end = pad_end + + self.use_librosa_like_stft = use_librosa_like_stft + + # fmt: off + self.epsilon = 
epsilon + assert self.epsilon > 1e-9 and self.epsilon <= 0.001, "epsilon must be in (1e-9, 0.001]" + # fmt: on + + self.lower_edge_hertz = lower_edge_hertz + self.upper_edge_hertz = upper_edge_hertz + + self.log_base = log_base + assert self.log_base in ("10", "e"), "log_base must be '10' or 'e'" + + self._normalize_signal = normalize_signal + self._normalize_zscore = normalize_zscore + self._normalize_min_max = normalize_min_max + + self.padding = padding + self.nfft = self.frame_length if nfft is None else nfft + + self.augmentations = Augmentation(augmentation_config) + + # ---------------------------------- signals --------------------------------- # + + def get_signal_chunk_size_and_step(self, nframes): + """ + This will ensure the "fft of chunked signal" is the same with "fft of whole signal" + The features are extracted by windowing the signal by length and strides + The chunk size is the size of the windowed signal, + which is (nframes - 1) * frame_step + frame_length + The next chunk will start at the position of the next frame, + which is nframes + 1 "steps", so we need to move nframes "steps" to get the next chunk + + Parameters + ---------- + nframes : int + Number of target frames of the chunk signals will result in + + Returns + ------- + (chunk_size, chunk_step) + Size of the chunk signals and the step to move to the next chunk + """ + chunk_size = (nframes - 1) * self.frame_step + self.frame_length + chunk_step = nframes * self.frame_step + return chunk_size, chunk_step + + def normalize_signal(self, signal): + if not self._normalize_signal: + return signal + gain = 1.0 / (tf.reduce_max(tf.abs(signal), axis=1, keepdims=True) + self.epsilon) + return signal * gain + + def preemphasis_signal(self, signal): + if not self.preemphasis or self.preemphasis <= 0.0: + return signal + s0 = tf.expand_dims(signal[:, 0], axis=-1) + s1 = signal[:, 1:] - self.preemphasis * signal[:, :-1] + return tf.concat([s0, s1], -1) + + # --------------------------------- features --------------------------------- # + + def normalize_audio_features(self, audio_feature): + if self._normalize_zscore: + mean = tf.reduce_mean(audio_feature, axis=1, keepdims=True) + stddev = tf.sqrt(tf.math.reduce_variance(audio_feature, axis=1, keepdims=True) + self.epsilon) + return tf.divide(tf.subtract(audio_feature, mean), stddev) + if self._normalize_min_max: + if self.feature_type.startswith("log_") or self.feature_type == FEATURE_TYPES.SPECTROGRAM: + min_value = self.logarithm(self.epsilon) + else: + min_value = tf.reduce_min(audio_feature, axis=1, keepdims=True) + return (audio_feature - min_value) / (tf.reduce_max(audio_feature, axis=1, keepdims=True) - min_value) + return audio_feature + + def stft(self, signal): + orig_dtype = signal.dtype + if orig_dtype in (tf.float16, tf.bfloat16): + signal = tf.cast(signal, tf.float32) + if self.use_librosa_like_stft: + # signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT") + window = tf.signal.hann_window(self.frame_length, periodic=True) + left_pad = (self.nfft - self.frame_length) // 2 + right_pad = self.nfft - self.frame_length - left_pad + window = tf.pad(window, [[left_pad, right_pad]]) + framed_signals = tf.signal.frame(signal, frame_length=self.nfft, frame_step=self.frame_step, pad_end=self.pad_end) + framed_signals *= window + fft_features = tf.abs(tf.signal.rfft(framed_signals, [self.nfft])) + else: + fft_features = tf.abs( + tf.signal.stft(signal, frame_length=self.frame_length, frame_step=self.frame_step, fft_length=self.nfft, 
pad_end=self.pad_end) + ) + fft_features = tf.square(fft_features) + if orig_dtype in (tf.float16, tf.bfloat16): + fft_features = tf.cast(fft_features, orig_dtype) + return fft_features + + def logarithm(self, S): + S += self.epsilon + if self.log_base == "10": + return math_util.log10(S) + return tf.math.log(S) + + def log_mel_spectrogram(self, signal): + S = self.stft(signal) + linear_to_weight_matrix = tf.signal.linear_to_mel_weight_matrix( + num_mel_bins=self.num_feature_bins, + num_spectrogram_bins=tf.shape(S)[-1], + sample_rate=self.sample_rate, + lower_edge_hertz=self.lower_edge_hertz, + upper_edge_hertz=self.upper_edge_hertz, + dtype=S.dtype, + ) + mel_spectrogram = tf.matmul(S, linear_to_weight_matrix) + return self.logarithm(mel_spectrogram) + + def spectrogram(self, signal): + spectrogram = self.logarithm(self.stft(signal)) + return spectrogram[:, :, : self.num_feature_bins] + + def mfcc(self, signal): + log_mel_spectrogram = self.log_mel_spectrogram(signal) + return tf.signal.mfccs_from_log_mel_spectrograms(log_mel_spectrogram) + + def log_gammatone_spectrogram(self, signal): + S = self.stft(signal) + gtone = gammatone.fft_weights( + self.nfft, + self.sample_rate, + self.num_feature_bins, + width=1.0, + fmin=int(self.lower_edge_hertz), + fmax=int(self.upper_edge_hertz), + maxlen=(self.nfft / 2 + 1), + ) + gtone_spectrogram = tf.matmul(S, gtone) + return self.logarithm(gtone_spectrogram) + + def call(self, inputs, training=False): + """ + Compute features of audio signals + + Parameters + ---------- + inputs : tf.Tensor, shape [B, None] + Audio signals that were resampled to sample_rate + + training : bool, optional + Training mode, by default False + + Returns + ------- + tf.Tensor, shape = [B, n_frames, num_feature_bins, 1] if has_channel_dim else [B, n_frames, num_feature_bins] + Features extracted from audio signals + """ + signals, signals_length = inputs + + if training: + signals, signals_length = self.augmentations.signal_augment(signals, signals_length) + + if self.padding > 0: + signals = tf.pad(signals, [[0, 0], [0, self.padding]], mode="CONSTANT", constant_values=0.0) + + signals = self.normalize_signal(signals) + signals = self.preemphasis_signal(signals) + + if self.feature_type == FEATURE_TYPES.SPECTROGRAM: + features = self.spectrogram(signals) + elif self.feature_type == FEATURE_TYPES.MFCC: + features = self.mfcc(signals) # TODO: add option to compute delta features for mfccs + elif self.feature_type == FEATURE_TYPES.LOG_GAMMATONE_SPECTROGRAM: + features = self.log_gammatone_spectrogram(signals) + else: # default as log_mel_spectrogram + features = self.log_mel_spectrogram(signals) + + features = self.normalize_audio_features(features) + features = tf.expand_dims(features, axis=-1) + features_length = tf.map_fn( + fn=self.get_nframes, + elems=tf.cast(signals_length, tf.int32), + fn_output_signature=tf.TensorSpec(shape=(), dtype=tf.int32), + ) + + if training: + features, features_length = self.augmentations.feature_augment(features, features_length) + + return features, features_length + + def get_nframes(self, nsamples): + # https://www.tensorflow.org/api_docs/python/tf/signal/frame + if self.use_librosa_like_stft: + if self.pad_end: + return -(-nsamples // self.frame_step) + return 1 + (nsamples - self.nfft) // self.frame_step + if self.pad_end: + return -(-nsamples // self.frame_step) + return 1 + (nsamples - self.frame_length) // self.frame_step + + def compute_mask(self, inputs, mask=None): + signals, signals_length = inputs + mask = 
tf.sequence_mask(signals_length, maxlen=(tf.shape(signals)[1] + self.padding), dtype=tf.bool) + nsamples = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1) + # nframes = tf.map_fn(fn=self.get_nframes, elems=nsamples, fn_output_signature=tf.TensorSpec(shape=(), dtype=tf.int32)) + nframes = self.get_nframes(nsamples) + padded_nframes = self.get_nframes(tf.shape(signals, tf.int32)[1] + self.padding) + return tf.sequence_mask(nframes, maxlen=padded_nframes, dtype=tf.bool), None + + def compute_output_shape(self, input_shape): + signal_shape, signal_length_shape = input_shape + B, nsamples = signal_shape + if nsamples is None: + output_shape = [B, None, self.num_feature_bins, 1] + else: + output_shape = [B, self.get_nframes(nsamples + self.padding), self.num_feature_bins, 1] + return tf.TensorShape(output_shape), tf.TensorShape(signal_length_shape) diff --git a/tensorflow_asr/models/layers/general.py b/tensorflow_asr/models/layers/general.py new file mode 100644 index 0000000000..874f3ea836 --- /dev/null +++ b/tensorflow_asr/models/layers/general.py @@ -0,0 +1,41 @@ +import keras +from keras.src import activations, backend + +from tensorflow_asr.utils import math_util + + +class Dropout(keras.layers.Dropout): + def __init__(self, rate, noise_shape=None, seed=None, **kwargs): + super().__init__(rate, noise_shape, seed, **kwargs) + self.built = False + + +class Identity(keras.layers.Identity): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.built = False + + +class Activation(keras.layers.Activation): + def __init__(self, activation, **kwargs): + super().__init__(activation, **kwargs) + self.built = False + + +class Softmax(keras.layers.Softmax): + """ + Softmax activation layer with better numerical stability to avoid Inf or NaN + """ + + def call(self, inputs, mask=None): + if mask is not None: + inputs = math_util.masked_fill( + inputs, + mask=mask, + value=math_util.large_compatible_negative_number(self.dtype), + ) + if isinstance(self.axis, (tuple, list)): + if len(self.axis) > 1: + return backend.numpy.exp(inputs - backend.math.logsumexp(inputs, axis=self.axis, keepdims=True)) + return activations.softmax(inputs, axis=self.axis[0]) + return activations.softmax(inputs, axis=self.axis) diff --git a/tensorflow_asr/models/layers/memory.py b/tensorflow_asr/models/layers/memory.py index 619510c02e..5de627eed1 100644 --- a/tensorflow_asr/models/layers/memory.py +++ b/tensorflow_asr/models/layers/memory.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +from keras.src import backend +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.utils import math_util @@ -27,83 +28,58 @@ def _shift(tensor, shift): return shifted_tensor +@keras.utils.register_keras_serializable(package=__name__) class Memory(Layer): - def __init__(self, batch_size, memory_length, dmodel, **kwargs): + """ + Memory Layer + This layer `call` method will do 2 things: + 1. prepend memory hidden states to inputs -> new_inputs + 2. 
concatenating memory and inputs, then slice to memory length -> new_memory + """ + + def __init__(self, memory_length, dmodel, **kwargs): super().__init__(trainable=False, **kwargs) assert memory_length > 0, "memory_length must be integer" - self.batch_size = batch_size self.memory_length = memory_length self.dmodel = dmodel - self.stateful = True - self.memory = tf.Variable( - initial_value=tf.zeros(shape=(self.batch_size, self.memory_length, self.dmodel), dtype=self.dtype), - trainable=False, - name="memory", - ) - self.memory_mask = tf.Variable( - initial_value=tf.zeros(shape=(self.batch_size, self.memory_length), dtype=tf.bool), - trainable=False, - name="memory_mask", - ) - def _get_inputs(self, inputs): - inputs_mask = getattr(inputs, "_keras_mask", None) - max_length = tf.shape(inputs)[1] + def _get_inputs(self, inputs, default_mask_value=1): + inputs_mask = backend.get_keras_mask(inputs) if inputs_mask is None: - inputs_mask = tf.ones([self.batch_size, max_length], dtype=tf.bool) + batch_size, max_length, *_ = tf.shape(inputs) + inputs_mask = tf.cast(tf.ones((batch_size, max_length), dtype=tf.int32) * default_mask_value, dtype=tf.bool) return inputs, inputs_mask - def attach_memory(self, inputs): - inputs, inputs_mask = self._get_inputs(inputs) - # shift memory and stop grad - memory_shift = _create_num_masked(self.memory_mask) - memory = _shift(self.memory, shift=memory_shift) - memory = tf.stop_gradient(memory) - memory_mask = _shift(self.memory_mask, shift=memory_shift) - memory_mask = tf.stop_gradient(memory_mask) - # prepend memory and inputs - new_inputs = tf.concat([memory, inputs], 1) - new_inputs._keras_mask = tf.concat([memory_mask, inputs_mask], 1) # pylint: disable=protected-access - return new_inputs - - def get_states(self): - return (self.memory, self.memory_mask) - - def reset_states(self, states=(None, None)): - memory, memory_mask = states - if memory is None: - memory = tf.zeros(shape=(self.batch_size, self.memory_length, self.dmodel), dtype=self.dtype) - if memory_mask is None: - memory_mask = tf.zeros(shape=(self.batch_size, self.memory_length), dtype=tf.bool) - self.add_update([tf.keras.backend.update(self.memory, memory), tf.keras.backend.update(self.memory_mask, memory_mask)]) + def get_initial_state(self, batch_size: int): + memory = tf.zeros(shape=(batch_size, self.memory_length, self.dmodel), dtype=self.dtype) + backend.set_keras_mask(memory, tf.zeros(shape=(batch_size, self.memory_length), dtype=tf.bool)) + return memory - def call(self, inputs): + def call(self, inputs, memories=None, training=False): + if memories is None: + return None inputs, inputs_mask = self._get_inputs(inputs) - # shift by memory mask - shift = _create_num_masked(self.memory_mask) - new_memory = _shift(self.memory, shift=shift) - new_memory_mask = _shift(self.memory_mask, shift=shift) - # prepend memory to inputs - new_memory = tf.concat([new_memory, inputs], 1) - new_memory_mask = tf.concat([new_memory_mask, inputs_mask], 1) - # shift by inputs mask - shift = _create_num_masked(inputs_mask) - new_memory = _shift(new_memory, shift=shift) - new_memory_mask = _shift(new_memory_mask, shift=shift) - # slice combination of memory and inputs into memory_length + memory, memory_mask = self._get_inputs(memories) + # create new_inputs by prepending memory to inputs + if training: + memory = tf.stop_gradient(memory) + memory_mask = tf.stop_gradient(memory_mask) + new_inputs = tf.concat([memory, inputs], 1) # prepend memory and inputs + new_inputs_mask = tf.concat([memory_mask, 
inputs_mask], 1) + new_inputs._keras_mask = new_inputs_mask # pylint: disable=protected-access + # create new_memory by slicing new_inputs to memory length new_memory = tf.slice( - new_memory, - begin=[0, tf.shape(new_memory)[1] - self.memory_length, 0], + new_inputs, + begin=[0, tf.shape(new_inputs)[1] - self.memory_length, 0], size=[-1, self.memory_length, -1], ) new_memory_mask = tf.slice( - new_memory_mask, - begin=[0, tf.shape(new_memory_mask)[1] - self.memory_length], + new_inputs_mask, + begin=[0, tf.shape(new_inputs_mask)[1] - self.memory_length], size=[-1, self.memory_length], ) - self.add_update([tf.keras.backend.update(self.memory, new_memory), tf.keras.backend.update(self.memory_mask, new_memory_mask)]) new_memory._keras_mask = new_memory_mask # pylint: disable=protected-access - return new_memory + return new_inputs, new_memory def compute_output_shape(self, input_shape): - return input_shape[0], self.memory_length, self.dmodel + return input_shape, (input_shape[0], self.memory_length, self.dmodel) diff --git a/tensorflow_asr/models/layers/multihead_attention.py b/tensorflow_asr/models/layers/multihead_attention.py index 7e6787400f..4a752605a2 100644 --- a/tensorflow_asr/models/layers/multihead_attention.py +++ b/tensorflow_asr/models/layers/multihead_attention.py @@ -13,22 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math +import collections -import tensorflow as tf -from keras.layers import EinsumDense -from keras.layers import MultiHeadAttention as KerasMultiHeadAttention -from keras.utils import tf_utils - -try: - from keras.layers.multi_head_attention import _build_attention_equation, _build_proj_equation, _get_output_shape -except ImportError: - from keras.layers.attention.multi_head_attention import _build_attention_equation, _build_proj_equation, _get_output_shape +from keras.src import backend +from keras.src.layers.attention import multi_head_attention as mha_module +from tensorflow_asr import keras, tf +from tensorflow_asr.models.layers.general import Dropout, Softmax from tensorflow_asr.models.layers.memory import Memory +from tensorflow_asr.utils import shape_util -def rel_left_shift(x): +def rel_left_shift(x, causal=False): """ Relative left shift @@ -50,21 +46,40 @@ def rel_left_shift(x): Returns: x: left shifted, shape BNTR """ - x = tf.transpose(x, perm=[2, 3, 0, 1]) # BNTR -> TRBN - x_shape = tf.shape(x) - - x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) # shift on position time dimension R - x = tf.reshape(x, [x_shape[1] + 1, x_shape[0], x_shape[2], x_shape[3]]) - x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) - x = tf.reshape(x, x_shape) - - x = tf.transpose(x, perm=[2, 3, 0, 1]) # TRBN -> BNTR - x *= tf.linalg.band_part(tf.ones((1, 1, x_shape[0], x_shape[1]), x.dtype), -1, 0) + b, n, t, r = shape_util.shape_list(x) + + # fmt: off + if causal: + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [1, 0]]) # [B, N, T, Th + T] + x = tf.reshape(x, [b, n, -1]) + x = tf.pad(x, [[0, 0], [0, 0], [r - t, 0]]) + x = tf.reshape(x, [b, n, 1 + t, r]) + x = tf.slice(x, begin=[0, 0, 1, 0], size=[-1, -1, -1, -1]) # [B, N, T, Th + T] + else: + x = tf.pad(x, [[0, 0], [0, 0], [0, 0], [0, 1]]) # [B, N, T, Th + 2*T] where R = Th + 2*T - 1, S = Th + T + x = tf.reshape(x, [b, n, -1]) # [B, N, TTh + 2*TT] + x = tf.pad(x, [[0, 0], [0, 0], [0, r - t]]) # [B, N, TTh + 2*TT + Th + 2*T - 1 - T] = [B, N, TTh + 2*TT + Th + T - 1] + x = tf.reshape(x, [b, n, 1 + t, r]) # TTh + 2*TT + Th + T - 1 = TTh + 2*TT + Th + 2*T - T - 
1 = Th(T + 1) + 2*T(T + 1) - (T + 1) = (T + 1)(Th + 2*T - 1) = (T + 1)R # pylint: disable=line-too-long + x = tf.slice(x, begin=[0, 0, 0, (t - 1)], size=[-1, -1, t, -1]) # [B, N, T, Th + T] + # fmt: on + + # x = tf.transpose(x, perm=[2, 3, 0, 1]) # BNTR -> TRBN + # x_shape = tf.shape(x) + + # x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]]) # shift on position time dimension R + # x = tf.reshape(x, [x_shape[1] + 1, x_shape[0], x_shape[2], x_shape[3]]) + # x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1]) + # x = tf.reshape(x, x_shape) + # if mask_upper_triangle: + # x *= tf.reverse(tf.linalg.band_part(tf.ones((x_shape[0], x_shape[1]), x.dtype), 0, -1), [0, 1])[..., tf.newaxis, tf.newaxis] + + # x = tf.transpose(x, perm=[2, 3, 0, 1]) # TRBN -> BNTR return x def compute_causal_mask(query, value=None): - """Computes a causal mask (e.g., for masked self-attention layers). + """ + Computes a causal mask (e.g., for masked self-attention layers). For example, if query and value both contain sequences of length 4, this function returns a boolean `Tensor` equal to: ``` @@ -86,7 +101,57 @@ def compute_causal_mask(query, value=None): return tf.linalg.band_part(tf.ones((1, q_seq_length, v_seq_length), tf.bool), -1, 0) # creates a lower triangular matrix -def compute_attention_mask(query, value, key=None, attention_mask=None, use_causal_mask=False): +def compute_streaming_mask(chunk_size, history_size, query, value=None): + """ + Computes a streaming mask as in http://arxiv.org/abs/2010.11395 + For example, if query and value both contain sequences of length 8, chunk size 2, history_size 2 + Chunk size = 2 -> it can see < 2 frames in the future because it in the same chunk, the 2nd frame is the last frame in the chunk therefore it does not see future # pylint: disable=line-too-long + History size = 2 -> it can see history_size = 2 frames in the past + The * indicates the current frame + All frames in the same chunk can see each other + this function returns a boolean `Tensor` equal to: + ``` + [[[ 1*, 1, 0, 0, 0, 0, 0, 0 ], + [ 1, 1*, 0, 0, 0, 0, 0, 0 ], + [ 1, 1, 1*, 1, 0, 0, 0, 0 ], + [ 1, 1, 1, 1*, 0, 0, 0, 0 ], + [ 0, 0, 1, 1, 1*, 1, 0, 0 ], + [ 0, 0, 1, 1, 1, 1*, 0, 0 ], + [ 0, 0, 0, 0, 1, 1, 1*, 1 ], + [ 0, 0, 0, 0, 1, 1, 1, 1*]]] + ``` + Args: + chunk_size: chunk size to split + history_size: history size to keep + query: query `Tensor` of shape `(B, T, ...)`. + value: value `Tensor` of shape `(B, S, ...)` (optional, defaults to query). + Returns: + mask: a boolean `Tensor` of shape [1, T, S] + """ + q_seq_length = shape_util.shape_list(query)[1] + v_seq_length = q_seq_length if value is None else shape_util.shape_list(value)[1] + hist_size = tf.where(tf.less(history_size, 0), v_seq_length, tf.constant(history_size, tf.int32)) + + def _fn(x): + index = x * chunk_size + start_index = tf.maximum(0, index - hist_size) + end_index_excluded = tf.minimum(v_seq_length, index + chunk_size) + keep = tf.sequence_mask(end_index_excluded, v_seq_length, dtype=tf.bool) + drop = tf.math.logical_not(tf.sequence_mask(start_index, v_seq_length, dtype=tf.bool)) + return keep & drop + + return tf.expand_dims(tf.map_fn(_fn, tf.math.floordiv(tf.range(q_seq_length, dtype=tf.int32), chunk_size), dtype=tf.bool), axis=0) + + +def compute_attention_mask( + query, + value, + key=None, + attention_mask=None, + use_causal_mask=False, + chunk_size=None, + history_size=None, +): """Computes the attention mask, using the Keras masks of the inputs. * The `query`'s mask is reshaped from [B, T] to [B, T, 1]. 
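Editor's note: a NumPy sketch (not part of this patch) that reproduces the streaming attention mask described in the compute_streaming_mask docstring above, for chunk_size=2, history_size=2 and 8 frames; every frame attends to its whole chunk plus history_size frames of the past, with no look-ahead beyond the chunk.

import numpy as np

T, chunk_size, history_size = 8, 2, 2
mask = np.zeros((T, T), dtype=bool)
for q in range(T):
    chunk_start = (q // chunk_size) * chunk_size
    start = max(0, chunk_start - history_size)   # history_size frames of the past
    end = min(T, chunk_start + chunk_size)       # the whole current chunk, nothing further ahead
    mask[q, start:end] = True
# mask equals the 8x8 example shown in the compute_streaming_mask docstring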
@@ -117,9 +182,9 @@ def compute_attention_mask(query, value, key=None, attention_mask=None, use_caus `query`, `key`, `value`, and `attention_mask` tensors, and the causal mask if `use_causal_mask=True`. """ - query_mask = getattr(query, "_keras_mask", None) - value_mask = getattr(value, "_keras_mask", None) - key_mask = getattr(key, "_keras_mask", None) + query_mask = backend.get_keras_mask(query) + value_mask = None + key_mask = None auto_mask = None if query_mask is not None: query_mask = tf.cast(query_mask, tf.bool) # defensive casting @@ -139,13 +204,17 @@ def compute_attention_mask(query, value, key=None, attention_mask=None, use_caus # the shape of the causal mask is [1, T, S] mask = compute_causal_mask(query, value) auto_mask = mask if auto_mask is None else auto_mask & mask + if chunk_size is not None and history_size is not None: + mask = compute_streaming_mask(chunk_size, history_size, query, value) + auto_mask = mask if auto_mask is None else auto_mask & mask if auto_mask is not None: # merge attention_mask & automatic mask, to shape [B, T, S] attention_mask = auto_mask if attention_mask is None else tf.cast(attention_mask, bool) & auto_mask return attention_mask -class MultiHeadAttention(KerasMultiHeadAttention): +@keras.utils.register_keras_serializable(package=__name__) +class MultiHeadAttention(keras.layers.MultiHeadAttention): def __init__( self, num_heads, @@ -155,7 +224,10 @@ def __init__( use_bias=True, output_shape=None, attention_axes=None, + flash_attention=None, memory_length=None, + history_size=None, + chunk_size=None, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None, @@ -163,49 +235,53 @@ def __init__( activity_regularizer=None, kernel_constraint=None, bias_constraint=None, + seed=None, **kwargs, ): + self._memory_length = memory_length + self._chunk_size = chunk_size + self._history_size = history_size + self._memory = None + if output_shape: + if not isinstance(output_shape, collections.abc.Sized): + output_shape = (output_shape,) super().__init__( - num_heads, - key_dim, - value_dim, - dropout, - use_bias, - output_shape, - attention_axes, - kernel_initializer, - bias_initializer, - kernel_regularizer, - bias_regularizer, - activity_regularizer, - kernel_constraint, - bias_constraint, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + dropout=dropout, + use_bias=use_bias, + output_shape=output_shape, + attention_axes=attention_axes, + flash_attention=flash_attention, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + seed=seed, **kwargs, ) - if not hasattr(self, "_compute_attention_mask"): - self._compute_attention_mask = compute_attention_mask - if not hasattr(self, "_compute_causal_mask"): - self._compute_causal_mask = compute_causal_mask - self._memory_length = memory_length - self.stateful = self._memory_length is not None - - def _get_common_kwargs_for_sublayer(self): - common_kwargs = dict( - kernel_regularizer=self._kernel_regularizer, - bias_regularizer=self._bias_regularizer, - activity_regularizer=self._activity_regularizer, - kernel_constraint=self._kernel_constraint, - bias_constraint=self._bias_constraint, - dtype=self.dtype, - ) - # Create new clone of kernel/bias initializer, so that we don't reuse - # the initializer instance, which could lead to same init value since - 
# initializer is stateless. - kernel_initializer = self._kernel_initializer.__class__.from_config(self._kernel_initializer.get_config()) - bias_initializer = self._bias_initializer.__class__.from_config(self._bias_initializer.get_config()) - common_kwargs["kernel_initializer"] = kernel_initializer - common_kwargs["bias_initializer"] = bias_initializer - return common_kwargs + self._precomputed_output_shape = None + + @property + def output_shape(self): + return self._precomputed_output_shape + + def build(self, input_shape): + query_shape, key_shape, value_shape, *_ = input_shape + if self._memory_length is not None: + self._memory = Memory( + batch_size=query_shape[0], + memory_length=self._memory_length, + dmodel=query_shape[-1], + name="memory", + dtype=self.dtype_policy, + ) + self._precomputed_output_shape = self.compute_output_shape(input_shape) + return super().build(query_shape, value_shape, key_shape) def _build_attention(self, rank): """Builds multi-head dot-product attention computations. @@ -225,59 +301,87 @@ def _build_attention(self, rank): self._dot_product_equation, self._combine_equation, attn_scores_rank, - ) = _build_attention_equation(rank, attn_axes=self._attention_axes) + ) = mha_module._build_attention_equation(rank, attn_axes=self._attention_axes) norm_axes = tuple(range(attn_scores_rank - len(self._attention_axes), attn_scores_rank)) - self._softmax = tf.keras.layers.Softmax(axis=norm_axes, dtype=self.dtype) - self._dropout_layer = tf.keras.layers.Dropout(rate=self._dropout, dtype=self.dtype) - - def _build_from_signature(self, query, value, key=None): - super()._build_from_signature(query, value, key) - with tf_utils.maybe_init_scope(self): # pylint: disable=not-context-manager - batch_size, _, dmodel = self._query_shape - if self._memory_length is not None: - self._memory = Memory(batch_size=batch_size, memory_length=self._memory_length, dmodel=dmodel, name="memory", dtype=self.dtype) - else: - self._memory = None - - def _update_with_memory(self, query, key, value): + self._softmax = Softmax(axis=norm_axes, dtype=self.dtype_policy) + self._dropout_layer = Dropout(rate=self._dropout, dtype=self.dtype_policy, seed=self.seed) + + def get_initial_state(self, batch_size: int): if self._memory is None: - return query, key, value + return None + return { + "key": self._memory.get_initial_state(batch_size), + "value": self._memory.get_initial_state(batch_size), + } - key = self._memory.attach_memory(key) - value = self._memory.attach_memory(value) + def _with_memory(self, query, key, value, initial_state=None, training=False): + if self._memory is None or initial_state is None: + return query, key, value, initial_state - self._memory(query) # update memory + new_key, new_key_memory = self._memory(key, memories=initial_state.get("key"), training=training) + new_value, new_value_memory = self._memory(value, memories=initial_state.get("value"), training=training) - return query, key, value + new_states = { + "key": new_key_memory, + "value": new_value_memory, + } - def get_states(self): - if self._memory is None: - return (None, None) - return self._memory.get_states() + return query, new_key, new_value, new_states - def reset_states(self, states=(None, None)): - if self._memory is None: - return - self._memory.reset_states(states) + def _compute_attention_mask( + self, + query, + value, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + use_causal_mask=False, + ): + attention_mask = super()._compute_attention_mask(query, value, query_mask, 
value_mask, key_mask, attention_mask, use_causal_mask) + if self._chunk_size is not None and self._history_size is not None: + mask = compute_streaming_mask(self._chunk_size, self._history_size, query, value) + attention_mask = mask if attention_mask is None else attention_mask & mask + return attention_mask def call( self, inputs, + query_mask=None, + value_mask=None, + key_mask=None, attention_mask=None, + use_auto_mask=True, return_attention_scores=False, training=None, use_causal_mask=False, - use_auto_mask=True, + initial_state=None, + return_states=False, + **kwargs, ): - query, key, value = inputs + query, key, value, *_ = inputs - if not self._built_from_signature: - self._build_from_signature(query=query, value=value, key=key) + self._return_attention_scores = return_attention_scores + if key is None: + key = value - query, key, value = self._update_with_memory(query, key, value) + # Delete the masks because the masks are handled at the level of the + # layer + query_mask = backend.get_keras_mask(query) + backend.set_keras_mask(query, None) + backend.set_keras_mask(value, None) + backend.set_keras_mask(key, None) if use_auto_mask: - attention_mask = self._compute_attention_mask(query, value, key=key, attention_mask=attention_mask, use_causal_mask=use_causal_mask) + attention_mask = self._compute_attention_mask( + query, + value, + query_mask=query_mask, + value_mask=value_mask, + key_mask=key_mask, + attention_mask=attention_mask, + use_causal_mask=use_causal_mask, + ) # N = `num_attention_heads` # H = `size_per_head` @@ -290,14 +394,69 @@ def call( # `value` = [B, S, N, H] value = self._value_dense(value) - attention_output, attention_scores = self._compute_attention(query, key, value, attention_mask, training) + states = None + + if return_states: + query, key, value, states = self._with_memory(query, key, value, initial_state, training) + + attention_output, attention_scores = self._compute_attention( + query, + key, + value, + attention_mask, + training, + return_attention_scores, + ) attention_output = self._output_dense(attention_output) + # Set mask on output if needed + if query_mask is not None: + backend.set_keras_mask(attention_output, query_mask) + if return_attention_scores: + if return_states: + return attention_output, states, attention_scores return attention_output, attention_scores - return attention_output + if return_states and states is not None: + return attention_output, states + return (attention_output,) + def compute_output_shape(self, input_shape): + query_shape, key_shape, value_shape, *_ = input_shape + return super().compute_output_shape(query_shape, value_shape, key_shape) + + def compute_output_spec( + self, + inputs, + query_mask=None, + value_mask=None, + key_mask=None, + attention_mask=None, + use_auto_mask=True, + return_attention_scores=False, + training=None, + use_causal_mask=False, + initial_state=None, + return_states=False, + ): + query, value, key, *_ = inputs + output_spec, *attention_score_spec = super().compute_output_spec( + query, value, key, query_mask, value_mask, key_mask, attention_mask, return_attention_scores, training, use_causal_mask + ) + if not return_states: + return [output_spec] + attention_score_spec + if self._memory_length is None: + return [output_spec, None] + attention_score_spec + states_shape = (query.shape[0], self._memory_length, query.shape[-1]) + states_spec = { + "key": keras.KerasTensor(states_shape, dtype=self.compute_dtype), + "value": keras.KerasTensor(states_shape, dtype=self.compute_dtype), + } + 
return [output_spec, states_spec] + attention_score_spec + + +@keras.utils.register_keras_serializable(package=__name__) class MultiHeadRelativeAttention(MultiHeadAttention): def __init__( self, @@ -308,53 +467,78 @@ def __init__( use_bias=True, output_shape=None, attention_axes=None, + flash_attention=None, memory_length=None, - kernel_initializer="variance_scaling", + history_size=None, + chunk_size=None, + kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, + seed=None, + use_attention_bias=False, + causal=False, **kwargs, ): super().__init__( - num_heads, - key_dim, - value_dim, - dropout, - use_bias, - output_shape, - attention_axes, - memory_length, - kernel_initializer, - bias_initializer, - kernel_regularizer, - bias_regularizer, - activity_regularizer, - kernel_constraint, - bias_constraint, + num_heads=num_heads, + key_dim=key_dim, + value_dim=value_dim, + dropout=dropout, + use_bias=use_bias, + output_shape=output_shape, + attention_axes=attention_axes, + flash_attention=flash_attention, + memory_length=memory_length, + history_size=history_size, + chunk_size=chunk_size, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + kernel_constraint=kernel_constraint, + bias_constraint=bias_constraint, + seed=seed, **kwargs, ) - self._relative_position_encoding_shape = None - - def _build_from_signature(self, query, value, relative_position_encoding, key=None): - super()._build_from_signature(query=query, value=value, key=key) - if hasattr(relative_position_encoding, "shape"): - self._relative_position_encoding_shape = tf.TensorShape(relative_position_encoding.shape) - else: - self._relative_position_encoding_shape = tf.TensorShape(relative_position_encoding) - with tf_utils.maybe_init_scope(self): # pylint: disable=not-context-manager - einsum_equation, bias_axes, output_rank = _build_proj_equation( - self._relative_position_encoding_shape.rank - 1, bound_dims=1, output_dims=2 + self._use_attention_bias = use_attention_bias + self._causal = causal + + def build(self, input_shape): + *rest_input_shape, relpe_shape = input_shape + relpe_rank = len(relpe_shape) + einsum_equation, bias_axes, output_rank = mha_module._build_proj_equation(relpe_rank - 1, bound_dims=1, output_dims=2) + self._relpe_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=mha_module._get_output_shape(output_rank - 1, [self._num_heads, self._key_dim]), + bias_axes=bias_axes if self._use_bias else None, + name="encoding", + **self._get_common_kwargs_for_sublayer(), + ) + if self._use_attention_bias: + self.content_attention_bias = self.add_weight( + name="content_attention_bias", + shape=[self._num_heads, self._key_dim], + trainable=True, + initializer="zeros", + regularizer=self._bias_regularizer, + dtype=self.variable_dtype, ) - self._encoding_dense = EinsumDense( - einsum_equation, - output_shape=_get_output_shape(output_rank - 1, [self._num_heads, self._key_dim]), - bias_axes=bias_axes if self._use_bias else None, - name="encoding", - **self._get_common_kwargs_for_sublayer(), + self.positional_attention_bias = self.add_weight( + name="positional_attention_bias", + shape=[self._num_heads, self._key_dim], + trainable=True, + initializer="zeros", + regularizer=self._bias_regularizer, + dtype=self.variable_dtype, ) + else: + 
self.content_attention_bias, self.positional_attention_bias = None, None + return super().build(rest_input_shape) def _compute_attention( self, @@ -362,82 +546,122 @@ def _compute_attention( key, value, position, - content_attention_bias, - positional_attention_bias, + content_attention_bias=None, + positional_attention_bias=None, attention_mask=None, training=None, ): - content_attention = tf.einsum( - self._dot_product_equation, - key, - (query + tf.cast(content_attention_bias, query.dtype)), - ) # BSNH,BTNH->BNTS - positional_attention = tf.einsum( - self._dot_product_equation, - position, - (query + tf.cast(positional_attention_bias, query.dtype)), - ) # BRNH,BTNH->BNTR - positional_attention = rel_left_shift(positional_attention) + cbias = self.content_attention_bias if content_attention_bias is None else content_attention_bias + pbias = self.positional_attention_bias if positional_attention_bias is None else positional_attention_bias + + content_query = tf.multiply((query + tf.cast(cbias, query.dtype)), tf.cast(self._inverse_sqrt_key_dim, query.dtype)) + content_attention = tf.einsum(self._dot_product_equation, key, content_query, optimize="optimal") # BSNH,BTNH->BNTS + + positional_query = tf.multiply((query + tf.cast(pbias, query.dtype)), tf.cast(self._inverse_sqrt_key_dim, query.dtype)) + positional_attention = tf.einsum(self._dot_product_equation, position, positional_query, optimize="optimal") # BRNH,BTNH->BNTR + positional_attention = rel_left_shift(positional_attention, causal=self._causal) # BNTR -> BNTS + positional_attention = tf.slice( + positional_attention, + begin=[0, 0, 0, tf.shape(positional_attention)[-1] - tf.shape(content_attention)[-1]], + size=[-1, -1, -1, tf.shape(content_attention)[-1]], + ) - attention_scores = content_attention + tf.slice(positional_attention, begin=[0, 0, 0, 0], size=tf.shape(content_attention)) - attention_scores = tf.multiply(attention_scores, 1.0 / math.sqrt(float(self._key_dim))) + attention_scores = content_attention + positional_attention attention_scores = self._masked_softmax(attention_scores, attention_mask) - attention_output = self._dropout_layer(attention_scores, training=training) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
+ if self.dropout: + final_attn_scores = self._dropout_layer(attention_scores, training=training) + else: + final_attn_scores = attention_scores - attention_output = tf.einsum(self._combine_equation, attention_output, value) # BNTS,BSNH->BTNH + # `context_layer` = [B, T, N, H] + attention_output = tf.einsum(self._combine_equation, final_attn_scores, value, optimize="optimal") return attention_output, attention_scores def call( self, inputs, - content_attention_bias, - positional_attention_bias, + content_attention_bias=None, + positional_attention_bias=None, + query_mask=None, + value_mask=None, + key_mask=None, attention_mask=None, - training=None, - use_causal_mask=False, use_auto_mask=True, return_attention_scores=False, + training=None, + use_causal_mask=False, + initial_state=None, + return_states=False, + **kwargs, ): - query, key, value, relative_position_encoding = inputs + query, key, value, relpe = inputs - if not self._built_from_signature: - self._build_from_signature(query, value, relative_position_encoding, key=key) + self._return_attention_scores = return_attention_scores + if key is None: + key = value - query, key, value = self._update_with_memory(query, key, value) + # Delete the masks because the masks are handled at the level of the + # layer + query_mask = backend.get_keras_mask(query) + backend.set_keras_mask(query, None) + backend.set_keras_mask(value, None) + backend.set_keras_mask(key, None) if use_auto_mask: - attention_mask = self._compute_attention_mask(query, value, key=key, attention_mask=attention_mask, use_causal_mask=use_causal_mask) + attention_mask = self._compute_attention_mask( + query, + value, + query_mask=query_mask, + value_mask=value_mask, + key_mask=key_mask, + attention_mask=attention_mask, + use_causal_mask=use_causal_mask, + ) # N = `num_attention_heads` # H = `size_per_head` # `query` = [B, T, N ,H] query = self._query_dense(query) - # `key` = [B, S + M, N, H] + # `key` = [B, S, N, H] key = self._key_dense(key) - # `value` = [B, S + M, N, H] + # `value` = [B, S, N, H] value = self._value_dense(value) # `position` = [B, R, N, H] - position = self._encoding_dense(relative_position_encoding) + position = self._relpe_dense(relpe) + + states = None + + if return_states: + query, key, value, states = self._with_memory(query, key, value, initial_state, training) attention_output, attention_scores = self._compute_attention( - query=query, - key=key, - value=value, - position=position, + query, + key, + value, + position, content_attention_bias=content_attention_bias, positional_attention_bias=positional_attention_bias, attention_mask=attention_mask, training=training, ) - - # `attention_output` = [B, S, N, H] attention_output = self._output_dense(attention_output) + # Set mask on output if needed + if query_mask is not None: + backend.set_keras_mask(attention_output, query_mask) + if return_attention_scores: + if return_states: + return attention_output, states, attention_scores return attention_output, attention_scores - return attention_output + + if return_states and states is not None: + return attention_output, states + return (attention_output,) diff --git a/tensorflow_asr/models/layers/norm.py b/tensorflow_asr/models/layers/norm.py new file mode 100644 index 0000000000..5d35845bc6 --- /dev/null +++ b/tensorflow_asr/models/layers/norm.py @@ -0,0 +1,391 @@ +# import warnings + +# import keras +# import tensorflow as tf + + +# def _running_with_dtensor_strategy(): +# """Check whether running with a `Strategy` that is backed by DTensor. 
+ +# In the DTensor based training, all the tensors are in global context, which +# is different from the local context. Some keras components need to +# behave differently, e.g. BatchNormalization and SyncBatchNormalization, as +# well as optimizers. + +# This check will help those layer to branch the logic and keep the correct +# behavior between different context. +# """ +# if not tf.distribute.has_strategy(): +# return False +# strategy = tf.distribute.get_strategy() +# # TODO(scottzhu): Finalize the strategy API to check if a strategy is backed +# # by DTensor. +# return getattr(strategy, "_mesh", None) is not None + + +# def _raise_for_non_sync_bn_with_renorm_and_dtensor_strategy( +# synchronized, +# training, +# renorm, +# ): +# if _running_with_dtensor_strategy() and not synchronized and training and renorm: +# raise NotImplementedError( +# "Renorm for BatchNormalization under DTensor based distribution " +# "strategy is not supported at the moment. Please file a feature " +# "request if this is blocking your adoption." +# ) + + +# class BatchNormalization(keras.layers.BatchNormalization): +# def call(self, inputs, training=None, mask=None): +# inputs = tf.cast(inputs, self.compute_dtype) +# training = self._get_training_value(training) +# # Determine a boolean value for `training`: could be True, False, or +# # None. +# _raise_for_non_sync_bn_with_renorm_and_dtensor_strategy( +# synchronized=self.synchronized, +# training=training, +# renorm=self.renorm, +# ) + +# if self.virtual_batch_size is not None: +# # Virtual batches (aka ghost batches) can be simulated by reshaping +# # the Tensor and reusing the existing batch norm implementation +# original_shape = tf.shape(inputs) +# original_shape = tf.concat([tf.constant([-1]), original_shape[1:]], axis=0) + +# if tf.__internal__.tf2.enabled(): +# expanded_shape = [self.virtual_batch_size, -1] if training else [-1, 1] +# expanded_shape = tf.concat( +# [ +# tf.constant(expanded_shape), +# original_shape[1:], +# ], +# axis=0, +# ) +# else: +# # Preserve incorrect legacy behavior for backwards compatibility +# expanded_shape = tf.concat( +# [ +# tf.constant([self.virtual_batch_size, -1]), +# original_shape[1:], +# ], +# axis=0, +# ) + +# # Will cause errors if virtual_batch_size does not divide the batch +# # size +# inputs = tf.reshape(inputs, expanded_shape) + +# def undo_virtual_batching(outputs): +# outputs = tf.reshape(outputs, original_shape) +# return outputs + +# if self.fused: +# outputs = self._fused_batch_norm(inputs, mask=mask, training=training) +# if self.virtual_batch_size is not None: +# # Currently never reaches here since fused_batch_norm does not +# # support virtual batching +# outputs = undo_virtual_batching(outputs) +# return outputs + +# inputs_dtype = inputs.dtype.base_dtype +# if inputs_dtype in (tf.float16, tf.bfloat16): +# # Do all math in float32 if given 16-bit inputs for numeric +# # stability. In particular, it's very easy for variance to overflow +# # in float16 and for safety we also choose to cast bfloat16 to +# # float32. 
+# inputs = tf.cast(inputs, tf.float32) + +# # Compute the axes along which to reduce the mean / variance +# input_shape = inputs.shape +# ndims = len(input_shape) +# reduction_axes = [i for i in range(ndims) if i not in self.axis] +# if self.virtual_batch_size is not None: +# del reduction_axes[1] # Do not reduce along virtual batch dim + +# # Broadcasting only necessary for single-axis batch norm where the axis +# # is not the last dimension +# broadcast_shape = [1] * ndims +# broadcast_shape[self.axis[0]] = input_shape.dims[self.axis[0]].value + +# def _broadcast(v): +# if v is not None and len(v.shape) != ndims and reduction_axes != list(range(ndims - 1)): +# return tf.reshape(v, broadcast_shape) +# return v + +# scale, offset = _broadcast(self.gamma), _broadcast(self.beta) + +# def _compose_transforms(scale, offset, then_scale, then_offset): +# if then_scale is not None: +# scale *= then_scale +# offset *= then_scale +# if then_offset is not None: +# offset += then_offset +# return (scale, offset) + +# if not training: # noqa: E712 +# mean, variance = self.moving_mean, self.moving_variance +# else: +# # The following long block are handling mean/variance update during +# # the training stage in various of different settings. +# if self.adjustment: +# adj_scale, adj_bias = self.adjustment(tf.shape(inputs)) +# scale, offset = _compose_transforms(adj_scale, adj_bias, scale, offset) + +# # Some of the computations here are not necessary when +# # training==False but not a constant. However, this makes the code +# # simpler. +# keep_dims = self.virtual_batch_size is not None or len(self.axis) > 1 +# mean, variance = self._moments( +# tf.cast(inputs, self._param_dtype), +# reduction_axes, +# keep_dims=keep_dims, +# mask=mask, +# ) + +# if self.virtual_batch_size is not None: +# # This isn't strictly correct since in ghost batch norm, you are +# # supposed to sequentially update the moving_mean and +# # moving_variance with each sub-batch. However, since the moving +# # statistics are only used during evaluation, it is more +# # efficient to just update in one step and should not make a +# # significant difference in the result. +# new_mean = tf.reduce_mean(mean, axis=1, keepdims=True) +# new_variance = tf.reduce_mean(variance, axis=1, keepdims=True) +# else: +# if _running_with_dtensor_strategy() and not self.synchronized: +# new_mean = tf.math.reduce_mean(mean, axis=reduction_axes) +# new_variance = tf.math.reduce_mean(variance, axis=reduction_axes) +# else: +# new_mean, new_variance = mean, variance + +# if self._support_zero_size_input(): +# # Keras assumes that batch dimension is the first dimension for +# # Batch Normalization. +# input_batch_size = tf.shape(inputs)[0] +# else: +# input_batch_size = None + +# if self.renorm: +# ( +# r, +# d, +# new_mean, +# new_variance, +# ) = self._renorm_correction_and_moments(new_mean, new_variance, training, input_batch_size) +# # When training, the normalized values (say, x) will be +# # transformed as x * gamma + beta without renorm, and (x * r + +# # d) * gamma + beta = x * (r * gamma) + (d * gamma + beta) with +# # renorm. 
+# r = _broadcast(tf.stop_gradient(r, name="renorm_r")) +# d = _broadcast(tf.stop_gradient(d, name="renorm_d")) +# scale, offset = _compose_transforms(r, d, scale, offset) + +# def _do_update(var, value): +# """Compute the updates for mean and variance.""" +# return self._assign_moving_average(var, value, self.momentum, input_batch_size) + +# def mean_update(): +# if training: +# return _do_update(self.moving_mean, new_mean) +# return self.moving_mean + +# def variance_update(): +# """Update the moving variance.""" + +# def true_branch_renorm(): +# # We apply epsilon as part of the moving_stddev to mirror +# # the training code path. +# moving_stddev = _do_update(self.moving_stddev, tf.sqrt(new_variance + self.epsilon)) +# return self._assign_new_value( +# self.moving_variance, +# # Apply relu in case floating point rounding causes it +# # to go negative. +# tf.nn.relu(moving_stddev * moving_stddev - self.epsilon), +# ) + +# if not training: +# return self.moving_variance + +# if self.renorm: +# return true_branch_renorm() + +# return _do_update(self.moving_variance, new_variance) + +# self.add_update(mean_update) +# self.add_update(variance_update) +# # End of handling mean/variance calculation and update. + +# mean = tf.cast(mean, inputs.dtype) +# variance = tf.cast(variance, inputs.dtype) +# if offset is not None: +# offset = tf.cast(offset, inputs.dtype) +# if scale is not None: +# scale = tf.cast(scale, inputs.dtype) +# outputs = tf.nn.batch_normalization( +# inputs, +# _broadcast(mean), +# _broadcast(variance), +# offset, +# scale, +# self.epsilon, +# ) +# if inputs_dtype in (tf.float16, tf.bfloat16): +# outputs = tf.cast(outputs, inputs_dtype) + +# # If some components of the shape got lost due to adjustments, fix that. +# outputs.set_shape(input_shape) + +# if self.virtual_batch_size is not None: +# outputs = undo_virtual_batching(outputs) +# return outputs + +# def _fused_batch_norm(self, inputs, mask, training): +# """Returns the output of fused batch norm.""" +# if mask is not None: +# warnings.warn( +# "Masking is not supported with `fused=True`. " +# "You should either turn off fusing " +# "(`fused=False`) or you should not pass a `mask` " +# "argument when calling the layer. " +# "For the moment `mask` will be ignored for the " +# "normalization." +# ) +# if self.center: +# beta = self.beta +# else: +# beta = tf.constant(0.0, dtype=self._param_dtype, shape=self._param_shape) +# if self.scale: +# gamma = self.gamma +# else: +# gamma = tf.constant(1.0, dtype=self._param_dtype, shape=self._param_shape) + +# input_batch_size = tf.shape(inputs)[0] +# use_fused_avg_updates = False +# exponential_avg_factor = None + +# def _maybe_add_or_remove_bessels_correction(variance, remove=True): +# r"""Add or remove Bessel's correction.""" +# # Removes Bessel's correction if remove == True, adds it otherwise. +# # This is to be consistent with non-fused batch norm. Note that the +# # variance computed by fused batch norm is with Bessel's correction. +# # This is only used in legacy V1 batch norm tests. 
+# if self._bessels_correction_test_only: +# return variance +# sample_size = tf.cast(tf.size(inputs) / tf.size(variance), variance.dtype) +# if remove: +# factor = (sample_size - tf.cast(1.0, variance.dtype)) / sample_size +# else: +# factor = sample_size / (sample_size - tf.cast(1.0, variance.dtype)) +# return variance * factor + +# def _fused_batch_norm_training(): +# return tf.compat.v1.nn.fused_batch_norm( +# inputs, +# gamma, +# beta, +# mean=self.moving_mean, +# variance=_maybe_add_or_remove_bessels_correction(self.moving_variance, remove=False), +# epsilon=self.epsilon, +# is_training=True, +# data_format=self._data_format, +# exponential_avg_factor=exponential_avg_factor, +# ) + +# def _fused_batch_norm_inference(): +# return tf.compat.v1.nn.fused_batch_norm( +# inputs, +# gamma, +# beta, +# mean=self.moving_mean, +# variance=self.moving_variance, +# epsilon=self.epsilon, +# is_training=False, +# data_format=self._data_format, +# ) + +# if training: +# output, mean, variance = _fused_batch_norm_training() +# else: +# output, mean, variance = _fused_batch_norm_inference() + +# variance = _maybe_add_or_remove_bessels_correction(variance, remove=True) + +# if training: +# momentum = tf.convert_to_tensor(self.momentum) + +# def mean_update(): +# """Update self.moving_mean with the most recent data point.""" +# if use_fused_avg_updates: +# return self._assign_new_value(self.moving_mean, mean) +# return self._assign_moving_average(self.moving_mean, mean, momentum, input_batch_size) + +# def variance_update(): +# """Update self.moving_variance with the most recent data +# point.""" +# if use_fused_avg_updates: +# return self._assign_new_value(self.moving_variance, variance) +# return self._assign_moving_average(self.moving_variance, variance, momentum, input_batch_size) + +# self.add_update(mean_update) +# self.add_update(variance_update) + +# return output + +# def _renorm_correction_and_moments(self, mean, variance, training, inputs_size): +# """Returns the correction and update values for renorm.""" +# stddev = tf.sqrt(variance + self.epsilon) +# # Compute the average mean and standard deviation, as if they were +# # initialized with this batch's moments. +# renorm_mean = self.renorm_mean +# # Avoid divide by zero early on in training. +# renorm_stddev = tf.maximum(self.renorm_stddev, tf.sqrt(self.epsilon)) +# # Compute the corrections for batch renorm. +# r = stddev / renorm_stddev +# d = (mean - renorm_mean) / renorm_stddev +# # Ensure the corrections use pre-update moving averages. +# with tf.control_dependencies([r, d]): +# mean = tf.identity(mean) +# stddev = tf.identity(stddev) +# rmin, rmax, dmax = [self.renorm_clipping.get(key) for key in ["rmin", "rmax", "dmax"]] +# if rmin is not None: +# r = tf.maximum(r, rmin) +# if rmax is not None: +# r = tf.minimum(r, rmax) +# if dmax is not None: +# d = tf.maximum(d, -dmax) +# d = tf.minimum(d, dmax) +# # When not training, use r=1, d=0. 
+# if not training: +# r = tf.ones_like(r) +# d = tf.zeros_like(d) + +# def _update_renorm_variable(var, value, inputs_size): +# """Updates a moving average and weight, returns the unbiased +# value.""" +# value = tf.identity(value) + +# def _do_update(): +# """Updates the var, returns the updated value.""" +# new_var = self._assign_moving_average(var, value, self.renorm_momentum, inputs_size) +# return new_var + +# def _fake_update(): +# return tf.identity(var) + +# if training: +# return _do_update() + +# return _fake_update() + +# # TODO(yuefengz): colocate the operations +# update_new_mean = _update_renorm_variable(self.renorm_mean, mean, inputs_size) +# update_new_stddev = _update_renorm_variable(self.renorm_stddev, stddev, inputs_size) + +# # Update the inference mode moving averages with the batch value. +# with tf.control_dependencies([update_new_mean, update_new_stddev]): +# out_mean = tf.identity(mean) +# out_variance = tf.identity(variance) + +# return (r, d, out_mean, out_variance) diff --git a/tensorflow_asr/models/layers/one_hot_blank.py b/tensorflow_asr/models/layers/one_hot_blank.py deleted file mode 100644 index cad06eaf02..0000000000 --- a/tensorflow_asr/models/layers/one_hot_blank.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright 2022 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import tensorflow as tf - -from tensorflow_asr.models.base_layer import Layer - - -class OneHotBlank(Layer): - """ - https://arxiv.org/pdf/1211.3711.pdf - The inputs are encoded as one-hot vectors; - that is, if Y consists of K labels and yu = k, then y^u is a length K vector whose elements are all zero - except the k-th, which is one. ∅ is encoded as a length K vector of zeros - """ - - def __init__(self, blank, depth, name="one_hot_blank", **kwargs): - super().__init__(name=name, **kwargs) - self.blank = blank - self.depth = depth - - def call(self, inputs, training=False): - minus_one_at_blank = tf.where(tf.equal(inputs, self.blank), -1, inputs) - return tf.one_hot(minus_one_at_blank, depth=self.depth, dtype=self.dtype) diff --git a/tensorflow_asr/models/layers/positional_encoding.py b/tensorflow_asr/models/layers/positional_encoding.py index 3e7d378427..e057611b85 100755 --- a/tensorflow_asr/models/layers/positional_encoding.py +++ b/tensorflow_asr/models/layers/positional_encoding.py @@ -13,9 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
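Aside on the OneHotBlank layer deleted above: its all-zero encoding for the blank symbol relies on tf.one_hot emitting a zero vector for index -1. A tiny illustrative check (the label values and blank id below are made up):

import tensorflow as tf

labels = tf.constant([0, 2, 3])                       # blank, then two real labels (K = 4)
no_blank = tf.where(tf.equal(labels, 0), -1, labels)  # map blank to -1
print(tf.one_hot(no_blank, depth=4).numpy())
# [[0. 0. 0. 0.]   <- blank becomes an all-zero vector
#  [0. 0. 1. 0.]
#  [0. 0. 0. 1.]]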
-import tensorflow as tf - +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer +from tensorflow_asr.models.layers.general import Dropout from tensorflow_asr.utils import shape_util @@ -49,25 +49,25 @@ def compute_sinusoid_position_encoding( angles = tf.einsum("i,d->id", position, timescales) pe = tf.concat([tf.sin(angles), tf.cos(angles)], -1) pe = tf.repeat(pe[None, :, :], repeats=batch_size, axis=0) - pe = tf.stop_gradient(pe) return pe -class PositionalEncoding(Layer): +@keras.utils.register_keras_serializable(package=__name__) +class SinusoidalPositionalEncoding(Layer): def __init__( self, - dropout=0.0, + dropout=0, scale=None, interleave=False, **kwargs, ): - super().__init__(**kwargs) - self.do = tf.keras.layers.Dropout(dropout, name="dropout") + super().__init__(trainable=False, **kwargs) + self.do = Dropout(dropout, dtype=self.dtype, name="dropout") self._scale = scale self._interleave = interleave def call(self, inputs, training=False): - outputs = inputs + outputs, outputs_length = inputs if self._scale is not None: outputs *= self._scale batch_size, length, dmodel = shape_util.shape_list(outputs) @@ -79,34 +79,46 @@ def call(self, inputs, training=False): interleave=self._interleave, dtype=outputs.dtype, ) + pe *= tf.expand_dims(tf.sequence_mask(outputs_length, maxlen=length, dtype=pe.dtype), axis=-1) pe = self.do(pe, training=training) outputs += pe return outputs, pe def compute_output_shape(self, input_shape): - output_shape = input_shape + output_shape, _ = input_shape return output_shape, output_shape -class RelativePositionalEncoding(PositionalEncoding): +@keras.utils.register_keras_serializable(package=__name__) +class RelativeSinusoidalPositionalEncoding(SinusoidalPositionalEncoding): def __init__( self, dropout=0, scale=None, interleave=False, memory_length=None, + causal=False, **kwargs, ): + """ + http://arxiv.org/abs/1901.02860 + Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context + Relative Sinusoidal Positional Encoding + Will be computed with weights as the Q in paper + ==> Define in reversed order + """ super().__init__(dropout, scale, interleave, **kwargs) - self._memory_length = memory_length + self._memory_length = memory_length or 0 + self._causal = causal def call(self, inputs, training=False): - outputs = inputs + outputs, outputs_length = inputs if self._scale is not None: outputs *= self._scale batch_size, length, dmodel = shape_util.shape_list(outputs) - start = tf.constant(0, dtype=tf.int32) if self._memory_length is None else -tf.convert_to_tensor(self._memory_length, dtype=tf.int32) - position = compute_position(start=start, end=length, step=1, dtype=outputs.dtype) + position_left = compute_position(start=length + self._memory_length - 1, end=0, step=-1, dtype=outputs.dtype) + position_right = compute_position(start=0, end=-length, step=-1, dtype=outputs.dtype) + position = tf.concat([position_left, position_right], axis=0) # 2 * length + self._memory_length - 1 pe = compute_sinusoid_position_encoding( position=position, batch_size=batch_size, @@ -114,11 +126,57 @@ def call(self, inputs, training=False): interleave=self._interleave, dtype=outputs.dtype, ) + if self._causal: + pe, _ = tf.map_fn( + fn=lambda x: ( # [B, length + self._memory_length, dmodel] + tf.multiply( + tf.slice( + tf.roll(input=x[0], shift=-(length - x[1]), axis=0), + begin=[0, 0], + size=[(length + self._memory_length), dmodel], + ), + tf.expand_dims( + tf.sequence_mask((x[1] + self._memory_length), maxlen=(length + 
self._memory_length), dtype=x[0].dtype), + axis=-1, + ), + ), + x[1], + ), + elems=(pe, outputs_length), + # fn_output_signature=( + # tf.TensorSpec(shape=[(length + self._memory_length), dmodel], dtype=pe.dtype), + # tf.TensorSpec(shape=[], dtype=outputs_length.dtype), + # ), + ) + else: + pe, _ = tf.map_fn( + fn=lambda x: ( # [B, 2 * length + self._memory_length - 1, dmodel] + tf.multiply( + tf.slice( + tf.roll(input=x[0], shift=-(length - x[1]), axis=0), + begin=[0, 0], + size=[(2 * length + self._memory_length - 1), dmodel], + ), + tf.expand_dims( + tf.sequence_mask((2 * x[1] + self._memory_length - 1), maxlen=(2 * length + self._memory_length - 1), dtype=x[0].dtype), + axis=-1, + ), + ), + x[1], + ), + elems=(pe, outputs_length), + # fn_output_signature=( + # tf.TensorSpec(shape=[(2 * length + self._memory_length - 1), dmodel], dtype=pe.dtype), + # tf.TensorSpec(shape=[], dtype=outputs_length.dtype), + # ), + ) pe = self.do(pe, training=training) return outputs, pe def compute_output_shape(self, input_shape): - output_shape = input_shape + output_shape, _ = input_shape B, T, V = output_shape - pT = (self._memory_length + T) if (self._memory_length is not None and T is not None) else None + pT = 2 * T - 1 if T is not None else None + if self._memory_length > 0 and T is not None: + pT += self._memory_length return output_shape, (B, pT, V) diff --git a/tensorflow_asr/models/layers/recurrent.py b/tensorflow_asr/models/layers/recurrent.py deleted file mode 100644 index de885b1ad4..0000000000 --- a/tensorflow_asr/models/layers/recurrent.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2023 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
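Aside on RelativeSinusoidalPositionalEncoding above: the relative positions are defined in reversed order, from the most distant past frame down to the most distant future frame. A small sketch, assuming compute_position behaves like a plain tf.range, for length T = 4 and memory_length Th = 2 (giving 2*T + Th - 1 = 9 positions in the non-causal case):

import tensorflow as tf

length, memory_length = 4, 2  # illustrative sizes
position_left = tf.range(length + memory_length - 1, 0, -1)  # [5 4 3 2 1], past distances
position_right = tf.range(0, -length, -1)                    # [0 -1 -2 -3], current and future
position = tf.concat([position_left, position_right], axis=0)
print(position.numpy())  # [ 5  4  3  2  1  0 -1 -2 -3] -> 2*4 + 2 - 1 = 9 relative positions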
- -import tensorflow as tf - - -class LSTM(tf.keras.layers.LSTM): - def __init__( - self, - units, - activation="tanh", - recurrent_activation="sigmoid", - use_bias=True, - kernel_initializer="glorot_uniform", - recurrent_initializer="orthogonal", - bias_initializer="zeros", - unit_forget_bias=True, - kernel_regularizer=None, - recurrent_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - recurrent_constraint=None, - bias_constraint=None, - dropout=0, - recurrent_dropout=0, - return_sequences=False, - return_state=False, - go_backwards=False, - stateful=False, - time_major=False, - unroll=False, - **kwargs - ): - super().__init__( - units, - activation, - recurrent_activation, - use_bias, - kernel_initializer, - recurrent_initializer, - bias_initializer, - unit_forget_bias, - kernel_regularizer, - recurrent_regularizer, - bias_regularizer, - activity_regularizer, - kernel_constraint, - recurrent_constraint, - bias_constraint, - dropout, - recurrent_dropout, - return_sequences, - return_state, - go_backwards, - stateful, - time_major, - unroll, - **kwargs - ) - self._could_use_gpu_kernel = self._could_use_gpu_kernel and tf.keras.mixed_precision.global_policy().name != "mixed_bfloat16" - - -class GRU(tf.keras.layers.GRU): - def __init__( - self, - units, - activation="tanh", - recurrent_activation="sigmoid", - use_bias=True, - kernel_initializer="glorot_uniform", - recurrent_initializer="orthogonal", - bias_initializer="zeros", - kernel_regularizer=None, - recurrent_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - kernel_constraint=None, - recurrent_constraint=None, - bias_constraint=None, - dropout=0, - recurrent_dropout=0, - return_sequences=False, - return_state=False, - go_backwards=False, - stateful=False, - unroll=False, - time_major=False, - reset_after=True, - **kwargs - ): - super().__init__( - units, - activation, - recurrent_activation, - use_bias, - kernel_initializer, - recurrent_initializer, - bias_initializer, - kernel_regularizer, - recurrent_regularizer, - bias_regularizer, - activity_regularizer, - kernel_constraint, - recurrent_constraint, - bias_constraint, - dropout, - recurrent_dropout, - return_sequences, - return_state, - go_backwards, - stateful, - unroll, - time_major, - reset_after, - **kwargs - ) - self._could_use_gpu_kernel = self._could_use_gpu_kernel and tf.keras.mixed_precision.global_policy().name != "mixed_bfloat16" diff --git a/tensorflow_asr/models/layers/residual.py b/tensorflow_asr/models/layers/residual.py index e1852b2f8f..96abcac074 100644 --- a/tensorflow_asr/models/layers/residual.py +++ b/tensorflow_asr/models/layers/residual.py @@ -14,11 +14,11 @@ from typing import Optional -import tensorflow as tf - +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer +@keras.utils.register_keras_serializable(package=__name__) class Residual(Layer): """Applying residual addition to layers - Normal addition with constant factor @@ -30,38 +30,36 @@ class Residual(Layer): def __init__( self, factor="rezero", - initializer: tf.keras.initializers.Initializer = "zeros", - regularizer: Optional[tf.keras.regularizers.Regularizer] = None, + initializer: keras.initializers.Initializer = "zeros", + regularizer: Optional[keras.regularizers.Regularizer] = None, name="residual", **kwargs, ): - super().__init__(name=name, **kwargs) + super().__init__(name=name, trainable=False, **kwargs) self._factor = factor self._initializer = initializer self._regularizer = 
regularizer def build(self, input_shape): if self._factor == "rezero": - self._alpha = self.add_weight(name="alpha", shape=[], initializer=self._initializer, regularizer=self._regularizer, trainable=True) + self._alpha = self.add_weight( + name="alpha", + shape=[], + initializer=self._initializer, + regularizer=self._regularizer, + trainable=True, + dtype=self.variable_dtype, + ) else: assert isinstance(self._factor, (int, float)) - self._alpha = tf.convert_to_tensor(self._factor, dtype=self.compute_dtype) + self._alpha = self._factor return super().build(input_shape) def call(self, inputs): x, residual_x = inputs - alpha = tf.cast(self._alpha, residual_x.dtype) + alpha = tf.cast(tf.convert_to_tensor(self._alpha, dtype=self.dtype), residual_x.dtype) x = x + alpha * residual_x return x - def get_config(self): - config = { - "factor": self._factor, - "initializer": self._initializer, - "regularizer": self._regularizer, - } - base_config = super().get_config() - return dict(list(base_config.items()) + list(config.items())) - def compute_output_shape(self, input_shape): return input_shape[0] diff --git a/tensorflow_asr/models/layers/row_conv_1d.py b/tensorflow_asr/models/layers/row_conv_1d.py deleted file mode 100755 index b20b8990fe..0000000000 --- a/tensorflow_asr/models/layers/row_conv_1d.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
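Aside on the Residual layer above: with factor="rezero" it learns a scalar alpha, initialized to zero, and computes output = x + alpha * residual, so every block starts out as an identity mapping. A framework-free sketch of the idea (tensor shapes and values are made up):

import tensorflow as tf

alpha = tf.Variable(0.0, name="alpha")     # ReZero: starts at 0, learned during training
x = tf.random.normal([2, 5, 8])            # block input
residual_x = tf.random.normal([2, 5, 8])   # branch output to be gated in

out = x + tf.cast(alpha, residual_x.dtype) * residual_x
print(tf.reduce_max(tf.abs(out - x)).numpy())  # 0.0 at initialization: the block is an identity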
-import tensorflow as tf -from keras.utils import conv_utils -from tensorflow.python.ops import nn_ops - - -class RowConv1D(tf.keras.layers.Conv1D): - def __init__( - self, - filters, - future_context, - **kwargs, - ): - assert future_context >= 0, "Future context must be positive" - super().__init__(filters=filters, kernel_size=(future_context * 2 + 1), **kwargs) - self.future_context = future_context - - def build( - self, - input_shape, - ): - input_shape = tf.TensorShape(input_shape) - input_channel = self._get_input_channel(input_shape) - kernel_shape = self.kernel_size + (input_channel, self.filters) - - self.kernel = self.add_weight( - name="kernel", - shape=kernel_shape, - initializer=self.kernel_initializer, - regularizer=self.kernel_regularizer, - constraint=self.kernel_constraint, - trainable=True, - ) - # Add mask to remove weights on half of the kernel to the left - # (only keep future - # context) - left_kernel_dims = (self.future_context, input_channel, self.filters) - left_kernel = tf.fill(dims=left_kernel_dims, value=0) - right_kernel_dims = (self.future_context + 1, input_channel, self.filters) - right_kernel = tf.fill(dims=right_kernel_dims, value=1) - mask_kernel = tf.cast(tf.concat([left_kernel, right_kernel], axis=0), dtype=self.dtype) - self.kernel = tf.multiply(self.kernel, mask_kernel) - - if self.use_bias: - self.bias = self.add_weight( - name="bias", - shape=(self.filters,), - initializer=self.bias_initializer, - regularizer=self.bias_regularizer, - constraint=self.bias_constraint, - trainable=True, - ) - else: - self.bias = None - channel_axis = self._get_channel_axis() - self.input_spec = tf.keras.layers.InputSpec(ndim=self.rank + 2, axes={channel_axis: input_channel}) - - self.make_conv_op_input_shape = input_shape - self.make_input_channel = input_channel - self._padding_op = self._get_padding_op() - self._conv_op_data_format = conv_utils.convert_data_format(self.data_format, self.rank + 2) - self._convolution_op = nn_ops.Convolution( - input_shape, - filter_shape=self.kernel.shape, - dilation_rate=self.dilation_rate, - strides=self.strides, - padding=self._padding_op, - data_format=self._conv_op_data_format, - ) - self.built = True diff --git a/tensorflow_asr/models/layers/sequence_wise_bn.py b/tensorflow_asr/models/layers/sequence_wise_bn.py index c01e51eb6b..80d32c9665 100644 --- a/tensorflow_asr/models/layers/sequence_wise_bn.py +++ b/tensorflow_asr/models/layers/sequence_wise_bn.py @@ -12,14 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import tensorflow as tf + +from tensorflow_asr import keras, tf # https://arxiv.org/abs/1510.01378 -class SequenceBatchNorm(tf.keras.layers.Layer): - def __init__(self, name, time_major=False, **kwargs): - super(SequenceBatchNorm, self).__init__(name=name, **kwargs) +class SequenceBatchNorm(keras.layers.Layer): + def __init__(self, name, time_major=False, gamma_regularizer=None, beta_regularizer=None, **kwargs): + super().__init__(name=name, **kwargs) self.time_major = time_major + self.gamma_regularizer = keras.regularizers.get(gamma_regularizer) + self.beta_regularizer = keras.regularizers.get(beta_regularizer) def build( self, @@ -29,17 +32,19 @@ def build( shape=[input_shape[-1]], name="beta", initializer="zeros", - regularizer=None, + regularizer=self.beta_regularizer, constraint=None, trainable=True, + dtype=self.variable_dtype, ) self.gamma = self.add_weight( shape=[input_shape[-1]], name="gamma", initializer="ones", - regularizer=None, + regularizer=self.gamma_regularizer, constraint=None, trainable=True, + dtype=self.variable_dtype, ) def call( @@ -49,12 +54,12 @@ def call( ): mean, variance = tf.nn.moments(inputs, axes=[0, 1], keepdims=False) if self.time_major: - total_padded_frames = tf.cast(tf.shape(inputs)[0], tf.keras.backend.dtype(mean)) - batch_size = tf.cast(tf.shape(inputs)[1], tf.keras.backend.dtype(mean)) + total_padded_frames = tf.cast(tf.shape(inputs)[0], keras.backend.dtype(mean)) + batch_size = tf.cast(tf.shape(inputs)[1], keras.backend.dtype(mean)) else: - total_padded_frames = tf.cast(tf.shape(inputs)[1], tf.keras.backend.dtype(mean)) - batch_size = tf.cast(tf.shape(inputs)[0], tf.keras.backend.dtype(mean)) - total_unpadded_frames_batch = tf.math.count_nonzero(inputs, axis=[0, 1], keepdims=False, dtype=tf.keras.backend.dtype(mean)) + total_padded_frames = tf.cast(tf.shape(inputs)[1], keras.backend.dtype(mean)) + batch_size = tf.cast(tf.shape(inputs)[0], keras.backend.dtype(mean)) + total_unpadded_frames_batch = tf.math.count_nonzero(inputs, axis=[0, 1], keepdims=False, dtype=keras.backend.dtype(mean)) mean = (mean * total_padded_frames * batch_size) / total_unpadded_frames_batch variance = (variance * total_padded_frames * batch_size) / total_unpadded_frames_batch return tf.nn.batch_normalization( @@ -63,5 +68,5 @@ def call( variance=variance, offset=self.beta, scale=self.gamma, - variance_epsilon=tf.keras.backend.epsilon(), + variance_epsilon=keras.backend.epsilon(), ) diff --git a/tensorflow_asr/models/layers/subsampling.py b/tensorflow_asr/models/layers/subsampling.py index 840948571f..1dd86ef701 100644 --- a/tensorflow_asr/models/layers/subsampling.py +++ b/tensorflow_asr/models/layers/subsampling.py @@ -12,43 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
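Aside on the SequenceBatchNorm change above: tf.nn.moments averages over padded (zero) frames as well, so the layer rescales the moments by (batch_size * padded_frames) / non-zero frames. A toy check of the mean correction (shapes and values are made up):

import tensorflow as tf

# B=2, T=3, 1 feature; the last frame of each utterance is zero padding.
x = tf.constant([[[2.0], [4.0], [0.0]],
                 [[6.0], [8.0], [0.0]]])
mean, _ = tf.nn.moments(x, axes=[0, 1], keepdims=False)            # averages over padding too
nonzero = tf.math.count_nonzero(x, axis=[0, 1], dtype=mean.dtype)  # 4 real frames
corrected = mean * (2.0 * 3.0) / nonzero                           # rescale by total/real frames
print(mean.numpy(), corrected.numpy())  # [3.333...] [5.0] -> mean over real frames only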
-import tensorflow as tf +import typing +from tensorflow_asr import keras, tf from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.layers.convolution import Conv1D, Conv2D +from tensorflow_asr.models.layers.general import Activation from tensorflow_asr.utils import math_util, shape_util -class Subsampling(Layer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.time_reduction_factor = 1 - - def call(self, inputs): - outputs, outputs_length = inputs - outputs = self._create_mask(outputs, outputs_length) - outputs, outputs_length = self._update_mask_and_input_length(outputs, outputs_length) - return outputs, outputs_length - - def _create_mask(self, inputs, inputs_length): - mask = getattr(inputs, "_keras_mask", None) - if mask is None: - mask = tf.sequence_mask(inputs_length, maxlen=tf.shape(inputs)[1], dtype=tf.bool) - inputs._keras_mask = mask # pylint: disable=protected-access - return inputs - - def _update_mask_and_input_length(self, inputs, inputs_length): - raise NotImplementedError() - - def compute_output_shape(self, input_shape): - inputs_shape, inputs_length_shape = input_shape - reduced_time = math_util.legacy_get_reduced_length(inputs_shape[1], self.time_reduction_factor) - inputs_shape = list(inputs_shape) - inputs_shape[1] = reduced_time - return inputs_shape, inputs_length_shape - - -class TimeReduction(Subsampling): +@keras.utils.register_keras_serializable(package=__name__) +class TimeReduction(Layer): def __init__(self, factor: int, name: str = "TimeReduction", **kwargs): super().__init__(name=name, **kwargs) self.time_reduction_factor = factor @@ -57,28 +31,36 @@ def padding(self, time): new_time = tf.math.ceil(time / self.time_reduction_factor) * self.time_reduction_factor return tf.cast(new_time, dtype=tf.int32) - time - def _update_mask_and_input_length(self, inputs, inputs_length): - outputs_length = math_util.get_reduced_length(inputs_length, self.time_reduction_factor) - outputs = math_util.apply_mask(inputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(inputs)[1], dtype=tf.bool)) - return outputs, outputs_length - def call(self, inputs): outputs, outputs_length = inputs - outputs = self._create_mask(outputs, outputs_length) shape = shape_util.shape_list(outputs) outputs = tf.pad(outputs, [[0, 0], [0, self.padding(shape[1])], [0, 0]]) outputs = tf.reshape(outputs, [shape[0], -1, shape[-1] * self.time_reduction_factor]) - outputs, outputs_length = super().call([outputs, outputs_length]) + outputs_length = math_util.get_reduced_length(outputs_length, reduction_factor=self.time_reduction_factor) return outputs, outputs_length + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + maxlen = tf.shape(outputs)[1] + maxlen, outputs_length = (math_util.get_reduced_length(length, self.time_reduction_factor) for length in (maxlen, outputs_length)) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None + + def compute_output_shape(self, input_shape): + output_shape, output_length_shape = input_shape + reduced_time = math_util.legacy_get_reduced_length(output_shape[1], self.time_reduction_factor) + output_shape = output_shape[:1] + (reduced_time,) + output_shape[2:] + return output_shape, output_length_shape + -class VggSubsampling(Subsampling): +@keras.utils.register_keras_serializable(package=__name__) +class VggSubsampling(Layer): def __init__( self, - filters: tuple or list = (32, 64), - kernel_size: int or list or tuple = 3, - pool_size: int or list or tuple 
= 2, - strides: int or list or tuple = 2, + filters: typing.Union[tuple, list] = (32, 64), + kernel_size: typing.Union[int, list, tuple] = 3, + pool_size: typing.Union[int, list, tuple] = 2, + strides: typing.Union[int, list, tuple] = 2, padding: str = "same", activation: str = "relu", kernel_regularizer=None, @@ -96,6 +78,7 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activation=activation, + dtype=self.dtype, ) self.conv2 = Conv2D( filters=filters[0], @@ -106,8 +89,9 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activation=activation, + dtype=self.dtype, ) - self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding=padding, name="maxpool_1") + self.maxpool1 = keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding="same", dtype=self.dtype, name="maxpool_1") self.conv3 = Conv2D( filters=filters[1], kernel_size=kernel_size, @@ -117,6 +101,7 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activation=activation, + dtype=self.dtype, ) self.conv4 = Conv2D( filters=filters[1], @@ -127,29 +112,13 @@ def __init__( kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activation=activation, + dtype=self.dtype, ) - self.maxpool2 = tf.keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding=padding, name="maxpool_2") + self.maxpool2 = keras.layers.MaxPool2D(pool_size=pool_size, strides=strides, padding="same", dtype=self.dtype, name="maxpool_2") self.time_reduction_factor = self.maxpool1.pool_size[0] * self.maxpool2.pool_size[0] - def _update_mask_and_input_length(self, inputs, inputs_length): - outputs_length = math_util.conv_output_length( - inputs_length, - self.maxpool1.pool_size[0], - padding=self.maxpool1.padding, - stride=self.maxpool1.strides[0], - ) - outputs_length = math_util.conv_output_length( - outputs_length, - self.maxpool2.pool_size[0], - padding=self.maxpool2.padding, - stride=self.maxpool2.strides[0], - ) - outputs = math_util.apply_mask(inputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(inputs)[1], dtype=tf.bool)) - return outputs, outputs_length - def call(self, inputs, training=False): - inputs, inputs_length = inputs - outputs = self._create_mask(inputs, inputs_length) + outputs, outputs_length = inputs outputs = self.conv1(outputs, training=training) outputs = self.conv2(outputs, training=training) @@ -160,176 +129,221 @@ def call(self, inputs, training=False): outputs = self.maxpool2(outputs, training=training) outputs = math_util.merge_two_last_dims(outputs) - outputs, outputs_length = super().call([outputs, inputs_length]) return outputs, outputs_length + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + maxlen = tf.shape(outputs)[1] + for pool in (self.maxpool1, self.maxpool2): + maxlen, outputs_length = ( + math_util.conv_output_length( + length, + pool.pool_size[0], + padding=pool.padding, + stride=pool.strides[0], + ) + for length in (maxlen, outputs_length) + ) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None + def compute_output_shape(self, input_shape): - inputs_shape, inputs_length_shape = input_shape - outputs_shape = self.conv1.compute_output_shape(inputs_shape) + output_shape, output_length_shape = input_shape + outputs_shape = self.conv1.compute_output_shape(output_shape) outputs_shape = self.conv2.compute_output_shape(outputs_shape) outputs_shape = 
self.maxpool1.compute_output_shape(outputs_shape) outputs_shape = self.conv3.compute_output_shape(outputs_shape) outputs_shape = self.conv4.compute_output_shape(outputs_shape) outputs_shape = self.maxpool2.compute_output_shape(outputs_shape) - outputs_shape = list(outputs_shape[:2]) + [outputs_shape[2] * outputs_shape[3]] - return outputs_shape, inputs_length_shape + outputs_shape = outputs_shape[:2] + (outputs_shape[2] * outputs_shape[3],) + return outputs_shape, output_length_shape -class Conv2dSubsampling(Subsampling): +@keras.utils.register_keras_serializable(package=__name__) +class Conv2dSubsampling(Layer): def __init__( self, - nlayers: int, - filters: int, - strides: list or tuple or int = 2, - kernel_size: int or list or tuple = 3, - padding: str = "same", - norm: str = "none", - activation: str = "relu", + filters: list, + strides: list = [[2, 1], [2, 1]], + kernels: list = [[3, 3], [3, 3]], + paddings: list = ["causal", "causal"], + norms: list = ["none", "none"], + activations: list = ["relu", "relu"], kernel_regularizer=None, bias_regularizer=None, name="conv2d_subsampling", **kwargs, ): super().__init__(name=name, **kwargs) + assert len(filters) == len(strides) == len(kernels) == len(paddings) == len(norms) == len(activations) self.convs = [] self.time_reduction_factor = 1 - for i in range(nlayers): - subblock = tf.keras.Sequential(name=f"block_{i}") + for i in range(len(filters)): + subblock = keras.Sequential(name=f"block_{i}") subblock.add( Conv2D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, + filters=filters[i], + kernel_size=kernels[i], + strides=strides[i], + padding=paddings[i], name=f"conv_{i}", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) ) - if norm == "batch": + if norms[i] == "batch": subblock.add( - tf.keras.layers.BatchNormalization( + keras.layers.BatchNormalization( name=f"bn_{i}", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, ) ) - elif norm == "layer": + elif norms[i] == "layer": subblock.add( - tf.keras.layers.LayerNormalization( + keras.layers.LayerNormalization( name=f"ln_{i}", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, ) ) - subblock.add(tf.keras.layers.Activation(activation, name=f"{activation}_{i}")) + subblock.add(Activation(activations[i], name=f"{activations[i]}_{i}", dtype=self.dtype)) self.convs.append(subblock) self.time_reduction_factor *= subblock.layers[0].strides[0] - def _update_mask_and_input_length(self, inputs, inputs_length): - outputs_length = inputs_length + def call(self, inputs, training=False): + outputs, outputs_length = inputs for block in self.convs: + outputs = block(outputs, training=training) outputs_length = math_util.conv_output_length( outputs_length, filter_size=block.layers[0].kernel_size[0], - padding=block.layers[0].padding, + padding=block.layers[0]._padding, stride=block.layers[0].strides[0], + dilation=block.layers[0].dilation_rate[0], ) - outputs = math_util.apply_mask(inputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(inputs)[1], dtype=tf.bool)) + outputs = math_util.merge_two_last_dims(outputs) return outputs, outputs_length - def call(self, inputs, training=False): - inputs, inputs_length = inputs - outputs = self._create_mask(inputs, inputs_length) + def compute_mask(self, inputs, mask=None): + outputs, 
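# --- Editor's illustrative sketch (not part of the patch) ---------------------
# After the conv/pool stacks above, math_util.merge_two_last_dims flattens the
# frequency and channel axes of the [B, T, F, C] conv output into a single
# feature axis before the result is passed on. A hedged, minimal equivalent:
import tensorflow as tf

def merge_two_last_dims(x):
    shape = tf.shape(x)                                    # [B, T, F, C]
    return tf.reshape(x, [shape[0], shape[1], -1])         # [B, T, F * C]

print(merge_two_last_dims(tf.zeros([2, 50, 40, 32])).shape)  # (2, 50, 1280)
# ------------------------------------------------------------------------------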
outputs_length = inputs + maxlen = tf.shape(outputs)[1] for block in self.convs: - outputs = block(outputs, training=training) - outputs = math_util.merge_two_last_dims(outputs) - outputs, outputs_length = super().call([outputs, inputs_length]) - return outputs, outputs_length + maxlen, outputs_length = ( + math_util.conv_output_length( + length, + filter_size=block.layers[0].kernel_size[0], + padding=block.layers[0]._padding, + stride=block.layers[0].strides[0], + dilation=block.layers[0].dilation_rate[0], + ) + for length in (maxlen, outputs_length) + ) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None def compute_output_shape(self, input_shape): - outputs_shape, inputs_length_shape = input_shape + output_shape, output_length_shape = input_shape for block in self.convs: - outputs_shape = block.layers[0].compute_output_shape(outputs_shape) - outputs_shape = list(outputs_shape[:2]) + [outputs_shape[2] * outputs_shape[3]] - return tuple(outputs_shape), inputs_length_shape + output_shape = block.layers[0].compute_output_shape(output_shape) + output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],) + return output_shape, output_length_shape -class Conv1dSubsampling(Subsampling): +@keras.utils.register_keras_serializable(package=__name__) +class Conv1dSubsampling(Layer): def __init__( self, - nlayers: int, - filters: int, - strides: int = 2, - kernel_size: int = 3, - padding: str = "causal", - norm: str = "none", - activation: str = "relu", + filters: list, + strides: list = [2, 2], + kernels: list = [3, 3], + paddings: list = ["causal", "causal"], + norms: list = ["none", "none"], + activations: list = ["relu", "relu"], kernel_regularizer=None, bias_regularizer=None, name="conv1d_subsampling", **kwargs, ): super().__init__(name=name, **kwargs) + assert len(filters) == len(strides) == len(kernels) == len(paddings) == len(norms) == len(activations) self.convs = [] self.time_reduction_factor = 1 - for i in range(nlayers): - subblock = tf.keras.Sequential(name=f"block_{i}") + for i in range(len(filters)): + subblock = keras.Sequential(name=f"block_{i}") subblock.add( Conv1D( - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, + filters=filters[i], + kernel_size=kernels[i], + strides=strides[i], + padding=paddings[i], name=f"conv_{i}", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + dtype=self.dtype, ) ) - if norm == "batch": + if norms[i] == "batch": subblock.add( - tf.keras.layers.BatchNormalization( + keras.layers.BatchNormalization( name=f"bn_{i}", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer, + beta_regularizer=kernel_regularizer, + synchronized=True, + dtype=self.dtype, ) ) - elif norm == "layer": + elif norms[i] == "layer": subblock.add( - tf.keras.layers.LayerNormalization( + keras.layers.LayerNormalization( name=f"ln_{i}", gamma_regularizer=kernel_regularizer, - beta_regularizer=bias_regularizer, + beta_regularizer=kernel_regularizer, + dtype=self.dtype, ) ) - subblock.add(tf.keras.layers.Activation(activation, name=f"{activation}_{i}")) + subblock.add(Activation(activations[i], name=f"{activations[i]}_{i}", dtype=self.dtype)) self.convs.append(subblock) self.time_reduction_factor *= subblock.layers[0].strides[0] - def _update_mask_and_input_length(self, inputs, inputs_length): - outputs_length = inputs_length + def call(self, inputs, training=False): + outputs, outputs_length = inputs + outputs = math_util.merge_two_last_dims(outputs) for block 
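# --- Editor's illustrative sketch (not part of the patch) ---------------------
# compute_mask above shrinks the per-example lengths with
# math_util.conv_output_length once per conv/pool block. That helper is assumed
# to follow the standard Keras output-length formula; a pure-Python sketch:
def conv_output_length(length, filter_size, padding, stride, dilation=1):
    effective = filter_size + (filter_size - 1) * (dilation - 1)
    if padding in ("same", "causal"):
        out = length
    elif padding == "valid":
        out = length - effective + 1
    else:
        raise ValueError(f"unsupported padding: {padding}")
    return (out + stride - 1) // stride                    # ceil division by the stride

# two stride-2 blocks: 100 frames -> 50 -> 25
length = 100
for _ in range(2):
    length = conv_output_length(length, filter_size=3, padding="causal", stride=2)
print(length)  # 25
# ------------------------------------------------------------------------------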
in self.convs: + outputs = block(outputs, training=training) outputs_length = math_util.conv_output_length( outputs_length, filter_size=block.layers[0].kernel_size[0], - padding=block.layers[0].padding, + padding=block.layers[0]._padding, stride=block.layers[0].strides[0], + dilation=block.layers[0].dilation_rate[0], ) - outputs = math_util.apply_mask(inputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(inputs)[1], dtype=tf.bool)) return outputs, outputs_length - def call(self, inputs, training=False): - inputs, inputs_length = inputs - outputs = self._create_mask(inputs, inputs_length) - outputs = math_util.merge_two_last_dims(outputs) + def compute_mask(self, inputs, mask=None): + outputs, outputs_length = inputs + maxlen = tf.shape(outputs)[1] for block in self.convs: - outputs = block(outputs, training=training) - outputs, outputs_length = super().call([outputs, inputs_length]) - return outputs, outputs_length + maxlen, outputs_length = ( + math_util.conv_output_length( + length, + filter_size=block.layers[0].kernel_size[0], + padding=block.layers[0]._padding, + stride=block.layers[0].strides[0], + dilation=block.layers[0].dilation_rate[0], + ) + for length in (maxlen, outputs_length) + ) + mask = tf.sequence_mask(outputs_length, maxlen=maxlen, dtype=tf.bool) + return mask, None def compute_output_shape(self, input_shape): - outputs_shape, inputs_length_shape = input_shape - outputs_shape = list(outputs_shape[:2]) + [outputs_shape[2] * outputs_shape[3]] + output_shape, output_length_shape = input_shape + output_shape = output_shape[:2] + (output_shape[2] * output_shape[3],) for block in self.convs: - outputs_shape = block.layers[0].compute_output_shape(outputs_shape) - return tuple(outputs_shape), inputs_length_shape + output_shape = block.layers[0].compute_output_shape(output_shape) + return output_shape, output_length_shape diff --git a/tensorflow_asr/models/transducer/__init__.py b/tensorflow_asr/models/transducer/__init__.py index dccb40d97a..9139bde684 100644 --- a/tensorflow_asr/models/transducer/__init__.py +++ b/tensorflow_asr/models/transducer/__init__.py @@ -1,4 +1,13 @@ -import tensorflow_asr.models.transducer.conformer -import tensorflow_asr.models.transducer.contextnet -import tensorflow_asr.models.transducer.rnn_transducer -import tensorflow_asr.models.transducer.transformer +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/models/transducer/base_transducer.py b/tensorflow_asr/models/transducer/base_transducer.py index 5e44229ed1..664a4c5e4e 100644 --- a/tensorflow_asr/models/transducer/base_transducer.py +++ b/tensorflow_asr/models/transducer/base_transducer.py @@ -12,19 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" https://arxiv.org/pdf/1811.06621.pdf """ +"""https://arxiv.org/pdf/1811.06621.pdf""" import collections -from typing import Dict +import typing -import tensorflow as tf +from keras.src import backend +from tensorflow_asr import keras, schemas, tf from tensorflow_asr.losses.rnnt_loss import RnntLoss from tensorflow_asr.models.base_layer import Layer from tensorflow_asr.models.base_model import BaseModel -from tensorflow_asr.models.layers.embedding import Embedding -from tensorflow_asr.models.layers.one_hot_blank import OneHotBlank -from tensorflow_asr.utils import data_util, layer_util, math_util, shape_util +from tensorflow_asr.models.layers.embedding import Embedding, OneHotBlank +from tensorflow_asr.models.layers.general import Activation +from tensorflow_asr.utils import env_util, layer_util, shape_util Hypothesis = collections.namedtuple("Hypothesis", ("index", "prediction", "states")) @@ -33,6 +34,7 @@ JOINT_MODES = ["add", "mul"] +@keras.utils.register_keras_serializable(package=__name__) class TransducerPrediction(Layer): def __init__( self, @@ -49,24 +51,24 @@ def __init__( projection_units: int = 0, kernel_regularizer=None, bias_regularizer=None, + activity_regularizer=None, + recurrent_regularizer=None, name="transducer_prediction", **kwargs, ): super().__init__(name=name, **kwargs) - if label_encoder_mode not in ["one_hot", "embedding"]: - raise ValueError("label_encode_mode must be either 'one_hot' or 'embedding'") - self.label_encoder_mode = label_encoder_mode - if self.label_encoder_mode == "embedding": - self.label_encoder = Embedding(vocab_size, embed_dim, regularizer=kernel_regularizer, name=self.label_encoder_mode) - else: - self.label_encoder = OneHotBlank(blank=blank, depth=vocab_size, name=self.label_encoder_mode) + assert label_encoder_mode in ("one_hot", "embedding"), "label_encode_mode must be either 'one_hot' or 'embedding'" + self.label_encoder = ( + Embedding(vocab_size, embed_dim, regularizer=kernel_regularizer, name=label_encoder_mode, dtype=self.dtype) + if label_encoder_mode == "embedding" + else OneHotBlank(blank=blank, depth=vocab_size, name=label_encoder_mode, dtype=self.dtype) + ) # Initialize rnn layers - RnnClass = layer_util.get_rnn(rnn_type) - self.rnns = [] + self.rnns: typing.List[typing.Union[keras.layers.GRU, keras.layers.LSTM, keras.layers.SimpleRNN]] = [] self.lns = [] self.projections = [] for i in range(num_rnns): - rnn = RnnClass( + rnn = layer_util.get_rnn(rnn_type)( units=rnn_units, return_sequences=True, name=f"{rnn_type}_{i}", @@ -76,18 +78,26 @@ def __init__( zero_output_for_mask=True, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + recurrent_regularizer=recurrent_regularizer, + use_cudnn=env_util.TF_CUDNN, + dtype=self.dtype, ) ln = ( - tf.keras.layers.LayerNormalization(name=f"ln_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer) + keras.layers.LayerNormalization( + name=f"ln_{i}", gamma_regularizer=kernel_regularizer, beta_regularizer=kernel_regularizer, dtype=self.dtype + ) if layer_norm else None ) projection = ( - tf.keras.layers.Dense( + keras.layers.Dense( projection_units, name=f"projection_{i}", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + dtype=self.dtype, ) if projection_units > 0 else None @@ -96,65 +106,74 @@ def __init__( self.lns.append(ln) self.projections.append(projection) - def get_initial_state(self): - """Get zeros states + def 
get_initial_state(self, batch_size: int): + """ + Get zeros states - Returns: - tf.Tensor: states having shape [num_rnns, 1 or 2, B, P] + Returns + ------- + tf.Tensor, shape [B, num_rnns, nstates, state_size] + Zero initialized states """ states = [] for rnn in self.rnns: - states.append(tf.stack(rnn.get_initial_state(tf.zeros([1, 1, 1], dtype=tf.float32)), axis=0)) - return tf.stack(states, axis=0) + states.append(tf.stack(rnn.get_initial_state(batch_size=batch_size), axis=0)) + return tf.transpose(tf.stack(states, axis=0), perm=[2, 0, 1, 3]) def call(self, inputs, training=False): - # inputs has shape [B, U] - # use tf.gather_nd instead of tf.gather for tflite conversion - outputs, prediction_length = inputs - outputs = self.label_encoder(outputs, training=training) - outputs = math_util.apply_mask(outputs, mask=tf.sequence_mask(prediction_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool)) + outputs, outputs_length = inputs + outputs, outputs_length = self.label_encoder((outputs, outputs_length), training=training) for i, rnn in enumerate(self.rnns): - outputs = rnn(outputs, training=training, mask=getattr(outputs, "_keras_mask", None)) - outputs = outputs[0] + outputs, *_ = rnn(outputs, training=training) # mask auto populate if self.lns[i] is not None: outputs = self.lns[i](outputs, training=training) if self.projections[i] is not None: outputs = self.projections[i](outputs, training=training) - return outputs + return outputs, outputs_length - def recognize(self, inputs, states, tflite: bool = False): - """Recognize function for prediction network + def call_next(self, inputs, previous_decoder_states): + """ + Recognize function for prediction network from the previous predicted tokens - Args: - inputs (tf.Tensor): shape [1, 1] - states (tf.Tensor): shape [num_lstms, 2, B, P] + Parameters + ---------- + inputs : tf.Tensor, shape [B, 1] + previous_decoder_states : tf.Tensor, shape [B, num_rnns, nstates, rnn_units] - Returns: - tf.Tensor: outputs with shape [1, 1, P] - tf.Tensor: new states with shape [num_lstms, 2, 1, P] + Returns + ------- + Tuple[tf.Tensor, tf.Tensor], shapes ([B, 1, rnn_units], [B, num_rnns, nstates, rnn_units]) + Outputs, new states """ - if tflite and self.label_encoder_mode == "embedding": - outputs = self.label_encoder.recognize_tflite(inputs) - else: - outputs = self.label_encoder(inputs, training=False) - new_states = [] - for i, rnn in enumerate(self.rnns): - outputs = rnn(outputs, training=False, initial_state=tf.unstack(states[i], axis=0)) - new_states.append(tf.stack(outputs[1:])) - outputs = outputs[0] - if self.lns[i] is not None: - outputs = self.lns[i](outputs, training=False) - if self.projections[i] is not None: - outputs = self.projections[i](outputs, training=False) - return outputs, tf.stack(new_states, axis=0) + with tf.name_scope(f"{self.name}_call_next"): + previous_decoder_states = tf.transpose(previous_decoder_states, perm=[1, 2, 0, 3]) + outputs = self.label_encoder.call_next(inputs) + new_states = [] + for i, rnn in enumerate(self.rnns): + outputs, *_states = rnn(outputs, training=False, initial_state=tf.unstack(previous_decoder_states[i], axis=0)) + new_states.append(tf.stack(_states)) + if self.lns[i] is not None: + outputs = self.lns[i](outputs, training=False) + if self.projections[i] is not None: + outputs = self.projections[i](outputs, training=False) + return outputs, tf.transpose(tf.stack(new_states, axis=0), perm=[2, 0, 1, 3]) + + def compute_mask(self, inputs, mask=None): + return self.label_encoder.compute_mask(inputs, 
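# --- Editor's illustrative sketch (not part of the patch) ---------------------
# get_initial_state above stacks the per-RNN state pairs into
# [num_rnns, nstates, B, units] and then transposes to the batch-major layout
# [B, num_rnns, nstates, units] used by call_next. Example shapes assumed:
import tensorflow as tf

batch_size, num_rnns, units = 4, 2, 8
per_rnn = [tf.stack([tf.zeros([batch_size, units])] * 2, axis=0) for _ in range(num_rnns)]
states = tf.stack(per_rnn, axis=0)                         # [num_rnns, 2, B, units]
states = tf.transpose(states, perm=[2, 0, 1, 3])           # [B, num_rnns, 2, units]
print(states.shape)                                        # (4, 2, 2, 8)
# ------------------------------------------------------------------------------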
mask=mask) def compute_output_shape(self, input_shape): - predictions_shape, _ = input_shape - output_size = self.projections[-1].units if self.projections[-1] is not None else self.rnns[-1].units - outputs_shape = predictions_shape + (output_size,) - return tuple(outputs_shape) + output_shape, output_length_shape = input_shape + output_shape, output_length_shape = self.label_encoder.compute_output_shape((output_shape, output_length_shape)) + for i, rnn in enumerate(self.rnns): + output_shape = ( + self.projections[i].compute_output_shape(output_shape) + if self.projections[i] is not None + else rnn.compute_output_shape(output_shape)[0] + ) + return tuple(output_shape), tuple(output_length_shape) +@keras.utils.register_keras_serializable(package=__name__) class TransducerJointMerge(Layer): def __init__(self, joint_mode: str = "add", name="transducer_joint_merge", **kwargs): super().__init__(name=name, **kwargs) @@ -164,8 +183,8 @@ def __init__(self, joint_mode: str = "add", name="transducer_joint_merge", **kwa def compute_mask(self, inputs, mask=None): enc_out, pred_out = inputs - enc_mask = getattr(enc_out, "_keras_mask", None) # BT - pred_mask = getattr(pred_out, "_keras_mask", None) # BU + enc_mask = mask[0] if mask else backend.get_keras_mask(enc_out) # BT + pred_mask = mask[1] if mask else backend.get_keras_mask(pred_out) # BU auto_mask = None if enc_mask is not None: auto_mask = enc_mask[:, :, tf.newaxis] # BT1 @@ -174,8 +193,6 @@ def compute_mask(self, inputs, mask=None): auto_mask = auto_mask & pred_mask[:, tf.newaxis, :] # BT1 & B1U -> BTU else: auto_mask = pred_mask[:, tf.newaxis, :] - if mask is not None and auto_mask is not None: - auto_mask = auto_mask & mask mask = auto_mask return mask @@ -187,14 +204,14 @@ def call(self, inputs): outputs = tf.add(enc_out, pred_out) # broadcast operator else: outputs = tf.multiply(enc_out, pred_out) # broadcast operator - outputs = math_util.apply_mask(outputs, mask=self.compute_mask(inputs)) return outputs # [B, T, U, V] def compute_output_shape(self, input_shape): enc_shape, pred_shape = input_shape - return (enc_shape[0], enc_shape[1], pred_shape[1], enc_shape[-1]) + return enc_shape[0], enc_shape[1], pred_shape[1], enc_shape[-1] +@keras.utils.register_keras_serializable(package=__name__) class TransducerJoint(Layer): def __init__( self, @@ -207,6 +224,7 @@ def __init__( joint_mode: str = "add", kernel_regularizer=None, bias_regularizer=None, + activity_regularizer=None, name="tranducer_joint", **kwargs, ): @@ -217,19 +235,47 @@ def __init__( self.postjoint_linear = postjoint_linear if self.prejoint_encoder_linear: - self.ffn_enc = tf.keras.layers.Dense(joint_dim, name="enc", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer) + self.ffn_enc = keras.layers.Dense( + joint_dim, + name="enc", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + dtype=self.dtype, + ) if self.prejoint_prediction_linear: - self.ffn_pred = tf.keras.layers.Dense(joint_dim, use_bias=False, name="pred", kernel_regularizer=kernel_regularizer) + self.ffn_pred = keras.layers.Dense( + joint_dim, + name="pred", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + dtype=self.dtype, + ) - self.joint = TransducerJointMerge(joint_mode=joint_mode, name="merge") + self.joint = TransducerJointMerge(joint_mode=joint_mode, name="merge", dtype=self.dtype) activation = activation.lower() - self.activation = 
tf.keras.layers.Activation(activation, name=activation) + self.activation = Activation(activation, name=activation, dtype=self.dtype) if self.postjoint_linear: - self.ffn = tf.keras.layers.Dense(joint_dim, name="ffn", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer) + self.ffn = keras.layers.Dense( + joint_dim, + name="ffn", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + dtype=self.dtype, + ) - self.ffn_out = tf.keras.layers.Dense(vocab_size, name="vocab", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer) + self.ffn_out = keras.layers.Dense( + vocab_size, + name="vocab", + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + dtype=self.dtype, + ) def call(self, inputs, training=False): # enc has shape [B, T, E] @@ -239,18 +285,21 @@ def call(self, inputs, training=False): enc_out = self.ffn_enc(enc_out, training=training) # [B, T, E] => [B, T, V] if self.prejoint_prediction_linear: pred_out = self.ffn_pred(pred_out, training=training) # [B, U, P] => [B, U, V] - outputs = self.joint([enc_out, pred_out]) # => [B, T, U, V] + outputs = self.joint((enc_out, pred_out)) # => [B, T, U, V] if self.postjoint_linear: outputs = self.ffn(outputs, training=training) outputs = self.activation(outputs, training=training) outputs = self.ffn_out(outputs, training=training) return outputs + def compute_mask(self, inputs, mask=None): + return self.joint.compute_mask(inputs, mask=mask) + def compute_output_shape(self, input_shape): encoder_shape, prediction_shape = input_shape batch_shape = encoder_shape[0] encoder_time_shape, prediction_time_shape = encoder_shape[1], prediction_shape[1] - return (batch_shape, encoder_time_shape, prediction_time_shape, self.ffn_out.units) + return batch_shape, encoder_time_shape, prediction_time_shape, self.ffn_out.units class Transducer(BaseModel): @@ -258,9 +307,10 @@ class Transducer(BaseModel): def __init__( self, - encoder: tf.keras.layers.Layer, blank: int, vocab_size: int, + speech_config: dict, + encoder: Layer, prediction_label_encoder_mode: str = "embedding", prediction_embed_dim: int = 512, prediction_num_rnns: int = 1, @@ -280,10 +330,13 @@ def __init__( postjoint_linear: bool = False, kernel_regularizer=None, bias_regularizer=None, + activity_regularizer=None, + recurrent_regularizer=None, name="transducer", **kwargs, ): - super().__init__(name=name, **kwargs) + super().__init__(speech_config=speech_config, name=name, **kwargs) + self.blank = blank self.encoder = encoder self.predict_net = TransducerPrediction( blank=blank, @@ -299,8 +352,11 @@ def __init__( projection_units=prediction_projection_units, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + recurrent_regularizer=recurrent_regularizer, trainable=prediction_trainable, name="prediction", + dtype=self.dtype, ) self.joint_net = TransducerJoint( vocab_size=vocab_size, @@ -312,688 +368,716 @@ def __init__( joint_mode=joint_mode, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, trainable=joint_trainable, name="joint", + dtype=self.dtype, ) self.time_reduction_factor = 1 - self.decoder_gwn_step = None - self.decoder_gwn_stddev = None - - def make(self, input_shape, prediction_shape=[None], batch_size=None, **kwargs): - inputs = tf.keras.Input(shape=input_shape, 
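# --- Editor's illustrative sketch (not part of the patch) ---------------------
# TransducerJointMerge above combines encoder and prediction outputs by
# broadcasting: the encoder output is assumed to gain a label axis and the
# prediction output a time axis (the expansion itself is outside the shown
# hunk), so that "add"/"mul" yields the [B, T, U, V] joint tensor fed to the
# RNN-T loss. Toy shapes assumed:
import tensorflow as tf

B, T, U, V = 2, 6, 4, 10
enc_out = tf.random.normal([B, T, V])
pred_out = tf.random.normal([B, U, V])
joint = tf.expand_dims(enc_out, 2) + tf.expand_dims(pred_out, 1)  # [B,T,1,V] + [B,1,U,V]
print(joint.shape)  # (2, 6, 4, 10)
# ------------------------------------------------------------------------------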
batch_size=batch_size, dtype=tf.float32) - inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) - predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32) - predictions_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32) - self( - data_util.create_inputs( - inputs=inputs, - inputs_length=inputs_length, - predictions=predictions, - predictions_length=predictions_length, - ), - training=False, - ) - def compile( - self, - optimizer, - blank=0, - run_eagerly=None, - mxp="none", - ga_steps=None, - apply_gwn_config=None, - **kwargs, - ): - loss = RnntLoss(blank=blank) - super().compile( - loss=loss, - optimizer=optimizer, - run_eagerly=run_eagerly, - mxp=mxp, - ga_steps=ga_steps, - apply_gwn_config=apply_gwn_config, - **kwargs, - ) + def compile(self, optimizer, output_shapes=None, **kwargs): + loss = RnntLoss(blank=self.blank, output_shapes=output_shapes, name="rnnt_loss") + return super().compile(loss=loss, optimizer=optimizer, **kwargs) def apply_gwn(self): - if self.apply_gwn_config: + if self.gwn_config: original_weights = {} - if self.apply_gwn_config.get("encoder_step") is not None and self.apply_gwn_config.get("encoder_stddev") is not None: + if self.gwn_config.get("encoder_step") is not None and self.gwn_config.get("encoder_stddev") is not None: original_weights["encoder"] = tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["encoder_step"]), - lambda: layer_util.add_gwn(self.encoder.trainable_weights, stddev=self.apply_gwn_config["encoder_stddev"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["encoder_step"]), + lambda: layer_util.add_gwn(self.encoder.trainable_weights, stddev=self.gwn_config["encoder_stddev"]), lambda: self.encoder.trainable_weights, ) - if self.apply_gwn_config.get("predict_net_step") is not None and self.apply_gwn_config.get("predict_net_stddev") is not None: + if self.gwn_config.get("predict_net_step") is not None and self.gwn_config.get("predict_net_stddev") is not None: original_weights["predict_net"] = tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["predict_net_step"]), - lambda: layer_util.add_gwn(self.predict_net.trainable_weights, stddev=self.apply_gwn_config["predict_net_stddev"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["predict_net_step"]), + lambda: layer_util.add_gwn(self.predict_net.trainable_weights, stddev=self.gwn_config["predict_net_stddev"]), lambda: self.predict_net.trainable_weights, ) - if self.apply_gwn_config.get("joint_net_step") is not None and self.apply_gwn_config.get("joint_net_stddev") is not None: + if self.gwn_config.get("joint_net_step") is not None and self.gwn_config.get("joint_net_stddev") is not None: original_weights["joint_net"] = tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["joint_net_step"]), - lambda: layer_util.add_gwn(self.joint_net.trainable_weights, stddev=self.apply_gwn_config["joint_net_stddev"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["joint_net_step"]), + lambda: layer_util.add_gwn(self.joint_net.trainable_weights, stddev=self.gwn_config["joint_net_stddev"]), lambda: self.joint_net.trainable_weights, ) return original_weights return {} def remove_gwn(self, original_weights): - if self.apply_gwn_config: + if self.gwn_config: if original_weights.get("encoder") is not None: tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["encoder_step"]), + 
tf.greater_equal(self.optimizer.iterations, self.gwn_config["encoder_step"]), lambda: layer_util.sub_gwn(original_weights["encoder"], self.encoder.trainable_weights), lambda: None, ) if original_weights.get("predict_net") is not None: tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["predict_net_step"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["predict_net_step"]), lambda: layer_util.sub_gwn(original_weights["predict_net"], self.predict_net.trainable_weights), lambda: None, ) if original_weights.get("joint_net") is not None: tf.cond( - tf.greater_equal((self.optimizer.iterations), self.apply_gwn_config["joint_net_step"]), + tf.greater_equal(self.optimizer.iterations, self.gwn_config["joint_net_step"]), lambda: layer_util.sub_gwn(original_weights["joint_net"], self.joint_net.trainable_weights), lambda: None, ) - def call(self, inputs, training=False): - enc, enc_length = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=training) - pred = self.predict_net([inputs["predictions"], inputs["predictions_length"]], training=training) - logits = self.joint_net([enc, pred], training=training) - return data_util.create_logits(logits=logits, logits_length=enc_length) - - # -------------------------------- INFERENCES ------------------------------------- - - def preprocess(self, signals: tf.Tensor): - with tf.name_scope("preprocess"): - batch = tf.constant(0, dtype=tf.int32) - total_batch = tf.shape(signals)[0] - - inputs = tf.TensorArray( - dtype=tf.float32, - size=total_batch, - dynamic_size=False, - clear_after_read=False, - element_shape=tf.TensorShape(self.speech_featurizer.shape), - ) - - inputs_length = tf.TensorArray( - dtype=tf.int32, - size=total_batch, - dynamic_size=False, - clear_after_read=False, - element_shape=tf.TensorShape([]), - ) - - def condition(_batch, _inputs, _inputs_length): - return tf.less(_batch, total_batch) - - def body(_batch, _inputs, _inputs_length): - item_inputs = self.speech_featurizer.tf_extract(signals[_batch]) - item_inputs_length = tf.cast(tf.shape(item_inputs)[0], tf.int32) - _inputs = _inputs.write(_batch, item_inputs) - _inputs_length = _inputs_length.write(_batch, item_inputs_length) - return _batch + 1, _inputs, _inputs_length - - batch, inputs, inputs_length = tf.while_loop( - condition, - body, - loop_vars=[batch, inputs, inputs_length], - ) - inputs = math_util.pad_tfarray(inputs, blank=0.0, element_axis=0) - - return inputs.stack(), inputs_length.stack() - - def encoder_inference(self, features: tf.Tensor): - """Infer function for encoder (or encoders) - - Args: - features (tf.Tensor): features with shape [T, F, C] + def call(self, inputs: schemas.TrainInput, training=False): + features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=training) + enc, logits_length, *_ = self.encoder((features, features_length), training=training) + pred, *_ = self.predict_net((inputs.predictions, inputs.predictions_length), training=training) + logits = self.joint_net((enc, pred), training=training) + return schemas.TrainOutput( + logits=logits, + logits_length=logits_length, + ) - Returns: - tf.Tensor: output of encoders with shape [T, E] + def call_next( + self, + current_frames: tf.Tensor, + previous_tokens: tf.Tensor, + previous_decoder_states: tf.Tensor, + ): """ - with tf.name_scope("encoder"): - inputs_length = tf.expand_dims(tf.shape(features)[0], axis=0) - outputs = tf.expand_dims(features, axis=0) - outputs, inputs_length = self.encoder([outputs, 
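# --- Editor's illustrative sketch (not part of the patch) ---------------------
# apply_gwn / remove_gwn above add Gaussian weight noise to the sub-networks once
# the optimizer passes the configured step and restore the clean weights after
# the forward pass. layer_util.add_gwn / sub_gwn are assumed to behave roughly
# like this (snapshot, perturb, restore):
import tensorflow as tf

def add_gwn(trainable_weights, stddev):
    originals = [tf.identity(w) for w in trainable_weights]        # snapshot clean weights
    for w in trainable_weights:
        w.assign_add(tf.random.normal(tf.shape(w), stddev=stddev, dtype=w.dtype))
    return originals

def sub_gwn(originals, trainable_weights):
    for orig, w in zip(originals, trainable_weights):
        w.assign(orig)                                              # restore clean weights
# ------------------------------------------------------------------------------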
inputs_length], training=False) - return tf.squeeze(outputs, axis=0) - - def decoder_inference(self, encoded: tf.Tensor, predicted: tf.Tensor, states: tf.Tensor, tflite: bool = False): - """Infer function for decoder - - Args: - encoded (tf.Tensor): output of encoder at each time step => shape [E] - predicted (tf.Tensor): last character index of predicted sequence => shape [] - states (nested lists of tf.Tensor): states returned by rnn layers - - Returns: - (ytu, new_states) + Decode current frame given previous predicted token and states + + Parameters + ---------- + current_frames : tf.Tensor, shape [B, 1, E] + Output of the encoder network of the current frame + previous_tokens : tf.Tensor, shape [B, 1] + Predicted token of the previous frame + previous_decoder_states : tf.Tensor, shape [B, num_rnns, nstates, state_size] + States got from previous frame + + Returns + ------- + Tuple[tf.Tensor, tf.Tensor], shapes ([B, 1, 1, V], [B, num_rnns, nstates, state_size]) + Output of joint network of the current frame, new states of prediction network """ - with tf.name_scope("decoder"): - encoded = tf.reshape(encoded, [1, 1, -1]) # [E] => [1, 1, E] - predicted = tf.reshape(predicted, [1, 1]) # [] => [1, 1] - y, new_states = self.predict_net.recognize(predicted, states, tflite=tflite) # [1, 1, P], states - ytu = tf.nn.log_softmax(self.joint_net([encoded, y], training=False)) # [1, 1, V] - ytu = tf.reshape(ytu, shape=[-1]) # [1, 1, V] => [V] + with tf.name_scope(f"{self.name}_call_next"): + y, new_states = self.predict_net.call_next(previous_tokens, previous_decoder_states) + ytu = self.joint_net([current_frames, y], training=False) + ytu = tf.nn.log_softmax(ytu) return ytu, new_states - # -------------------------------- GREEDY ------------------------------------- + def get_initial_encoder_states(self, batch_size=1): + return [] - def recognize(self, inputs: Dict[str, tf.Tensor]): - """ - RNN Transducer Greedy decoding - Args: - features (tf.Tensor): a batch of extracted features - input_length (tf.Tensor): a batch of extracted features length + def get_initial_decoder_states(self, batch_size=1): + return self.predict_net.get_initial_state(batch_size) - Returns: - tf.Tensor: a batch of decoded transcripts - """ - encoded, encoded_length = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=False) - return self._perform_greedy_batch(encoded=encoded, encoded_length=encoded_length) + # -------------------------------- GREEDY ------------------------------------- - @tf.function(input_signature=[tf.TensorSpec(shape=[None, None], dtype=tf.float32)]) - def recognize_from_signal(self, signals: tf.Tensor): + def recognize(self, inputs: schemas.PredictInput, max_tokens_per_frame: int = 3, **kwargs): """ - RNN Transuder Greedy Decoding From Batch of Signals - - Args: - signals (tf.Tensor): batch of signals in shape [B, None] - - Returns: - tf.Tensor: batch of decoded transcripts in shape [B] + Recognize greedy from input signals + + Parameters + ---------- + inputs : schemas.PredictInput + + Returns + ------- + named tuple of + ( + tokens, will be feed to text_featurizer.detokenize or text_featurizer.detokenize_unicode_points, + next_encoder_states, if encoder does not have states, returns None, will be used to predict next chunk of audio, + next_tokens, will be used to predict next chunk of audio, + next_decoder_states, next states of predict_net, will be used to predict next chunk of audio, + ) """ - inputs, inputs_length = self.preprocess(signals) - return 
self.recognize(data_util.create_inputs(inputs=inputs, inputs_length=inputs_length)) + if self._batch_size == 1: + return self.recognize_single(inputs, max_tokens_per_frame=max_tokens_per_frame, **kwargs) + return self.recognize_batch(inputs, **kwargs) - def recognize_tflite(self, signal, predicted, states): + def recognize_batch(self, inputs: schemas.PredictInput, **kwargs): """ - Function to convert to tflite using greedy decoding (default streaming mode) - Args: - signal: tf.Tensor with shape [None] indicating a single audio signal - predicted: last predicted character with shape [] - states: lastest rnn states with shape [num_rnns, 1 or 2, 1, P] - - Return: - transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32 - predicted: last predicted character with shape [] - states: lastest rnn states with shape [num_rnns, 1 or 2, 1, P] + Ref: https://arxiv.org/pdf/1801.00841.pdf + This is a greedy decoding algorithm that greedily select the best token at each time step + Only apply for batch size > 1 """ - features = self.speech_featurizer.tf_extract(signal) - encoded = self.encoder_inference(features) - hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=True) - transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) - return transcript, hypothesis.index, hypothesis.states - - def recognize_tflite_with_timestamp(self, signal, predicted, states): - features = self.speech_featurizer.tf_extract(signal) - encoded = self.encoder_inference(features) - hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=True) - indices = self.text_featurizer.normalize_indices(hypothesis.prediction) - upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length] - - num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) - total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step - - stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) - stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) - - etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) - etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) - - non_blank = tf.where(tf.not_equal(upoints, 0)) - non_blank_transcript = tf.gather_nd(upoints, non_blank) - non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) - non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) - - return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states - - def _perform_greedy_batch( - self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - parallel_iterations: int = 10, - swap_memory: bool = False, - ): - with tf.name_scope("perform_greedy_batch"): - total_batch = tf.shape(encoded)[0] - batch = tf.constant(0, dtype=tf.int32) - - decoded = tf.TensorArray( - dtype=tf.int32, - size=total_batch, - dynamic_size=False, - clear_after_read=False, - element_shape=tf.TensorShape([None]), - ) - - def condition(batch, _): - return tf.less(batch, total_batch) - - def body(batch, decoded): - hypothesis = self._perform_greedy_v2( - encoded=encoded[batch], - encoded_length=encoded_length[batch], - predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32), - 
states=self.predict_net.get_initial_state(), - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, + with tf.name_scope(f"{self.name}_recognize"): + features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=False) + encoded, encoded_length, next_encoder_states = self.encoder.call_next(features, features_length, inputs.previous_encoder_states) + + nframes = tf.expand_dims(encoded_length, axis=-1) # [B, 1] + batch_size, max_frames, _ = shape_util.shape_list(encoded) + # The current indices of the output of encoder, shape [B, 1] + frame_indices = tf.zeros([batch_size, 1], dtype=tf.int32, name="frame_indices") + # Previous predicted tokens, initially are blanks, shape [B, 1] + previous_tokens = inputs.previous_tokens + # Previous states of the prediction network, initially are zeros, shape [B, num_rnns, nstates, rnn_units] + previous_decoder_states = inputs.previous_decoder_states + # Assumption that number of tokens can not exceed (2 * the size of output of encoder + 1), this is for static runs like TPU or TFLite + max_tokens = max_frames * 2 + 1 + # All of the tokens that are getting recognized, initially are blanks, shape [B, nframes * 2 + 1] + tokens = tf.ones([batch_size, max_tokens], dtype=tf.int32, name="tokens") * self.blank + # The current indices of the token that are currently being recognized, shape [B, 1], the tokens indices are started with 1 so that any + # blank token recognized got updated to index 0 to avoid affecting results + tokens_indices = tf.ones([batch_size, 1], dtype=tf.int32, name="tokens_indices") + + def cond(_frame_indices, _previous_tokens, _previous_decoder_states, _tokens, _tokens_indices): + return tf.logical_not( # Reversed so that the loop check and continue + # One of the following condition met will terminate the loop + tf.logical_or( + # Stop when ALL of the indices of the output of the encoder reach the end + tf.math.reduce_all(tf.greater_equal(_frame_indices, nframes - 1)), + # Stop when ALL of the indices of recognized tokens reach the end + tf.math.reduce_all(tf.greater_equal(_tokens_indices, max_tokens - 1)), + ) ) - decoded = decoded.write(batch, hypothesis.prediction) - return batch + 1, decoded - - batch, decoded = tf.while_loop( - condition, - body, - loop_vars=[batch, decoded], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - ) - - decoded = math_util.pad_tfarray(decoded, blank=self.text_featurizer.blank) - return self.text_featurizer.iextract(decoded.stack()) - def _perform_greedy( - self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - predicted: tf.Tensor, - states: tf.Tensor, - parallel_iterations: int = 10, - swap_memory: bool = False, - tflite: bool = False, - ): - with tf.name_scope("greedy"): - time = tf.constant(0, dtype=tf.int32) - total = encoded_length - - hypothesis = Hypothesis( - index=predicted, - prediction=tf.TensorArray( - dtype=tf.int32, - size=total, - dynamic_size=False, - clear_after_read=False, - element_shape=tf.TensorShape([]), - ), - states=states, - ) - - def condition(_time, _hypothesis): - return tf.less(_time, total) - - def body(_time, _hypothesis): - ytu, _states = self.decoder_inference( - # avoid using [index] in tflite - encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])), - predicted=_hypothesis.index, - states=_hypothesis.states, - tflite=tflite, + def body(_frame_indices, _previous_tokens, _previous_decoder_states, _tokens, _tokens_indices): + _current_frames = tf.expand_dims(tf.gather_nd(encoded, 
tf.minimum(_frame_indices, nframes - 1), batch_dims=1), axis=1) # [B, 1, E] + _log_softmax, _states = self.call_next(_current_frames, _previous_tokens, _previous_decoder_states) + _current_tokens = tf.reshape(tf.argmax(_log_softmax, axis=-1, output_type=tf.int32), [batch_size, 1]) # [B, 1, 1] -> [B, 1] + # conditions, blanks are ignored + _equal_blank = tf.equal(_current_tokens, self.blank) # [B, 1] + # if the token index >= max tokens, it's already finished, set to blank to ignore + _equal_blank = tf.logical_or(_equal_blank, tf.greater_equal(_tokens_indices, max_tokens)) + # if the frame index > nframes, it's already done, set to blank to ignore + _equal_blank = tf.logical_or(_equal_blank, tf.greater(_frame_indices, nframes)) + # update results + _update_tokens = tf.reshape(tf.where(_equal_blank, self.blank, _current_tokens), [batch_size]) # [B] + _update_tokens_indices = tf.where( + _equal_blank, 0, tf.minimum(tf.add(_tokens_indices, 1), max_tokens - 1) + ) # blanks are getting updated at index 0 to avoid affecting results + _tokens = tf.tensor_scatter_nd_update( + tensor=_tokens, + indices=tf.concat([tf.expand_dims(tf.range(batch_size, dtype=tf.int32), axis=-1), _update_tokens_indices], -1), # [B, 2] + updates=_update_tokens, # [B] ) - _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] - - # something is wrong with tflite that drop support for tf.cond - # def equal_blank_fn(): return _hypothesis.index, _hypothesis.states - # def non_equal_blank_fn(): return _predict, _states # update if the new prediction is a non-blank - # _index, _states = tf.cond(tf.equal(_predict, blank), equal_blank_fn, non_equal_blank_fn) - - _equal = tf.equal(_predict, self.text_featurizer.blank) - _index = tf.where(_equal, _hypothesis.index, _predict) - _states = tf.where(_equal, _hypothesis.states, _states) - - _prediction = _hypothesis.prediction.write(_time, _predict) - _hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states) - - return _time + 1, _hypothesis - - time, hypothesis = tf.while_loop( - condition, - body, - loop_vars=[time, hypothesis], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - ) - - return Hypothesis( - index=hypothesis.index, - prediction=hypothesis.prediction.stack(), - states=hypothesis.states, - ) - - def _perform_greedy_v2( - self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - predicted: tf.Tensor, - states: tf.Tensor, - parallel_iterations: int = 10, - swap_memory: bool = False, - tflite: bool = False, - ): - """Ref: https://arxiv.org/pdf/1801.00841.pdf""" - with tf.name_scope("greedy_v2"): - time = tf.constant(0, dtype=tf.int32) - pred_index = tf.constant(0, dtype=tf.int32) - total = encoded_length - - hypothesis = Hypothesis( - index=predicted, - prediction=tf.TensorArray( - dtype=tf.int32, - size=(2 * total), - dynamic_size=False, - clear_after_read=False, - element_shape=tf.TensorShape([]), - ), - states=states, + _tokens_indices = tf.where(_equal_blank, _tokens_indices, tf.minimum(tf.add(_tokens_indices, 1), max_tokens - 1)) + # update states + _frame_indices = tf.where(_equal_blank, tf.add(_frame_indices, 1), _frame_indices) # blank then next frames, else current frames + _previous_tokens = tf.where(_equal_blank, _previous_tokens, _current_tokens) # blank then keep prev tokens, else next tokens + _previous_decoder_states = tf.where( + tf.reshape(_equal_blank, [batch_size, 1, 1, 1]), _previous_decoder_states, _states + ) # blank then keep prev states, else next states # pylint: disable=line-too-long + 
return _frame_indices, _previous_tokens, _previous_decoder_states, _tokens, _tokens_indices + + ( + frame_indices, + next_tokens, + next_decoder_states, + tokens, + tokens_indices, + ) = tf.while_loop(cond, body, loop_vars=(frame_indices, previous_tokens, previous_decoder_states, tokens, tokens_indices)) + + return schemas.PredictOutput( + tokens=tokens, + next_tokens=next_tokens, + next_encoder_states=next_encoder_states, + next_decoder_states=next_decoder_states, ) - def condition(_time, _pred_index, _hypothesis): - return tf.logical_and(tf.less(_time, total), tf.less(_pred_index, 2 * total - 1)) - - def body(_time, _pred_index, _hypothesis): - ytu, _states = self.decoder_inference( - encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])), # avoid using [index] in tflite - predicted=_hypothesis.index, - states=_hypothesis.states, - tflite=tflite, - ) - _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] - - _equal_blank = tf.equal(_predict, self.text_featurizer.blank) - _time = tf.where(_equal_blank, _time + 1, _time) - _index = tf.where(_equal_blank, _hypothesis.index, _predict) - _states = tf.where(_equal_blank, _hypothesis.states, _states) - _pred_index = tf.where(_equal_blank, _pred_index, _pred_index + 1) - _prediction = _hypothesis.prediction.write(_pred_index, _index) - - _hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states) - - return _time, _pred_index, _hypothesis - - time, pred_index, hypothesis = tf.while_loop( - condition, - body, - loop_vars=[time, pred_index, hypothesis], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - ) - - return Hypothesis( - index=hypothesis.index, - prediction=hypothesis.prediction.stack(), - states=hypothesis.states, - ) - - # -------------------------------- BEAM SEARCH ------------------------------------- - - def recognize_beam(self, inputs: Dict[str, tf.Tensor], lm: bool = False): + def recognize_single(self, inputs: schemas.PredictInput, max_tokens_per_frame: int = 3, **kwargs): """ - RNN Transducer Beam Search - Args: - inputs (Dict[str, tf.Tensor]): Input dictionary containing "inputs" and "inputs_length" - lm (bool, optional): whether to use language model. Defaults to False. 
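# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The while_loop above implements greedy RNN-T decoding: the argmax of the joint
# output either emits a token (stay on the frame, feed the token back into the
# prediction network) or is blank (advance to the next frame). The same rule in
# a toy, framework-free form; `predict` and `joint` are hypothetical stand-ins:
import numpy as np

def greedy_rnnt_decode(enc_frames, predict, joint, blank=0):
    """enc_frames: [T, E]; returns the emitted (non-blank) token ids."""
    tokens, prev_token, state = [], blank, None
    t, max_tokens = 0, 2 * len(enc_frames) + 1             # static bound, as in the loop above
    while t < len(enc_frames) and len(tokens) < max_tokens:
        pred_out, new_state = predict(prev_token, state)
        token = int(np.argmax(joint(enc_frames[t], pred_out)))
        if token == blank:
            t += 1                                          # blank: move on to the next frame
        else:
            tokens.append(token)                            # non-blank: emit, stay on the frame
            prev_token, state = token, new_state
    return tokens
# ------------------------------------------------------------------------------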
- - Returns: - tf.Tensor: a batch of decoded transcripts + Ref: https://arxiv.org/pdf/1801.00841.pdf + This is a greedy decoding algorithm that greedily select the best token at each time step + Only apply for batch size 1 """ - encoded, encoded_length = self.encoder([inputs["inputs"], inputs["inputs_length"]], training=False) - return self._perform_beam_search_batch(encoded=encoded, encoded_length=encoded_length, lm=lm) + with tf.name_scope(f"{self.name}_decode_greedy"): + features, features_length = self.feature_extraction((inputs.inputs, inputs.inputs_length), training=False) + encoded, encoded_length, next_encoder_states = self.encoder.call_next(features, features_length, inputs.previous_encoder_states) - def _perform_beam_search_batch( - self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - lm: bool = False, - parallel_iterations: int = 10, - swap_memory: bool = True, - ): - with tf.name_scope("perform_beam_search_batch"): - total_batch = tf.shape(encoded)[0] - batch = tf.constant(0, dtype=tf.int32) + frame = tf.zeros([1, 1], dtype=tf.int32) + nframes = encoded_length - decoded = tf.TensorArray( + previous_tokens = inputs.previous_tokens + token_index = tf.ones([], dtype=tf.int32) * -1 + tokens = tf.TensorArray( dtype=tf.int32, - size=total_batch, + size=tf.reshape(nframes, shape=[]) * max_tokens_per_frame, dynamic_size=False, clear_after_read=False, - element_shape=None, + element_shape=tf.TensorShape([]), + ) + num_tokens_per_frame = tf.TensorArray( + dtype=tf.int32, + size=tf.reshape(nframes, shape=[]), + dynamic_size=False, + clear_after_read=False, + element_shape=tf.TensorShape([]), ) - def condition(batch, _): - return tf.less(batch, total_batch) + previous_decoder_states = inputs.previous_decoder_states + + def condition( + _frame, + _nframes, + _previous_tokens, + _token_index, + _tokens, + _num_tokens_per_frame, + _max_tokens_per_frame, + _previous_decoder_states, + ): + return tf.less(_frame, _nframes) + + def body( + _frame, + _nframes, + _previous_tokens, + _token_index, + _tokens, + _num_tokens_per_frame, + _max_tokens_per_frame, + _previous_decoder_states, + ): + _current_frame = tf.expand_dims(tf.gather_nd(encoded, _frame, batch_dims=1), axis=1) # [1, 1, E] + _log_softmax, _states = self.call_next(_current_frame, _previous_tokens, _previous_decoder_states) + _current_tokens = tf.reshape(tf.argmax(_log_softmax, axis=-1, output_type=tf.int32), [1, 1]) # [1, 1, 1] -> [1, 1] + + ##################### conditions, blanks are ignored + _equal_blank = tf.equal(_current_tokens, self.blank) # [1, 1] + + ##################### step updates + __frame_index = tf.reshape(_frame, shape=[]) + __equal_blank_index = tf.reshape(_equal_blank, shape=[]) + # only non-blank tokens are counted in number of tokens per frame + _current_frame_num_tokens = tf.where( + __equal_blank_index, + _num_tokens_per_frame.read(__frame_index), + tf.add(_num_tokens_per_frame.read(__frame_index), 1), + ) + _num_tokens_per_frame = _num_tokens_per_frame.write(__frame_index, _current_frame_num_tokens) + # increase frame index if current tokens are blank or number of tokens per frame exceeds max tokens per frame + _frame = tf.where( + tf.logical_or(_equal_blank, tf.greater_equal(_current_frame_num_tokens, _max_tokens_per_frame)), + tf.add(_frame, 1), + _frame, + ) + # increase token index if current token is not blank, so that it can be appended to tokens array + _token_index = tf.where(__equal_blank_index, _token_index, tf.add(_token_index, 1)) + + ##################### content updates + # keep previous 
tokens if current tokens are blank + _current_tokens = tf.where(_equal_blank, _previous_tokens, _current_tokens) + # keep previous states if current tokens are blank + _states = tf.where(tf.reshape(_equal_blank, [1, 1, 1, 1]), _previous_decoder_states, _states) + # token_index initialized as -1, so that the first recognized token will be at index 0 + # therefore only update (append) tokens when token_index >= 0 + _tokens = tf.cond( + tf.greater_equal(_token_index, 0), + lambda: _tokens.write(_token_index, tf.reshape(_current_tokens, shape=[])), + lambda: _tokens, + ) - def body(batch, decoded): - hypothesis = self._perform_beam_search( - encoded[batch], - encoded_length[batch], - lm, - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, + ##################### return + return ( + _frame, + _nframes, + _current_tokens, + _token_index, + _tokens, + _num_tokens_per_frame, + _max_tokens_per_frame, + _states, ) - decoded = decoded.write(batch, hypothesis.prediction) - return batch + 1, decoded - batch, decoded = tf.while_loop( + ( + frame, + nframes, + next_tokens, + token_index, + tokens, + num_tokens_per_frame, + max_tokens_per_frame, + next_decoder_states, + ) = tf.while_loop( condition, body, - loop_vars=[batch, decoded], - parallel_iterations=parallel_iterations, - swap_memory=True, - ) - - decoded = math_util.pad_tfarray(decoded, blank=self.text_featurizer.blank) - return self.text_featurizer.iextract(decoded.stack()) - - def _perform_beam_search( - self, - encoded: tf.Tensor, - encoded_length: tf.Tensor, - lm: bool = False, - parallel_iterations: int = 10, - swap_memory: bool = True, - tflite: bool = False, - ): - with tf.name_scope("beam_search"): - beam_width = tf.where( - tf.less(self.text_featurizer.decoder_config.beam_width, self.text_featurizer.num_classes), - self.text_featurizer.decoder_config.beam_width, - self.text_featurizer.num_classes - 1, + loop_vars=( + frame, + nframes, + previous_tokens, + token_index, + tokens, + num_tokens_per_frame, + max_tokens_per_frame, + previous_decoder_states, + ), + back_prop=False, ) - total = encoded_length - - def initialize_beam(dynamic=False): - return BeamHypothesis( - score=tf.TensorArray( - dtype=tf.float32, - size=beam_width if not dynamic else 0, - dynamic_size=dynamic, - element_shape=tf.TensorShape([]), - clear_after_read=False, - ), - indices=tf.TensorArray( - dtype=tf.int32, - size=beam_width if not dynamic else 0, - dynamic_size=dynamic, - element_shape=tf.TensorShape([]), - clear_after_read=False, - ), - prediction=tf.TensorArray( - dtype=tf.int32, - size=beam_width if not dynamic else 0, - dynamic_size=dynamic, - element_shape=None, - clear_after_read=False, - ), - states=tf.TensorArray( - dtype=tf.float32, - size=beam_width if not dynamic else 0, - dynamic_size=dynamic, - element_shape=tf.TensorShape(shape_util.shape_list(self.predict_net.get_initial_state())), - clear_after_read=False, - ), - ) - B = initialize_beam() - B = BeamHypothesis( - score=B.score.write(0, 0.0), - indices=B.indices.write(0, self.text_featurizer.blank), - prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank), - states=B.states.write(0, self.predict_net.get_initial_state()), + return schemas.PredictOutput( + tokens=tf.reshape(tokens.stack(), shape=[1, -1]), + next_tokens=next_tokens, + next_encoder_states=next_encoder_states, + next_decoder_states=next_decoder_states, ) - def condition(time, total, B): - return tf.less(time, total) - - def body(time, total, B): - A = initialize_beam(dynamic=True) 
- A = BeamHypothesis( - score=A.score.unstack(B.score.stack()), - indices=A.indices.unstack(B.indices.stack()), - prediction=A.prediction.unstack(math_util.pad_tfarray(B.prediction, blank=self.text_featurizer.blank).stack()), - states=A.states.unstack(B.states.stack()), - ) - A_i = tf.constant(0, tf.int32) - B = initialize_beam() - - encoded_t = tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)) - - def beam_condition(beam, beam_width, A, A_i, B): - return tf.less(beam, beam_width) - - def beam_body(beam, beam_width, A, A_i, B): - # get y_hat - y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1, sorted=True) - y_hat_score = y_hat_score[0] - y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index) - y_hat_prediction = tf.gather_nd( - math_util.pad_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), - y_hat_score_index, - ) - y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index) - - # remove y_hat from A - remain_indices = tf.range(0, tf.shape(A.score.stack())[0], dtype=tf.int32) - remain_indices = tf.gather_nd(remain_indices, tf.where(tf.not_equal(remain_indices, y_hat_score_index[0]))) - remain_indices = tf.expand_dims(remain_indices, axis=-1) - A = BeamHypothesis( - score=A.score.unstack(tf.gather_nd(A.score.stack(), remain_indices)), - indices=A.indices.unstack(tf.gather_nd(A.indices.stack(), remain_indices)), - prediction=A.prediction.unstack( - tf.gather_nd( - math_util.pad_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), - remain_indices, - ) - ), - states=A.states.unstack(tf.gather_nd(A.states.stack(), remain_indices)), - ) - A_i = tf.where(tf.equal(A_i, 0), A_i, A_i - 1) - - ytu, new_states = self.decoder_inference(encoded=encoded_t, predicted=y_hat_index, states=y_hat_states, tflite=tflite) - - def predict_condition(pred, A, A_i, B): - return tf.less(pred, self.text_featurizer.num_classes) - - def predict_body(pred, A, A_i, B): - new_score = y_hat_score + tf.gather_nd(ytu, tf.expand_dims(pred, axis=-1)) - - def true_fn(): - return ( - B.score.write(beam, new_score), - B.indices.write(beam, y_hat_index), - B.prediction.write(beam, y_hat_prediction), - B.states.write(beam, y_hat_states), - A.score, - A.indices, - A.prediction, - A.states, - A_i, - ) - - def false_fn(): - scatter_index = math_util.count_non_blank(y_hat_prediction, blank=self.text_featurizer.blank) - updated_prediction = tf.tensor_scatter_nd_update( - y_hat_prediction, - indices=tf.reshape(scatter_index, [1, 1]), - updates=tf.expand_dims(pred, axis=-1), - ) - return ( - B.score, - B.indices, - B.prediction, - B.states, - A.score.write(A_i, new_score), - A.indices.write(A_i, pred), - A.prediction.write(A_i, updated_prediction), - A.states.write(A_i, new_states), - A_i + 1, - ) - - b_score, b_indices, b_prediction, b_states, a_score, a_indices, a_prediction, a_states, A_i = tf.cond( - tf.equal(pred, self.text_featurizer.blank), true_fn=true_fn, false_fn=false_fn - ) - - B = BeamHypothesis(score=b_score, indices=b_indices, prediction=b_prediction, states=b_states) - A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states) - - return pred + 1, A, A_i, B - - _, A, A_i, B = tf.while_loop( - predict_condition, - predict_body, - loop_vars=[0, A, A_i, B], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - ) + # def recognize_tflite_with_timestamp(self, signal, predicted, states): + # features = self.speech_featurizer.tf_extract(signal) + # encoded = self.encoder_inference(features) + # hypothesis = 
self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, states, tflite=True) + # indices = self.text_featurizer.normalize_indices(hypothesis.prediction) + # upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, max_subword_length] + + # num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) + # total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step + + # stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + # stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + # etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) + # etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) + + # non_blank = tf.where(tf.not_equal(upoints, 0)) + # non_blank_transcript = tf.gather_nd(upoints, non_blank) + # non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + # non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) + + # return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, hypothesis.states + + # def _perform_greedy_batch( + # self, + # encoded: tf.Tensor, + # encoded_length: tf.Tensor, + # parallel_iterations: int = 10, + # swap_memory: bool = False, + # ): + # with tf.name_scope("perform_greedy_batch"): + # total_batch = tf.shape(encoded)[0] + # batch = tf.constant(0, dtype=tf.int32) + + # decoded = tf.TensorArray( + # dtype=tf.int32, + # size=total_batch, + # dynamic_size=False, + # clear_after_read=False, + # element_shape=tf.TensorShape([None]), + # ) + + # def condition(batch, _): + # return tf.less(batch, total_batch) + + # def body(batch, decoded): + # hypothesis = self._perform_greedy_v2( + # encoded=encoded[batch], + # encoded_length=encoded_length[batch], + # predicted=tf.constant(self.text_featurizer.blank, dtype=tf.int32), + # states=self.predict_net.get_initial_state(), + # parallel_iterations=parallel_iterations, + # swap_memory=swap_memory, + # ) + # decoded = decoded.write(batch, hypothesis.prediction) + # return batch + 1, decoded + + # batch, decoded = tf.while_loop( + # condition, + # body, + # loop_vars=[batch, decoded], + # parallel_iterations=parallel_iterations, + # swap_memory=swap_memory, + # ) + + # decoded = math_util.pad_tfarray(decoded, blank=self.text_featurizer.blank) + # return self.text_featurizer.detokenize(decoded.stack()) + + # def _perform_greedy( + # self, + # encoded: tf.Tensor, + # encoded_length: tf.Tensor, + # predicted: tf.Tensor, + # states: tf.Tensor, + # tflite: bool = False, + # ): + # """Ref: https://arxiv.org/pdf/1801.00841.pdf""" + # with tf.name_scope("greedy_v2"): + # time = tf.constant(0, dtype=tf.int32) + # pred_index = tf.constant(0, dtype=tf.int32) + # total = encoded_length + + # hypothesis = Hypothesis( + # index=predicted, + # prediction=tf.TensorArray( + # dtype=tf.int32, + # size=(2 * total), + # dynamic_size=False, + # clear_after_read=False, + # element_shape=tf.TensorShape([]), + # ), + # states=states, + # ) + + # def condition(_time, _pred_index, _hypothesis): + # return tf.logical_and(tf.less(_time, total), tf.less(_pred_index, 2 * total - 1)) + + # def body(_time, _pred_index, _hypothesis): + # ytu, _states = self.decoder_inference( + # encoded=tf.gather_nd(encoded, tf.reshape(_time, shape=[1])), # avoid using [index] in tflite + # predicted=_hypothesis.index, + # 
states=_hypothesis.states, + # tflite=tflite, + # ) + # _predict = tf.argmax(ytu, axis=-1, output_type=tf.int32) # => argmax [] + + # _equal_blank = tf.equal(_predict, self.text_featurizer.blank) + # _time = tf.where(_equal_blank, _time + 1, _time) + # _index = tf.where(_equal_blank, _hypothesis.index, _predict) + # _states = tf.where(_equal_blank, _hypothesis.states, _states) + # _pred_index = tf.where(_equal_blank, _pred_index, _pred_index + 1) + # _prediction = _hypothesis.prediction.write(_pred_index, _index) + + # _hypothesis = Hypothesis(index=_index, prediction=_prediction, states=_states) + + # return _time, _pred_index, _hypothesis + + # time, pred_index, hypothesis = tf.while_loop(condition, body, loop_vars=[time, pred_index, hypothesis]) + + # return Hypothesis( + # index=hypothesis.index, + # prediction=hypothesis.prediction.stack(), + # states=hypothesis.states, + # ) - return beam + 1, beam_width, A, A_i, B - - _, _, A, A_i, B = tf.while_loop( - beam_condition, - beam_body, - loop_vars=[0, beam_width, A, A_i, B], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - ) - - return time + 1, total, B - - _, _, B = tf.while_loop( - condition, - body, - loop_vars=[0, total, B], - parallel_iterations=parallel_iterations, - swap_memory=swap_memory, - ) + # -------------------------------- BEAM SEARCH ------------------------------------- - scores = B.score.stack() - prediction = math_util.pad_tfarray(B.prediction, blank=self.text_featurizer.blank).stack() - if self.text_featurizer.decoder_config.norm_score: - prediction_lengths = math_util.count_non_blank(prediction, blank=self.text_featurizer.blank, axis=1) - scores /= tf.cast(prediction_lengths, dtype=scores.dtype) - - y_hat_score, y_hat_score_index = tf.math.top_k(scores, k=1) - y_hat_score = y_hat_score[0] - y_hat_index = tf.gather_nd(B.indices.stack(), y_hat_score_index) - y_hat_prediction = tf.gather_nd(prediction, y_hat_score_index) - y_hat_states = tf.gather_nd(B.states.stack(), y_hat_score_index) - - return Hypothesis(index=y_hat_index, prediction=y_hat_prediction, states=y_hat_states) - - # -------------------------------- TFLITE ------------------------------------- - - def make_tflite_function(self, timestamp: bool = False): - tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite - return tf.function( - tflite_func, - input_signature=[ - tf.TensorSpec([None], dtype=tf.float32), - tf.TensorSpec([], dtype=tf.int32), - tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32), - ], - ) + def recognize_beam(self, inputs: schemas.PredictInput, beam_width: int = 10, **kwargs): + return self.recognize(inputs=inputs, **kwargs) # TODO: Implement beam search + + # def _perform_beam_search_batch( + # self, + # encoded: tf.Tensor, + # encoded_length: tf.Tensor, + # lm: bool = False, + # parallel_iterations: int = 10, + # swap_memory: bool = True, + # ): + # with tf.name_scope("perform_beam_search_batch"): + # total_batch = tf.shape(encoded)[0] + # batch = tf.constant(0, dtype=tf.int32) + + # decoded = tf.TensorArray( + # dtype=tf.int32, + # size=total_batch, + # dynamic_size=False, + # clear_after_read=False, + # element_shape=None, + # ) + + # def condition(batch, _): + # return tf.less(batch, total_batch) + + # def body(batch, decoded): + # hypothesis = self._perform_beam_search( + # encoded[batch], + # encoded_length[batch], + # lm, + # parallel_iterations=parallel_iterations, + # swap_memory=swap_memory, + # ) + # decoded = decoded.write(batch, 
hypothesis.prediction) + # return batch + 1, decoded + + # batch, decoded = tf.while_loop( + # condition, + # body, + # loop_vars=[batch, decoded], + # parallel_iterations=parallel_iterations, + # swap_memory=True, + # ) + + # decoded = math_util.pad_tfarray(decoded, blank=self.text_featurizer.blank) + # return self.text_featurizer.detokenize(decoded.stack()) + + # def _perform_beam_search( + # self, + # encoded: tf.Tensor, + # encoded_length: tf.Tensor, + # lm: bool = False, + # parallel_iterations: int = 10, + # swap_memory: bool = True, + # tflite: bool = False, + # ): + # with tf.name_scope("beam_search"): + # beam_width = tf.where( + # tf.less(self.text_featurizer.decoder_config.beam_width, self.text_featurizer.num_classes), + # self.text_featurizer.decoder_config.beam_width, + # self.text_featurizer.num_classes - 1, + # ) + # total = encoded_length + + # def initialize_beam(dynamic=False): + # return BeamHypothesis( + # score=tf.TensorArray( + # dtype=tf.float32, + # size=beam_width if not dynamic else 0, + # dynamic_size=dynamic, + # element_shape=tf.TensorShape([]), + # clear_after_read=False, + # ), + # indices=tf.TensorArray( + # dtype=tf.int32, + # size=beam_width if not dynamic else 0, + # dynamic_size=dynamic, + # element_shape=tf.TensorShape([]), + # clear_after_read=False, + # ), + # prediction=tf.TensorArray( + # dtype=tf.int32, + # size=beam_width if not dynamic else 0, + # dynamic_size=dynamic, + # element_shape=None, + # clear_after_read=False, + # ), + # states=tf.TensorArray( + # dtype=tf.float32, + # size=beam_width if not dynamic else 0, + # dynamic_size=dynamic, + # element_shape=tf.TensorShape(shape_util.shape_list(self.predict_net.get_initial_state())), + # clear_after_read=False, + # ), + # ) + + # B = initialize_beam() + # B = BeamHypothesis( + # score=B.score.write(0, 0.0), + # indices=B.indices.write(0, self.text_featurizer.blank), + # prediction=B.prediction.write(0, tf.ones([total], dtype=tf.int32) * self.text_featurizer.blank), + # states=B.states.write(0, self.predict_net.get_initial_state(4)), + # ) + + # def condition(time, total, B): + # return tf.less(time, total) + + # def body(time, total, B): + # A = initialize_beam(dynamic=True) + # A = BeamHypothesis( + # score=A.score.unstack(B.score.stack()), + # indices=A.indices.unstack(B.indices.stack()), + # prediction=A.prediction.unstack(math_util.pad_tfarray(B.prediction, blank=self.text_featurizer.blank).stack()), + # states=A.states.unstack(B.states.stack()), + # ) + # A_i = tf.constant(0, tf.int32) + # B = initialize_beam() + + # encoded_t = tf.gather_nd(encoded, tf.expand_dims(time, axis=-1)) + + # def beam_condition(beam, beam_width, A, A_i, B): + # return tf.less(beam, beam_width) + + # def beam_body(beam, beam_width, A, A_i, B): + # # get y_hat + # y_hat_score, y_hat_score_index = tf.math.top_k(A.score.stack(), k=1, sorted=True) + # y_hat_score = y_hat_score[0] + # y_hat_index = tf.gather_nd(A.indices.stack(), y_hat_score_index) + # y_hat_prediction = tf.gather_nd( + # math_util.pad_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), + # y_hat_score_index, + # ) + # y_hat_states = tf.gather_nd(A.states.stack(), y_hat_score_index) + + # # remove y_hat from A + # remain_indices = tf.range(0, tf.shape(A.score.stack())[0], dtype=tf.int32) + # remain_indices = tf.gather_nd(remain_indices, tf.where(tf.not_equal(remain_indices, y_hat_score_index[0]))) + # remain_indices = tf.expand_dims(remain_indices, axis=-1) + # A = BeamHypothesis( + # score=A.score.unstack(tf.gather_nd(A.score.stack(), 
remain_indices)), + # indices=A.indices.unstack(tf.gather_nd(A.indices.stack(), remain_indices)), + # prediction=A.prediction.unstack( + # tf.gather_nd( + # math_util.pad_tfarray(A.prediction, blank=self.text_featurizer.blank).stack(), + # remain_indices, + # ) + # ), + # states=A.states.unstack(tf.gather_nd(A.states.stack(), remain_indices)), + # ) + # A_i = tf.where(tf.equal(A_i, 0), A_i, A_i - 1) + + # ytu, new_states = self.decoder_inference(encoded=encoded_t, predicted=y_hat_index, states=y_hat_states, tflite=tflite) + + # def predict_condition(pred, A, A_i, B): + # return tf.less(pred, self.text_featurizer.num_classes) + + # def predict_body(pred, A, A_i, B): + # new_score = y_hat_score + tf.gather_nd(ytu, tf.expand_dims(pred, axis=-1)) + + # def true_fn(): + # return ( + # B.score.write(beam, new_score), + # B.indices.write(beam, y_hat_index), + # B.prediction.write(beam, y_hat_prediction), + # B.states.write(beam, y_hat_states), + # A.score, + # A.indices, + # A.prediction, + # A.states, + # A_i, + # ) + + # def false_fn(): + # scatter_index = math_util.count_non_blank(y_hat_prediction, blank=self.text_featurizer.blank) + # updated_prediction = tf.tensor_scatter_nd_update( + # y_hat_prediction, + # indices=tf.reshape(scatter_index, [1, 1]), + # updates=tf.expand_dims(pred, axis=-1), + # ) + # return ( + # B.score, + # B.indices, + # B.prediction, + # B.states, + # A.score.write(A_i, new_score), + # A.indices.write(A_i, pred), + # A.prediction.write(A_i, updated_prediction), + # A.states.write(A_i, new_states), + # A_i + 1, + # ) + + # b_score, b_indices, b_prediction, b_states, a_score, a_indices, a_prediction, a_states, A_i = tf.cond( + # tf.equal(pred, self.text_featurizer.blank), true_fn=true_fn, false_fn=false_fn + # ) + + # B = BeamHypothesis(score=b_score, indices=b_indices, prediction=b_prediction, states=b_states) + # A = BeamHypothesis(score=a_score, indices=a_indices, prediction=a_prediction, states=a_states) + + # return pred + 1, A, A_i, B + + # _, A, A_i, B = tf.while_loop( + # predict_condition, + # predict_body, + # loop_vars=[0, A, A_i, B], + # parallel_iterations=parallel_iterations, + # swap_memory=swap_memory, + # ) + + # return beam + 1, beam_width, A, A_i, B + + # _, _, A, A_i, B = tf.while_loop( + # beam_condition, + # beam_body, + # loop_vars=[0, beam_width, A, A_i, B], + # parallel_iterations=parallel_iterations, + # swap_memory=swap_memory, + # ) + + # return time + 1, total, B + + # _, _, B = tf.while_loop( + # condition, + # body, + # loop_vars=[0, total, B], + # parallel_iterations=parallel_iterations, + # swap_memory=swap_memory, + # ) + + # scores = B.score.stack() + # prediction = math_util.pad_tfarray(B.prediction, blank=self.text_featurizer.blank).stack() + # if self.text_featurizer.decoder_config.norm_score: + # prediction_lengths = math_util.count_non_blank(prediction, blank=self.text_featurizer.blank, axis=1) + # scores /= tf.cast(prediction_lengths, dtype=scores.dtype) + + # y_hat_score, y_hat_score_index = tf.math.top_k(scores, k=1) + # y_hat_score = y_hat_score[0] + # y_hat_index = tf.gather_nd(B.indices.stack(), y_hat_score_index) + # y_hat_prediction = tf.gather_nd(prediction, y_hat_score_index) + # y_hat_states = tf.gather_nd(B.states.stack(), y_hat_score_index) + + # return Hypothesis(index=y_hat_index, prediction=y_hat_prediction, states=y_hat_states) + # return Hypothesis(index=y_hat_index, prediction=y_hat_prediction, states=y_hat_states) diff --git a/tensorflow_asr/models/transducer/conformer.py 
b/tensorflow_asr/models/transducer/conformer.py index 4a433e4236..5bc0ac5449 100644 --- a/tensorflow_asr/models/transducer/conformer.py +++ b/tensorflow_asr/models/transducer/conformer.py @@ -13,18 +13,18 @@ # limitations under the License. -import tensorflow as tf - +from tensorflow_asr import keras from tensorflow_asr.models.encoders.conformer import L2, ConformerEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable(package=__name__) class Conformer(Transducer): def __init__( self, blank: int, vocab_size: int, + speech_config: dict, encoder_subsampling: dict, encoder_dmodel: int = 144, encoder_num_blocks: int = 16, @@ -34,18 +34,24 @@ def __init__( encoder_interleave_relpe: bool = True, encoder_use_attention_causal_mask: bool = False, encoder_use_attention_auto_mask: bool = True, - encoder_kernel_size: int = 32, + encoder_kernel_size: int = 31, encoder_padding: str = "causal", encoder_ffm_scale_factor: int = 4, encoder_ffm_residual_factor: float = 0.5, encoder_mhsam_residual_factor: float = 1.0, + encoder_mhsam_use_attention_bias: bool = False, + encoder_mhsam_causal: bool = False, + encoder_mhsam_flash_attention: bool = False, encoder_convm_scale_factor: int = 2, encoder_convm_residual_factor: float = 1.0, + encoder_convm_use_group_conv: bool = False, + encoder_convm_dw_norm_type: str = "batch", encoder_dropout: float = 0.1, encoder_module_norm_position: str = "pre", encoder_block_norm_position: str = "post", encoder_memory_length: int = None, - encoder_mhsam_before_convm: bool = True, + encoder_history_size: int = None, + encoder_chunk_size: int = None, encoder_trainable: bool = True, prediction_label_encode_mode: str = "embedding", prediction_embed_dim: int = 512, @@ -65,11 +71,14 @@ def __init__( joint_mode: str = "add", joint_trainable: bool = True, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, + activity_regularizer=None, + recurrent_regularizer=None, name: str = "conformer", **kwargs, ): super().__init__( + speech_config=speech_config, encoder=ConformerEncoder( subsampling=encoder_subsampling, dmodel=encoder_dmodel, @@ -85,15 +94,22 @@ def __init__( ffm_scale_factor=encoder_ffm_scale_factor, ffm_residual_factor=encoder_ffm_residual_factor, mhsam_residual_factor=encoder_mhsam_residual_factor, + mhsam_use_attention_bias=encoder_mhsam_use_attention_bias, + mhsam_causal=encoder_mhsam_causal, + mhsam_flash_attention=encoder_mhsam_flash_attention, convm_scale_factor=encoder_convm_scale_factor, convm_residual_factor=encoder_convm_residual_factor, + convm_use_group_conv=encoder_convm_use_group_conv, + convm_dw_norm_type=encoder_convm_dw_norm_type, dropout=encoder_dropout, module_norm_position=encoder_module_norm_position, block_norm_position=encoder_block_norm_position, memory_length=encoder_memory_length, - mhsam_before_convm=encoder_mhsam_before_convm, + history_size=encoder_history_size, + chunk_size=encoder_chunk_size, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, trainable=encoder_trainable, name="encoder", ), @@ -118,8 +134,13 @@ def __init__( joint_trainable=joint_trainable, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + recurrent_regularizer=recurrent_regularizer, name=name, **kwargs, ) self.dmodel = encoder_dmodel self.time_reduction_factor = 
self.encoder.conv_subsampling.time_reduction_factor + + def get_initial_encoder_states(self, batch_size=1): + return self.encoder.get_initial_state(batch_size) diff --git a/tensorflow_asr/models/transducer/contextnet.py b/tensorflow_asr/models/transducer/contextnet.py index 4f4ad083ef..d5f5f78ff6 100644 --- a/tensorflow_asr/models/transducer/contextnet.py +++ b/tensorflow_asr/models/transducer/contextnet.py @@ -14,18 +14,18 @@ from typing import List -import tensorflow as tf - +from tensorflow_asr import keras from tensorflow_asr.models.encoders.contextnet import L2, ContextNetEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable(package=__name__) class ContextNet(Transducer): def __init__( self, blank: int, vocab_size: int, + speech_config: dict, encoder_blocks: List[dict], encoder_alpha: float = 0.5, encoder_trainable: bool = True, @@ -47,11 +47,12 @@ def __init__( joint_mode: str = "add", joint_trainable: bool = True, kernel_regularizer=L2, - bias_regularizer=L2, + bias_regularizer=None, name: str = "contextnet", **kwargs, ): super().__init__( + speech_config=speech_config, encoder=ContextNetEncoder( blocks=encoder_blocks, alpha=encoder_alpha, diff --git a/tensorflow_asr/models/transducer/rnn_transducer.py b/tensorflow_asr/models/transducer/rnn_transducer.py deleted file mode 100644 index 8b32365b0e..0000000000 --- a/tensorflow_asr/models/transducer/rnn_transducer.py +++ /dev/null @@ -1,345 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
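# A minimal sketch of constructing the updated Conformer transducer above, assuming the
# speech/feature and subsampling configs are loaded elsewhere (their keys are not shown in
# these hunks); every literal value below is illustrative.
from tensorflow_asr.models.transducer.conformer import Conformer

speech_config = {}        # placeholder: the feature-extraction section of the model config
encoder_subsampling = {}  # placeholder: subsampling block config

model = Conformer(
    blank=0,
    vocab_size=1000,                  # illustrative
    speech_config=speech_config,      # new required argument
    encoder_subsampling=encoder_subsampling,
    encoder_kernel_size=31,           # default changed from 32 to 31
    encoder_chunk_size=None,          # new: set for chunk-wise streaming attention
    bias_regularizer=None,            # default changed from L2 to None
)
encoder_states = model.get_initial_encoder_states(batch_size=1)  # helper added in this change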
-""" http://arxiv.org/abs/1811.06621 """ - -import tensorflow as tf - -from tensorflow_asr.models.base_layer import Layer -from tensorflow_asr.models.layers.subsampling import TimeReduction -from tensorflow_asr.models.transducer.base_transducer import Transducer -from tensorflow_asr.utils import layer_util, math_util - - -class Reshape(Layer): - def call(self, inputs): - outputs, outputs_length = inputs - outputs = math_util.merge_two_last_dims(outputs) - outputs = math_util.apply_mask(outputs, mask=tf.sequence_mask(outputs_length, maxlen=tf.shape(outputs)[1], dtype=tf.bool)) - return outputs, outputs_length - - def compute_output_shape(self, input_shape): - output_shape, output_length_shape = input_shape - output_shape = list(output_shape) - return (output_shape[0], output_shape[1], output_shape[2] * output_shape[3]), tuple(output_length_shape) - - -class RnnTransducerBlock(Layer): - def __init__( - self, - reduction_factor: int = 0, - dmodel: int = 640, - rnn_type: str = "lstm", - rnn_units: int = 2048, - rnn_unroll: bool = False, - layer_norm: bool = True, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__(**kwargs) - - RnnClass = layer_util.get_rnn(rnn_type) - self.rnn = RnnClass( - units=rnn_units, - return_sequences=True, - name=rnn_type, - unroll=rnn_unroll, - return_state=True, - zero_output_for_mask=True, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - ) - - if layer_norm: - self.ln = tf.keras.layers.LayerNormalization(name="ln", gamma_regularizer=kernel_regularizer, beta_regularizer=bias_regularizer) - else: - self.ln = None - - if reduction_factor > 0: - self.reduction = TimeReduction(reduction_factor, name="reduction") - else: - self.reduction = None - - self.projection = tf.keras.layers.Dense(dmodel, name="projection", kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer) - - def call(self, inputs, training=False): - outputs, outputs_length = inputs - outputs = self.rnn(outputs, training=training, mask=getattr(outputs, "_keras_mask", None)) - outputs = outputs[0] - if self.ln is not None: - outputs = self.ln(outputs, training=training) - if self.reduction is not None: - outputs, outputs_length = self.reduction([outputs, outputs_length]) - outputs = self.projection(outputs, training=training) - return outputs, outputs_length - - def compute_mask(self, inputs, mask=None): - if self.reduction is not None: - mask = self.reduction.compute_mask(inputs) - return mask - - def recognize(self, inputs, states): - outputs = inputs - outputs = self.rnn(outputs, training=False, initial_state=states, mask=getattr(outputs, "_keras_mask", None)) - new_states = tf.stack(outputs[1:], axis=0) - outputs = outputs[0] - if self.ln is not None: - outputs = self.ln(outputs, training=False) - if self.reduction is not None: - outputs, _ = self.reduction([outputs, tf.reshape(tf.shape(outputs)[1], [1])]) - outputs = self.projection(outputs, training=False) - return outputs, new_states - - def compute_output_shape(self, input_shape): - if self.reduction is None: - return tuple(input_shape) - return self.reduction.compute_output_shape(input_shape) - - -class RnnTransducerEncoder(Layer): - def __init__( - self, - reductions: dict = {0: 3, 1: 2}, - dmodel: int = 640, - nlayers: int = 8, - rnn_type: str = "lstm", - rnn_units: int = 2048, - rnn_unroll: bool = False, - layer_norm: bool = True, - kernel_regularizer=None, - bias_regularizer=None, - **kwargs, - ): - super().__init__(**kwargs) - self._dmodel = dmodel - 
self.reshape = Reshape(name="reshape") - - self.blocks = [ - RnnTransducerBlock( - reduction_factor=reductions.get(i, 0) if reductions else 0, # key is index, value is the factor - dmodel=dmodel, - rnn_type=rnn_type, - rnn_units=rnn_units, - rnn_unroll=rnn_unroll, - layer_norm=layer_norm, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=f"block_{i}", - ) - for i in range(nlayers) - ] - - self.time_reduction_factor = 1 - for block in self.blocks: - if block.reduction is not None: - self.time_reduction_factor *= block.reduction.time_reduction_factor - - def get_initial_state(self, batch_size=1): - """Get zeros states - - Returns: - tf.Tensor: states having shape [num_rnns, 1 or 2, 1, P] - """ - states = [] - for block in self.blocks: - states.append(tf.stack(block.rnn.get_initial_state(tf.zeros([batch_size, 1, 1], dtype=tf.float32)), axis=0)) - return tf.stack(states, axis=0) - - def call(self, inputs, training=False): - outputs, outputs_length = self.reshape(inputs) - for block in self.blocks: - outputs, outputs_length = block([outputs, outputs_length], training=training) - return outputs, outputs_length - - def recognize(self, inputs, states): - """Recognize function for encoder network - - Args: - inputs (tf.Tensor): shape [1, T, F, C] - states (tf.Tensor): shape [num_lstms, 1 or 2, 1, P] - - Returns: - tf.Tensor: outputs with shape [1, T, E] - tf.Tensor: new states with shape [num_lstms, 1 or 2, 1, P] - """ - outputs, _ = self.reshape([inputs, tf.reshape(tf.shape(inputs)[1], [1])]) - new_states = [] - for i, block in enumerate(self.blocks): - outputs, block_states = block.recognize(outputs, states=tf.unstack(states[i], axis=0)) - new_states.append(block_states) - return outputs, tf.stack(new_states, axis=0) - - def compute_output_shape(self, input_shape): - output_shape, output_length_shape = self.reshape.compute_output_shape(input_shape) - output_shape = list(output_shape) - output_shape[1] = None if output_shape[1] is None else math_util.legacy_get_reduced_length(output_shape[1], self.time_reduction_factor) - output_shape[2] = self._dmodel - return tuple(output_shape), output_length_shape - - -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") -class RnnTransducer(Transducer): - def __init__( - self, - blank: int, - vocab_size: int, - encoder_reductions: dict = {0: 3, 1: 2}, - encoder_dmodel: int = 640, - encoder_nlayers: int = 8, - encoder_rnn_type: str = "lstm", - encoder_rnn_units: int = 2048, - encoder_rnn_unroll: bool = False, - encoder_layer_norm: bool = False, - encoder_trainable: bool = True, - prediction_label_encode_mode: str = "embedding", - prediction_embed_dim: int = 320, - prediction_num_rnns: int = 2, - prediction_rnn_units: int = 2048, - prediction_rnn_type: str = "lstm", - prediction_rnn_implementation: int = 2, - prediction_rnn_unroll: bool = False, - prediction_layer_norm: bool = False, - prediction_projection_units: int = 640, - prediction_trainable: bool = True, - joint_dim: int = 640, - joint_activation: str = "tanh", - prejoint_encoder_linear: bool = True, - prejoint_prediction_linear: bool = True, - postjoint_linear: bool = False, - joint_mode: str = "add", - joint_trainable: bool = True, - kernel_regularizer=None, - bias_regularizer=None, - name="rnn_transducer", - **kwargs, - ): - super().__init__( - encoder=RnnTransducerEncoder( - reductions=encoder_reductions, - dmodel=encoder_dmodel, - nlayers=encoder_nlayers, - rnn_type=encoder_rnn_type, - rnn_units=encoder_rnn_units, - 
rnn_unroll=encoder_rnn_unroll, - layer_norm=encoder_layer_norm, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - trainable=encoder_trainable, - name="encoder", - ), - blank=blank, - vocab_size=vocab_size, - prediction_label_encoder_mode=prediction_label_encode_mode, - prediction_embed_dim=prediction_embed_dim, - prediction_num_rnns=prediction_num_rnns, - prediction_rnn_units=prediction_rnn_units, - prediction_rnn_type=prediction_rnn_type, - prediction_layer_norm=prediction_layer_norm, - prediction_rnn_implementation=prediction_rnn_implementation, - prediction_rnn_unroll=prediction_rnn_unroll, - prediction_projection_units=prediction_projection_units, - prediction_trainable=prediction_trainable, - joint_dim=joint_dim, - joint_activation=joint_activation, - prejoint_encoder_linear=prejoint_encoder_linear, - prejoint_prediction_linear=prejoint_prediction_linear, - postjoint_linear=postjoint_linear, - joint_mode=joint_mode, - joint_trainable=joint_trainable, - kernel_regularizer=kernel_regularizer, - bias_regularizer=bias_regularizer, - name=name, - **kwargs, - ) - self.time_reduction_factor = self.encoder.time_reduction_factor - self.dmodel = encoder_dmodel - - def encoder_inference(self, features: tf.Tensor, states: tf.Tensor): - """Infer function for encoder (or encoders) - - Args: - features (tf.Tensor): features with shape [T, F, C] - states (tf.Tensor): previous states of encoders with shape [num_rnns, 1 or 2, 1, P] - - Returns: - tf.Tensor: output of encoders with shape [T, E] - tf.Tensor: states of encoders with shape [num_rnns, 1 or 2, 1, P] - """ - with tf.name_scope("encoder"): - outputs = tf.expand_dims(features, axis=0) - outputs, new_states = self.encoder.recognize(outputs, states) - return tf.squeeze(outputs, axis=0), new_states - - # -------------------------------- GREEDY ------------------------------------- - - def recognize_tflite(self, signal, predicted, encoder_states, prediction_states): - """ - Function to convert to tflite using greedy decoding (default streaming mode) - Args: - signal: tf.Tensor with shape [None] indicating a single audio signal - predicted: last predicted character with shape [] - encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P] - prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P] - - Return: - transcript: tf.Tensor of Unicode Code Points with shape [None] and dtype tf.int32 - predicted: last predicted character with shape [] - encoder_states: lastest encoder states with shape [num_rnns, 1 or 2, 1, P] - prediction_states: lastest prediction states with shape [num_rnns, 1 or 2, 1, P] - """ - features = self.speech_featurizer.tf_extract(signal) - encoded, new_encoder_states = self.encoder_inference(features, encoder_states) - hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) - transcript = self.text_featurizer.indices2upoints(hypothesis.prediction) - return transcript, hypothesis.index, new_encoder_states, hypothesis.states - - def recognize_tflite_with_timestamp(self, signal, predicted, encoder_states, prediction_states): - features = self.speech_featurizer.tf_extract(signal) - encoded, new_encoder_states = self.encoder_inference(features, encoder_states) - hypothesis = self._perform_greedy(encoded, tf.shape(encoded)[0], predicted, prediction_states) - indices = self.text_featurizer.normalize_indices(hypothesis.prediction) - upoints = tf.gather_nd(self.text_featurizer.upoints, tf.expand_dims(indices, axis=-1)) # [None, 
max_subword_length] - - num_samples = tf.cast(tf.shape(signal)[0], dtype=tf.float32) - total_time_reduction_factor = self.time_reduction_factor * self.speech_featurizer.frame_step - - stime = tf.range(0, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) - stime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) - - etime = tf.range(total_time_reduction_factor, num_samples, delta=total_time_reduction_factor, dtype=tf.float32) - etime /= tf.cast(self.speech_featurizer.sample_rate, dtype=tf.float32) - - non_blank = tf.where(tf.not_equal(upoints, 0)) - non_blank_transcript = tf.gather_nd(upoints, non_blank) - non_blank_stime = tf.gather_nd(tf.repeat(tf.expand_dims(stime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) - non_blank_etime = tf.gather_nd(tf.repeat(tf.expand_dims(etime, axis=-1), tf.shape(upoints)[-1], axis=-1), non_blank) - - return non_blank_transcript, non_blank_stime, non_blank_etime, hypothesis.index, new_encoder_states, hypothesis.states - - # -------------------------------- TFLITE ------------------------------------- - - def make_tflite_function( - self, - timestamp: bool = True, - ): - tflite_func = self.recognize_tflite_with_timestamp if timestamp else self.recognize_tflite - return tf.function( - tflite_func, - input_signature=[ - tf.TensorSpec([None], dtype=tf.float32), - tf.TensorSpec([], dtype=tf.int32), - tf.TensorSpec(self.encoder.get_initial_state().get_shape(), dtype=tf.float32), - tf.TensorSpec(self.predict_net.get_initial_state().get_shape(), dtype=tf.float32), - ], - ) diff --git a/tensorflow_asr/models/transducer/rnnt.py b/tensorflow_asr/models/transducer/rnnt.py new file mode 100644 index 0000000000..71280fd537 --- /dev/null +++ b/tensorflow_asr/models/transducer/rnnt.py @@ -0,0 +1,103 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
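# The streaming RNN-T model moves from models/transducer/rnn_transducer.py (deleted above)
# to the new models/transducer/rnnt.py below, with the encoder split out into
# models/encoders/rnnt, so downstream imports change accordingly:
# from tensorflow_asr.models.transducer.rnn_transducer import RnnTransducer   # old path, removed
from tensorflow_asr.models.transducer.rnnt import RnnTransducer
from tensorflow_asr.models.encoders.rnnt import RnnTransducerEncoder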
+""" http://arxiv.org/abs/1811.06621 """ + +from tensorflow_asr import keras +from tensorflow_asr.models.encoders.rnnt import RnnTransducerEncoder +from tensorflow_asr.models.transducer.base_transducer import Transducer + + +@keras.utils.register_keras_serializable(package=__name__) +class RnnTransducer(Transducer): + def __init__( + self, + blank: int, + vocab_size: int, + speech_config: dict, + encoder_reduction_positions: list = ["pre", "pre", "pre", "pre", "pre", "pre", "pre", "pre"], + encoder_reduction_factors: list = [6, 0, 0, 0, 0, 0, 0, 0], + encoder_dmodel: int = 640, + encoder_nlayers: int = 8, + encoder_rnn_type: str = "lstm", + encoder_rnn_units: int = 2048, + encoder_rnn_unroll: bool = False, + encoder_layer_norm: bool = False, + encoder_trainable: bool = True, + prediction_label_encode_mode: str = "embedding", + prediction_embed_dim: int = 320, + prediction_num_rnns: int = 2, + prediction_rnn_units: int = 2048, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_rnn_unroll: bool = False, + prediction_layer_norm: bool = False, + prediction_projection_units: int = 640, + prediction_trainable: bool = True, + joint_dim: int = 640, + joint_activation: str = "tanh", + prejoint_encoder_linear: bool = True, + prejoint_prediction_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=None, + bias_regularizer=None, + name="rnn_transducer", + **kwargs, + ): + super().__init__( + speech_config=speech_config, + encoder=RnnTransducerEncoder( + reduction_positions=encoder_reduction_positions, + reduction_factors=encoder_reduction_factors, + dmodel=encoder_dmodel, + nlayers=encoder_nlayers, + rnn_type=encoder_rnn_type, + rnn_units=encoder_rnn_units, + rnn_unroll=encoder_rnn_unroll, + layer_norm=encoder_layer_norm, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + trainable=encoder_trainable, + name="encoder", + ), + blank=blank, + vocab_size=vocab_size, + prediction_label_encoder_mode=prediction_label_encode_mode, + prediction_embed_dim=prediction_embed_dim, + prediction_num_rnns=prediction_num_rnns, + prediction_rnn_units=prediction_rnn_units, + prediction_rnn_type=prediction_rnn_type, + prediction_layer_norm=prediction_layer_norm, + prediction_rnn_implementation=prediction_rnn_implementation, + prediction_rnn_unroll=prediction_rnn_unroll, + prediction_projection_units=prediction_projection_units, + prediction_trainable=prediction_trainable, + joint_dim=joint_dim, + joint_activation=joint_activation, + prejoint_encoder_linear=prejoint_encoder_linear, + prejoint_prediction_linear=prejoint_prediction_linear, + postjoint_linear=postjoint_linear, + joint_mode=joint_mode, + joint_trainable=joint_trainable, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + name=name, + **kwargs, + ) + self.time_reduction_factor = self.encoder.time_reduction_factor + self.dmodel = encoder_dmodel + + def get_initial_encoder_states(self, batch_size=1): + return self.encoder.get_initial_state(batch_size) diff --git a/tensorflow_asr/models/transducer/transformer.py b/tensorflow_asr/models/transducer/transformer.py index f1da72112b..7547568f4b 100644 --- a/tensorflow_asr/models/transducer/transformer.py +++ b/tensorflow_asr/models/transducer/transformer.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
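# A hedged sketch of the new per-layer time-reduction configuration in RnnTransducer above:
# the old {layer_index: factor} dict becomes two parallel lists with one position marker and
# one factor per encoder layer (a factor of 0 presumably means no reduction at that layer,
# matching the old block's behaviour). Values other than the declared defaults are illustrative.
from tensorflow_asr.models.transducer.rnnt import RnnTransducer

model = RnnTransducer(
    blank=0,
    vocab_size=1000,           # illustrative
    speech_config={},          # placeholder, as above
    encoder_reduction_positions=["pre"] * 8,
    encoder_reduction_factors=[6, 0, 0, 0, 0, 0, 0, 0],  # overall encoder time reduction of 6
    encoder_nlayers=8,
)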
-import tensorflow as tf - +from tensorflow_asr import keras from tensorflow_asr.models.encoders.transformer import TransformerEncoder from tensorflow_asr.models.transducer.base_transducer import Transducer -@tf.keras.utils.register_keras_serializable("tensorflow_asr.models.transducer") +@keras.utils.register_keras_serializable(package=__name__) class Transformer(Transducer): def __init__( self, blank: int, vocab_size: int, + speech_config: dict, encoder_subsampling: dict, encoder_dmodel: int = 512, encoder_dff: int = 1024, @@ -39,6 +39,10 @@ def __init__( encoder_pwffn_activation: str = "relu", encoder_dropout: float = 0.1, encoder_memory_length: int = None, + encoder_history_size: int = None, + encoder_chunk_size: int = None, + encoder_mha_causal: bool = False, + encoder_flash_attention: bool = False, encoder_trainable: bool = True, prediction_label_encode_mode: str = "embedding", prediction_embed_dim: int = 512, @@ -63,6 +67,7 @@ def __init__( **kwargs, ): super().__init__( + speech_config=speech_config, encoder=TransformerEncoder( subsampling=encoder_subsampling, num_blocks=encoder_num_blocks, @@ -79,6 +84,10 @@ def __init__( pwffn_activation=encoder_pwffn_activation, dropout=encoder_dropout, memory_length=encoder_memory_length, + history_size=encoder_history_size, + chunk_size=encoder_chunk_size, + relmha_causal=encoder_mha_causal, + flash_attention=encoder_flash_attention, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, trainable=encoder_trainable, diff --git a/tensorflow_asr/optimizers/__init__.py b/tensorflow_asr/optimizers/__init__.py index 3aa4501e17..9139bde684 100755 --- a/tensorflow_asr/optimizers/__init__.py +++ b/tensorflow_asr/optimizers/__init__.py @@ -1 +1,13 @@ -import tensorflow_asr.optimizers.schedules +import glob +from os.path import basename, dirname, isdir, isfile, join + +for fd in glob.glob(join(dirname(__file__), "*")): + if not isfile(fd) and not isdir(fd): + continue + if isfile(fd) and not fd.endswith(".py"): + continue + fd = fd if isdir(fd) else fd[:-3] + fd = basename(fd) + if fd.startswith("__"): + continue + __import__(f"{__name__}.{fd}") diff --git a/tensorflow_asr/optimizers/accumulation.py b/tensorflow_asr/optimizers/accumulation.py index e684956609..c0e01f6662 100644 --- a/tensorflow_asr/optimizers/accumulation.py +++ b/tensorflow_asr/optimizers/accumulation.py @@ -1,9 +1,10 @@ """ Gradient Accummulation for training TF2 custom training loop. -Copy and modified from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py. """ -import tensorflow as tf +from keras.src.optimizers.base_optimizer import BaseOptimizer + +from tensorflow_asr import tf class GradientAccumulator: @@ -11,66 +12,59 @@ class GradientAccumulator: # performed on assignment. To get the value, we call .value() which returns the # value on the current replica without synchronization. 
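# The glob-based loop in tensorflow_asr/optimizers/__init__.py above imports every submodule
# of the package, presumably so that their register_keras_serializable classes are registered
# as a side effect. For the modules visible in this diff it resolves to roughly:
import tensorflow_asr.optimizers.accumulation
import tensorflow_asr.optimizers.regularizers
import tensorflow_asr.optimizers.schedules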
- def __init__( - self, - ga_steps, - trainable_variables, - name="ga", - ): + def __init__(self, ga_steps, optimizer: BaseOptimizer, name="ga"): self.name = name if ga_steps is None: raise ValueError("ga_steps must be defined") - if trainable_variables is None: - raise ValueError("trainable_variables must be defined") - self._ga_steps = tf.constant(ga_steps, dtype=tf.int32) - self._accum_step = tf.Variable( - tf.constant(0, dtype=tf.int32), - trainable=False, - synchronization=tf.VariableSynchronization.ON_READ, - aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA, - name="accum_step", - ) - self._gradients = [ - tf.Variable( - tf.zeros_like(v), - trainable=False, - synchronization=tf.VariableSynchronization.ON_READ, - aggregation=tf.VariableAggregation.NONE, - name=f"{name}_{i}", - ) - for i, v in enumerate(trainable_variables) - ] + self._ga_steps = ga_steps + self._optimizer = optimizer + self._accumulated_gradients = [] + self.built = False - @property - def step(self): - """Number of accumulated steps.""" - return self._accum_step.value() + def build(self, variables): + if not self._optimizer.built: + self._optimizer.build(variables) + for i, variable in enumerate(variables): + self._accumulated_gradients.append( + self._optimizer.add_variable_from_reference( + variable, + name="gradient_accumulator", + ) + ) + self.built = True @property def total_steps(self): return self._ga_steps - @property - def is_apply_step(self): - return tf.equal(self.step, self.total_steps) + # def is_apply_step(self, step): + # return tf.math.equal(step % self._ga_steps, 0) - @property - def gradients(self): - """The accumulated gradients on the current replica.""" - return tf.cond( # zeros gradients so that apply_gradient has no effect - self.is_apply_step, - lambda: list(gradient.value() for gradient in self._gradients), - lambda: list(tf.zeros_like(gradient) for gradient in self._gradients), - ) + def reset(self): + for g_acc in self._accumulated_gradients: + g_acc.assign(tf.zeros(g_acc.shape, dtype=g_acc.dtype)) - def accumulate(self, gradients): + def _get_acc_grads(self, trainable_variables): + # `trainable_variables` might have been filtered in previous + # processing steps, so we need to ensure the correct mapping between + # `self._accumulated_gradients` and `trainable_variables` + acc_grads = [self._accumulated_gradients[self._optimizer._get_variable_index(v)] for v in trainable_variables] + return acc_grads + + def accumulate(self, grads, trainable_variables): """Accumulates :obj:`gradients` on the current replica.""" - for accum_gradient, gradient in zip(self._gradients, gradients): - accum_gradient.assign_add(gradient, read_value=False) - self._accum_step.assign_add(1) + if not self.built: + self.build(trainable_variables) + # return [None if x is None else x if y is None else x + y for x, y in zip(gradients, per_ga_gradients)] + acc_grads = self._get_acc_grads(trainable_variables) + new_g_accs = [(g + acc_g) for g, acc_g in zip(grads, acc_grads)] + for n_g_acc, g_acc in zip(new_g_accs, acc_grads): + g_acc.assign(n_g_acc) - def reset(self): - """Resets the accumulated gradients on the current replica.""" - self._accum_step.assign(0) - for gradient in self._gradients: - gradient.assign(tf.zeros_like(gradient), read_value=False) + def gradients(self, grads, trainable_variables): + """Gets the gradients for the apply step.""" + if not self.built: + self.build(trainable_variables) + acc_grads = self._get_acc_grads(trainable_variables) + grads = [(g + acc_g) / self._ga_steps for g, acc_g in 
zip(grads, acc_grads)] + return grads diff --git a/tensorflow_asr/optimizers/regularizers.py b/tensorflow_asr/optimizers/regularizers.py new file mode 100644 index 0000000000..3b6a74d94d --- /dev/null +++ b/tensorflow_asr/optimizers/regularizers.py @@ -0,0 +1,50 @@ +from typing import List + +from tensorflow_asr import keras, tf + + +@keras.utils.register_keras_serializable(package=__name__) +class TimeDependentGaussianGradientNoise(keras.regularizers.Regularizer): + """ + Reference: https://openreview.net/pdf/ZY9xxQDMMu5Pk8ELfEz4.pdf + """ + + def __init__( + self, + mean: float = 0.0, + eta: float = 1.0, # {0.01, 0.3, 1.0} + gamma: float = 0.55, + ): + self.mean = mean + self.eta = eta + self.gamma = gamma + super().__init__() + + def noise(self, step: tf.Tensor, gradient: tf.Tensor): + sigma_squared = self.eta / ((1 + tf.cast(step, dtype=gradient.dtype)) ** self.gamma) + return tf.random.normal(mean=self.mean, stddev=tf.math.sqrt(sigma_squared), shape=tf.shape(gradient), dtype=gradient.dtype) + + def __call__(self, step: tf.Tensor, gradients: List[tf.Tensor]): + """ + Apply gaussian noise with time dependent to gradients + + Parameters + ---------- + step : tf.Tensor + Training step + gradients : List[tf.Tensor] + Gradients calculated from optimizer + + Returns + ------- + List[tf.Tensor] + Noise added gradients + """ + return list(tf.add(gradient, self.noise(step, gradient=gradient)) for gradient in gradients) + + def get_config(self): + return { + "mean": self.mean, + "eta": self.eta, + "gamma": self.gamma, + } diff --git a/tensorflow_asr/optimizers/schedules.py b/tensorflow_asr/optimizers/schedules.py index 7be74536c4..3322186967 100755 --- a/tensorflow_asr/optimizers/schedules.py +++ b/tensorflow_asr/optimizers/schedules.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
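# Hypothetical usage sketch for the reworked GradientAccumulator and the new gradient-noise
# regularizer above; the real call sites live in the training loop (not in these hunks), so
# the apply condition and the point where noise is injected are assumptions.
from tensorflow_asr import keras
from tensorflow_asr.optimizers.accumulation import GradientAccumulator
from tensorflow_asr.optimizers.regularizers import TimeDependentGaussianGradientNoise

optimizer = keras.optimizers.Adam(1e-4)
accumulator = GradientAccumulator(ga_steps=4, optimizer=optimizer)
grad_noise = TimeDependentGaussianGradientNoise(eta=1.0, gamma=0.55)  # sigma^2 = eta / (1 + step)^gamma

def maybe_apply(step, grads, trainable_variables):
    grads = grad_noise(step, grads)  # assumed: noise added to the fresh gradients each step
    if step % accumulator.total_steps == 0:  # assumed apply condition, mirroring the commented-out is_apply_step
        mean_grads = accumulator.gradients(grads, trainable_variables)  # (fresh + accumulated) / ga_steps
        optimizer.apply_gradients(zip(mean_grads, trainable_variables))
        accumulator.reset()
    else:
        accumulator.accumulate(grads, trainable_variables)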
-import tensorflow as tf +from tensorflow_asr import keras, tf -@tf.keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") -class TransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): - def __init__(self, dmodel, initial_lr=1.0, warmup_steps=4000, max_lr=None, min_lr=None): +@keras.utils.register_keras_serializable(package=__name__) +class TransformerSchedule(keras.optimizers.schedules.LearningRateSchedule): + def __init__(self, dmodel, scale=1.0, warmup_steps=4000, max_lr=None, min_lr=None): super().__init__() self.dmodel = tf.convert_to_tensor(dmodel, dtype=tf.float32) - self.initial_lr = tf.convert_to_tensor(initial_lr, dtype=tf.float32) + self.scale = tf.convert_to_tensor(scale, dtype=tf.float32) self.warmup_steps = tf.convert_to_tensor(warmup_steps, dtype=tf.float32) - self.max_lr = max_lr - self.min_lr = min_lr + self.max_lr = eval(max_lr) if isinstance(max_lr, str) else max_lr + self.min_lr = eval(min_lr) if isinstance(min_lr, str) else min_lr - def __call__(self, step): + def __call__(self, current_step): # lr = (d_model^-0.5) * min(step^-0.5, step*(warm_up^-1.5)) - step = tf.cast(step, dtype=tf.float32) + step = tf.cast(current_step, dtype=tf.float32) lr = (self.dmodel**-0.5) * tf.math.minimum(step**-0.5, step * (self.warmup_steps**-1.5)) - lr = self.initial_lr * lr + lr = self.scale * lr if self.max_lr is not None: lr = tf.math.minimum(self.max_lr, lr) if self.min_lr is not None: @@ -38,42 +38,16 @@ def __call__(self, step): def get_config(self): return { - "dmodel": self.dmodel, - "initial_lr": self.initial_lr, - "warmup_steps": self.warmup_steps, + "dmodel": int(self.dmodel.numpy()), + "scale": float(self.scale.numpy()), + "warmup_steps": int(self.warmup_steps.numpy()), "max_lr": self.max_lr, "min_lr": self.min_lr, } -@tf.keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") -class BoundExponentialDecay(tf.keras.optimizers.schedules.ExponentialDecay): - def __init__(self, min_lr=0.0, **kwargs): - super().__init__(**kwargs) - self.min_lr = min_lr - - def __call__(self, step): - with tf.name_scope(self.name or "ExponentialDecay") as name: - initial_learning_rate = tf.convert_to_tensor(self.initial_learning_rate, name="initial_learning_rate") - dtype = initial_learning_rate.dtype - decay_steps = tf.cast(self.decay_steps, dtype) - decay_rate = tf.cast(self.decay_rate, dtype) - - global_step_recomp = tf.cast(step, dtype) - p = global_step_recomp / decay_steps - if self.staircase: - p = tf.math.floor(p) - new_lr = tf.multiply(initial_learning_rate, tf.pow(decay_rate, p), name=name) - return tf.maximum(self.min_lr, new_lr) - - def get_config(self): - return { - "min_lr": self.min_lr, - } - - -@tf.keras.utils.register_keras_serializable("tensorflow_asr.optimizers.schedules") -class CyclicTransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): +@keras.utils.register_keras_serializable(package=__name__) +class CyclicTransformerSchedule(keras.optimizers.schedules.LearningRateSchedule): """This callback implements a cyclical learning rate policy (CLR) to the square root decay generally used to train transformers. 
The method cycles the learning rate around the square root decay LR with an amplitude @@ -89,11 +63,11 @@ class CyclicTransformerSchedule(tf.keras.optimizers.schedules.LearningRateSchedu It is inspired from the paper: # References - - [Cyclical Learning Rates for Training Neural Networks]( - https://arxiv.org/abs/1506.01186) + - [Cyclical Learning Rates for Training Neural Networks]( + https://arxiv.org/abs/1506.01186) """ - def __init__(self, dmodel, warmup_steps=4000, max_lr=None, step_size=None): + def __init__(self, dmodel, step_size, max_lr, warmup_steps=4000): """Applies triangular cyclic to the square root decay learning rate. Args: d_model: Model dimension @@ -102,13 +76,14 @@ def __init__(self, dmodel, warmup_steps=4000, max_lr=None, step_size=None): step_size: The size of the cyclic triangular half cycle. """ super().__init__() - self.dmodel = tf.cast(dmodel, tf.float32) - self.warmup_steps = tf.cast(warmup_steps, tf.float32) - self.max_lr = tf.cast(max_lr, tf.float32) - self.step_size = tf.cast(step_size, tf.float32) - - def __call__(self, step): - step = tf.cast(step, tf.float32) + self.dmodel = tf.convert_to_tensor(dmodel, tf.float32) + self.warmup_steps = tf.convert_to_tensor(warmup_steps, tf.float32) + self.max_lr = eval(max_lr) if isinstance(max_lr, str) else max_lr + self.max_lr = tf.convert_to_tensor(self.max_lr, tf.float32) + self.step_size = tf.convert_to_tensor(step_size, tf.float32) + + def __call__(self, current_step): + step = tf.cast(current_step, tf.float32) warmup = step * (self.warmup_steps**-1.5) lr = 2 * tf.math.rsqrt(step) lr = tf.math.rsqrt(self.dmodel) * tf.math.minimum(lr, warmup) @@ -121,8 +96,8 @@ def __call__(self, step): def get_config(self): return { - "dmodel": self.dmodel, - "warmup_steps": self.warmup_steps, - "max_lr": self.max_lr, - "step_size": self.step_size, + "dmodel": float(self.dmodel.numpy()), + "warmup_steps": int(self.warmup_steps.numpy()), + "max_lr": float(self.max_lr.numpy()), + "step_size": int(self.step_size.numpy()), } diff --git a/tensorflow_asr/schemas.py b/tensorflow_asr/schemas.py new file mode 100644 index 0000000000..312362015a --- /dev/null +++ b/tensorflow_asr/schemas.py @@ -0,0 +1,62 @@ +# Copyright 2023 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
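# A small worked example of the renamed TransformerSchedule above (initial_lr is now scale);
# the numbers are illustrative. String values for max_lr / min_lr are eval'ed, so simple
# expressions are accepted.
from tensorflow_asr.optimizers.schedules import TransformerSchedule

lr_schedule = TransformerSchedule(dmodel=144, scale=2.0, warmup_steps=4000, max_lr="0.05 / (144 ** 0.5)")
# At the end of warmup (step 4000): lr = scale * dmodel**-0.5 * 4000**-0.5
#                                      = 2.0 * 144**-0.5 * 4000**-0.5 ~= 2.6e-3
# which stays below max_lr ~= 4.2e-3, so no clipping applies at that step.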
+ +import typing + +import tensorflow as tf + + +class TrainInput(typing.NamedTuple): + inputs: tf.Tensor + inputs_length: tf.Tensor + predictions: tf.Tensor + predictions_length: tf.Tensor + + +class TrainOutput(typing.NamedTuple): + logits: tf.Tensor + logits_length: tf.Tensor + + +class TrainLabel(typing.NamedTuple): + labels: tf.Tensor + labels_length: tf.Tensor + + +class TrainData(typing.NamedTuple): + inputs: TrainInput + labels: TrainLabel + + +class PredictInput(typing.NamedTuple): + inputs: tf.Tensor + inputs_length: tf.Tensor + previous_tokens: typing.Optional[tf.Tensor] = None + previous_encoder_states: typing.Optional[tf.Tensor] = None + previous_decoder_states: typing.Optional[tf.Tensor] = None + + +class PredictOutput(typing.NamedTuple): + tokens: tf.Tensor + next_tokens: tf.Tensor + next_encoder_states: typing.Optional[tf.Tensor] = None + next_decoder_states: typing.Optional[tf.Tensor] = None + + +class PredictOutputWithTranscript(typing.NamedTuple): + transcript: tf.Tensor + tokens: tf.Tensor + next_tokens: tf.Tensor + next_encoder_states: typing.Optional[tf.Tensor] = None + next_decoder_states: typing.Optional[tf.Tensor] = None diff --git a/tensorflow_asr/scripts/__init__.py b/tensorflow_asr/scripts/__init__.py new file mode 100644 index 0000000000..95ef39dd17 --- /dev/null +++ b/tensorflow_asr/scripts/__init__.py @@ -0,0 +1,19 @@ +from tensorflow_asr.scripts import save, test, tflite, train +from tensorflow_asr.scripts.utils import create_datasets_metadata, create_mls_trans, create_tfrecords +from tensorflow_asr.utils import cli_util + + +def main(): + cli_util.run( + { + "train": train.main, + "test": test.main, + "tflite": tflite.main, + "save": save.main, + "utils": { + "create_mls_trans": create_mls_trans.main, + "create_tfrecords": create_tfrecords.main, + "create_datasets_metadata": create_datasets_metadata.main, + }, + } + ) diff --git a/tensorflow_asr/scripts/save.py b/tensorflow_asr/scripts/save.py new file mode 100644 index 0000000000..39eb8044fc --- /dev/null +++ b/tensorflow_asr/scripts/save.py @@ -0,0 +1,58 @@ +# Copyright 2024 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
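# schemas.py above defines the NamedTuples consumed and produced by the schema-based
# recognize path (see the greedy loop in base_transducer earlier in this diff). A hedged
# sketch of chunk-by-chunk inference, assuming `model` is a built transducer and
# `feature_chunks` yields (inputs, inputs_length) pairs in whatever shape recognize()
# expects; those shapes, and how None previous states are handled on the first call,
# are not fully visible in these hunks.
from tensorflow_asr import schemas

prev_tokens = None
enc_states = model.get_initial_encoder_states(batch_size=1)  # helper added on the transducer models above
dec_states = None                                            # previous_* fields default to None in PredictInput

for chunk, chunk_length in feature_chunks:                   # placeholder iterable
    outputs = model.recognize(
        inputs=schemas.PredictInput(
            inputs=chunk,
            inputs_length=chunk_length,
            previous_tokens=prev_tokens,
            previous_encoder_states=enc_states,
            previous_decoder_states=dec_states,
        )
    )
    prev_tokens = outputs.next_tokens
    enc_states = outputs.next_encoder_states
    dec_states = outputs.next_decoder_states
    # outputs.tokens holds the [1, None] tensor of token ids written by the greedy loop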
+ +import logging +import os + +from tensorflow_asr import keras, tf, tokenizers +from tensorflow_asr.configs import Config +from tensorflow_asr.models.base_model import BaseModel +from tensorflow_asr.utils import cli_util, env_util, keras_util + +logger = logging.getLogger(__name__) + + +def main( + config_path: str, + output: str, + h5: str = None, + bs: int = 2, + save_format: str = "h5", + repodir: str = os.getcwd(), +): + assert output + keras.backend.clear_session() + env_util.setup_seed() + + config = Config(config_path, training=False, repodir=repodir) + tokenizer = tokenizers.get(config) + tokenizer.make() + + logger.info(f"Configs: {str(config)}") + + model: BaseModel = keras_util.model_from_config(config.model_config) + model.tokenizer = tokenizer + model.make(batch_size=bs) + if h5 and tf.io.gfile.exists(h5): + model.load_weights(h5, skip_mismatch=False) + model.summary() + + model.save(output, save_format=save_format) + loaded_model: BaseModel = keras.models.load_model(output) + logger.info(loaded_model.to_json()) + loaded_model.summary() + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/tensorflow_asr/scripts/test.py b/tensorflow_asr/scripts/test.py new file mode 100644 index 0000000000..67648c4a33 --- /dev/null +++ b/tensorflow_asr/scripts/test.py @@ -0,0 +1,92 @@ +# Copyright 2023 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import logging +import os + +from tensorflow_asr import datasets, tf, tokenizers # import to aid logging messages +from tensorflow_asr.callbacks import PredictLogger +from tensorflow_asr.configs import Config +from tensorflow_asr.models.base_model import BaseModel +from tensorflow_asr.utils import app_util, cli_util, env_util, file_util, keras_util + +logger = logging.getLogger(__name__) + + +def main( + config_path: str, + dataset_type: str, + datadir: str, + outputdir: str, + h5: str = None, + mxp: str = "none", + bs: int = 1, + jit_compile: bool = False, + repodir: str = os.getcwd(), +): + + outputdir = file_util.preprocess_paths(outputdir, isdir=True) + checkpoint_name = os.path.splitext(os.path.basename(h5))[0] + + env_util.setup_seed() + env_util.setup_mxp(mxp=mxp) + + config = Config(config_path, training=False, repodir=repodir, datadir=datadir) + batch_size = bs + + tokenizer = tokenizers.get(config) + tokenizer.make() + + logger.info(f"Configs: {str(config)}") + + model: BaseModel = keras_util.model_from_config(config.model_config) + model.tokenizer = tokenizer + model.make(batch_size=batch_size) + model.load_weights(h5, skip_mismatch=False) + model.jit_compile = jit_compile + model.summary() + + for test_data_config in config.data_config.test_dataset_configs: + if not test_data_config.name: + raise ValueError("Test dataset name must be provided") + logger.info(f"Testing dataset: {test_data_config.name}") + + output = os.path.join(outputdir, f"{test_data_config.name}-{checkpoint_name}.tsv") + + test_dataset = datasets.get(tokenizer=tokenizer, dataset_config=test_data_config, dataset_type=dataset_type) + test_data_loader = test_dataset.create(batch_size) + + overwrite = True + if tf.io.gfile.exists(output): + while overwrite not in ["yes", "no"]: + overwrite = input(f"File {output} exists, overwrite? (yes/no): ").lower() + overwrite = overwrite == "yes" + + if overwrite: + with file_util.save_file(output) as output_file_path: + model.predict( + test_data_loader, + verbose=1, + callbacks=[ + PredictLogger(test_dataset=test_dataset, output_file_path=output_file_path), + ], + ) + + evaluation_outputs = app_util.evaluate_hypotheses(output) + logger.info(f"Results:\n{evaluation_outputs.to_markdown()}") + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/tensorflow_asr/scripts/tflite.py b/tensorflow_asr/scripts/tflite.py new file mode 100644 index 0000000000..e12d1b0932 --- /dev/null +++ b/tensorflow_asr/scripts/tflite.py @@ -0,0 +1,55 @@ +# Copyright 2023 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import os + +from tensorflow_asr import keras, tf, tokenizers # import to aid logging messages +from tensorflow_asr.configs import Config +from tensorflow_asr.models.base_model import BaseModel +from tensorflow_asr.utils import app_util, cli_util, env_util, keras_util + +logger = logging.getLogger(__name__) + + +def main( + config_path: str, + output: str, + h5: str = None, + bs: int = 1, + beam_width: int = 0, + repodir: str = os.getcwd(), +): + assert output + keras.backend.clear_session() + env_util.setup_seed() + + config = Config(config_path, training=False, repodir=repodir) + tokenizer = tokenizers.get(config) + tokenizer.make() + + logger.info(f"Configs: {str(config)}") + + model: BaseModel = keras_util.model_from_config(config.model_config) + model.tokenizer = tokenizer + model.make(batch_size=bs) + if h5 and tf.io.gfile.exists(h5): + model.load_weights(h5, skip_mismatch=False) + model.summary() + + app_util.convert_tflite(model=model, output=output, batch_size=bs, beam_width=beam_width) + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/tensorflow_asr/scripts/train.py b/tensorflow_asr/scripts/train.py new file mode 100644 index 0000000000..09fd4cb24a --- /dev/null +++ b/tensorflow_asr/scripts/train.py @@ -0,0 +1,125 @@ +# Copyright 2023 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
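Since the converter enables SELECT_TF_OPS, the resulting file is meant to be loaded with the interpreter that ships in the full tensorflow package. A minimal sketch for sanity-checking the converted artifact (the output path is hypothetical, and the exact input/output signature depends on the model's make_tflite_function, which is not shown here):

import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="exported/conformer-small.tflite")  # hypothetical path
print(interpreter.get_input_details())   # expected dtypes and shapes of the serving inputs
print(interpreter.get_output_details())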
+ +import json +import logging +import os + +os.environ["TQDM_DISABLE"] = "1" + +from tensorflow_asr import callbacks, datasets, keras, tokenizers # import to aid logging messages +from tensorflow_asr.configs import Config +from tensorflow_asr.models.base_model import BaseModel +from tensorflow_asr.utils import cli_util, env_util, file_util, keras_util + +logger = logging.getLogger(__name__) + + +def main( + config_path: str, + modeldir: str, + datadir: str, + dataset_type: str, + dataset_cache: bool = False, + bs: int = None, + spx: int = 1, + devices: list = None, + tpu_address: str = None, + tpu_vm: bool = False, + device_type: str = "gpu", + mxp: str = "none", + jit_compile: bool = False, + ga_steps: int = None, + verbose: int = 1, + repodir: str = os.getcwd(), + clean: bool = False, + **kwargs, +): + if clean: + file_util.clean_dir(modeldir) + + keras.backend.clear_session() + env_util.setup_seed() + strategy = env_util.setup_strategy(device_type=device_type, devices=devices, tpu_address=tpu_address, tpu_vm=tpu_vm) + env_util.setup_mxp(mxp=mxp) + + config = Config(config_path, training=True, repodir=repodir, datadir=datadir, modeldir=modeldir, **kwargs) + + tokenizer = tokenizers.get(config) + tokenizer.make() + + train_dataset = datasets.get( + tokenizer=tokenizer, + dataset_config=config.data_config.train_dataset_config, + dataset_type=dataset_type, + dataset_cache=dataset_cache, + ) + eval_dataset = datasets.get( + tokenizer=tokenizer, + dataset_config=config.data_config.eval_dataset_config, + dataset_type=dataset_type, + dataset_cache=dataset_cache, + ) + + logger.info(f"Configs: {str(config)}") + + model_shapes, batch_size, padded_shapes = datasets.get_global_shape( + config, + strategy, + train_dataset, + eval_dataset, + batch_size=bs or config.learning_config.batch_size, + ) + ga_steps = ga_steps or config.learning_config.ga_steps or 1 + + train_data_loader = train_dataset.create(batch_size, ga_steps=ga_steps, padded_shapes=padded_shapes) + logger.info(f"train_data_loader.element_spec = {json.dumps(train_data_loader.element_spec, indent=2, default=str)}") + + eval_data_loader = eval_dataset.create(batch_size, padded_shapes=padded_shapes) + if eval_data_loader: + logger.info(f"eval_data_loader.element_spec = {json.dumps(eval_data_loader.element_spec, indent=2, default=str)}") + + with strategy.scope(): + model: BaseModel = keras_util.model_from_config(config.model_config) + model.tokenizer = tokenizer + output_shapes = model.make(**model_shapes) + if config.learning_config.pretrained: + model.load_weights( + file_util.preprocess_paths(config.learning_config.pretrained), + by_name=file_util.is_hdf5_filepath(config.learning_config.pretrained), + skip_mismatch=True, + ) + model.compile( + optimizer=keras.optimizers.get(config.learning_config.optimizer_config), + output_shapes=output_shapes, + steps_per_execution=spx, + jit_compile=jit_compile, + ga_steps=ga_steps or config.learning_config.ga_steps, + gwn_config=config.learning_config.gwn_config, + gradn_config=config.learning_config.gradn_config, + ) + model.summary() + model.fit( + train_data_loader, + epochs=config.learning_config.num_epochs, + verbose=verbose, + validation_data=eval_data_loader, + callbacks=callbacks.deserialize(config.learning_config.callbacks), + steps_per_epoch=train_dataset.total_steps, + validation_steps=eval_dataset.total_steps if eval_data_loader else None, + ) + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/tensorflow_asr/helpers/__init__.py b/tensorflow_asr/scripts/utils/__init__.py 
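The training entrypoint above (tensorflow_asr/scripts/train.py) can likewise be called directly; a minimal sketch with hypothetical paths and values:

from tensorflow_asr.scripts.train import main as train

train(
    config_path="examples/models/transducer/conformer/small.yml.j2",  # hypothetical config
    modeldir="models/conformer-small",   # checkpoints, logs and states live here
    datadir="datasets/librispeech",      # referenced by the data config
    dataset_type="tfrecord",             # loader type resolved by datasets.get()
    bs=4,                                # overrides learning_config.batch_size when set
    mxp="strict",                        # mixed precision policy, see env_util.setup_mxp
    jit_compile=True,
    # any extra keyword arguments are forwarded to Config(...)
)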
similarity index 100% rename from tensorflow_asr/helpers/__init__.py rename to tensorflow_asr/scripts/utils/__init__.py diff --git a/tensorflow_asr/scripts/utils/create_datasets_metadata.py b/tensorflow_asr/scripts/utils/create_datasets_metadata.py new file mode 100644 index 0000000000..78bf099a98 --- /dev/null +++ b/tensorflow_asr/scripts/utils/create_datasets_metadata.py @@ -0,0 +1,62 @@ +# Copyright 2022 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os + +from tensorflow_asr import datasets, tokenizers +from tensorflow_asr.configs import Config +from tensorflow_asr.utils import cli_util + +logger = logging.getLogger(__name__) + + +def main( + config_path: str, + datadir: str, + dataset_type: str, + repodir: str = os.getcwd(), +): + config = Config(config_path, repodir=repodir, datadir=datadir) + if not config.decoder_config.vocabulary: + raise ValueError("decoder_config.vocabulary must be defined") + + tokenizer = tokenizers.get(config) + + logger.info("Preparing train metadata ...") + config.data_config.train_dataset_config.drop_remainder = False + config.data_config.train_dataset_config.shuffle = False + train_dataset = datasets.get( + tokenizer=tokenizer, + dataset_config=config.data_config.train_dataset_config, + dataset_type=dataset_type, + ) + tokenizer.build(train_dataset) + tokenizer.make() + train_dataset.update_metadata() + + logger.info("Preparing eval metadata ...") + config.data_config.eval_dataset_config.drop_remainder = False + config.data_config.eval_dataset_config.shuffle = False + eval_dataset = datasets.get( + tokenizer=tokenizer, + dataset_config=config.data_config.eval_dataset_config, + dataset_type=dataset_type, + ) + eval_dataset.update_metadata() + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/scripts/create_mls_trans.py b/tensorflow_asr/scripts/utils/create_mls_trans.py similarity index 69% rename from scripts/create_mls_trans.py rename to tensorflow_asr/scripts/utils/create_mls_trans.py index 3e1a82f3ac..3de555a681 100644 --- a/scripts/create_mls_trans.py +++ b/tensorflow_asr/scripts/utils/create_mls_trans.py @@ -16,29 +16,16 @@ import os import librosa -import tensorflow as tf -import tqdm + +from tensorflow_asr import keras # example usage: python create_mls_trans.py -dataset-home /mnt/datasets/mls --language polish --opus base_url = "https://dl.fbaipublicfiles.com/mls/" -langs = [ - "dutch", - "english", - "german", - "french", - "italian", - "portuguese", - "polish", - "spanish" -] - -splits = [ - "dev", - "test", - "train" -] +langs = ["dutch", "english", "german", "french", "italian", "portuguese", "polish", "spanish"] + +splits = ["dev", "test", "train"] chars = set() @@ -46,17 +33,19 @@ def prepare_split(dataset_dir, split, opus=False): # Setup necessary paths split_home = os.path.join(dataset_dir, split) - transcripts_infile = os.path.join(split_home, 'transcripts.txt') - transcripts_outfile = os.path.join(split_home, 'transcripts_tfasr.tsv') + transcripts_infile = 
os.path.join(split_home, "transcripts.txt") + transcripts_outfile = os.path.join(split_home, "transcripts_tfasr.tsv") audio_home = os.path.join(split_home, "audio") extension = ".opus" if opus else ".flac" transcripts = [] + from tqdm.auto import tqdm + # Make paths absolute, get durations and read chars to form alphabet later on - with open(transcripts_infile, 'r', encoding='utf8') as infile: - for line in tqdm.tqdm(infile.readlines(), desc=f"Reading from {transcripts_infile}..."): - file_id, transcript = line.strip().split('\t') - speaker_id, book_id, _ = file_id.split('_') + with open(transcripts_infile, "r", encoding="utf8") as infile: + for line in tqdm(infile.readlines(), desc=f"Reading from {transcripts_infile}...", disable=False): + file_id, transcript = line.strip().split("\t") + speaker_id, book_id, _ = file_id.split("_") audio_path = os.path.join(audio_home, speaker_id, book_id, f"{file_id}{extension}") y, sr = librosa.load(audio_path, sr=None) duration = librosa.get_duration(y, sr) @@ -65,15 +54,15 @@ def prepare_split(dataset_dir, split, opus=False): chars.add(char) # Write transcripts to file - with open(transcripts_outfile, 'w', encoding='utf8') as outfile: + with open(transcripts_outfile, "w", encoding="utf8") as outfile: outfile.write("PATH\tDURATION\tTRANSCRIPT\n") - for t in tqdm.tqdm(transcripts, desc=f"Writing to {transcripts_outfile}"): + for t in tqdm(transcripts, desc=f"Writing to {transcripts_outfile}", disable=False): outfile.write(t) def make_alphabet_file(filepath, chars_list, lang): print(f"Writing alphabet to {filepath}...") - with open(filepath, 'w', encoding='utf8') as outfile: + with open(filepath, "w", encoding="utf8") as outfile: outfile.write(f"# Alphabet file for language {lang}\n") outfile.write("Automatically generated. Do not edit\n#\n") for char in sorted(list(chars_list)): @@ -82,12 +71,12 @@ def make_alphabet_file(filepath, chars_list, lang): outfile.write("# end of file") -if __name__ == "__main__": +def main(): ap = argparse.ArgumentParser(description="Download and prepare MLS dataset in a given language") - ap.add_argument("--dataset-home", "-d", default=None, required=False, - help="Path to home directory to download and prepare dataset. Default to ~/.keras") - ap.add_argument("--language", "-l", type=str, choices=langs, default=None, required=True, - help="Any name of language included in MLS") + ap.add_argument( + "--dataset-home", "-d", default=None, required=False, help="Path to home directory to download and prepare dataset. Default to ~/.keras" + ) + ap.add_argument("--language", "-l", type=str, choices=langs, default=None, required=True, help="Any name of language included in MLS") ap.add_argument("--opus", default=False, action="store_true", help="Whether to use dataset in opus format or not") args = ap.parse_args() @@ -97,12 +86,7 @@ def make_alphabet_file(filepath, chars_list, lang): dataset_dir = os.path.join(dataset_home, subdir) full_url = base_url + fname - downloaded_file = tf.keras.utils.get_file( - fname, - full_url, - cache_subdir=dataset_home, - extract=True - ) + downloaded_file = keras.utils.get_file(fname, full_url, cache_subdir=dataset_home, extract=True) print(f"Dataset extracted to {dataset_dir}. 
Preparing...") @@ -110,3 +94,7 @@ def make_alphabet_file(filepath, chars_list, lang): prepare_split(dataset_dir=dataset_dir, split=split, opus=args.opus) make_alphabet_file(os.path.join(dataset_dir, "alphabet.txt"), chars, args.language) + + +if __name__ == "__main__": + main() diff --git a/tensorflow_asr/scripts/utils/create_tfrecords.py b/tensorflow_asr/scripts/utils/create_tfrecords.py new file mode 100644 index 0000000000..53bd414c86 --- /dev/null +++ b/tensorflow_asr/scripts/utils/create_tfrecords.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List + +from tensorflow_asr import datasets, tokenizers +from tensorflow_asr.configs import Config +from tensorflow_asr.utils import cli_util + + +def main( + config_path: str, + datadir: str, + modes: List[str], + repodir: str = os.getcwd(), + dataset_type: str = "tfrecord", +): + config = Config(config_path, repodir=repodir, datadir=datadir) + tokenizer = tokenizers.get(config=config) + tokenizer.make() + for mode in modes: + dat = datasets.get( + tokenizer=tokenizer, + dataset_config=getattr(config.data_config, f"{mode}_dataset_config"), + dataset_type=dataset_type, + ) + dat.create_tfrecords() + + +if __name__ == "__main__": + cli_util.run(main) diff --git a/tensorflow_asr/tokenizers.py b/tensorflow_asr/tokenizers.py new file mode 100755 index 0000000000..df2a75b23d --- /dev/null +++ b/tensorflow_asr/tokenizers.py @@ -0,0 +1,431 @@ +# Copyright 2020 Huy Le Nguyen (@nglehuy) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
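The TFRecord writer above can be driven the same way; a minimal sketch with a hypothetical config path and data directory:

from tensorflow_asr.scripts.utils.create_tfrecords import main as create_tfrecords

create_tfrecords(
    config_path="examples/models/transducer/conformer/small.yml.j2",  # hypothetical config
    datadir="datasets/librispeech",
    modes=["train", "eval"],  # resolves <mode>_dataset_config from the data config
)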
+ +import codecs +import logging +import multiprocessing +import os +import unicodedata +from dataclasses import asdict, dataclass + +import sentencepiece as sp +import tensorflow_text as tft +from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab + +from tensorflow_asr import tf +from tensorflow_asr.abstracts import AbstractDataset, AbstractTokenizer +from tensorflow_asr.configs import Config, DecoderConfig +from tensorflow_asr.utils import file_util + +logger = logging.getLogger(__name__) + + +@dataclass +class TOKENIZER_TYPES: + CHARACTERS: str = "characters" + WORDPIECE: str = "wordpiece" + SENTENCEPIECE: str = "sentencepiece" + + +def get(config: Config): + if config.decoder_config.type == TOKENIZER_TYPES.SENTENCEPIECE: + logger.info("Loading SentencePiece model ...") + return SentencePieceTokenizer(config.decoder_config) + if config.decoder_config.type == TOKENIZER_TYPES.WORDPIECE: + logger.info("Loading wordpiece ...") + return WordPieceTokenizer(config.decoder_config) + if config.decoder_config.type == TOKENIZER_TYPES.CHARACTERS: + logger.info("Loading characters ...") + return CharTokenizer(config.decoder_config) + raise ValueError(f"type must be in {asdict(TOKENIZER_TYPES()).values()}, received {config.decoder_config.type}") + + +ENGLISH_CHARACTERS = [ + "", + " ", + "a", + "b", + "c", + "d", + "e", + "f", + "g", + "h", + "i", + "j", + "k", + "l", + "m", + "n", + "o", + "p", + "q", + "r", + "s", + "t", + "u", + "v", + "w", + "x", + "y", + "z", + "'", +] + + +class Tokenizer(AbstractTokenizer): + def __init__(self, decoder_config: DecoderConfig): + self.scorer = None + self.decoder_config = decoder_config + if self.decoder_config.vocabulary: + self.decoder_config.vocabulary = file_util.preprocess_paths(self.decoder_config.vocabulary) + self.blank = None + self.tokens2indices = {} + self.tokens = [] + self.num_classes = None + self.max_length = 0 + self.blank = self.decoder_config.blank_index + self.initialized = False + + def generator(self, *datasets: AbstractDataset): + from tqdm import tqdm + + for dataset in datasets: + dataset.read_entries() + for text in tqdm( + dataset.vocab_generator(), + total=dataset.num_entries, + desc=f"Building vocabulary in dataset {dataset.name}", + disable=False, + ): + data = self.normalize_text(text, self.decoder_config).numpy() + yield data + + def build(self, *datasets: AbstractDataset): + raise NotImplementedError("Tokenizer.build() must be implemented in subclasses") + + @property + def shape(self) -> list: + return [self.max_length if self.max_length > 0 else None] + + @property + def prepand_shape(self) -> list: + return [self.max_length + 1 if self.max_length > 0 else None] + + def update_length( + self, + length: int, + ): + self.max_length = max(self.max_length, length) + + def reset_length(self): + self.max_length = 0 + + @classmethod + def normalize_text(cls, text: tf.Tensor, decoder_config: DecoderConfig): + text = tf.strings.regex_replace(text, b"\xe2\x81\x87".decode("utf-8"), "") + text = tft.normalize_utf8(text, decoder_config.normalization_form) + text = tf.strings.regex_replace(text, r"\p{Cc}|\p{Cf}", " ") + text = tf.strings.regex_replace(text, decoder_config.unknown_token, "") + text = tf.strings.regex_replace(text, decoder_config.pad_token, "") + text = tf.strings.regex_replace(text, r" +", " ") + text = tf.strings.lower(text, encoding="utf-8") + text = tf.strings.strip(text) # remove trailing whitespace + return text + + def add_scorer(self, scorer: any = None): + """Add scorer to this 
instance""" + self.scorer = scorer + + def normalize_indices(self, indices: tf.Tensor) -> tf.Tensor: + """ + Remove -1 in indices by replacing them with blanks + Args: + indices (tf.Tensor): shape any + + Returns: + tf.Tensor: normalized indices with shape same as indices + """ + with tf.name_scope("normalize_indices"): + minus_one = -1 * tf.ones_like(indices, dtype=tf.int32) + blank_like = self.blank * tf.ones_like(indices, dtype=tf.int32) + return tf.where(tf.equal(indices, minus_one), blank_like, indices) + + def prepand_blank(self, text: tf.Tensor) -> tf.Tensor: + """Prepand blank index for transducer models""" + return tf.concat([[self.blank], text], 0) + + def tokenize(self, text: str) -> tf.Tensor: + raise NotImplementedError() + + def detokenize(self, indices: tf.Tensor) -> tf.Tensor: + raise NotImplementedError() + + def detokenize_unicode_points(self, indices: tf.Tensor) -> tf.Tensor: + raise NotImplementedError() + + +class CharTokenizer(Tokenizer): + """ + Extract text feature based on char-level granularity. + By looking up the vocabulary table, each line of transcript will be + converted to a sequence of integer indexes. + """ + + def make(self): + lines = [] + if self.decoder_config.vocabulary is not None: + with codecs.open(self.decoder_config.vocabulary, "r", "utf-8") as fin: + lines.extend(fin.readlines()) + else: + lines = ENGLISH_CHARACTERS + self.tokens = [] + for line in lines: + line = unicodedata.normalize(self.decoder_config.normalization_form, line.lower()).strip("\n") + if line.startswith("#") or not line: + continue + if line == "": + line = "" + self.tokens.append(line) + if self.blank is None: + self.blank = len(self.tokens) # blank not at zero + self.num_classes = len(self.tokens) + self.indices = tf.range(self.num_classes, dtype=tf.int32) + self.tokenizer = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(keys=self.tokens, values=self.indices, key_dtype=tf.string, value_dtype=tf.int32), + default_value=self.blank, + ) + self.detokenizer = tf.lookup.StaticHashTable( + tf.lookup.KeyValueTensorInitializer(keys=self.indices, values=self.tokens, key_dtype=tf.int32, value_dtype=tf.string), + default_value=self.tokens[self.blank], + ) + self.upoints = tf.strings.unicode_decode(self.tokens, "UTF-8").to_tensor(shape=[None, 1]) + self.initialized = True + + def build(self, *datasets: AbstractDataset): + vocab_file_path = file_util.preprocess_paths(self.decoder_config.vocabulary) + + def write_vocab_file(filepath, vocab): + with tf.io.gfile.GFile(filepath, "w") as f: + for token in vocab: + print(token, file=f) + + vocab = set() + for data in self.generator(*datasets): + vocab.update(data) + + write_vocab_file(vocab_file_path, vocab) + + def tokenize(self, text): + text = self.normalize_text(text, self.decoder_config) + text = tf.strings.unicode_split(text, "UTF-8") + return self.tokenizer.lookup(text) + + def detokenize(self, indices: tf.Tensor) -> tf.Tensor: + """ + Convert list of indices to string + Args: + indices: tf.Tensor with dim [B, None] + + Returns: + transcripts: tf.Tensor of dtype tf.string with dim [B] + """ + indices = self.normalize_indices(indices) + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) + tokens = self.detokenizer.lookup(indices) + tokens = tf.strings.reduce_join(tokens, axis=-1) + tokens = self.normalize_text(tokens, self.decoder_config) + return tokens + + @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) + def detokenize_unicode_points(self, indices: tf.Tensor) -> 
tf.Tensor: + """ + Transform Predicted Indices to Unicode Code Points (for using tflite) + Args: + indices: tf.Tensor of Classes in shape [None] + + Returns: + unicode code points transcript with dtype tf.int32 and shape [None] + """ + with tf.name_scope("indices2upoints"): + indices = self.normalize_indices(indices) + upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1)) + return tf.gather_nd(upoints, tf.where(tf.not_equal(upoints, 0))) + + +class SentencePieceTokenizer(Tokenizer): + def make(self): + self.blank = self.decoder_config.blank_index + self.tokenizer = tft.FastSentencepieceTokenizer(self.__load_model(), reverse=False, add_bos=False, add_eos=False) + self.num_classes = int(self.tokenizer.vocab_size()) + self.initialized = True + + def __load_model(self): + with file_util.read_file(self.decoder_config.vocabulary) as path: + with open(path, "rb") as f: + return f.read() + + def build(self, *datasets: AbstractDataset): + vocab_file_path = file_util.preprocess_paths(self.decoder_config.vocabulary) + + sp.SentencePieceTrainer.Train( + sentence_iterator=self.generator(*datasets), + model_prefix=os.path.splitext(vocab_file_path)[0], + model_type=self.decoder_config.model_type, + vocab_size=self.decoder_config.vocab_size, + hard_vocab_limit=True, + unk_id=self.decoder_config.unknown_index, + bos_id=self.decoder_config.bos_index, + eos_id=self.decoder_config.eos_index, + pad_id=self.decoder_config.pad_index, + character_coverage=self.decoder_config.character_coverage, + unk_surface="", # change default unk surface U+2047("⁇") by "" + allow_whitespace_only_pieces=False, + split_by_whitespace=(not self.decoder_config.keep_whitespace), + treat_whitespace_as_suffix=False, + user_defined_symbols="", + max_sentencepiece_length=self.decoder_config.max_sentencepiece_length, + max_sentence_length=self.decoder_config.max_sentence_length, # bytes + remove_extra_whitespaces=True, + num_threads=multiprocessing.cpu_count(), + ) + + def tokenize(self, text: tf.Tensor) -> tf.Tensor: + text = self.normalize_text(text, self.decoder_config) + indices = self.tokenizer.tokenize(text) + indices = tf.cast(indices, tf.int32) + return indices + + def detokenize(self, indices: tf.Tensor) -> tf.Tensor: + """ + Convert list of indices to string + Args: + indices: tf.Tensor with dim [B, None] + + Returns: + transcripts: tf.Tensor of dtype tf.string with dim [B] + """ + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index)) + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.bos_index)) + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.eos_index)) + transcripts = self.tokenizer.detokenize(indices) + transcripts = self.normalize_text(transcripts, self.decoder_config) + # transcripts = tf.strings.regex_replace(transcripts, r" +", " ") + return transcripts + + @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) + def detokenize_unicode_points(self, indices: tf.Tensor) -> tf.Tensor: + """ + Transform Predicted Indices to Unicode Code Points (for using tflite) + Args: + indices: tf.Tensor of Classes in shape [None] + + Returns: + unicode code points transcript with dtype tf.int32 and shape [None] + """ + with tf.name_scope("indices2upoints"): + transcripts = self.detokenize(tf.reshape(indices, [1, -1])) + upoints = tf.strings.unicode_decode(transcripts, "UTF-8").to_tensor() + return 
tf.reshape(upoints, [-1]) + + +class WordPieceTokenizer(Tokenizer): + def make(self): + self.vocab = None + with tf.io.gfile.GFile(self.decoder_config.vocabulary, "r") as voc: + self.vocab = voc.read().splitlines() + if not self.vocab: + raise ValueError("Unable to read vocabulary") + self.tokenizer = tft.FastWordpieceTokenizer( + vocab=self.vocab, + token_out_type=tf.int32, + unknown_token=self.decoder_config.unknown_token, + no_pretokenization=True, # False is limited, so we manually do pretokenization + support_detokenization=True, + ) + self.num_classes = len(self.vocab) + self.initialized = True + + def build(self, *datasets: AbstractDataset): + vocab_file_path = file_util.preprocess_paths(self.decoder_config.vocabulary) + + def write_vocab_file(filepath, vocab): + with tf.io.gfile.GFile(filepath, "w") as f: + for token in vocab: + print(token, file=f) + + dataset = ( + tf.data.Dataset.from_generator(self.generator(*datasets), output_signature=tf.TensorSpec(shape=(), dtype=tf.string)) + .batch(1000) + .prefetch(2) + ) + vocab = bert_vocab.bert_vocab_from_dataset( + dataset, + vocab_size=self.decoder_config.vocab_size, + reserved_tokens=self.decoder_config.reserved_tokens, + bert_tokenizer_params={ + "lower_case": False, # keep original from dataset + "keep_whitespace": self.decoder_config.keep_whitespace, + "normalization_form": self.decoder_config.normalization_form, + "preserve_unused_token": False, + }, + learn_params={ + "max_token_length": self.decoder_config.max_token_length, + "max_unique_chars": self.decoder_config.max_unique_chars, + "num_iterations": self.decoder_config.num_iterations, + }, + ) + write_vocab_file(vocab_file_path, vocab) + + def tokenize(self, text: tf.Tensor) -> tf.Tensor: + text = self.normalize_text(text, self.decoder_config) + if self.decoder_config.keep_whitespace: + text = tf.strings.regex_replace(text, " ", "| |") + text = tf.strings.split(text, sep="|") + else: + text = tf.strings.split(text) + indices = self.tokenizer.tokenize(text).merge_dims(0, 1) + return indices + + def detokenize(self, indices: tf.Tensor) -> tf.Tensor: + """ + Convert list of indices to string + Args: + indices: tf.Tensor with dim [B, None] + + Returns: + transcripts: tf.Tensor of dtype tf.string with dim [B] + """ + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.blank)) + # indices = tf.ragged.boolean_mask(indices, tf.not_equal(indices, self.decoder_config.unknown_index)) + transcripts = self.tokenizer.detokenize(indices) + transcripts = self.normalize_text(transcripts, self.decoder_config) + # transcripts = tf.strings.regex_replace(transcripts, r" +", " ") + return transcripts + + @tf.function(input_signature=[tf.TensorSpec([None], dtype=tf.int32)]) + def detokenize_unicode_points(self, indices: tf.Tensor) -> tf.Tensor: + """ + Transform Predicted Indices to Unicode Code Points (for using tflite) + Args: + indices: tf.Tensor of Classes in shape [None] + + Returns: + unicode code points transcript with dtype tf.int32 and shape [None] + """ + with tf.name_scope("indices2upoints"): + transcripts = self.detokenize(tf.reshape(indices, [1, -1])) + upoints = tf.strings.unicode_decode(transcripts, "UTF-8").to_tensor() + return tf.reshape(upoints, [-1]) diff --git a/tensorflow_asr/utils/app_util.py b/tensorflow_asr/utils/app_util.py index f14e7de5f1..4e74b0cd7a 100644 --- a/tensorflow_asr/utils/app_util.py +++ b/tensorflow_asr/utils/app_util.py @@ -13,37 +13,99 @@ # See the License for the specific language governing permissions and # limitations under the 
License. -import tensorflow as tf -from tqdm import tqdm +import logging -from tensorflow_asr.metrics.error_rates import ErrorRate -from tensorflow_asr.utils import file_util, metric_util +import jiwer -logger = tf.get_logger() +from tensorflow_asr import tf +from tensorflow_asr.models.base_model import BaseModel +from tensorflow_asr.utils import file_util, math_util +logger = logging.getLogger(__name__) -def evaluate_results( - filepath: str, -): - logger.info(f"Evaluating result from {filepath} ...") - metrics = { - "greedy_wer": ErrorRate(metric_util.tf_wer, name="greedy_wer", dtype=tf.float32), - "greedy_cer": ErrorRate(metric_util.tf_cer, name="greedy_cer", dtype=tf.float32), - "beamsearch_wer": ErrorRate(metric_util.tf_wer, name="beamsearch_wer", dtype=tf.float32), - "beamsearch_cer": ErrorRate(metric_util.tf_cer, name="beamsearch_cer", dtype=tf.float32), - } + +def evaluate_hypotheses(filepath: str): + """ + Compute wer, cer, mer, wil, wip for given lists of greedy and beamsearch hypotheses + + Parameters + ---------- + filepath : str + Output tsv file path for the predictions + + Returns + ------- + dict + {"greedy": {wer, cer, mer, wil, wip}, "beam": {wer, cer, mer, wil, wip}} + The results are original, NOT multiplied with 100. + """ + import pandas as pd # pylint: disable=import-outside-toplevel + from tqdm import tqdm # pylint: disable=import-outside-toplevel + + logger.info(f"Reading file {filepath} ...") + reference, greedy_hypothesis, beam_hypothesis = [], [], [] with file_util.read_file(filepath) as path: - with open(path, "r", encoding="utf-8") as openfile: + with tf.io.gfile.GFile(path, "r") as openfile: lines = openfile.read().splitlines() lines = lines[1:] # skip header - for eachline in tqdm(lines): - _, _, groundtruth, greedy, beamsearch = eachline.split("\t") - groundtruth = tf.convert_to_tensor([groundtruth], dtype=tf.string) - greedy = tf.convert_to_tensor([greedy], dtype=tf.string) - beamsearch = tf.convert_to_tensor([beamsearch], dtype=tf.string) - metrics["greedy_wer"].update_state(decode=greedy, target=groundtruth) - metrics["greedy_cer"].update_state(decode=greedy, target=groundtruth) - metrics["beamsearch_wer"].update_state(decode=beamsearch, target=groundtruth) - metrics["beamsearch_cer"].update_state(decode=beamsearch, target=groundtruth) - for key, value in metrics.items(): - logger.info(f"{key}: {value.result().numpy()}") + for eachline in tqdm(lines, disable=False): + _, groundtruth, greedy, beamsearch = eachline.split("\t") + reference.append(groundtruth) + greedy_hypothesis.append(greedy) + beam_hypothesis.append(beamsearch) + + logger.info("Evaluating greedy results ...") + greedy_wordoutput = jiwer.process_words(reference=reference, hypothesis=greedy_hypothesis) + greedy_charoutput = jiwer.process_characters(reference=reference, hypothesis=greedy_hypothesis) + + logger.info("Evaluating beamsearch results ...") + beam_wordoutput = jiwer.process_words(reference=reference, hypothesis=beam_hypothesis) + beam_charoutput = jiwer.process_characters(reference=reference, hypothesis=beam_hypothesis) + + outputs = { + "greedy": { + "wer": greedy_wordoutput.wer, + "cer": greedy_charoutput.cer, + "mer": greedy_wordoutput.mer, + "wil": greedy_wordoutput.wil, + "wip": greedy_wordoutput.wip, + }, + "beam": { + "wer": beam_wordoutput.wer, + "cer": beam_charoutput.cer, + "mer": beam_wordoutput.mer, + "wil": beam_wordoutput.wil, + "wip": beam_wordoutput.wip, + }, + } + df = pd.DataFrame.from_dict(outputs, orient="index") + return df + + +def convert_tflite( + model: 
BaseModel, + output: str, + batch_size: int = 1, + beam_width: int = 0, +): + if not math_util.is_power_of_two(model.feature_extraction.nfft): + logger.error("NFFT must be power of 2 for TFLite conversion") + overwrite_nfft = input("Do you want to overwrite nfft to the nearest power of 2? (y/n): ") + if overwrite_nfft.lower() == "y": + model.feature_extraction.nfft = math_util.next_power_of_two(model.feature_extraction.nfft) + logger.info(f"Overwritten nfft to {model.feature_extraction.nfft}") + else: + raise ValueError("NFFT must be power of 2 for TFLite conversion") + + concrete_func = model.make_tflite_function(batch_size=batch_size, beam_width=beam_width).get_concrete_function() + converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func], trackable_obj=model) + converter.target_spec.supported_ops = [ + tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. + tf.lite.OpsSet.SELECT_TF_OPS, # enable TensorFlow ops. + ] + converter.allow_custom_ops = True + tflite_model = converter.convert() + + output = file_util.preprocess_paths(output) + with open(output, "wb") as tflite_out: + tflite_out.write(tflite_model) diff --git a/tensorflow_asr/utils/data_util.py b/tensorflow_asr/utils/data_util.py index 47c8b21757..8fd66adb60 100644 --- a/tensorflow_asr/utils/data_util.py +++ b/tensorflow_asr/utils/data_util.py @@ -14,39 +14,42 @@ # tf.data.Dataset does not work well for namedtuple so we are using dict +import os +from functools import reduce +from typing import Any -def create_inputs( - inputs, - inputs_length, - predictions=None, - predictions_length=None, -) -> dict: - data = { - "inputs": inputs, - "inputs_length": inputs_length, - } - if predictions is not None: - data["predictions"] = predictions - if predictions_length is not None: - data["predictions_length"] = predictions_length - return data - - -def create_logits( - logits, - logits_length, -) -> dict: - return { - "logits": logits, - "logits_length": logits_length, - } - - -def create_labels( - labels, - labels_length, -) -> dict: - return { - "labels": labels, - "labels_length": labels_length, - } +import librosa +import tensorflow as tf + + +def load_and_convert_to_wav( + path: str, + sample_rate: int = None, +): + wave, rate = librosa.load(os.path.realpath(os.path.expanduser(path)), sr=sample_rate, mono=True) + return tf.audio.encode_wav(tf.expand_dims(wave, axis=-1), sample_rate=rate) + + +def read_raw_audio(audio: tf.Tensor): + wave, _ = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1) + return tf.reshape(wave, shape=[-1]) # reshape for using tf.signal + + +def get( + obj: dict, + path: str, + default: Any = None, +): + path = str(path) + + def _reduce_fn(d, key): + if isinstance(d, dict): + return d.get(key, default) + if isinstance(d, list): + try: + return d[int(key)] + except (IndexError, ValueError): + return default + return default + + return reduce(_reduce_fn, path.split("."), obj) diff --git a/tensorflow_asr/utils/env_util.py b/tensorflow_asr/utils/env_util.py index 8d28604127..3f6d838d91 100644 --- a/tensorflow_asr/utils/env_util.py +++ b/tensorflow_asr/utils/env_util.py @@ -12,72 +12,90 @@ # See the License for the specific language governing permissions and # limitations under the License. 
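Besides the audio helpers, data_util gains a small get() that walks nested dicts and lists with a dotted path; a quick illustration (the structure below is made up):

from tensorflow_asr.utils import data_util

conf = {"model": {"encoder": {"filters": [144, 256]}}}
print(data_util.get(conf, "model.encoder.filters.1"))   # -> 256, list indices are numeric path parts
print(data_util.get(conf, "model.decoder.units", 320))  # missing key -> falls back to the default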
- +import logging +import os import random +import sys +import warnings +from datetime import datetime, timezone from typing import List, Union +TF_LOG_LEVEL = os.getenv("TF_LOG_LEVEL", "warning").upper() +TF_SOFT_PLACEMENT = os.getenv("TF_SOFT_PLACEMENT", "false").lower() == "true" +TF_ENABLE_CHECK_NUMERIC = os.getenv("TF_ENABLE_CHECK_NUMERIC", "false").lower() == "true" +TF_CUDNN = os.getenv("TF_CUDNN", "auto").lower() +TF_CUDNN = "auto" if TF_CUDNN == "auto" else TF_CUDNN == "true" +DEBUG = TF_LOG_LEVEL == "DEBUG" + + +def _logging_format_time(self, record, datefmt=None): + return datetime.fromtimestamp(record.created, timezone.utc).astimezone().isoformat(sep="T", timespec="milliseconds") + + +logging.basicConfig(level=logging.INFO, format=logging.BASIC_FORMAT, stream=sys.stdout, force=True) +logging.Formatter.formatTime = _logging_format_time +logging.captureWarnings(True) +warnings.filterwarnings("ignore") + +import keras import numpy as np import tensorflow as tf +from packaging import version +from tensorflow.python.util import deprecation # pylint: disable = no-name-in-module -logger = tf.get_logger() +tf.get_logger().setLevel(TF_LOG_LEVEL) +deprecation._PRINT_DEPRECATION_WARNINGS = False # comment this line to print deprecation warnings +if TF_ENABLE_CHECK_NUMERIC: + tf.debugging.enable_check_numerics() -def setup_devices( +KERAS_SRC = "keras.src" if version.parse(tf.version.VERSION) >= version.parse("2.13.0") else "keras" + +logger = logging.getLogger(__name__) + + +def setup_gpu( devices: List[int] = None, - cpu: bool = False, ): - """Setting visible devices - - Args: - devices (list): list of visible devices' indices - cpu (bool): use cpu or not - """ - if cpu: - cpus = tf.config.list_physical_devices("CPU") - tf.config.set_visible_devices(cpus, "CPU") - tf.config.set_visible_devices([], "GPU") - logger.info(f"Run on {len(cpus)} Physical CPUs") - else: - gpus = tf.config.list_physical_devices("GPU") - if gpus: - if devices is not None: - gpus = [gpus[i] for i in devices] - tf.config.set_visible_devices(gpus, "GPU") - logger.info(f"Run on {len(gpus)} Physical GPUs") + logger.info(f"Using TF_CUDNN={TF_CUDNN}, TF_SOFT_PLACEMENT={TF_SOFT_PLACEMENT}") + tf.config.set_soft_device_placement(DEBUG or TF_SOFT_PLACEMENT) + gpus = tf.config.list_physical_devices("GPU") + if not gpus: + raise RuntimeError("No GPUs found!") + if devices is not None: + gpus = [gpus[i] for i in devices] + tf.config.set_visible_devices(gpus, "GPU") + logger.info("Run on GPU") + logger.info(f"All devices: {gpus}") + return tf.distribute.MirroredStrategy() def setup_tpu( tpu_address=None, + tpu_vm: bool = False, ): - if tpu_address is None: - resolver = tf.distribute.cluster_resolver.TPUClusterResolver() - else: - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="grpc://" + tpu_address) - tf.config.experimental_connect_to_cluster(resolver) + # might cause performance penalty if ops fallback to cpu, see https://cloud.google.com/tpu/docs/tensorflow-ops + tf.config.set_soft_device_placement(DEBUG) + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_address) + if not tpu_vm: + tf.config.experimental_connect_to_cluster(resolver) tf.tpu.experimental.initialize_tpu_system(resolver) - logger.info(f"All TPUs: {tf.config.list_logical_devices('TPU')}") + logger.info(f"Run on TPU {tpu_address}") + logger.info(f"All devices: {tf.config.list_logical_devices('TPU')}") return tf.distribute.TPUStrategy(resolver) def setup_strategy( - devices: List[int], + device_type: str, + devices: List[int] 
= None, tpu_address: str = None, + tpu_vm: bool = False, ): - """Setting mirrored strategy for training - - Args: - devices (list): list of visible devices' indices - tpu_address (str): an optional custom tpu address - - Returns: - tf.distribute.Strategy: TPUStrategy for training on tpus or MirroredStrategy for training on gpus - """ - try: - return setup_tpu(tpu_address) - except (ValueError, tf.errors.NotFoundError) as e: - logger.warning(e) - setup_devices(devices) - return tf.distribute.MirroredStrategy() + if device_type.lower() == "tpu": + return setup_tpu(tpu_address, tpu_vm) + if device_type.lower() == "gpu": + return setup_gpu(devices) + return tf.distribute.get_strategy() def has_devices( @@ -94,20 +112,36 @@ def setup_mxp( """ Setup mixed precision - Args: - mxp (str, optional): either "strict" or "auto". Defaults to "strict". + Parameters + ---------- + mxp : str, optional + Either "strict", "auto" or "none", by default "strict" + + Raises + ------ + ValueError + Wrong value for mxp """ - options = ["strict", "auto", "none"] + options = ["strict", "strict_auto", "auto", "none"] if mxp not in options: raise ValueError(f"mxp must be in {options}") if mxp == "strict": policy = "mixed_bfloat16" if has_devices("TPU") else "mixed_float16" - tf.keras.mixed_precision.set_global_policy(policy) - tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) + keras.mixed_precision.set_global_policy(policy) + tf.config.optimizer.set_experimental_options({"auto_mixed_precision": False}) logger.info(f"USING mixed precision policy {policy}") + elif mxp == "strict_auto": + policy = "mixed_bfloat16" if has_devices("TPU") else "mixed_float16" + keras.mixed_precision.set_global_policy(policy) + tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) + logger.info(f"USING auto mixed precision policy {policy}") elif mxp == "auto": tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True}) logger.info("USING auto mixed precision policy") + else: + keras.mixed_precision.set_global_policy("float32") + tf.config.optimizer.set_experimental_options({"auto_mixed_precision": False}) + logger.info("USING float32 precision policy") def setup_seed( @@ -120,11 +154,12 @@ def setup_seed( I sat at my desk, stared into the garden and thought 42 will do!" - Douglas Adams's popular 1979 science-fiction novel The Hitchhiker's Guide to the Galaxy - Args: - seed (int, optional): integer. Defaults to 42. + Parameters + ---------- + seed : int, optional + Random seed, by default 42 """ random.seed(seed) np.random.seed(seed) tf.random.set_seed(seed) - tf.keras.backend.experimental.enable_tf_random_generator() - tf.keras.utils.set_random_seed(seed) + keras.utils.set_random_seed(seed) diff --git a/tensorflow_asr/utils/feature_util.py b/tensorflow_asr/utils/feature_util.py index 4a0e28c5bc..a2398953a2 100644 --- a/tensorflow_asr/utils/feature_util.py +++ b/tensorflow_asr/utils/feature_util.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +from tensorflow_asr import tf def float_feature( diff --git a/tensorflow_asr/utils/file_util.py b/tensorflow_asr/utils/file_util.py index f6e37f231b..73c9834164 100644 --- a/tensorflow_asr/utils/file_util.py +++ b/tensorflow_asr/utils/file_util.py @@ -13,19 +13,27 @@ # limitations under the License. 
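The reworked environment helpers are applied together in the order train.py uses them; a minimal sketch (the device index is hypothetical, and setup_gpu raises if no GPU is visible):

from tensorflow_asr.utils import env_util

env_util.setup_seed()                                               # seeds python, numpy and tf
strategy = env_util.setup_strategy(device_type="gpu", devices=[0])  # MirroredStrategy on the chosen GPUs
env_util.setup_mxp(mxp="strict")                                    # mixed_float16, or mixed_bfloat16 on TPU
with strategy.scope():
    pass  # build and compile the model here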
import contextlib +import logging import os import re +import tarfile import tempfile +import zipfile from typing import List, Union -import tensorflow as tf +import jinja2 import yaml -from jinja2 import BaseLoader, Environment + +from tensorflow_asr import tf + +ENABLE_PATH_PREPROCESS = True +logger = logging.getLogger(__name__) def load_yaml( path: str, -) -> dict: + **kwargs, +): # Fix yaml numbers https://stackoverflow.com/a/30462009/11037553 loader = yaml.SafeLoader loader.add_implicit_resolver( @@ -43,7 +51,10 @@ def load_yaml( list("-+0123456789."), ) with tf.io.gfile.GFile(path, "r") as file: - return yaml.load(Environment(loader=BaseLoader()).from_string(file.read()).render(), Loader=loader) + return yaml.load( + jinja2.Environment(loader=jinja2.FileSystemLoader([kwargs["repodir"]])).from_string(file.read()).render(**kwargs), + Loader=loader, + ) def is_hdf5_filepath( @@ -70,6 +81,7 @@ def preprocess_paths( paths: Union[List[str], str], isdir: bool = False, enabled: bool = True, + check_exists: bool = False, ) -> Union[List[str], str]: """Expand the path to the root "/" and makedirs @@ -79,20 +91,27 @@ def preprocess_paths( Returns: Union[List, str]: A processed path or list of paths, return None if it's not path """ - if not enabled: + if not (enabled and ENABLE_PATH_PREPROCESS): return paths if isinstance(paths, (list, tuple)): paths = [path if is_cloud_path(path) else os.path.abspath(os.path.expanduser(path)) for path in paths] - for path in paths: + for i, path in enumerate(paths): dirpath = path if isdir else os.path.dirname(path) - if not tf.io.gfile.exists(dirpath): - tf.io.gfile.makedirs(dirpath) - return paths + if not tf.io.gfile.exists(path): + if check_exists: + paths[i] = None + else: + if not tf.io.gfile.exists(dirpath): + tf.io.gfile.makedirs(dirpath) + return list(filter(None, paths)) if isinstance(paths, str): paths = paths if is_cloud_path(paths) else os.path.abspath(os.path.expanduser(paths)) dirpath = paths if isdir else os.path.dirname(paths) - if not tf.io.gfile.exists(dirpath): - tf.io.gfile.makedirs(dirpath) + if not tf.io.gfile.exists(paths): + if check_exists: + return None + if not tf.io.gfile.exists(dirpath): + tf.io.gfile.makedirs(dirpath) return paths return None @@ -102,8 +121,9 @@ def save_file( filepath: str, ): if is_cloud_path(filepath): - _, ext = os.path.splitext(filepath) - with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + _, *ext = os.path.basename(filepath).split(".") + suffix = "." + ".".join(ext) + with tempfile.NamedTemporaryFile(suffix=suffix) as tmp: yield tmp.name tf.io.gfile.copy(tmp.name, filepath, overwrite=True) else: @@ -115,9 +135,32 @@ def read_file( filepath: str, ): if is_cloud_path(filepath): - _, ext = os.path.splitext(filepath) - with tempfile.NamedTemporaryFile(suffix=ext) as tmp: + _, *ext = os.path.basename(filepath).split(".") + suffix = "." 
+ ".".join(ext) + with tempfile.NamedTemporaryFile(suffix=suffix) as tmp: tf.io.gfile.copy(filepath, tmp.name, overwrite=True) yield tmp.name else: yield filepath + + +def clean_dir(dirpath: str): + path = preprocess_paths(dirpath, isdir=True) + logger.info(f"Cleaning up {path}") + if tf.io.gfile.exists(path): + tf.io.gfile.rmtree(path) + + +def extract_file( + filepath: str, + extractpath: str, +): + if filepath.endswith(".tar.gz") or filepath.endswith(".tgz") or filepath.endswith(".tar"): + with tarfile.open(filepath, "r:gz") as tar: + tar.extractall(path=os.path.realpath(extractpath)) + return + if filepath.endswith(".zip"): + with zipfile.ZipFile(filepath, "r") as zip_ref: + zip_ref.extractall(os.path.realpath(extractpath)) + return + raise ValueError(f"Unsupported file format: {filepath}") diff --git a/tensorflow_asr/utils/keras_util.py b/tensorflow_asr/utils/keras_util.py new file mode 100644 index 0000000000..67848d9eb6 --- /dev/null +++ b/tensorflow_asr/utils/keras_util.py @@ -0,0 +1,26 @@ +import tensorflow as tf +from keras.src.saving import serialization_lib + + +def model_from_config(model_config: dict, custom_objects=None): + return serialization_lib.deserialize_keras_object(model_config, custom_objects=custom_objects) + + +def reduce_per_replica(values, strategy, reduction): + if reduction == "auto": + if isinstance(strategy, tf.distribute.TPUStrategy): + reduction = "first" + else: + reduction = "mean" + + def _reduce(v): + """Reduce a single `PerReplica` object.""" + if reduction == "first": + return strategy.experimental_local_results(v)[0] + if reduction == "sum": + return strategy.reduce("SUM", v, axis=None) + if reduction == "mean": + return strategy.reduce("MEAN", v, axis=None) + raise ValueError("`reduction` must be one of " '"first", "mean", "sum", or "auto". ' f"Received: reduction={reduction}.") + + return tf.nest.map_structure(_reduce, values) diff --git a/tensorflow_asr/utils/layer_util.py b/tensorflow_asr/utils/layer_util.py index ac6bdcacfb..5c84fdc9df 100644 --- a/tensorflow_asr/utils/layer_util.py +++ b/tensorflow_asr/utils/layer_util.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
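preprocess_paths now also supports a check_exists mode; a small illustration with hypothetical paths:

from tensorflow_asr.utils import file_util

# Expands "~", creates the parent directory if needed, and returns the absolute path.
tsv_path = file_util.preprocess_paths("~/outputs/test-clean.tsv")

# With check_exists=True a missing path returns None instead of having its directory created.
ckpt = file_util.preprocess_paths("~/models/does-not-exist.h5", check_exists=True)
print(ckpt)  # None when the file is absent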
-import tensorflow as tf +from typing import List -from tensorflow_asr.models.layers import convolution, recurrent +from tensorflow_asr import keras, tf +from tensorflow_asr.models.layers.convolution import Conv1D, Conv2D def get_rnn( @@ -22,10 +23,10 @@ def get_rnn( ): assert rnn_type in ["lstm", "gru", "rnn"] if rnn_type == "lstm": - return recurrent.LSTM + return keras.layers.LSTM if rnn_type == "gru": - return recurrent.GRU - return tf.keras.layers.SimpleRNN + return keras.layers.GRU + return keras.layers.SimpleRNN def get_conv( @@ -33,18 +34,18 @@ def get_conv( ): assert conv_type in ["conv1d", "conv2d"] if conv_type == "conv1d": - return convolution.Conv1D - return convolution.Conv2D + return Conv1D + return Conv2D def add_gwn( - trainable_weights: list, + trainable_weights: List[tf.Variable], stddev: float = 1.0, ): original_weights = [] for weight in trainable_weights: noise = tf.stop_gradient(tf.random.normal(mean=0.0, stddev=stddev, shape=weight.shape, dtype=weight.dtype)) - original_weights.append(weight.value()) + original_weights.append(weight) weight.assign_add(noise) return original_weights diff --git a/tensorflow_asr/utils/math_util.py b/tensorflow_asr/utils/math_util.py index bcc2ed73a4..6a17da1038 100644 --- a/tensorflow_asr/utils/math_util.py +++ b/tensorflow_asr/utils/math_util.py @@ -16,8 +16,9 @@ from typing import Union import numpy as np -import tensorflow as tf +from keras.src import backend +from tensorflow_asr import tf from tensorflow_asr.utils import shape_util @@ -45,6 +46,30 @@ def nan_to_zero( return tf.where(tf.math.is_nan(input_tensor), tf.zeros_like(input_tensor), input_tensor) +def nan_to_num(x, nan=0.0, posinf=None, neginf=None): + x = tf.convert_to_tensor(x) + + dtype = x.dtype + dtype_as_dtype = tf.as_dtype(dtype) + if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric: + return x + + # Replace NaN with `nan` + x = tf.where(tf.math.is_nan(x), nan, x) + + # Replace positive infinity with `posinf` or `dtype.max` + if posinf is None: + posinf = dtype.max + x = tf.where(tf.math.is_inf(x) & (x > 0), posinf, x) + + # Replace negative infinity with `neginf` or `dtype.min` + if neginf is None: + neginf = dtype.min + x = tf.where(tf.math.is_inf(x) & (x < 0), neginf, x) + + return x + + def bytes_to_string( array: np.ndarray, encoding: str = "utf-8", @@ -75,18 +100,25 @@ def legacy_get_reduced_length( def count_non_blank( tensor: tf.Tensor, - blank: int or tf.Tensor = 0, + blank: Union[int, tf.Tensor] = 0, axis=None, + dtype=tf.int32, + keepdims=False, ): return tf.reduce_sum( - tf.where(tf.not_equal(tensor, blank), x=tf.ones_like(tensor), y=tf.zeros_like(tensor)), + tf.where( + tf.not_equal(tf.cast(tensor, dtype), blank), + x=tf.ones_like(tensor, dtype=dtype), + y=tf.zeros_like(tensor, dtype=dtype), + ), axis=axis, + keepdims=keepdims, ) def count( tensor: tf.Tensor, - value: float or int or tf.Tensor = 0, + value: Union[float, int, tf.Tensor] = 0, axis=None, ): return tf.reduce_sum( @@ -200,19 +232,38 @@ def masked_fill( value=0, ): shape = shape_util.shape_list(tensor) - mask = tf.broadcast_to(mask, shape) + mask = tf.cast(tf.broadcast_to(mask, shape), dtype=tf.bool) values = tf.cast(tf.fill(shape, value), tensor.dtype) return tf.where(mask, tensor, values) -def large_compatible_negative( +def large_compatible_negative_number( tensor_type, ): - if tensor_type == tf.float16: + dtype = backend.standardize_dtype(tensor_type) + if dtype == "float16": return tf.float16.min return -1e9 +def large_compatible_positive_number( + tensor_type, +): + dtype = 
backend.standardize_dtype(tensor_type) + if dtype == "float16": + return tf.float16.max + return 1e9 + + +def compatible_epsilon( + tensor_type, +): + dtype = backend.standardize_dtype(tensor_type) + if dtype == "float16": + return 1e-6 + return 1e-9 + + def apply_mask( outputs, mask=None, @@ -249,4 +300,72 @@ def conv_output_length(input_length, filter_size, padding, stride, dilation=1): output_length = input_length - dilated_filter_size + 1 elif padding == "full": output_length = input_length + dilated_filter_size - 1 + else: + raise ValueError(f"Invalid padding: {padding}") return (output_length + stride - 1) // stride + + +def get_nsamples( + duration: float, + sample_rate: int = 16000, +): + return math.ceil(float(duration) * sample_rate) + + +def slice_batch_tensor( + tensor: tf.Tensor, + index: int, + batch_size: int, +): + with tf.name_scope("slice_batch_tensor"): + begin = [index * batch_size] + [0] * (tensor.shape.rank - 1) + size = [batch_size] + [-1] * (tensor.shape.rank - 1) + sliced_tensor = tf.slice(tensor, begin, size) + return sliced_tensor + + +def split_tensor_by_ga( + tensor: tf.Tensor, # [B, ...] + batch_size: int, + ga_steps: int, +): + """ + Parameters + ---------- + tensor : tf.Tensor of shape [B, ...] + + Returns + ------- + tf.Tensor of shape [num_batches, mini_batch_size, ...] + """ + with tf.name_scope("split_tensor_by_ga"): + splits = [batch_size] * ga_steps + return tf.stack(tf.split(tensor, splits, num=ga_steps, axis=0), axis=0) + + +def compute_time_length( + tensor: tf.Tensor, + dtype=tf.int32, +): + with tf.name_scope("compute_time_length"): + batch_size, time_length, *_ = shape_util.shape_list(tensor) + return tf.cast(tf.repeat(time_length, batch_size, axis=0), dtype=dtype) + + +def is_power_of_two( + x: int, +): + return x != 0 and (x & (x - 1)) == 0 + + +def next_power_of_two( + x: int, +): + return 1 if x == 0 else 2 ** math.ceil(math.log2(x)) + + +def add_gauss_noise( + data, + stddev: float = 0.075, +): + return tf.nest.map_structure(lambda x: tf.add(x, tf.random.normal(shape=tf.shape(x), mean=0.0, stddev=stddev, dtype=x.dtype)), data) diff --git a/tensorflow_asr/utils/metric_util.py b/tensorflow_asr/utils/metric_util.py index 5e1388a122..f1c63e541d 100644 --- a/tensorflow_asr/utils/metric_util.py +++ b/tensorflow_asr/utils/metric_util.py @@ -1,125 +1,125 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
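Two of the new math_util helpers, illustrated with made-up shapes; split_tensor_by_ga expects the leading dimension to equal batch_size * ga_steps:

import tensorflow as tf

from tensorflow_asr.utils import math_util

batch = tf.ones([4, 10])  # global batch of 4 = 2 accumulation steps x mini-batch of 2
print(math_util.split_tensor_by_ga(batch, batch_size=2, ga_steps=2).shape)  # (2, 2, 10)

labels = tf.constant([[5, 7, 0, 0], [3, 0, 0, 0]])
print(math_util.count_non_blank(labels, blank=0, axis=1))  # [2, 1] non-blank tokens per row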
- -from typing import Tuple - -import tensorflow as tf -from nltk.metrics import distance - -from tensorflow_asr.utils import math_util - - -def execute_wer( - decode, - target, -): - decode = math_util.bytes_to_string(decode) - target = math_util.bytes_to_string(target) - dis = 0.0 - length = 0.0 - for dec, tar in zip(decode, target): - words = set(dec.split() + tar.split()) - word2char = dict(zip(words, range(len(words)))) - - new_decode = [chr(word2char[w]) for w in dec.split()] - new_target = [chr(word2char[w]) for w in tar.split()] - - dis += distance.edit_distance("".join(new_decode), "".join(new_target)) - length += len(tar.split()) - return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32) - - -def wer( - decode: tf.Tensor, - target: tf.Tensor, -) -> Tuple[tf.Tensor, tf.Tensor]: - """Word Error Rate - - Args: - decode (np.ndarray): array of prediction texts - target (np.ndarray): array of groundtruth texts - - Returns: - tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text - """ - return tf.numpy_function(execute_wer, inp=[decode, target], Tout=[tf.float32, tf.float32]) - - -def execute_cer(decode, target): - decode = math_util.bytes_to_string(decode) - target = math_util.bytes_to_string(target) - dis = 0 - length = 0 - for dec, tar in zip(decode, target): - dis += distance.edit_distance(dec, tar) - length += len(tar) - return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32) - - -def cer( - decode: tf.Tensor, - target: tf.Tensor, -) -> Tuple[tf.Tensor, tf.Tensor]: - """Character Error Rate - - Args: - decode (np.ndarray): array of prediction texts - target (np.ndarray): array of groundtruth texts - - Returns: - tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text - """ - return tf.numpy_function(execute_cer, inp=[decode, target], Tout=[tf.float32, tf.float32]) - - -def tf_wer( - decode: tf.Tensor, - target: tf.Tensor, -) -> Tuple[tf.Tensor, tf.Tensor]: - """ - Tensorflow Word Error Rate - - Args: - decode (tf.Tensor): tensor shape [B] - target (tf.Tensor): tensor shape [B] - - Returns: - tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text - """ - decode = tf.strings.split(decode) - target = tf.strings.split(target) - distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False) # [B] - lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32) # [B] - return tf.reduce_sum(distances), tf.reduce_sum(lengths) - - -def tf_cer( - decode: tf.Tensor, - target: tf.Tensor, -) -> Tuple[tf.Tensor, tf.Tensor]: - """ - Tensorflow Charactor Error rate - - Args: - decoder (tf.Tensor): tensor shape [B] - target (tf.Tensor): tensor shape [B] - - Returns: - tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text - """ - decode = tf.strings.bytes_split(decode) # [B, N] - target = tf.strings.bytes_split(target) # [B, M] - distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False) # [B] - lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32) # [B] - return tf.reduce_sum(distances), tf.reduce_sum(lengths) +# # Copyright 2020 Huy Le Nguyen (@nglehuy) +# # +# # Licensed under the Apache License, Version 2.0 (the "License"); +# # you may not use this file except in compliance with the License. 
+# # You may obtain a copy of the License at +# # +# # http://www.apache.org/licenses/LICENSE-2.0 +# # +# # Unless required by applicable law or agreed to in writing, software +# # distributed under the License is distributed on an "AS IS" BASIS, +# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# # See the License for the specific language governing permissions and +# # limitations under the License. + +# from typing import Tuple + +# from nltk.metrics import distance + +# from tensorflow_asr import tf +# from tensorflow_asr.utils import math_util + + +# def execute_wer( +# decode, +# target, +# ): +# decode = math_util.bytes_to_string(decode) +# target = math_util.bytes_to_string(target) +# dis = 0.0 +# length = 0.0 +# for dec, tar in zip(decode, target): +# words = set(dec.split() + tar.split()) +# word2char = dict(zip(words, range(len(words)))) + +# new_decode = [chr(word2char[w]) for w in dec.split()] +# new_target = [chr(word2char[w]) for w in tar.split()] + +# dis += distance.edit_distance("".join(new_decode), "".join(new_target)) +# length += len(tar.split()) +# return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32) + + +# def wer( +# decode: tf.Tensor, +# target: tf.Tensor, +# ) -> Tuple[tf.Tensor, tf.Tensor]: +# """Word Error Rate + +# Args: +# decode (np.ndarray): array of prediction texts +# target (np.ndarray): array of groundtruth texts + +# Returns: +# tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text +# """ +# return tf.numpy_function(execute_wer, inp=[decode, target], Tout=[tf.float32, tf.float32]) + + +# def execute_cer(decode, target): +# decode = math_util.bytes_to_string(decode) +# target = math_util.bytes_to_string(target) +# dis = 0 +# length = 0 +# for dec, tar in zip(decode, target): +# dis += distance.edit_distance(dec, tar) +# length += len(tar) +# return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32) + + +# def cer( +# decode: tf.Tensor, +# target: tf.Tensor, +# ) -> Tuple[tf.Tensor, tf.Tensor]: +# """Character Error Rate + +# Args: +# decode (np.ndarray): array of prediction texts +# target (np.ndarray): array of groundtruth texts + +# Returns: +# tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text +# """ +# return tf.numpy_function(execute_cer, inp=[decode, target], Tout=[tf.float32, tf.float32]) + + +# def tf_wer( +# decode: tf.Tensor, +# target: tf.Tensor, +# ) -> Tuple[tf.Tensor, tf.Tensor]: +# """ +# Tensorflow Word Error Rate + +# Args: +# decode (tf.Tensor): tensor shape [B] +# target (tf.Tensor): tensor shape [B] + +# Returns: +# tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text +# """ +# decode = tf.strings.split(decode) +# target = tf.strings.split(target) +# distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False) # [B] +# lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32) # [B] +# return tf.reduce_sum(distances), tf.reduce_sum(lengths) + + +# def tf_cer( +# decode: tf.Tensor, +# target: tf.Tensor, +# ) -> Tuple[tf.Tensor, tf.Tensor]: +# """ +# Tensorflow Charactor Error rate + +# Args: +# decoder (tf.Tensor): tensor shape [B] +# target (tf.Tensor): tensor shape [B] + +# Returns: +# tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text +# """ +# decode = tf.strings.bytes_split(decode) # [B, N] +# target = tf.strings.bytes_split(target) # [B, M] +# distances = tf.edit_distance(decode.to_sparse(), 
target.to_sparse(), normalize=False) # [B] +# lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32) # [B] +# return tf.reduce_sum(distances), tf.reduce_sum(lengths) diff --git a/tensorflow_asr/utils/plot_util.py b/tensorflow_asr/utils/plot_util.py index d14e471a72..7c3a6e6ec2 100644 --- a/tensorflow_asr/utils/plot_util.py +++ b/tensorflow_asr/utils/plot_util.py @@ -3,7 +3,7 @@ import matplotlib.pyplot as plt -def plotmesh(data, title="data", scale_ysize=4): +def plotmesh(data, title="data", scale_ysize=4, invert_yaxis=True): xsize = data.shape[1] ysize = data.shape[0] gcd = math.gcd(xsize, ysize) @@ -15,6 +15,8 @@ def plotmesh(data, title="data", scale_ysize=4): fig, ax = plt.subplots(figsize=figsize) ax.set_title(title, fontweight="bold") ax.minorticks_on() + if invert_yaxis: + ax.invert_yaxis() img = ax.pcolormesh(data, cmap="viridis") cbar = fig.colorbar(img, ax=ax, format="%.2f", pad=0.01) cbar.minorticks_on() diff --git a/tensorflow_asr/utils/shape_util.py b/tensorflow_asr/utils/shape_util.py index 8a016c1449..5013000231 100644 --- a/tensorflow_asr/utils/shape_util.py +++ b/tensorflow_asr/utils/shape_util.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorflow as tf +from tensorflow_asr import tf def shape_list(x, out_type=tf.int32): @@ -22,6 +22,12 @@ def shape_list(x, out_type=tf.int32): return [dynamic[i] if s is None else s for i, s in enumerate(static)] +def shape_list_per_replica(x, per_replica_batch_size): + _, *rest_shape = x.shape + shapes = (int(per_replica_batch_size),) + tuple(rest_shape) + return shapes + + def get_shape_invariants(tensor): shapes = shape_list(tensor) return tf.TensorShape([i if isinstance(i, int) else None for i in shapes]) @@ -30,3 +36,8 @@ def get_shape_invariants(tensor): def get_float_spec(tensor): shape = get_shape_invariants(tensor) return tf.TensorSpec(shape, dtype=tf.float32) + + +def get_dim(tensor, i): + """Get value of tensor shape[i] preferring static value if available.""" + return tf.compat.dimension_value(tensor.shape[i]) or tf.shape(tensor)[i] diff --git a/tensorflow_asr/utils/tf_util.py b/tensorflow_asr/utils/tf_util.py new file mode 100644 index 0000000000..a729106ec4 --- /dev/null +++ b/tensorflow_asr/utils/tf_util.py @@ -0,0 +1,36 @@ +# # import importlib + +# import tensorflow as tf +# from keras.src.utils import tf_utils + +# from tensorflow_asr.utils.env_util import KERAS_SRC + +# # tf_utils = importlib.import_module(f"{KERAS_SRC}.utils.tf_utils") + + +# def convert_shapes(input_shape, to_tuples=True): +# if input_shape is None: +# return None + +# def _is_shape_component(value): +# return value is None or isinstance(value, (int, tf.compat.v1.Dimension)) + +# def _is_atomic_shape(input_shape): +# # Ex: TensorShape or (None, 10, 32) or 5 or `None` +# if _is_shape_component(input_shape): +# return True +# if isinstance(input_shape, tf.TensorShape): +# return True +# if isinstance(input_shape, (tuple, list)) and all(_is_shape_component(ele) for ele in input_shape): +# return True +# return False + +# def _convert_shape(input_shape): +# if input_shape is None: +# return None +# input_shape = tf.TensorShape(input_shape) +# if to_tuples: +# input_shape = tuple(input_shape.as_list()) +# return input_shape + +# return tf_utils.map_structure_with_atomic(_is_atomic_shape, _convert_shape, input_shape) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/conformer/config.yml 
b/tests/conformer/config.yml deleted file mode 100644 index 573b5d9e6f..0000000000 --- a/tests/conformer/config.yml +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - preemphasis: 0.97 - normalize_signal: True - normalize_feature: True - normalize_per_frame: False - -decoder_config: - vocabulary: null - vocab_size: 1024 - max_subword_length: 4 - blank_at_zero: True - beam_width: 5 - norm_score: True - -model_config: - name: conformer - encoder_subsampling: - type: conv2d - filters: 144 - kernel_size: 3 - strides: 2 - encoder_positional_encoding: sinusoid_concat - encoder_dmodel: 144 - encoder_num_blocks: 16 - encoder_head_size: 36 - encoder_num_heads: 4 - encoder_mha_type: relmha - encoder_kernel_size: 32 - encoder_fc_factor: 0.5 - encoder_dropout: 0.1 - prediction_embed_dim: 320 - prediction_embed_dropout: 0 - prediction_num_rnns: 1 - prediction_rnn_units: 320 - prediction_rnn_type: lstm - prediction_rnn_implementation: 1 - prediction_layer_norm: True - prediction_projection_units: 0 - joint_dim: 320 - joint_activation: tanh - prejoint_linear: False - joint_mode: concat - -learning_config: - augmentations: - feature_augment: - time_masking: - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - num_masks: 1 - mask_factor: 27 - - dataset_config: - train_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - eval_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - test_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null - - optimizer_config: - warmup_steps: 40000 - beta1: 0.9 - beta2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 2 - accumulation_steps: 4 - num_epochs: 20 - outdir: /mnt/Miscellanea/Models/local/conformer - log_interval_steps: 300 - eval_interval_steps: 500 - save_interval_steps: 1000 diff --git a/tests/conformer/test_conformer.py b/tests/conformer/test_conformer.py deleted file mode 100644 index e30bda7132..0000000000 --- a/tests/conformer/test_conformer.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - -logger = tf.get_logger() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer -from tensorflow_asr.models.transducer.conformer import Conformer - - -def test_conformer(): - config = Config(DEFAULT_YAML) - - text_featurizer = CharFeaturizer(config.decoder_config) - - speech_featurizer = SpeechFeaturizer(config.speech_config) - - model = Conformer(vocab_size=text_featurizer.num_classes, **config.model_config) - - model.make(speech_featurizer.shape) - model.summary() - - model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) - - concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite = converter.convert() - - logger.info("Converted successfully with no timestamp") - - concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with timestamp") - - tflitemodel = tf.lite.Interpreter(model_content=tflite) - signal = tf.random.normal([4000]) - - input_details = tflitemodel.get_input_details() - output_details = tflitemodel.get_output_details() - tflitemodel.resize_tensor_input(input_details[0]["index"], [4000]) - tflitemodel.allocate_tensors() - tflitemodel.set_tensor(input_details[0]["index"], signal) - tflitemodel.set_tensor(input_details[1]["index"], tf.constant(text_featurizer.blank, dtype=tf.int32)) - tflitemodel.set_tensor( - input_details[2]["index"], - tf.zeros([config.model_config["prediction_num_rnns"], 2, 1, config.model_config["prediction_rnn_units"]], dtype=tf.float32), - ) - tflitemodel.invoke() - hyp = tflitemodel.get_tensor(output_details[0]["index"]) - - logger.info(hyp) - - -if __name__ == "__main__": - test_conformer() diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/contextnet/config.yml b/tests/contextnet/config.yml deleted file mode 100644 index 3e5ad4d7f0..0000000000 --- a/tests/contextnet/config.yml +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - preemphasis: 0.97 - normalize_signal: True - normalize_feature: True - normalize_per_frame: False - -decoder_config: - vocabulary: null - vocab_size: 1024 - max_subword_length: 4 - blank_at_zero: True - beam_width: 5 - norm_score: True - -model_config: - name: contextnet - encoder_alpha: 0.5 - encoder_blocks: - # C0 - - nlayers: 1 - kernel_size: 5 - filters: 256 - strides: 1 - residual: False - activation: silu - # C1-C2 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - # C3 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 2 - residual: True - activation: silu - # C4-C6 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - # C7 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 2 - residual: True - activation: silu - # C8 - C10 - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 256 - strides: 1 - residual: True - activation: silu - # C11 - C13 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - # C14 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 2 - residual: True - activation: silu - # C15 - C21 - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - - nlayers: 5 - kernel_size: 5 - filters: 512 - strides: 1 - residual: True - activation: silu - # C22 - - nlayers: 1 - kernel_size: 5 - filters: 640 - strides: 1 - residual: False - activation: silu - prediction_embed_dim: 640 - prediction_embed_dropout: 0 - prediction_num_rnns: 1 - prediction_rnn_units: 640 - prediction_rnn_type: lstm - prediction_rnn_implementation: 1 - prediction_layer_norm: True - prediction_projection_units: 0 - joint_dim: 640 - joint_activation: tanh - -learning_config: - augmentations: - feature_augment: - time_masking: - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - num_masks: 1 - mask_factor: 27 - - dataset_config: - train_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - eval_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - test_paths: - - 
/mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null - - optimizer_config: - warmup_steps: 40000 - beta1: 0.9 - beta2: 0.98 - epsilon: 1e-9 - - running_config: - batch_size: 2 - accumulation_steps: 4 - num_epochs: 20 - outdir: /mnt/Miscellanea/Models/local/contextnet - log_interval_steps: 300 - eval_interval_steps: 500 - save_interval_steps: 1000 diff --git a/tests/contextnet/test_contextnet.py b/tests/contextnet/test_contextnet.py deleted file mode 100644 index f802570222..0000000000 --- a/tests/contextnet/test_contextnet.py +++ /dev/null @@ -1,82 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - -logger = tf.get_logger() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer -from tensorflow_asr.models.transducer.contextnet import ContextNet - - -def test_contextnet(): - config = Config(DEFAULT_YAML) - - text_featurizer = CharFeaturizer(config.decoder_config) - - speech_featurizer = SpeechFeaturizer(config.speech_config) - - model = ContextNet(vocab_size=text_featurizer.num_classes, **config.model_config) - - model.make(speech_featurizer.shape) - model.summary() - - model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) - - concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite = converter.convert() - - logger.info("Converted successfully with no timestamp") - - concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with timestamp") - - tflitemodel = tf.lite.Interpreter(model_content=tflite) - signal = tf.random.normal([4000]) - - input_details = tflitemodel.get_input_details() - output_details = tflitemodel.get_output_details() - tflitemodel.resize_tensor_input(input_details[0]["index"], [4000]) - tflitemodel.allocate_tensors() - tflitemodel.set_tensor(input_details[0]["index"], signal) - tflitemodel.set_tensor(input_details[1]["index"], tf.constant(text_featurizer.blank, dtype=tf.int32)) - tflitemodel.set_tensor( - input_details[2]["index"], - 
tf.zeros([config.model_config["prediction_num_rnns"], 2, 1, config.model_config["prediction_rnn_units"]], dtype=tf.float32), - ) - tflitemodel.invoke() - hyp = tflitemodel.get_tensor(output_details[0]["index"]) - - logger.info(hyp) - - -if __name__ == "__main__": - test_contextnet() diff --git a/tests/deepspeech2/config.yml b/tests/deepspeech2/config.yml deleted file mode 100644 index 22b44b9dae..0000000000 --- a/tests/deepspeech2/config.yml +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: spectrogram - preemphasis: 0.97 - normalize_signal: True - normalize_feature: True - normalize_per_frame: False - -decoder_config: - vocabulary: null - blank_at_zero: False - beam_width: 500 - lm_config: - model_path: null - alpha: 2.0 - beta: 1.0 - -model_config: - name: deepspeech2 - conv_type: conv2d - conv_kernels: [[11, 41], [11, 21], [11, 11]] - conv_strides: [[2, 2], [1, 2], [1, 2]] - conv_filters: [32, 32, 96] - conv_dropout: 0.1 - rnn_nlayers: 5 - rnn_type: lstm - rnn_units: 512 - rnn_bidirectional: True - rnn_rowconv: 0 - rnn_dropout: 0.1 - fc_nlayers: 0 - fc_units: 1024 - -learning_config: - augmentations: null - - dataset_config: - train_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - eval_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - test_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null - - optimizer_config: - class_name: adam - config: - learning_rate: 0.0001 - - running_config: - batch_size: 4 - num_epochs: 20 - accumulation_steps: 8 - outdir: /mnt/Miscellanea/Models/local/deepspeech2 - log_interval_steps: 400 - save_interval_steps: 400 - eval_interval_steps: 800 diff --git a/tests/deepspeech2/test_ds2.py b/tests/deepspeech2/test_ds2.py deleted file mode 100644 index 8554d62bdd..0000000000 --- a/tests/deepspeech2/test_ds2.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - -logger = tf.get_logger() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer -from tensorflow_asr.models.ctc.deepspeech2 import DeepSpeech2 - - -def test_ds2(): - config = Config(DEFAULT_YAML) - - text_featurizer = CharFeaturizer(config.decoder_config) - - speech_featurizer = SpeechFeaturizer(config.speech_config) - - model = DeepSpeech2(vocab_size=text_featurizer.num_classes, **config.model_config) - - model.make(speech_featurizer.shape) - model.summary() - - model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) - - concrete_func = model.make_tflite_function(greedy=False).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with beam search") - - concrete_func = model.make_tflite_function(greedy=True).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with greedy") diff --git a/tests/featurizer/test_sentencepiece.py b/tests/featurizer/test_sentencepiece.py index 08f310bc4e..538c03546d 100644 --- a/tests/featurizer/test_sentencepiece.py +++ b/tests/featurizer/test_sentencepiece.py @@ -3,9 +3,9 @@ import sentencepiece as spm -from tensorflow_asr.datasets.asr_dataset import ASRSliceDataset, ASRSliceTestDataset -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer, SubwordFeaturizer, TextFeaturizer +from tensorflow_asr.datasets import ASRSliceDataset, ASRSliceTestDataset +from tensorflow_asr.features.speech_featurizers import SpeechFeaturizer +from tensorflow_asr.tokenizers import SentencePieceTokenizer, SubwordFeaturizer, Tokenizer def test_encoder(): @@ -34,7 +34,7 @@ def test_featurizer(): "blank_at_zero": True, "beam_width": 5, "norm_score": True, - "corpus_files": [ + "train_files": [ "/data/datasets/LibriSpeech/train-clean-100/transcripts.tsv" "/data/datasets/LibriSpeech/train-clean-360/transcripts.tsv" "/data/datasets/LibriSpeech/train-other-500/transcripts.tsv" @@ -53,7 +53,7 @@ def test_featurizer(): "normalize_per_frame": False, } - text_featurizer_sentencepiece = SentencePieceFeaturizer.load_from_file(config, None) + text_featurizer_sentencepiece = SentencePieceTokenizer.load_from_file(config, None) subwords_path = os.path.join( os.path.abspath(os.path.dirname(__file__)), os.pardir, os.pardir, "vocabularies", "librispeech_train_4_1030.subwords" ) @@ -61,7 +61,7 @@ def test_featurizer(): speech_featurizer = SpeechFeaturizer(config_speech) data_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "transcripts_librispeech_train_clean_100.tsv") - def get_data(featurizer: TextFeaturizer): + def 
get_data(featurizer: Tokenizer): train_dataset = ASRSliceDataset( data_paths=[data_path], speech_featurizer=speech_featurizer, @@ -88,7 +88,7 @@ def test_iextract(): "blank_at_zero": True, "beam_width": 5, "norm_score": True, - "corpus_files": [ + "train_files": [ "/data/datasets/LibriSpeech/train-clean-100/transcripts.tsv" "/data/datasets/LibriSpeech/train-clean-360/transcripts.tsv" "/data/datasets/LibriSpeech/train-other-500/transcripts.tsv" @@ -107,7 +107,7 @@ def test_iextract(): "normalize_per_frame": False, } - text_featurizer_sentencepiece = SentencePieceFeaturizer.load_from_file(config, None) + text_featurizer_sentencepiece = SentencePieceTokenizer.load_from_file(config, None) speech_featurizer = SpeechFeaturizer(config_speech) data_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "transcripts_librispeech_train_clean_100.tsv") diff --git a/tests/featurizer/test_speech_featurizer.py b/tests/featurizer/test_speech_featurizer.py index d42c855f35..ef3f5e7e54 100644 --- a/tests/featurizer/test_speech_featurizer.py +++ b/tests/featurizer/test_speech_featurizer.py @@ -1,116 +1,113 @@ -# %% -import librosa -import librosa.display -import matplotlib.pyplot as plt -import numpy as np -import tensorflow as tf +# # %% +# import librosa +# import librosa.display +# import matplotlib.pyplot as plt +# import numpy as np -from tensorflow_asr.augmentations.methods import specaugment -from tensorflow_asr.configs.config import SpeechConfig -from tensorflow_asr.featurizers import speech_featurizers -from tensorflow_asr.utils import env_util +# from tensorflow_asr import tf +# from tensorflow_asr.augmentations.methods import specaugment +# from tensorflow_asr.configs import SpeechConfig +# from tensorflow_asr.features import speech_featurizers -env_util.setup_environment() +# speech_conf = SpeechConfig( +# { +# "sample_rate": 16000, +# "frame_ms": 25, +# "stride_ms": 10, +# "feature_type": "log_mel_spectrogram", +# "num_feature_bins": 80, +# # "compute_energy": True, +# # "use_natural_log": False, +# # "use_librosa_like_stft": True, +# # "fft_overdrive": False, +# # "normalize_feature": False, +# } +# ) +# signal = speech_featurizers.read_raw_audio("./test.flac", speech_conf.sample_rate) -speech_conf = SpeechConfig( - { - "sample_rate": 16000, - "frame_ms": 25, - "stride_ms": 10, - "feature_type": "log_mel_spectrogram", - "num_feature_bins": 80, - # "compute_energy": True, - # "use_natural_log": False, - # "use_librosa_like_stft": True, - # "fft_overdrive": False, - # "normalize_feature": False, - } -) -signal = speech_featurizers.read_raw_audio("./test.flac", speech_conf.sample_rate) +# print(f"signal length: {len(signal)}") +# sf = speech_featurizers.SpeechFeaturizer(speech_conf) +# ft = sf.extract(signal) +# freq_mask = specaugment.FreqMasking(prob=1, mask_value="min") +# ft = freq_mask.augment(ft) +# time_mask = specaugment.TimeMasking(prob=1, p_upperbound=0.05) +# ft = time_mask.augment(ft) +# ft = tf.squeeze(ft, axis=-1) +# ft = ft.numpy().T +# print(ft.shape) -print(f"signal length: {len(signal)}") -sf = speech_featurizers.SpeechFeaturizer(speech_conf) -ft = sf.extract(signal) -freq_mask = specaugment.FreqMasking(prob=1, mask_value="min") -ft = freq_mask.augment(ft) -time_mask = specaugment.TimeMasking(prob=1, p_upperbound=0.05) -ft = time_mask.augment(ft) -ft = tf.squeeze(ft, axis=-1) -ft = ft.numpy().T -print(ft.shape) +# plt.figure(figsize=(24, 5)) +# ax = plt.gca() +# ax.set_title("log_mel_spectrogram", fontweight="bold") +# librosa.display.specshow(ft, cmap="viridis") +# v1 = 
np.linspace(ft.min(), ft.max(), 8, endpoint=True) +# plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) +# plt.tight_layout() +# plt.show() -plt.figure(figsize=(24, 5)) -ax = plt.gca() -ax.set_title("log_mel_spectrogram", fontweight="bold") -librosa.display.specshow(ft, cmap="viridis") -v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) -plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) -plt.tight_layout() -plt.show() +# sf.speech_config.normalize_per_frame = True +# ft = sf.extract(signal) +# ft = tf.squeeze(ft, axis=-1) +# ft = ft.numpy().T +# print(ft.shape) -sf.speech_config.normalize_per_frame = True -ft = sf.extract(signal) -ft = tf.squeeze(ft, axis=-1) -ft = ft.numpy().T -print(ft.shape) +# plt.figure(figsize=(24, 5)) +# ax = plt.gca() +# ax.set_title("log_mel_spectrogram", fontweight="bold") +# librosa.display.specshow(ft, cmap="viridis") +# v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) +# plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) +# plt.tight_layout() +# plt.show() -plt.figure(figsize=(24, 5)) -ax = plt.gca() -ax.set_title("log_mel_spectrogram", fontweight="bold") -librosa.display.specshow(ft, cmap="viridis") -v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) -plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) -plt.tight_layout() -plt.show() +# print(np.std(ft)) +# print(np.mean(ft)) -print(np.std(ft)) -print(np.mean(ft)) +# nframes = 5 +# chunk_size = (nframes - 1) * sf.speech_config.frame_step + sf.speech_config.frame_length +# stride = nframes * sf.speech_config.frame_step +# print(f"With chunk size: {chunk_size} and nfft: {sf.nfft}") +# signal_length = len(signal) +# all_ft = None +# for i in range(int(np.ceil((signal_length - chunk_size) / stride))): # this ensure the fft shape of chunked signal is the same with whole signal +# chunk = signal[i * stride : i * stride + chunk_size] +# # cft = sf.power_to_db(sf.stft(chunk)) +# cft = sf.extract(chunk) +# cft = tf.squeeze(cft, axis=-1) +# cft = cft.numpy() +# if all_ft is None: +# all_ft = cft +# else: +# all_ft = np.concatenate([all_ft, cft], axis=0) +# all_ft = all_ft.T +# all_ft = np.pad(all_ft, [[0, 0], [0, ft.shape[-1] - all_ft.shape[-1]]]) +# print(all_ft.shape) -nframes = 5 -chunk_size = (nframes - 1) * sf.speech_config.frame_step + sf.speech_config.frame_length -stride = nframes * sf.speech_config.frame_step -print(f"With chunk size: {chunk_size} and nfft: {sf.nfft}") -signal_length = len(signal) -all_ft = None -for i in range(int(np.ceil((signal_length - chunk_size) / stride))): # this ensure the fft shape of chunked signal is the same with whole signal - chunk = signal[i * stride : i * stride + chunk_size] - # cft = sf.power_to_db(sf.stft(chunk)) - cft = sf.extract(chunk) - cft = tf.squeeze(cft, axis=-1) - cft = cft.numpy() - if all_ft is None: - all_ft = cft - else: - all_ft = np.concatenate([all_ft, cft], axis=0) -all_ft = all_ft.T -all_ft = np.pad(all_ft, [[0, 0], [0, ft.shape[-1] - all_ft.shape[-1]]]) -print(all_ft.shape) +# plt.figure(figsize=(24, 5)) +# ax = plt.gca() +# ax.set_title(f"chunked log_mel_spectrogram", fontweight="bold") +# librosa.display.specshow(all_ft, cmap="viridis") +# v1 = np.linspace(all_ft.min(), all_ft.max(), 8, endpoint=True) +# plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) +# plt.tight_layout() +# plt.show() -plt.figure(figsize=(24, 5)) -ax = plt.gca() -ax.set_title(f"chunked log_mel_spectrogram", fontweight="bold") -librosa.display.specshow(all_ft, cmap="viridis") 
-v1 = np.linspace(all_ft.min(), all_ft.max(), 8, endpoint=True) -plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) -plt.tight_layout() -plt.show() +# dft = all_ft - ft -dft = all_ft - ft +# plt.figure(figsize=(24, 5)) +# ax = plt.gca() +# ax.set_title(f"diff of chunked log_mel_spectrogram with whole log_mel_spectrogram", fontweight="bold") +# librosa.display.specshow(dft, cmap="viridis") +# v1 = np.linspace(dft.min(), dft.max(), 8, endpoint=True) +# plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) +# plt.tight_layout() +# plt.show() -plt.figure(figsize=(24, 5)) -ax = plt.gca() -ax.set_title(f"diff of chunked log_mel_spectrogram with whole log_mel_spectrogram", fontweight="bold") -librosa.display.specshow(dft, cmap="viridis") -v1 = np.linspace(dft.min(), dft.max(), 8, endpoint=True) -plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) -plt.tight_layout() -plt.show() +# plt.figure(figsize=(24, 5)) +# ax = plt.gca() +# ax.set_title(f"RMSE of chunked log_mel_spectrogram with whole log_mel_spectrogram", fontweight="bold") +# plt.plot(np.sqrt(np.mean(dft**2, axis=0))) +# plt.tight_layout() +# plt.show() -plt.figure(figsize=(24, 5)) -ax = plt.gca() -ax.set_title(f"RMSE of chunked log_mel_spectrogram with whole log_mel_spectrogram", fontweight="bold") -plt.plot(np.sqrt(np.mean(dft**2, axis=0))) -plt.tight_layout() -plt.show() - -# %% +# # %% diff --git a/tests/jasper/config.yml b/tests/jasper/config.yml deleted file mode 100644 index 6ac4bf11bc..0000000000 --- a/tests/jasper/config.yml +++ /dev/null @@ -1,83 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
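The chunking check commented out above relies on the frame-alignment arithmetic chunk_size = (nframes - 1) * frame_step + frame_length and stride = nframes * frame_step. Below is a small self-contained sketch of that arithmetic, using the same 16 kHz / 25 ms frame / 10 ms hop values as the commented-out config; it is an illustration only, not code from the patch.

import numpy as np
import tensorflow as tf

sample_rate = 16000
frame_length = sample_rate * 25 // 1000  # 400 samples per 25 ms frame
frame_step = sample_rate * 10 // 1000    # 160 samples per 10 ms hop
nframes = 5

# Each chunk holds exactly `nframes` frames, and consecutive chunks start
# `nframes * frame_step` samples apart, so chunked framing tiles the signal
# without gaps or overlaps relative to framing the whole signal at once.
chunk_size = (nframes - 1) * frame_step + frame_length  # 1040 samples
stride = nframes * frame_step                           # 800 samples

signal = np.random.randn(16000).astype(np.float32)
whole = tf.signal.frame(signal, frame_length, frame_step)

chunks = [
    tf.signal.frame(signal[start : start + chunk_size], frame_length, frame_step)
    for start in range(0, len(signal) - chunk_size + 1, stride)
]
chunked = tf.concat(chunks, axis=0)

# The chunked frames match the leading frames of the whole-signal framing exactly.
np.testing.assert_array_equal(chunked.numpy(), whole.numpy()[: chunked.shape[0]])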
- -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - preemphasis: 0.97 - normalize_signal: True - normalize_feature: True - normalize_per_frame: False - -decoder_config: - vocabulary: null - blank_at_zero: False - beam_width: 500 - lm_config: - model_path: null - alpha: 2.0 - beta: 1.0 - -model_config: - name: jasper - dense: True - first_additional_block_channels: 256 - first_additional_block_kernels: 11 - first_additional_block_strides: 2 - first_additional_block_dilation: 1 - first_additional_block_dropout: 0.2 - nsubblocks: 1 - block_channels: [256, 384, 512, 640, 768] - block_kernels: [11, 13, 17, 21, 25] - block_dropout: [0.2, 0.2, 0.2, 0.3, 0.3] - second_additional_block_channels: 896 - second_additional_block_kernels: 1 - second_additional_block_strides: 1 - second_additional_block_dilation: 2 - second_additional_block_dropout: 0.4 - third_additional_block_channels: 1024 - third_additional_block_kernels: 1 - third_additional_block_strides: 1 - third_additional_block_dilation: 1 - third_additional_block_dropout: 0.4 - -learning_config: - augmentations: null - - dataset_config: - train_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - eval_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - test_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null - - optimizer_config: - class_name: adam - config: - learning_rate: 0.0001 - - running_config: - batch_size: 4 - num_epochs: 20 - accumulation_steps: 8 - outdir: /mnt/Miscellanea/Models/local/jasper - log_interval_steps: 400 - save_interval_steps: 400 - eval_interval_steps: 800 diff --git a/tests/jasper/test_jasper.py b/tests/jasper/test_jasper.py deleted file mode 100644 index 974cd3aa8a..0000000000 --- a/tests/jasper/test_jasper.py +++ /dev/null @@ -1,60 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - -logger = tf.get_logger() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer -from tensorflow_asr.models.ctc.jasper import Jasper - - -def test_jasper(): - config = Config(DEFAULT_YAML) - - text_featurizer = CharFeaturizer(config.decoder_config) - - speech_featurizer = SpeechFeaturizer(config.speech_config) - - model = Jasper(vocab_size=text_featurizer.num_classes, **config.model_config) - - model.make(speech_featurizer.shape) - model.summary() - - model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) - - concrete_func = model.make_tflite_function(greedy=False).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with beam search") - - concrete_func = model.make_tflite_function(greedy=True).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with greedy") diff --git a/tests/losses/test_rnnt_loss.py b/tests/losses/test_rnnt_loss.py deleted file mode 100644 index e442e92f3e..0000000000 --- a/tests/losses/test_rnnt_loss.py +++ /dev/null @@ -1,209 +0,0 @@ -import numpy as np -import tensorflow as tf - -import tensorflow_asr.losses.rnnt_loss as rnnt_losses - - -class WarpRNNTTest(tf.test.TestCase): - def _run_rnnt(self, acts, labels, input_lengths, label_lengths, expected_costs, expected_grads, use_gpu=False): - self.assertEquals(acts.shape, expected_grads.shape) - acts_t = tf.constant(acts) - labels_t = tf.constant(labels) - input_lengths_t = tf.constant(input_lengths) - label_lengths_t = tf.constant(label_lengths) - - with tf.GradientTape() as tape: - # by default, GradientTape doesn’t track constants - tape.watch(acts_t) - tape.watch(labels_t) - tape.watch(input_lengths_t) - tape.watch(label_lengths_t) - logits = acts_t if use_gpu else tf.nn.log_softmax(acts_t) - costs = rnnt_losses.rnnt_loss( - logits=logits, labels=labels_t, label_length=label_lengths_t, logit_length=input_lengths_t, name=None - ) - - grads = tape.gradient(costs, [acts_t])[0] - self.assertAllClose(costs, expected_costs, atol=1e-6) - self.assertAllClose(grads, expected_grads, atol=1e-6) - - def _test_multiple_batches(self, use_gpu): - B = 2 - T = 4 - U = 3 - V = 3 - - acts = np.array( - [ - 0.065357, - 0.787530, - 0.081592, - 0.529716, - 0.750675, - 0.754135, - 0.609764, - 0.868140, - 0.622532, - 0.668522, - 0.858039, - 0.164539, - 0.989780, - 0.944298, - 0.603168, - 0.946783, - 0.666203, - 0.286882, - 0.094184, - 0.366674, - 0.736168, - 0.166680, - 0.714154, - 0.399400, - 0.535982, - 0.291821, - 0.612642, - 0.324241, - 0.800764, - 0.524106, - 0.779195, - 0.183314, - 0.113745, - 0.240222, - 0.339470, - 0.134160, - 0.505562, - 0.051597, - 0.640290, 
- 0.430733, - 0.829473, - 0.177467, - 0.320700, - 0.042883, - 0.302803, - 0.675178, - 0.569537, - 0.558474, - 0.083132, - 0.060165, - 0.107958, - 0.748615, - 0.943918, - 0.486356, - 0.418199, - 0.652408, - 0.024243, - 0.134582, - 0.366342, - 0.295830, - 0.923670, - 0.689929, - 0.741898, - 0.250005, - 0.603430, - 0.987289, - 0.592606, - 0.884672, - 0.543450, - 0.660770, - 0.377128, - 0.358021, - ], - dtype=np.float32, - ).reshape(B, T, U, V) - - expected_costs = np.array([4.28065, 3.93844], dtype=np.float32) - expected_grads = np.array( - [ - -0.186844, - -0.062555, - 0.249399, - -0.203377, - 0.202399, - 0.000977, - -0.141016, - 0.079123, - 0.061893, - -0.011552, - -0.081280, - 0.092832, - -0.154257, - 0.229433, - -0.075176, - -0.246593, - 0.146405, - 0.100188, - -0.012918, - -0.061593, - 0.074512, - -0.055986, - 0.219831, - -0.163845, - -0.497627, - 0.209240, - 0.288387, - 0.013605, - -0.030220, - 0.016615, - 0.113925, - 0.062781, - -0.176706, - -0.667078, - 0.367659, - 0.299419, - -0.356344, - -0.055347, - 0.411691, - -0.096922, - 0.029459, - 0.067463, - -0.063518, - 0.027654, - 0.035863, - -0.154499, - -0.073942, - 0.228441, - -0.166790, - -0.000088, - 0.166878, - -0.172370, - 0.105565, - 0.066804, - 0.023875, - -0.118256, - 0.094381, - -0.104707, - -0.108934, - 0.213642, - -0.369844, - 0.180118, - 0.189726, - 0.025714, - -0.079462, - 0.053748, - 0.122328, - -0.238789, - 0.116460, - -0.598687, - 0.302203, - 0.296484, - ], - dtype=np.float32, - ).reshape(B, T, U, V) - - labels = np.array([[1, 2], [1, 1]], dtype=np.int32) - input_lengths = np.array([4, 4], dtype=np.int32) - label_lengths = np.array([2, 2], dtype=np.int32) - - self._run_rnnt(acts, labels, input_lengths, label_lengths, expected_costs, expected_grads) - - def test_multiple_batches_gpu(self): - rnnt_losses.use_warprnnt = False - self._test_multiple_batches(use_gpu=True) - - def test_multiple_batches_cpu(self): - rnnt_losses.use_warprnnt = False - self._test_multiple_batches(use_gpu=False) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tests/rnn_transducer/config.yml b/tests/rnn_transducer/config.yml deleted file mode 100644 index 7608224c2e..0000000000 --- a/tests/rnn_transducer/config.yml +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -speech_config: - sample_rate: 16000 - frame_ms: 25 - stride_ms: 10 - num_feature_bins: 80 - feature_type: log_mel_spectrogram - preemphasis: 0.97 - normalize_signal: True - normalize_feature: True - normalize_per_frame: False - -decoder_config: - vocabulary: null - vocab_size: 1024 - max_subword_length: 4 - blank_at_zero: True - beam_width: 5 - norm_score: True - -model_config: - name: streaming_transducer - encoder_reductions: - 0: 3 - 1: 2 - encoder_dmodel: 320 - encoder_rnn_type: lstm - encoder_rnn_units: 1024 - encoder_nlayers: 2 - encoder_layer_norm: True - prediction_embed_dim: 320 - prediction_embed_dropout: 0.0 - prediction_num_rnns: 2 - prediction_rnn_units: 1024 - prediction_rnn_type: lstm - prediction_projection_units: 320 - prediction_layer_norm: True - joint_dim: 320 - joint_activation: tanh - -learning_config: - augmentations: - feature_augment: - time_masking: - num_masks: 10 - mask_factor: 100 - p_upperbound: 0.05 - freq_masking: - num_masks: 1 - mask_factor: 27 - - dataset_config: - train_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv - eval_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv - test_paths: - - /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv - tfrecords_dir: null - - optimizer_config: - class_name: adam - config: - learning_rate: 0.0001 - - running_config: - batch_size: 2 - accumulation_steps: 1 - num_epochs: 20 - outdir: /mnt/Miscellanea/Models/local/streaming_transducer - log_interval_steps: 300 - eval_interval_steps: 500 - save_interval_steps: 1000 diff --git a/tests/rnn_transducer/test_rnn_transducer.py b/tests/rnn_transducer/test_rnn_transducer.py deleted file mode 100644 index f7748196d6..0000000000 --- a/tests/rnn_transducer/test_rnn_transducer.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2020 Huy Le Nguyen (@nglehuy) -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import os - -os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" -import tensorflow as tf - -logger = tf.get_logger() - -DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config_wp.j2") - -from tensorflow_asr.configs.config import Config -from tensorflow_asr.featurizers.speech_featurizers import SpeechFeaturizer -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer -from tensorflow_asr.models.transducer.rnn_transducer import RnnTransducer - - -def test_streaming_transducer(): - config = Config(DEFAULT_YAML) - - text_featurizer = CharFeaturizer(config.decoder_config) - - speech_featurizer = SpeechFeaturizer(config.speech_config) - - model = RnnTransducer(vocab_size=text_featurizer.num_classes, **config.model_config) - - model.make(speech_featurizer.shape) - model.summary() - - model.add_featurizers(speech_featurizer=speech_featurizer, text_featurizer=text_featurizer) - - concrete_func = model.make_tflite_function(timestamp=False).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - tflite_model = converter.convert() - - logger.info("Converted successfully with no timestamp") - - concrete_func = model.make_tflite_function(timestamp=True).get_concrete_function() - converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) - converter.optimizations = [tf.lite.Optimize.DEFAULT] - converter.experimental_new_converter = True - converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS] - converter.convert() - - logger.info("Converted successfully with timestamp") - - tflitemodel = tf.lite.Interpreter(model_content=tflite_model) - signal = tf.random.normal([4000]) - - input_details = tflitemodel.get_input_details() - output_details = tflitemodel.get_output_details() - tflitemodel.resize_tensor_input(input_details[0]["index"], signal.shape) - tflitemodel.allocate_tensors() - tflitemodel.set_tensor(input_details[0]["index"], signal) - tflitemodel.set_tensor(input_details[1]["index"], tf.constant(text_featurizer.blank, dtype=tf.int32)) - tflitemodel.set_tensor( - input_details[2]["index"], - tf.zeros([config.model_config["encoder_nlayers"], 2, 1, config.model_config["encoder_rnn_units"]], dtype=tf.float32), - ) - tflitemodel.set_tensor( - input_details[3]["index"], - tf.zeros([config.model_config["prediction_num_rnns"], 2, 1, config.model_config["prediction_rnn_units"]], dtype=tf.float32), - ) - tflitemodel.invoke() - hyp = tflitemodel.get_tensor(output_details[0]["index"]) - - logger.info(hyp) - - -if __name__ == "__main__": - test_streaming_transducer() diff --git a/tests/featurizer/test.flac b/tests/test.flac similarity index 100% rename from tests/featurizer/test.flac rename to tests/test.flac diff --git a/tests/test_bug.py b/tests/test_bug.py new file mode 100644 index 0000000000..5258685631 --- /dev/null +++ b/tests/test_bug.py @@ -0,0 +1,17 @@ +import keras + + +class Model(keras.Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dense = keras.layers.Dense(10) + self.mha = keras.layers.MultiHeadAttention(10, 10, output_shape=(100,)) + + def call(self, inputs): + x = self.dense(inputs) + return self.mha(x, x, x) + + +model = Model() +model(keras.Input(shape=(10, 10))) +model.summary() diff --git 
a/tests/test_callbacks.py b/tests/test_callbacks.py new file mode 100644 index 0000000000..52cadf7216 --- /dev/null +++ b/tests/test_callbacks.py @@ -0,0 +1,23 @@ +import os +import tempfile + +from tensorflow_asr.callbacks import KaggleModelBackupAndRestore + + +def test_kaggle_model_backup_and_restore(): + model_handle = os.getenv("TEST_MODEL_HANDLE") + if not model_handle: + return + with tempfile.TemporaryDirectory() as temp_dir: + os.environ["KAGGLEHUB_CACHE"] = os.path.join(temp_dir, "cache") + os.makedirs(os.environ["KAGGLEHUB_CACHE"], exist_ok=True) + model_dir = os.path.join(temp_dir, "model") + os.makedirs(model_dir, exist_ok=True) + with open(os.path.join(model_dir, "model.h5"), "w", encoding="utf-8") as f: + f.write("dummy model data") + callback = KaggleModelBackupAndRestore( + model_handle=model_handle, + model_dir=model_dir, + save_freq=1, + ) + callback._backup_kaggle(logs={}, notes="Backed up model at batch") diff --git a/tests/test_char_featurizer.py b/tests/test_char_featurizer.py deleted file mode 100644 index c49535e620..0000000000 --- a/tests/test_char_featurizer.py +++ /dev/null @@ -1,31 +0,0 @@ -# pylint: disable=line-too-long -import os - -import tensorflow as tf - -from tensorflow_asr.configs.config import DecoderConfig -from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer - -decoder_config = DecoderConfig( - { - "vocabulary": f"{os.path.dirname(__file__)}/../vocabularies/english.characters", - } -) - -text = "i'm good but it would have broken down after ten miles of that hard trail dawn came while they wound over the crest of the range and with the sun in their faces they took the downgrade it was well into the morning before nash reached logan" - - -def test(): - featurizer = CharFeaturizer(decoder_config=decoder_config) - print(featurizer.tokens) - print(featurizer.num_classes) - print(text) - indices = featurizer.extract(text) - print(indices.numpy()) - indices = featurizer.tf_extract(text) - print(indices.numpy()) - batch_indices = tf.stack([indices, indices], axis=0) - reversed_text = featurizer.iextract(batch_indices) - print(reversed_text.numpy()) - upoints = featurizer.indices2upoints(indices) - print(upoints.numpy()) diff --git a/tests/test_dataset.py b/tests/test_dataset.py deleted file mode 100644 index 57a81147fd..0000000000 --- a/tests/test_dataset.py +++ /dev/null @@ -1,12 +0,0 @@ -import tensorflow as tf - - -def test_dataset(): - a = [1, 2, 3, 4, 5, 6, 7] - batch = 2 - ds = tf.data.Dataset.from_tensor_slices(a) - ds = ds.cache() - ds = ds.shuffle(3) - ds = ds.repeat(3) - ds = ds.batch(batch, drop_remainder=True) - print(list(ds.as_numpy_iterator())) diff --git a/tests/test_error_rates.py b/tests/test_error_rates.py deleted file mode 100644 index 1194463ccb..0000000000 --- a/tests/test_error_rates.py +++ /dev/null @@ -1,13 +0,0 @@ -from tensorflow_asr.utils import metric_util - - -def test_wer(): - decode = [ - "hello i am huy", - ] - target = [ - "hello i am huy", - ] - a, b = metric_util.tf_wer(decode, target) - print(a.numpy()) - print(b.numpy()) diff --git a/tests/test_layers.py b/tests/test_layers.py new file mode 100644 index 0000000000..99d7fb1e16 --- /dev/null +++ b/tests/test_layers.py @@ -0,0 +1,78 @@ +# pylint: disable=line-too-long +import os + +import librosa +import matplotlib.pyplot as plt +import numpy as np + +from tensorflow_asr import tf +from tensorflow_asr.augmentations.augmentation import Augmentation +from tensorflow_asr.models.layers.feature_extraction import FeatureExtraction +from tensorflow_asr.utils 
import data_util, file_util + +# config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "log_mel_spectrogram.yml.j2") +# config = file_util.load_yaml(config_path) + +audio_file_path = os.path.join(os.path.dirname(__file__), "test.flac") + + +def plot_specs(ft, title): + ft = ft.numpy() if isinstance(ft, tf.Tensor) else ft + ft = np.squeeze(ft) + ft = ft.T + plt.figure(figsize=(24, 5)) + ax = plt.gca() + ax.set_title(title, fontweight="bold") + librosa.display.specshow(ft, cmap="viridis") + v1 = np.linspace(ft.min(), ft.max(), 8, endpoint=True) + plt.colorbar(pad=0.01, fraction=0.02, ax=ax, format="%.2f", ticks=v1) + plt.tight_layout() + plt.show() + + +def test_feature_extraction(): + signal = data_util.load_and_convert_to_wav(audio_file_path) + signal = tf.expand_dims(data_util.read_raw_audio(signal), axis=0) + signal_length = tf.expand_dims(tf.shape(signal)[1], axis=0) + signal = tf.pad(signal, paddings=[[0, 0], [0, 16000]], mode="CONSTANT", constant_values=0.0) + + feature_extraction_layer = FeatureExtraction() + + for ftype in ("spectrogram", "log_mel_spectrogram", "log_gammatone_spectrogram", "mfcc"): + feature_extraction_layer.feature_type = ftype + ft, _ = feature_extraction_layer((signal, signal_length)) + plot_specs(ft, feature_extraction_layer.feature_type) + + mask, _ = feature_extraction_layer.compute_mask((signal, signal_length)) + print(mask) + + feature_extraction_layer.feature_type = "log_mel_spectrogram" + feature_extraction_layer.preemphasis = 0.0 + ft1, _ = feature_extraction_layer((signal, signal_length)) + feature_extraction_layer.preemphasis = 0.97 + ft2, _ = feature_extraction_layer((signal, signal_length)) + ft = ft1 - ft2 + plot_specs(ft, feature_extraction_layer.feature_type) + + feature_extraction_layer.augmentations = Augmentation( + { + "feature_augment": { + "freq_masking": { + "num_masks": 2, + "mask_factor": 27, + "prob": 0.0, + "mask_value": 0, + }, + "time_masking": { + "num_masks": 2, + "mask_factor": -1, + "prob": 0.0, + "mask_value": 0, + "p_upperbound": 0.05, + }, + } + } + ) + feature_extraction_layer.preemphasis = 0.0 + ft1, _ = feature_extraction_layer((signal, signal_length), training=True) + plot_specs(ft1, feature_extraction_layer.feature_type) diff --git a/tests/test_load_yaml.py b/tests/test_load_yaml.py deleted file mode 100644 index 7569c01d46..0000000000 --- a/tests/test_load_yaml.py +++ /dev/null @@ -1,8 +0,0 @@ -import os - -from tensorflow_asr.utils import file_util - - -def test(): - a = file_util.load_yaml(f"{os.path.dirname(__file__)}/../examples/conformer/config_wp.yml") - print(a) diff --git a/tests/test_mask.py b/tests/test_mask.py new file mode 100644 index 0000000000..b411efb6d3 --- /dev/null +++ b/tests/test_mask.py @@ -0,0 +1,55 @@ +import tensorflow as tf + +from tensorflow_asr.models.layers.multihead_attention import compute_streaming_mask + + +def test_mha_streaming_mask(): + mask = compute_streaming_mask(2, 2, tf.zeros([5, 8, 8])) + print(mask) + assert tf.reduce_all( + tf.equal( + mask, + tf.constant( + [ + [ + [True, True, False, False, False, False, False, False], + [True, True, False, False, False, False, False, False], + [True, True, True, True, False, False, False, False], + [True, True, True, True, False, False, False, False], + [False, False, True, True, True, True, False, False], + [False, False, True, True, True, True, False, False], + [False, False, False, False, True, True, True, True], + [False, False, False, False, True, True, True, True], + ] + ] + ), + ) + ).numpy() + + mask = 
compute_streaming_mask(3, 3, tf.zeros([5, 14, 14])) + print(mask) + assert tf.reduce_all( + tf.equal( + mask, + tf.constant( + [ + [ + [True, True, True, False, False, False, False, False, False, False, False, False, False, False], + [True, True, True, False, False, False, False, False, False, False, False, False, False, False], + [True, True, True, False, False, False, False, False, False, False, False, False, False, False], + [True, True, True, True, True, True, False, False, False, False, False, False, False, False], + [True, True, True, True, True, True, False, False, False, False, False, False, False, False], + [True, True, True, True, True, True, False, False, False, False, False, False, False, False], + [False, False, False, True, True, True, True, True, True, False, False, False, False, False], + [False, False, False, True, True, True, True, True, True, False, False, False, False, False], + [False, False, False, True, True, True, True, True, True, False, False, False, False, False], + [False, False, False, False, False, False, True, True, True, True, True, True, False, False], + [False, False, False, False, False, False, True, True, True, True, True, True, False, False], + [False, False, False, False, False, False, True, True, True, True, True, True, False, False], + [False, False, False, False, False, False, False, False, False, True, True, True, True, True], + [False, False, False, False, False, False, False, False, False, True, True, True, True, True], + ] + ] + ), + ) + ).numpy() diff --git a/tests/test_masked_fill.py b/tests/test_masked_fill.py deleted file mode 100644 index 3b89d55b2a..0000000000 --- a/tests/test_masked_fill.py +++ /dev/null @@ -1,12 +0,0 @@ -import tensorflow as tf - -from tensorflow_asr.utils import math_util - - -def test(): - a = math_util.masked_fill( - tf.convert_to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], tf.float32), - [[True, True, True], [True, False, True], [False, True, True]], - value=-1e9, - ) - print(a.numpy()) diff --git a/tests/test_relpe.py b/tests/test_relpe.py index 79a353fa9c..7a4285f52c 100644 --- a/tests/test_relpe.py +++ b/tests/test_relpe.py @@ -1,17 +1,24 @@ -import tensorflow as tf - -from tensorflow_asr.models.layers import positional_encoding +from tensorflow_asr import tf from tensorflow_asr.models.layers.multihead_attention import rel_left_shift +from tensorflow_asr.models.layers.positional_encoding import RelativeSinusoidalPositionalEncoding from tensorflow_asr.utils import plot_util def test(): - batch_size, input_length, max_length, dmodel = 1, 300, 500, 144 - position = positional_encoding.compute_position(-input_length, max_length, 1) - pe = positional_encoding.compute_sinusoid_position_encoding(position, batch_size, dmodel, interleave=True) - pe = pe[0].numpy().T + batch_size, input_length, max_length, dmodel = 2, 300, 500, 144 + causal = False + layer = RelativeSinusoidalPositionalEncoding(interleave=True, memory_length=input_length, causal=causal) + _, pe = layer((tf.random.normal([batch_size, max_length, dmodel]), tf.convert_to_tensor([input_length, input_length + 10])), training=False) + shift = tf.einsum("brd,btd->btr", pe, tf.ones([batch_size, max_length, dmodel])) + shift = rel_left_shift(shift[0][None, None, ...], causal=causal) + pe = tf.transpose(pe[0], perm=[1, 0]) + pe = pe.numpy() print(pe.shape) - plot_util.plotmesh(pe, title="sinusoid position encoding") + shift = shift[0][0] + shift = shift.numpy() + print(shift.shape) + plot_util.plotmesh(pe, title="sinusoid position encoding", invert_yaxis=False) + 
plot_util.plotmesh(shift, title="relshift") def test_relshift(): @@ -19,7 +26,7 @@ def test_relshift(): print(a) a = a[None, ...] a = a[None, ...] - b = rel_left_shift(a) + b = rel_left_shift(a, causal=True) b = tf.squeeze(b, 0) b = tf.squeeze(b, 0) print(b) diff --git a/tests/test_rnnt_loss.py b/tests/test_rnnt_loss.py index 7aa7beb90c..534dd09ddc 100644 --- a/tests/test_rnnt_loss.py +++ b/tests/test_rnnt_loss.py @@ -1,10 +1,9 @@ import time -import tensorflow as tf +from tensorflow_asr import tf +from tensorflow_asr.losses.rnnt_loss import compute_rnnt_loss_and_grad_helper -from tensorflow_asr.losses.rnnt_loss_naive import compute_rnnt_loss_and_grad_helper - -B = 4 +B = 1 T = 743 U = 200 V = 1000 @@ -24,7 +23,6 @@ def run(): labels=labels, label_length=label_length, logit_length=logit_length, - blank=blank, ) t1 = time.time() tf.print(loss) diff --git a/tests/test_schedules.py b/tests/test_schedules.py new file mode 100644 index 0000000000..0691851dab --- /dev/null +++ b/tests/test_schedules.py @@ -0,0 +1,16 @@ +import matplotlib.pyplot as plt + +from tensorflow_asr.optimizers.schedules import CyclicTransformerSchedule, TransformerSchedule + + +def test_transformer_schedule(): + sched = TransformerSchedule(dmodel=176, scale=10.0, warmup_steps=10000, max_lr="0.05/(176**0.5)", min_lr=None) + sched2 = CyclicTransformerSchedule(dmodel=320, step_size=10000, warmup_steps=15000, max_lr=0.0025) + lrs = [sched(i).numpy() for i in range(100000)] + print(lrs[:100]) + plt.plot(lrs) + plt.show() + lrs = [sched2(i).numpy() for i in range(100000)] + print(lrs[:100]) + plt.plot(lrs) + plt.show() diff --git a/tests/test_sp_featurizer.py b/tests/test_sp_featurizer.py deleted file mode 100644 index 4558523ead..0000000000 --- a/tests/test_sp_featurizer.py +++ /dev/null @@ -1,42 +0,0 @@ -# pylint: disable=line-too-long -import os - -import tensorflow as tf - -from tensorflow_asr.configs.config import DecoderConfig -from tensorflow_asr.featurizers.text_featurizers import SentencePieceFeaturizer - -decoder_config = DecoderConfig( - { - "model_type": "unigram", - "vocabulary": f"{os.path.dirname(__file__)}/../vocabularies/librispeech/sentencepiece/train_bpe_1000.model", - "blank_index": 0, - "pad_token": "", - "pad_index": 0, - "unknown_token": "", - "unknown_index": 1, - "bos_token": "", - "bos_index": 2, - "eos_token": "", - "eos_index": 3, - } -) - -text = "i'm good but it would have broken down after ten miles of that hard trail dawn came while they wound over the crest of the range and with the sun in their faces they took the downgrade it was well into the morning before nash reached logan" - - -def test(): - featurizer = SentencePieceFeaturizer(decoder_config=decoder_config) - print(featurizer.num_classes) - print(text) - indices = featurizer.extract(text) - print(indices.numpy()) - indices = featurizer.tf_extract(text) - print(indices.numpy()) - indices = list(indices.numpy()) - indices += [0, 0] - batch_indices = tf.stack([indices, indices], axis=0) - reversed_text = featurizer.iextract(batch_indices) - print(reversed_text.numpy()) - upoints = featurizer.indices2upoints(indices) - print(upoints.numpy()) diff --git a/tests/test_tokenizers.py b/tests/test_tokenizers.py new file mode 100644 index 0000000000..a66f643ff6 --- /dev/null +++ b/tests/test_tokenizers.py @@ -0,0 +1,65 @@ +# pylint: disable=line-too-long +import os + +from tensorflow_asr import tf +from tensorflow_asr.configs import DecoderConfig +from tensorflow_asr.tokenizers import CharTokenizer, SentencePieceTokenizer, WordPieceTokenizer +from 
tensorflow_asr.utils import file_util + +file_util.ENABLE_PATH_PREPROCESS = False + +repodir = os.path.realpath(os.path.join(os.path.dirname(__file__), "..")) + + +text = "i'm good but it would have broken down after ten miles of that hard trail dawn came while they wound over the crest of the range and with the sun in their faces they took the downgrade it was well into the morning before nash reached logan" +# text = "a b" + + +def test_char(): + config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "librispeech", "characters", "char.yml.j2") + config = file_util.load_yaml(config_path, repodir=repodir) + decoder_config = DecoderConfig(config["decoder_config"]) + featurizer = CharTokenizer(decoder_config=decoder_config) + print(featurizer.num_classes) + print(text) + indices = featurizer.tokenize(text) + print(indices.numpy()) + batch_indices = tf.stack([indices, indices], axis=0) + reversed_text = featurizer.detokenize(batch_indices) + print(reversed_text.numpy()) + upoints = featurizer.detokenize_unicode_points(indices) + print(upoints.numpy()) + + +def test_wp(): + config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "librispeech", "wordpiece", "wp.yml.j2") + config = file_util.load_yaml(config_path, repodir=repodir) + decoder_config = DecoderConfig(config["decoder_config"]) + featurizer = WordPieceTokenizer(decoder_config=decoder_config) + print(featurizer.num_classes) + print(text) + indices = featurizer.tokenize(text) + print(indices.numpy()) + batch_indices = tf.stack([indices, indices], axis=0) + reversed_text = featurizer.detokenize(batch_indices) + print(reversed_text.numpy()) + upoints = featurizer.detokenize_unicode_points(indices) + print(upoints.numpy()) + + +def test_sp(): + config_path = os.path.join(os.path.dirname(__file__), "..", "examples", "configs", "librispeech", "sentencepiece", "sp.yml.j2") + config = file_util.load_yaml(config_path, repodir=repodir) + decoder_config = DecoderConfig(config["decoder_config"]) + featurizer = SentencePieceTokenizer(decoder_config=decoder_config) + print(featurizer.num_classes) + print(text) + indices = featurizer.tokenize(text) + print(indices) + indices = list(indices.numpy()) + indices += [0, 0] + batch_indices = tf.stack([indices, indices], axis=0) + reversed_text = featurizer.detokenize(batch_indices) + print(reversed_text.numpy()) + upoints = featurizer.detokenize_unicode_points(indices) + print(upoints.numpy()) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000000..969df3e4bf --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,35 @@ +import os + +from tensorflow_asr import tf +from tensorflow_asr.utils import file_util, math_util + + +def test_load_yaml(): + a = file_util.load_yaml(f"{os.path.dirname(__file__)}/../examples/conformer/config_wp.yml") + print(a) + + +def test_mask_fill(): + a = math_util.masked_fill( + tf.convert_to_tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], tf.float32), + [[True, True, True], [True, False, True], [False, True, True]], + value=-1e9, + ) + print(a.numpy()) + + +def test_dataset(): + a = [1, 2, 3, 4, 5, 6, 7] + batch = 2 + ds = tf.data.Dataset.from_tensor_slices(a) + ds = ds.cache() + ds = ds.shuffle(3) + ds = ds.repeat(3) + ds = ds.batch(batch, drop_remainder=True) + print(list(ds.as_numpy_iterator())) + + +def test_split_batch(): + a = tf.ones((12, 2, 4), tf.float32) + b = math_util.split_tensor_by_ga(a, 4, 3) + print(b) diff --git a/tests/test_wp_featurizer.py b/tests/test_wp_featurizer.py deleted file 
mode 100644 index 9f73a53c58..0000000000 --- a/tests/test_wp_featurizer.py +++ /dev/null @@ -1,32 +0,0 @@ -# pylint: disable=line-too-long -import os - -import tensorflow as tf - -from tensorflow_asr.configs.config import DecoderConfig -from tensorflow_asr.featurizers.text_featurizers import WordPieceFeaturizer - -decoder_config = DecoderConfig( - { - "vocabulary": f"{os.path.dirname(__file__)}/../vocabularies/librispeech/wordpiece/train_1000_50.tokens", - "max_subword_length": 50, - "unknown_token": "", - } -) - -text = " i'm good but it would have broken down after ten miles of that hard trail dawn came while they wound over the crest of the range and with the sun in their faces they took the downgrade it was well into the morning before nash reached logan" - - -def test_wordpiece_featurizer(): - featurizer = WordPieceFeaturizer(decoder_config=decoder_config) - print(featurizer.num_classes) - print(text) - indices = featurizer.extract(text) - print(indices.numpy()) - indices = featurizer.tf_extract(text) - print(indices.numpy()) - batch_indices = tf.stack([indices, indices], axis=0) - reversed_text = featurizer.iextract(batch_indices) - print(reversed_text.numpy()) - upoints = featurizer.indices2upoints(indices) - print(upoints.numpy()) diff --git a/vocabularies/README.md b/vocabularies/README.md deleted file mode 100644 index 0881fbc4d6..0000000000 --- a/vocabularies/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Predefined Vocabularies - -- `language.characters` files contain all of that language's characters -- `corpus_maxlength_nwords.subwords` files contain subwords generated from corpus transcripts, with maximum length of a subword is `maxlength` and number of subwords is `nwords`. -- `corpus_maxlength_nwords.metadata.json` files contain metadata calculated from corpus duration and transcripts, for using static training \ No newline at end of file diff --git a/vocabularies/librispeech/characters/english.characters b/vocabularies/librispeech/characters/english.characters deleted file mode 100644 index 487d25d048..0000000000 --- a/vocabularies/librispeech/characters/english.characters +++ /dev/null @@ -1,33 +0,0 @@ -# List of alphabets (utf-8 encoded). Note that '#' starts a comment line, which -# will be ignored by the parser. 
-# begin of vocabulary - - -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z -' -# end of vocabulary diff --git a/vocabularies/librispeech/sentencepiece/train_bpe_1000.vocab b/vocabularies/librispeech/sentencepiece/train_bpe_1000.vocab deleted file mode 100644 index a5e1b3075d..0000000000 --- a/vocabularies/librispeech/sentencepiece/train_bpe_1000.vocab +++ /dev/null @@ -1,1000 +0,0 @@ - 0 - 0 - 0 - 0 -▁t -0 -he -1 -▁a -2 -▁the -3 -in -4 -▁s -5 -▁w -6 -▁o -7 -re -8 -nd -9 -▁b -10 -▁h -11 -er -12 -▁m -13 -▁i -14 -ou -15 -▁c -16 -▁f -17 -at -18 -ed -19 -▁and -20 -en -21 -▁to -22 -▁of -23 -on -24 -is -25 -▁d -26 -ing -27 -▁th -28 -▁p -29 -▁he -30 -or -31 -▁l -32 -es -33 -▁in -34 -ll -35 -it -36 -ar -37 -as -38 -an -39 -▁n -40 -▁g -41 -om -42 -▁be -43 -▁ha -44 -▁e -45 -le -46 -ot -47 -▁y -48 -ut -49 -ow -50 -ic -51 -▁wh -52 -▁it -53 -ld -54 -ve -55 -▁that -56 -ly -57 -▁was -58 -id -59 -se -60 -st -61 -▁on -62 -gh -63 -ent -64 -▁re -65 -▁you -66 -im -67 -ce -68 -▁u -69 -ver -70 -ion -71 -▁as -72 -et -73 -▁for -74 -ay -75 -▁we -76 -▁his -77 -ith -78 -al -79 -ir -80 -▁r -81 -▁with -82 -▁st -83 -ad -84 -ur -85 -ght -86 -▁an -87 -▁her -88 -▁not -89 -▁had -90 -▁is -91 -ter -92 -her -93 -ac -94 -am -95 -▁at -96 -oo -97 -▁but -98 -ould -99 -▁she -100 -▁k -101 -▁se -102 -▁sa -103 -▁sh -104 -▁fr -105 -▁him -106 -▁so -107 -ill -108 -▁me -109 -ain -110 -▁su -111 -ight -112 -ch -113 -red -114 -ct -115 -all -116 -ro -117 -ke -118 -ess -119 -il -120 -ore -121 -▁de -122 -▁they -123 -▁my -124 -▁whe -125 -▁all -126 -ich -127 -▁ne -128 -ri -129 -▁by -130 -▁have -131 -ome -132 -pp -133 -▁this -134 -▁li -135 -▁do -136 -▁con -137 -us -138 -▁which -139 -▁ch -140 -ul -141 -qu -142 -▁j -143 -▁up -144 -▁said -145 -▁from -146 -ard -147 -ge -148 -▁or -149 -▁v -150 -▁one -151 -th -152 -▁no -153 -▁ex -154 -▁were -155 -▁there -156 -pe -157 -and -158 -est -159 -▁man -160 -▁who -161 -ble -162 -ant -163 -ie -164 -▁al -165 -res -166 -ous -167 -ust -168 -very -169 -ation -170 -▁fe -171 -▁them -172 -lf -173 -▁when -174 -ind -175 -nt -176 -ame -177 -ra -178 -▁go -179 -ers -180 -ast -181 -fe -182 -ood -183 -▁kn -184 -▁int -185 -ist -186 -art -187 -▁are -188 -out -189 -▁would -190 -▁le -191 -os -192 -▁their -193 -ong -194 -▁what -195 -our -196 -▁if -197 -ound -198 -▁com -199 -▁ab -200 -▁out -201 -▁wor -202 -em -203 -▁will -204 -ak -205 -▁mis -206 -ate -207 -ol -208 -um -209 -un -210 -itt -211 -ough -212 -ked -213 -ap -214 -ig -215 -one -216 -▁been -217 -own -218 -ive -219 -▁then -220 -▁br -221 -ven -222 -if -223 -▁ar -224 -▁tr -225 -self -226 -▁pl -227 -▁ro -228 -ther -229 -▁pr -230 -reat -231 -▁un -232 -▁af -233 -▁sp -234 -▁qu -235 -▁pro -236 -ity -237 -hed -238 -▁tw -239 -▁ag -240 -▁could -241 -ost -242 -ace -243 -ort -244 -ure -245 -ake -246 -ack -247 -▁am -248 -▁any -249 -▁some -250 -▁your -251 -▁more -252 -▁can -253 -au -254 -▁tim -255 -ep -256 -▁en -257 -ag -258 -ck -259 -▁cl -260 -▁into -261 -ry -262 -hing -263 -▁now -264 -nder -265 -are -266 -▁very -267 -▁gr -268 -el -269 -ose -270 -▁loo -271 -▁bo -272 -ved -273 -op -274 -▁other -275 -▁did -276 -ance -277 -▁than -278 -ittle -279 -▁little -280 -ine -281 -ies -282 -way -283 -ite -284 -▁like -285 -ide -286 -ass -287 -▁bl -288 -able -289 -▁lo -290 -urn -291 -ought -292 -▁know -293 -other -294 -▁time -295 -▁im -296 -▁dis -297 -▁us -298 -▁co -299 -fore -300 -▁te -301 -▁how -302 -ence -303 -▁day -304 -▁ad -305 -ade -306 -▁about -307 -ice -308 -▁see -309 -▁over -310 -pt -311 -cc -312 -▁too -313 -ink -314 -▁fl -315 -wn -316 -▁great -317 -▁after -318 -pl 
-319 -de -320 -▁per -321 -▁again -322 -ment -323 -▁upon -324 -▁hand -325 -ab -326 -ree -327 -▁has -328 -ish -329 -ci -330 -▁only -331 -ally -332 -▁well -333 -▁should -334 -▁po -335 -▁mar -336 -ress -337 -▁say -338 -▁good -339 -ather -340 -▁two -341 -ings -342 -▁pe -343 -ount -344 -▁our -345 -ire -346 -ving -347 -▁down -348 -ars -349 -ert -350 -we -351 -▁before -352 -ile -353 -▁app -354 -ves -355 -▁every -356 -▁its -357 -▁old -358 -▁thr -359 -▁mu -360 -▁made -361 -ick -362 -ied -363 -▁long -364 -te -365 -age -366 -ft -367 -▁where -368 -▁never -369 -ang -370 -▁pre -371 -▁must -372 -▁sm -373 -▁such -374 -ull -375 -ful -376 -▁str -377 -ions -378 -▁sc -379 -▁off -380 -▁came -381 -ious -382 -ue -383 -▁miss -384 -ward -385 -▁fir -386 -ild -387 -▁even -388 -▁under -389 -▁these -390 -act -391 -▁come -392 -▁part -393 -▁fo -394 -ated -395 -ness -396 -▁rem -397 -▁bec -398 -ord -399 -▁may -400 -ty -401 -▁think -402 -▁much -403 -per -404 -▁mister -405 -▁way -406 -led -407 -orn -408 -▁ey -409 -▁let -410 -▁cont -411 -▁gl -412 -▁thought -413 -▁look -414 -ect -415 -▁spe -416 -▁back -417 -ise -418 -▁bet -419 -▁ye -420 -ady -421 -ach -422 -ans -423 -▁just -424 -▁first -425 -▁here -426 -ren -427 -▁ho -428 -▁des -429 -▁ob -430 -▁own -431 -ried -432 -ud -433 -ary -434 -▁went -435 -▁himself -436 -▁mo -437 -cl -438 -▁men -439 -air -440 -ave -441 -ath -442 -▁sl -443 -ff -444 -co -445 -▁cr -446 -llow -447 -▁res -448 -▁might -449 -ily -450 -▁seem -451 -int -452 -ip -453 -▁beg -454 -ouse -455 -anc -456 -▁wat -457 -▁through -458 -▁comp -459 -ber -460 -▁car -461 -▁away -462 -▁em -463 -▁get -464 -▁imp -465 -▁head -466 -oss -467 -▁don -468 -▁bel -469 -▁life -470 -▁without -471 -▁pass -472 -▁most -473 -▁make -474 -ened -475 -▁cons -476 -▁som -477 -▁turn -478 -av -479 -ng -480 -▁shall -481 -▁those -482 -▁eyes -483 -▁pres -484 -▁acc -485 -▁house -486 -▁somet -487 -▁jo -488 -▁still -489 -▁call -490 -hes -491 -▁op -492 -▁night -493 -ause -494 -▁wom -495 -less -496 -▁last -497 -ks -498 -ared -499 -▁comm -500 -▁nothing -501 -▁ent -502 -▁tell -503 -▁new -504 -▁take -505 -ign -506 -▁being -507 -▁many -508 -▁word -509 -▁found -510 -ons -511 -▁ret -512 -ase -513 -▁while -514 -▁ear -515 -▁att -516 -ory -517 -▁saw -518 -ix -519 -▁put -520 -oth -521 -ne -522 -▁ser -523 -▁peop -524 -iend -525 -▁wr -526 -ark -527 -▁young -528 -dy -529 -aking -530 -les -531 -▁la -532 -▁once -533 -ens -534 -▁count -535 -pect -536 -▁friend -537 -▁people -538 -ible -539 -ors -540 -▁mat -541 -fect -542 -ince -543 -▁room -544 -ered -545 -▁three -546 -▁yet -547 -ail -548 -▁same -549 -▁father -550 -▁right -551 -▁child -552 -igh -553 -▁cour -554 -▁another -555 -▁place -556 -ult -557 -iv -558 -▁though -559 -ition -560 -▁ind -561 -▁want -562 -▁nor -563 -▁far -564 -▁king -565 -▁end -566 -▁happ -567 -▁heart -568 -▁face -569 -▁ever -570 -▁nat -571 -get -572 -thing -573 -▁took -574 -▁hu -575 -▁love -576 -▁dist -577 -ew -578 -ever -579 -▁arm -580 -ian -581 -▁inst -582 -man -583 -▁work -584 -▁light -585 -▁set -586 -▁ple -587 -ict -588 -▁looked -589 -▁char -590 -▁missus -591 -▁ac -592 -▁mind -593 -▁inte -594 -▁rep -595 -▁asked -596 -▁supp -597 -cess -598 -▁yes -599 -ently -600 -▁left -601 -ertain -602 -gg -603 -▁ke -604 -ished -605 -▁pers -606 -▁things -607 -ub -608 -ways -609 -▁mom -610 -irl -611 -alk -612 -▁sir -613 -▁moment -614 -▁wa -615 -ations -616 -▁sat -617 -sel -618 -▁find -619 -ia -620 -ower -621 -rew -622 -▁world -623 -ject -624 -vent -625 -▁give -626 -▁gen -627 -▁cap -628 -so -629 -▁gu -630 -▁sw -631 -▁why -632 -lt -633 -ling -634 -▁always -635 -▁mother -636 
-dd -637 -pped -638 -▁soon -639 -▁ans -640 -▁act -641 -▁form -642 -▁el -643 -▁heard -644 -der -645 -ret -646 -▁thing -647 -▁seemed -648 -▁something -649 -ange -650 -▁door -651 -▁sub -652 -▁girl -653 -ced -654 -ither -655 -▁appe -656 -▁wind -657 -▁mon -658 -▁dif -659 -▁because -660 -ss -661 -▁told -662 -▁going -663 -orm -664 -▁home -665 -▁war -666 -ained -667 -▁got -668 -aught -669 -▁gi -670 -▁god -671 -▁eng -672 -▁sur -673 -land -674 -ning -675 -▁hands -676 -▁woman -677 -aut -678 -▁vo -679 -▁poss -680 -▁follow -681 -▁feel -682 -ched -683 -▁rel -684 -ph -685 -ple -686 -ical -687 -▁return -688 -ook -689 -▁boy -690 -▁knew -691 -▁reg -692 -▁each -693 -ner -694 -▁rest -695 -▁kind -696 -▁ma -697 -▁exp -698 -▁cle -699 -iver -700 -▁oh -701 -▁hel -702 -▁sil -703 -ual -704 -▁water -705 -ting -706 -▁del -707 -▁ass -708 -▁inf -709 -▁wo -710 -▁bre -711 -▁certain -712 -▁against -713 -▁conf -714 -cept -715 -▁belie -716 -▁hard -717 -row -718 -▁unt -719 -▁years -720 -▁quite -721 -iness -722 -▁near -723 -▁ph -724 -ined -725 -▁side -726 -▁hor -727 -▁four -728 -ired -729 -ters -730 -ool -731 -▁few -732 -ier -733 -rest -734 -▁done -735 -most -736 -▁half -737 -▁che -738 -▁better -739 -ited -740 -▁tre -741 -▁min -742 -ock -743 -ps -744 -▁also -745 -uck -746 -▁care -747 -oub -748 -▁began -749 -ully -750 -ised -751 -▁having -752 -ru -753 -▁enough -754 -▁gener -755 -▁dra -756 -▁seen -757 -▁lady -758 -▁pur -759 -aps -760 -ott -761 -▁hum -762 -ross -763 -aken -764 -ying -765 -▁ter -766 -ank -767 -▁inde -768 -▁called -769 -▁hour -770 -ial -771 -ason -772 -▁beh -773 -▁does -774 -▁whole -775 -▁morn -776 -▁ste -777 -▁pleas -778 -▁turned -779 -ib -780 -▁ref -781 -ense -782 -▁ins -783 -ream -784 -▁occ -785 -▁course -786 -gether -787 -▁both -788 -▁gave -789 -uth -790 -▁cur -791 -▁sou -792 -een -793 -▁read -794 -▁add -795 -ween -796 -▁col -797 -selves -798 -▁between -799 -▁among -800 -ular -801 -▁beaut -802 -▁keep -803 -▁inc -804 -▁poor -805 -▁sure -806 -▁morning -807 -▁white -808 -ged -809 -▁dear -810 -▁name -811 -▁toward -812 -▁whom -813 -▁small -814 -▁sk -815 -▁repl -816 -▁lar -817 -ute -818 -▁felt -819 -osed -820 -bo -821 -ating -822 -▁open -823 -▁six -824 -▁myself -825 -ond -826 -▁however -827 -xt -828 -▁bu -829 -▁herself -830 -▁inter -831 -▁high -832 -aint -833 -▁fore -834 -▁wi -835 -ction -836 -▁stood -837 -▁hund -838 -▁tra -839 -▁hundred -840 -▁ev -841 -▁sent -842 -aster -843 -▁sim -844 -▁show -845 -▁round -846 -▁point -847 -▁almost -848 -▁days -849 -▁words -850 -vel -851 -▁gra -852 -ale -853 -▁dr -854 -▁gre -855 -▁eight -856 -ents -857 -dden -858 -ates -859 -▁bus -860 -▁fam -861 -ces -862 -▁land -863 -▁stand -864 -ung -865 -▁ed -866 -▁sun -867 -haps -868 -ird -869 -▁mean -870 -▁perhaps -871 -ned -872 -ures -873 -iet -874 -▁since -875 -▁sudden -876 -▁sle -877 -▁best -878 -▁dark -879 -iss -880 -▁replied -881 -▁voice -882 -▁bar -883 -▁met -884 -▁till -885 -▁anything -886 -▁until -887 -▁underst -888 -its -889 -▁black -890 -oud -891 -aring -892 -▁bro -893 -▁looking -894 -ins -895 -▁cried -896 -amp -897 -▁prin -898 -▁fact -899 -▁next -900 -▁less -901 -▁law -902 -▁lay -903 -up -904 -▁power -905 -▁prop -906 -▁brought -907 -not -908 -enty -909 -ately -910 -rent -911 -▁country -912 -▁help -913 -med -914 -▁vis -915 -▁sn -916 -als -917 -▁air -918 -▁quest -919 -▁together -920 -fully -921 -▁spo -922 -▁adv -923 -▁person -924 -▁need -925 -▁use -926 -▁indeed -927 -▁contin -928 -oney -929 -ows -930 -▁present -931 -▁gent -932 -▁par -933 -▁unc -934 -ured -935 -▁run -936 -▁full -937 -▁aw -938 -▁rather -939 -▁ide -940 -nded -941 -▁feet 
-942 -tain -943 -▁cond -944 -▁sy -945 -▁lat -946 -be -947 -▁fall -948 -du -949 -▁five -950 -eter -951 -▁har -952 -▁fin -953 -cei -954 -▁bed -955 -▁mil -956 -▁doct -957 -▁interest -958 -oc -959 -▁matter -960 -▁gone -961 -ressed -962 -▁lord -963 -▁wife -964 -▁pat -965 -▁es -966 -fort -967 -ering -968 -▁ -969 -e -970 -t -971 -a -972 -o -973 -n -974 -i -975 -h -976 -s -977 -r -978 -d -979 -l -980 -u -981 -m -982 -c -983 -w -984 -f -985 -g -986 -y -987 -p -988 -b -989 -v -990 -k -991 -' -992 -x -993 -j -994 -q -995 diff --git a/vocabularies/librispeech/sentencepiece/train_uni_1000.vocab b/vocabularies/librispeech/sentencepiece/train_uni_1000.vocab deleted file mode 100644 index 84388bd0d3..0000000000 --- a/vocabularies/librispeech/sentencepiece/train_uni_1000.vocab +++ /dev/null @@ -1,1000 +0,0 @@ - 0 - 0 - 0 - 0 -s -3.23452 -▁the -3.26227 -▁and -3.87796 -▁a -3.9333 -▁of -3.99028 -▁to -4.02804 -ed -4.10377 -t -4.20956 -▁in -4.35719 -d -4.39652 -ing -4.49816 -▁i -4.5609 -n -4.56591 -▁he -4.68651 -e -4.75941 -y -4.77317 -▁that -4.84679 -▁was -4.87831 -▁it -4.91969 -ly -4.92571 -' -5.10953 -er -5.12267 -▁his -5.12582 -r -5.13731 -▁for -5.13769 -▁ -5.13911 -l -5.14069 -▁be -5.14874 -m -5.20857 -p -5.23353 -▁with -5.26542 -c -5.27253 -▁you -5.27265 -▁as -5.3009 -in -5.31751 -▁had -5.37408 -▁her -5.41689 -re -5.43176 -▁is -5.4677 -u -5.49233 -▁but -5.49646 -a -5.50101 -▁she -5.50697 -▁not -5.50821 -g -5.53661 -al -5.5606 -▁so -5.57503 -o -5.57825 -▁me -5.5797 -▁re -5.58798 -▁at -5.59131 -b -5.59372 -▁on -5.6336 -▁s -5.64878 -or -5.65444 -i -5.65526 -le -5.65714 -st -5.66957 -ll -5.73439 -k -5.74368 -▁him -5.80162 -▁all -5.84885 -▁we -5.87595 -▁have -5.87901 -ar -5.88157 -ri -5.89493 -▁de -5.89857 -▁by -5.89973 -ce -5.90393 -▁my -5.91128 -▁this -5.91582 -▁they -5.92788 -▁which -5.94054 -▁no -5.94858 -w -5.98018 -f -5.99006 -▁said -5.99861 -▁from -6.0003 -ve -6.0439 -▁one -6.05333 -▁an -6.0666 -▁were -6.08323 -it -6.09159 -se -6.09375 -▁b -6.1028 -ne -6.10377 -v -6.12752 -ter -6.13139 -▁or -6.13307 -li -6.14984 -ch -6.15309 -▁c -6.15608 -en -6.15777 -ra -6.1617 -th -6.16863 -ck -6.17516 -il -6.18247 -▁do -6.18707 -on -6.20253 -▁e -6.20357 -ent -6.21641 -▁when -6.24556 -▁there -6.25355 -es -6.25862 -te -6.25914 -▁f -6.25919 -▁con -6.2732 -h -6.28147 -ro -6.29465 -▁would -6.31178 -nd -6.31218 -▁what -6.33079 -ation -6.33766 -▁are -6.35024 -▁their -6.36263 -▁if -6.37062 -▁who -6.37271 -▁out -6.38386 -▁will -6.40459 -▁pa -6.40694 -ion -6.4095 -ry -6.42329 -an -6.42491 -ur -6.42879 -▁up -6.44363 -ir -6.45702 -▁g -6.45721 -▁them -6.463 -▁pro -6.46673 -ver -6.48036 -ic -6.48229 -▁man -6.49095 -▁ex -6.50568 -at -6.50753 -▁been -6.50898 -▁st -6.5342 -▁w -6.55098 -ge -6.55641 -tion -6.55945 -ment -6.5757 -▁ro -6.57662 -▁un -6.57802 -▁ma -6.60257 -▁could -6.61424 -▁mi -6.6418 -▁t -6.64951 -ted -6.65531 -▁bo -6.66553 -▁la -6.67214 -▁more -6.67288 -▁p -6.68925 -▁sp -6.70037 -▁some -6.70101 -▁into -6.70491 -ity -6.71011 -▁time -6.71455 -▁sa -6.72089 -ad -6.72559 -la -6.72781 -▁co -6.73328 -▁now -6.73891 -▁ha -6.74084 -us -6.74978 -ke -6.75609 -▁very -6.76175 -▁like -6.76686 -mp -6.76728 -▁go -6.77596 -▁your -6.77809 -▁know -6.78389 -ci -6.78739 -is -6.78788 -el -6.79088 -as -6.7949 -id -6.79797 -▁other -6.80144 -▁can -6.80771 -▁than -6.80994 -▁fa -6.8155 -▁little -6.81564 -▁did -6.81593 -ow -6.82561 -▁then -6.83685 -lo -6.8373 -▁po -6.84146 -de -6.8431 -ight -6.84521 -ru -6.85041 -▁see -6.85164 -ol -6.85513 -▁mo -6.85537 -ful -6.86181 -age -6.86188 -pp -6.86757 -x -6.86801 -▁d -6.8692 -▁th -6.87556 -ul -6.87593 -ate 
-6.87864 -vi -6.87896 -est -6.8818 -▁ra -6.888 -▁lo -6.88812 -▁day -6.88965 -▁any -6.89083 -▁has -6.89109 -▁ho -6.89251 -▁about -6.89477 -ng -6.89816 -▁over -6.90184 -▁ba -6.90189 -able -6.90505 -un -6.92151 -ive -6.92736 -me -6.92759 -pe -6.93011 -▁o -6.93848 -▁ca -6.94145 -▁great -6.9444 -▁dis -6.9537 -ta -6.95396 -ence -6.95478 -sh -6.96035 -▁upon -6.96338 -▁pre -6.96364 -▁se -6.97086 -ous -6.97976 -▁di -6.99005 -ti -6.99109 -ness -6.9921 -mo -6.9947 -▁should -6.99758 -▁well -7.00726 -ty -7.01203 -▁only -7.01397 -▁le -7.01688 -ish -7.02149 -▁good -7.02389 -▁two -7.02653 -▁say -7.02878 -▁after -7.03222 -▁us -7.03627 -om -7.03944 -ure -7.04141 -he -7.04657 -▁li -7.05363 -▁down -7.06304 -ant -7.06487 -▁before -7.08756 -am -7.08988 -▁k -7.09238 -▁mar -7.09745 -▁come -7.10241 -▁gra -7.10626 -▁made -7.10685 -▁our -7.11316 -ers -7.11391 -▁old -7.11695 -▁wi -7.11833 -▁even -7.11863 -▁ri -7.11868 -▁su -7.12612 -▁long -7.12868 -mi -7.15031 -et -7.15224 -▁where -7.15399 -ance -7.15919 -nt -7.16349 -▁never -7.16758 -▁must -7.17264 -ist -7.17275 -▁bu -7.17417 -▁how -7.17696 -▁such -7.17951 -per -7.18132 -▁came -7.18311 -co -7.18433 -▁sh -7.1857 -ting -7.19186 -▁ru -7.19261 -ut -7.20053 -ten -7.21002 -ious -7.22093 -▁en -7.22315 -ho -7.23013 -▁hu -7.2368 -▁think -7.24856 -▁much -7.25053 -▁ga -7.25811 -▁mister -7.25905 -▁may -7.26171 -▁cl -7.26379 -▁way -7.26412 -▁vi -7.27271 -▁ne -7.27706 -▁men -7.28094 -ff -7.29047 -▁thought -7.29265 -▁am -7.29721 -ie -7.30094 -▁ta -7.30166 -▁back -7.30415 -▁cha -7.31559 -▁va -7.33378 -▁first -7.33515 -▁just -7.33625 -▁da -7.34428 -▁every -7.35004 -▁own -7.35062 -▁again -7.35661 -um -7.35768 -▁make -7.35993 -▁these -7.36065 -▁sta -7.36129 -ies -7.36444 -red -7.36725 -▁himself -7.37261 -ward -7.37387 -▁hand -7.37396 -▁went -7.37679 -▁ar -7.38149 -▁sc -7.38215 -▁sto -7.3825 -ard -7.39236 -ine -7.40209 -▁off -7.40249 -man -7.4042 -▁might -7.40556 -row -7.40978 -▁per -7.41675 -less -7.41787 -▁give -7.41842 -▁part -7.43462 -j -7.43917 -▁place -7.4398 -ot -7.4422 -▁fi -7.44394 -▁house -7.44839 -ac -7.45306 -▁here -7.45371 -ian -7.45427 -▁fl -7.45635 -und -7.45979 -▁mu -7.46065 -▁pi -7.4627 -▁through -7.46442 -▁pe -7.47056 -▁love -7.47201 -ble -7.47216 -der -7.47697 -im -7.47983 -▁away -7.48442 -▁tra -7.48707 -▁get -7.49505 -▁life -7.49995 -▁without -7.51324 -▁com -7.5173 -cu -7.51775 -▁head -7.51811 -▁pass -7.51997 -▁take -7.52354 -▁car -7.52738 -▁most -7.52867 -ber -7.53307 -ud -7.53706 -▁don -7.54146 -▁col -7.54155 -ma -7.54347 -gg -7.54483 -▁its -7.55138 -▁shall -7.55578 -▁those -7.55587 -▁si -7.56504 -▁pri -7.56594 -nce -7.56989 -▁eyes -7.57063 -ha -7.57095 -▁too -7.57216 -▁du -7.57259 -ated -7.57567 -▁fo -7.58001 -▁still -7.58105 -ling -7.58232 -▁night -7.58451 -les -7.59716 -▁last -7.60295 -▁nothing -7.60714 -led -7.609 -▁face -7.61038 -▁tell -7.6115 -▁under -7.61262 -▁work -7.61287 -▁new -7.61462 -qui -7.61481 -▁hi -7.62062 -our -7.63745 -▁far -7.63847 -▁word -7.63997 -▁found -7.64249 -op -7.64356 -side -7.64646 -▁while -7.64831 -▁many -7.65325 -▁let -7.6583 -▁gu -7.66391 -▁te -7.66687 -▁mean -7.67024 -▁saw -7.67053 -ress -7.67111 -no -7.67565 -▁people -7.67688 -▁imp -7.68303 -▁put -7.68635 -▁look -7.6926 -▁young -7.69521 -▁friend -7.69572 -tic -7.69696 -ton -7.70067 -▁jo -7.70191 -bo -7.70227 -po -7.70304 -▁three -7.70679 -▁na -7.70914 -▁being -7.71256 -▁room -7.7134 -▁once -7.71392 -▁king -7.72044 -▁yet -7.72316 -▁same -7.72588 -▁sw -7.72917 -▁right -7.7353 -▁father -7.73566 -▁though -7.73573 -fe -7.74355 -▁another -7.74503 -▁heart -7.74621 -na -7.75055 -▁want 
-7.75668 -▁ever -7.76595 -ugh -7.76709 -▁read -7.77042 -▁took -7.78104 -lar -7.7845 -ig -7.78847 -▁sea -7.79747 -▁sha -7.79761 -▁light -7.80058 -if -7.80197 -▁arm -7.80439 -▁nor -7.80592 -▁tri -7.81257 -▁end -7.81404 -ical -7.81763 -▁missus -7.81895 -▁gre -7.82172 -▁open -7.82431 -▁em -7.82452 -▁ju -7.82895 -▁mind -7.82922 -▁asked -7.83008 -▁door -7.83258 -▁looked -7.83337 -▁mon -7.83989 -▁left -7.8416 -ary -7.84493 -▁wa -7.84512 -ther -7.84544 -son -7.846 -ring -7.84966 -ia -7.85396 -▁sir -7.85481 -▁yes -7.8559 -que -7.85754 -▁near -7.86771 -▁moment -7.86779 -▁ja -7.86865 -▁pu -7.87053 -▁war -7.87071 -▁find -7.87783 -▁world -7.87913 -▁comp -7.88141 -▁ob -7.88425 -▁ve -7.88561 -▁home -7.88583 -▁things -7.88836 -wa -7.89353 -▁ti -7.89355 -▁why -7.9009 -▁cu -7.90454 -▁always -7.90486 -▁mother -7.90533 -▁fe -7.91143 -▁answer -7.91311 -▁thing -7.9137 -▁water -7.91522 -ni -7.91889 -▁side -7.9223 -▁live -7.92724 -▁something -7.92865 -▁soon -7.92887 -▁girl -7.93238 -▁cor -7.9333 -▁heard -7.93448 -ions -7.93523 -▁seemed -7.93756 -▁name -7.94347 -▁because -7.94537 -ath -7.94777 -port -7.95106 -▁told -7.95419 -vo -7.95541 -▁flo -7.95616 -ak -7.95744 -▁ab -7.95793 -▁ear -7.95934 -ign -7.9596 -▁high -7.96099 -ft -7.97106 -▁miss -7.97566 -▁god -7.9781 -▁ni -7.97952 -to -7.98198 -▁woman -7.98268 -▁going -7.98508 -ite -7.98513 -▁act -7.98696 -▁che -7.98768 -▁follow -7.99045 -▁feel -7.9928 -va -7.99363 -▁got -8.00711 -▁return -8.01046 -go -8.01261 -▁set -8.01365 -▁knew -8.01495 -ach -8.01854 -▁each -8.02106 -▁form -8.02271 -ca -8.02402 -▁kind -8.02671 -rn -8.02846 -▁je -8.02903 -▁par -8.02962 -che -8.0304 -be -8.03209 -▁oh -8.0376 -▁lu -8.04198 -▁boy -8.04714 -▁certain -8.04971 -cy -8.05586 -▁hands -8.05649 -▁against -8.05844 -wn -8.06334 -ship -8.06389 -▁care -8.06434 -▁quite -8.06768 -▁conf -8.07082 -▁rest -8.07244 -▁sat -8.07343 -▁appear -8.07365 -▁hard -8.07489 -▁bri -8.0785 -▁four -8.0841 -▁ten -8.08445 -▁few -8.08584 -▁years -8.08605 -▁better -8.08721 -▁half -8.08998 -▁present -8.09019 -▁show -8.09166 -▁bar -8.09484 -▁bra -8.09509 -uc -8.10091 -▁wish -8.10207 -▁prince -8.10309 -▁also -8.10912 -▁began -8.11413 -▁done -8.1166 -ial -8.12015 -▁having -8.12362 -▁enough -8.12528 -▁person -8.12825 -men -8.13014 -▁lady -8.13338 -tro -8.13597 -▁dear -8.14089 -▁whole -8.15168 -▁white -8.15423 -▁course -8.15563 -▁both -8.15642 -▁voice -8.15856 -▁hour -8.15966 -▁called -8.16579 -▁speak -8.16714 -▁close -8.16881 -land -8.17065 -▁seen -8.17717 -▁does -8.1794 -▁gave -8.18493 -▁state -8.18729 -▁turned -8.18742 -▁power -8.19096 -ily -8.19361 -▁fra -8.19394 -▁morning -8.19515 -▁between -8.19572 -▁hope -8.19787 -▁poor -8.19829 -▁among -8.1991 -▁keep -8.20136 -ap -8.20635 -▁walk -8.20759 -▁matter -8.20955 -▁order -8.21293 -▁believe -8.21844 -▁sun -8.22012 -▁small -8.22062 -▁mor -8.22533 -▁talk -8.22841 -▁ste -8.22953 -gu -8.23063 -▁felt -8.23266 -▁rep -8.2338 -▁cur -8.23714 -▁horse -8.23735 -▁sure -8.23875 -▁myself -8.23972 -▁six -8.24179 -▁however -8.24221 -▁full -8.24247 -▁pla -8.24317 -▁char -8.24413 -▁herself -8.24635 -▁use -8.2465 -▁el -8.24862 -▁point -8.25104 -▁stood -8.25238 -▁hundred -8.25326 -▁whom -8.25352 -▁help -8.25899 -▁har -8.26154 -▁turn -8.26273 -▁almost -8.26723 -▁round -8.26746 -▁qua -8.27235 -▁since -8.27605 -▁large -8.28123 -▁leave -8.28604 -ttle -8.29039 -▁sent -8.29301 -▁stand -8.29542 -▁enter -8.29596 -▁perhaps -8.29681 -▁land -8.30542 -▁law -8.30709 -▁count -8.30719 -▁dark -8.30785 -▁sudden -8.30827 -▁replied -8.31461 -form -8.31738 -ible -8.31824 -▁anything -8.32155 -▁wait -8.32288 -▁till 
-8.32674 -▁wonder -8.32704 -▁until -8.32831 -▁black -8.32946 -▁ran -8.33239 -▁cried -8.33814 -▁fire -8.34465 -▁next -8.34566 -▁child -8.35079 -▁cre -8.35125 -▁looking -8.35594 -▁brought -8.35599 -▁fear -8.35759 -▁fact -8.35846 -▁seem -8.36131 -▁strange -8.3651 -▁country -8.36911 -▁less -8.37229 -▁together -8.37405 -▁reason -8.37687 -▁general -8.37907 -▁laugh -8.38112 -▁indeed -8.38328 -▁tree -8.39143 -▁lay -8.3926 -▁rather -8.39792 -▁feet -8.40569 -▁idea -8.40705 -▁question -8.41277 -▁call -8.41297 -ries -8.41432 -▁five -8.42471 -▁interest -8.42521 -▁ye -8.42686 -▁fall -8.42767 -▁consider -8.42809 -▁lord -8.43189 -▁wife -8.4338 -▁gone -8.43399 -▁gentle -8.43878 -▁death -8.44773 -▁nature -8.45072 -▁seven -8.4524 -▁sur -8.45246 -▁cap -8.45289 -▁sub -8.45973 -▁along -8.46692 -▁cannot -8.46719 -▁themselves -8.46958 -▁low -8.46959 -▁remain -8.47395 -▁tea -8.47425 -▁case -8.47714 -▁master -8.48118 -▁taken -8.48432 -▁direct -8.48764 -▁step -8.49384 -▁true -8.4959 -▁letter -8.49697 -▁thus -8.49777 -▁play -8.50395 -▁thousand -8.50624 -▁watch -8.50825 -▁brother -8.51072 -▁money -8.51095 -▁doubt -8.51329 -▁behind -8.51457 -▁children -8.51657 -▁doctor -8.518 -▁wood -8.5207 -▁twenty -8.52103 -▁book -8.52467 -▁thou -8.52704 -▁sound -8.53115 -▁hold -8.53349 -▁grow -8.53435 -▁clear -8.5382 -▁free -8.54074 -gue -8.54149 -▁whose -8.54238 -▁whi -8.54276 -▁pur -8.54404 -▁fair -8.54712 -▁alone -8.54987 -▁plan -8.55655 -▁strong -8.56103 -ities -8.56238 -▁gold -8.56263 -▁ground -8.56676 -▁window -8.56846 -▁short -8.56892 -▁dead -8.57228 -▁happen -8.57289 -▁change -8.57639 -▁spoke -8.57668 -▁remember -8.57732 -▁vo -8.57897 -▁chi -8.58201 -▁cause -8.58933 -▁draw -8.59407 -▁cra -8.59658 -▁fell -8.59715 -▁earth -8.59839 -ative -8.6006 -▁sign -8.60112 -▁therefore -8.60124 -▁deep -8.60264 -▁table -8.60295 -▁body -8.60443 -▁second -8.60971 -▁suppose -8.61148 -▁manner -8.61844 -▁dur -8.63164 -▁dress -8.6342 -▁bring -8.63439 -▁line -8.64292 -▁express -8.64429 -▁minute -8.65089 -▁become -8.65128 -▁smile -8.65765 -▁understand -8.66089 -▁wall -8.66506 -▁everything -8.66696 -▁try -8.66884 -▁eye -8.67092 -▁above -8.67265 -▁beautiful -8.67591 -▁least -8.67797 -▁spirit -8.67896 -▁gen -8.68005 -▁already -8.68007 -▁itself -8.68199 -▁around -8.68937 -▁quick -8.69503 -▁street -8.70199 -▁fine -8.70861 -▁whether -8.70897 -▁learn -8.71507 -▁expect -8.71642 -▁captain -8.71643 -▁office -8.71881 -▁women -8.72079 -▁john -8.72279 -▁please -8.72752 -▁held -8.73406 -▁perfect -8.7355 -▁big -8.73895 -▁else -8.74073 -▁foot -8.74141 -▁year -8.74748 -▁either -8.74885 -▁human -8.75409 -▁kept -8.7566 -pose -8.75777 -▁sleep -8.76277 -▁dream -8.76941 -▁fellow -8.77107 -▁became -8.77273 -▁making -8.77427 -▁ski -8.7744 -▁object -8.7773 -▁broke -8.78407 -▁towards -8.78481 -▁trouble -8.79845 -▁daughter -8.80059 -▁sense -8.80267 -▁sometimes -8.80574 -▁visit -8.80964 -ddle -8.82151 -▁nu -8.82186 -▁listen -8.82379 -▁truth -8.83371 -ctor -8.83397 -▁business -8.83857 -▁subject -8.84124 -▁court -8.8433 -▁family -8.85027 -▁desire -8.8534 -▁glad -8.85454 -▁several -8.85521 -▁kill -8.85921 -▁christ -8.86061 -▁different -8.86243 -▁possible -8.86244 -▁sister -8.87376 -▁natural -8.87805 -▁blood -8.88037 -▁quiet -8.88224 -▁character -8.89063 -▁common -8.89352 -▁across -8.89485 -▁except -8.89535 -▁difficult -8.90146 -▁number -8.90526 -▁church -8.90669 -▁touch -8.90729 -▁discover -8.90813 -▁piece -8.91159 -▁remark -8.91194 -▁effect -8.91482 -▁cross -8.91486 -▁wild -8.91619 -itude -8.92337 -▁husband -8.92544 -▁respect -8.92544 -▁regard -8.93478 -▁drop -8.93693 -▁happy 
-8.93822 -▁suffer -8.94014 -▁bird -8.94684 -▁possess -8.95399 -▁reached -8.95967 -▁pretty -8.96012 -▁sweet -8.9681 -▁week -8.97326 -▁occasion -8.97366 -▁garden -8.97419 -▁mountain -8.97723 -▁danger -8.98189 -ably -8.98366 -▁bright -8.98761 -▁secret -8.98854 -▁promise -8.98935 -▁school -8.99737 -▁carried -8.9981 -▁surprise -8.99841 -▁command -9.01057 -▁usual -9.01484 -▁south -9.01646 -▁particular -9.01685 -▁account -9.01738 -▁ought -9.02224 -▁breath -9.02865 -▁worth -9.03844 -▁heaven -9.04054 -▁purpose -9.0408 -▁phil -9.04331 -▁instant -9.05298 -▁english -9.05478 -▁view -9.05879 -▁success -9.06533 -▁neither -9.06646 -▁uncle -9.06669 -▁self -9.0717 -▁chief -9.07265 -▁immediate -9.07318 -▁pleasure -9.08109 -▁front -9.08226 -▁further -9.08736 -▁agree -9.09805 -▁queen -9.09829 -▁comfort -9.09882 -▁added -9.10162 -▁eighteen -9.10694 -▁public -9.10694 -▁notice -9.1102 -▁pray -9.11356 -▁affect -9.11512 -▁although -9.1157 -▁animal -9.11987 -▁mouth -9.12287 -▁living -9.13129 -▁yourself -9.13168 -▁write -9.13834 -▁field -9.13894 -▁beyond -9.13948 -▁condition -9.14187 -hood -9.14669 -▁america -9.14974 -▁travel -9.15139 -ified -9.15507 -▁yo -9.16968 -▁observe -9.17373 -▁according -9.17433 -▁silence -9.17881 -▁soldier -9.18308 -▁servant -9.18371 -▁bank -9.1848 -▁figure -9.18627 -▁england -9.1894 -▁afraid -9.19382 -qua -9.19818 -▁equal -9.20469 -▁service -9.20723 -▁drive -9.21083 -▁glass -9.21114 -▁exclaimed -9.21563 -▁faith -9.21836 -▁drew -9.21836 -▁complete -9.21851 -▁escape -9.2241 -▁position -9.2241 -▁necessary -9.22607 -▁length -9.22609 -▁picture -9.22738 -▁heavy -9.23397 -▁attention -9.23468 -▁caught -9.23933 -▁utter -9.2427 -▁experience -9.24326 -▁stopped -9.24939 -▁strength -9.25466 -▁repeat -9.25548 -▁opinion -9.25669 -▁knight -9.25678 -▁instead -9.25805 -▁creature -9.25939 -▁author -9.26009 -▁squ -9.26314 -▁third -9.26982 -▁straight -9.27102 -▁determin -9.27236 -▁dinner -9.27443 -▁sharp -9.27484 -▁beauty -9.27579 -▁peace -9.27658 -▁approach -9.27993 -▁silent -9.28156 -▁sitting -9.28564 -▁thirty -9.2862 -▁break -9.28691 -▁explain -9.28757 -▁glance -9.28905 -▁companion -9.2939 -clock -9.29539 -▁french -9.30232 -▁struck -9.30377 -▁distance -9.30674 -▁knowledge -9.30729 -▁wrong -9.30731 -▁fifty -9.31018 -▁attempt -9.31085 -q -9.66285 diff --git a/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json b/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json deleted file mode 100644 index 9083cc6070..0000000000 --- a/vocabularies/librispeech/wordpiece/train_1000_4.metadata.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "train": { - "max_input_length": 2972, - "max_label_length": 250, - "num_entries": 281241 - } -} \ No newline at end of file diff --git a/vocabularies/librispeech/wordpiece/train_1000_4.tokens b/vocabularies/librispeech/wordpiece/train_1000_4.tokens deleted file mode 100644 index e397cfa13e..0000000000 --- a/vocabularies/librispeech/wordpiece/train_1000_4.tokens +++ /dev/null @@ -1,985 +0,0 @@ - -' -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z -the -and -of -to -in -he -that -was -it -his -you -with -as -for -had -is -her -but -not -she -at -on -be -him -they -have -by -this -all -my -said -from -so -one -were -me -we -no -or -when -are -an -them -if -what -who -will -been -out -then -up -do -man -more -into -now -very -your -some -time -like -than -did -upon -can -only -has -any -well -two -see -its -good -down -over -know -made -our -old -such -came -must -how -day -come -much -go -us -##s -may -way -here -went -long -back -say -men -own -am 
-too -away -just -even -life -make -most -eyes -don -take -many -hand -last -once -off -saw -let -yet -##e -same -head -get -face -tell -took -room -ever -put -left -ll -sir -look -give -king -why -new -yes -told -mind -love -got -door -home -far -knew -soon -each -work -oh -find -done -few -side -also -half -seen -god -part -lady -gave -both -poor -whom -set -felt -miss -girl -dear -days -name -nor -##t -end -want -does -four -next -till -##a -boy -feet -ve -five -word -best -##n -gone -wife -full -sure -##d -sat -kind -##y -air -thus -lay -near -high -open -true -thou -keep -less -##o -lord -fire -hope -body -help -rest -sea -hear -wish -fell -fact -dead -son -hour -used -re -sent -##k -use -case -hard -bed -held -care -says -lost -kept -red -six -dark -land -mean -read -ask -call -feel -city -##l -else -live -didn -won -sun -arms -ten -turn -year -fear -soul -rose -idea -able -big -sort -town -form -ran -fine -thy -hair -thee -cold -law -met -john -road -talk -war -need -run -eye -show -tree -hold -past -low -deep -##in -bad -##an -book -cut -gold -blue -free -none -wind -##g -glad -seem -arm -laid -age -foot -meet -late -ah -fall -line -real -ship -bear -try -led -drew -fair -wild -mine -mary -grew -tom -ye -self -ago -ill -aunt -play -de -wood -##h -##i -walk -pass -stay -art -wait -save -boys -deal -wall -eat -died -view -dog -tone -lips -die -hall -act -rich -boat -##p -food -top -easy -pay -##m -##it -##b -anne -send -army -##ap -fast -nine -joy -bit -week -soft -stop -step -sit -cry -sky -wide -duty -hot -snow -box -pale -bird -isn -evil -##ar -note -##id -born -shot -neck -news -hat -spot -##r -hill -moon -lie -tea -warm -legs -mere -iron -pain -shut -wise -cat -safe -ears -tall -bank -fish -seat -hung -camp -game -paid -sad -gate -west -sign -jane -ways -cast -sake -bell -rock -laws -nice -path -##f -lad -##x -goes -ones -lose -move -wine -##at -##ot -dry -wasn -lead -coat -dick -kill -edge -race -##ad -ring -rain -##u -ear -dare -lot -lake -##am -sick -##ug -##al -duke -blow -##re -beat -thin -rode -papa -fate -hath -jack -fit -rise -##il -gray -due -add -east -##w -hurt -em -grow -vain -bent -##on -pity -seek -ice -pray -hole -##le -skin -heat -##ip -loss -plan -baby -loud -york -calm -nose -flew -rate -wore -dogs -busy -holy -mark -vast -ball -gods -pure -##oe -bill -fly -grey -song -##um -draw -weak -fool -post -fond -##ur -maid -lies -tail -pair -main -dust -drop -shop -mad -##od -rule -huge -unto -buy -milk -slow -bow -cap -mass -roof -##ed -ride -sand -yard -##as -gun -sing -suit -task -wear -alas -bore -##im -ain -dull -flat -mile -##ice -size -##ig -fat -salt -##aw -bare -cup -worn -luck -meat -##ail -harm -nay -type -cook -##ue -car -key -hale -hide -##c -bag -##op -lamp -cost -hadn -aid -firm -lest -deck -wet -beg -fox -farm -join -tis -sail -meal -sold -hate -rope -tied -leg -rome -ease -tiny -ate -cool -eggs -tale -sons -##ow -fail -fill -##ush -joe -hers -##ub -club -gaze -sell -##ir -star -sank -rage -jim -gain -silk -kiss -rank -ugly -dawn -##z -bade -fort -sum -bay -host -sang -##ll -fled -cave -##ut -lucy -##ash -##uck -##ab -lion -paul -mill -rare -##ay -bob -ends -pipe -sigh -##ob -corn -bye -hans -adam -##un -##ul -##en -sin -##ax -oak -odd -win -##oss -pull -band -crew -hut -bold -er -gift -poet -la -guns -sees -hid -hit -gets -hast -rush -pick -##ind -brow -risk -wolf -folk -mc -pink -soil -torn -june -gay -earl -katy -ages -hero -ha -park -##et -##ve -##er -fun -van -wept -##st -stir -mud -ruin -job -ours -page -rays -tad -##ond -pen -kate -row -dim 
-##ock -cow -anna -bar -wash -##ah -oil -##es -ere -keen -rid -desk -base -obey -port -et -trip -##ag -hunt -mode -acts -ma -nest -pace -rang -wave -##om -per -shed -##avy -##ye -eyed -##is -amid -##or -ned -sex -gown -##ick -##vy -hang -noon -pony -lack -toby -jean -pack -inn -mood -hell -term -##ine -##ud -knee -roll -##ce -##oil -heap -deny -trap -wake -##af -##el -##eat -cart -lamb -pine -pot -bush -##ury -fury -slip -##ile -##ok -rude -##us -lift -##ly -date -rear -aim -bath -cats -##ale -beds -flag -##ee -##em -##me -dish -tear -chin -leaf -##bs -bid -deed -fed -tent -##age -##ell -ruth -##ak -##ch -fix -lily -##ra -##te -inch -sam -veil -wit -tide -##ef -feed -##ke -##ove -burn -dale -idle -list -pole -sore -##ss -##ake -lit -##usk -aged -prey -##aws -lock -fro -##rim -##up -cake -##ew -##ure -ben -##ars -blew -lane -owe -##ike -pan -pile -##one -egg -warn -pool -##ff -##ump -lean -##rab -ann -bark -yo -##osy -mist -##se -seas -##alm -tie -coal -bull -doth -wilt -joke -pour -tore -##ame -##rag -clay -fog -sunk -card -whip -##lot -bone -sale -sole -weep -##lt -hint -pig -##ack -##ang -##ie -##ust -flow -oath -grim -jew -tin -##to -sofa -apt -arts -july -test -##awn -##olt -emma -hay -jest -##ipe -fame -##ft -##ire -cent -ray -##ry -goat -log -jump -wing -##oke -bars -damp -deer -eve -roar -##ic -##v -rent -root -bred -pope -tip -##ent -jews -##hy -bees -horn -wire -##umb -belt -glow -kit -##' -##j -##q diff --git a/vocabularies/librispeech/wordpiece/train_1000_6.metadata.json b/vocabularies/librispeech/wordpiece/train_1000_6.metadata.json deleted file mode 100644 index 67508a1f06..0000000000 --- a/vocabularies/librispeech/wordpiece/train_1000_6.metadata.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "train": { - "max_input_length": 2972, - "max_label_length": 217, - "num_entries": 281241 - } -} \ No newline at end of file diff --git a/vocabularies/librispeech/wordpiece/train_1000_6.tokens b/vocabularies/librispeech/wordpiece/train_1000_6.tokens deleted file mode 100644 index 6119cc2837..0000000000 --- a/vocabularies/librispeech/wordpiece/train_1000_6.tokens +++ /dev/null @@ -1,973 +0,0 @@ - -' -a -b -c -d -e -f -g -h -i -j -k -l -m -n -o -p -q -r -s -t -u -v -w -x -y -z -the -and -of -to -in -he -##s -that -was -it -his -you -with -as -for -had -is -her -but -not -she -at -on -be -him -they -have -by -this -all -which -my -said -from -so -one -were -me -##y -we -there -##ed -no -or -when -##d -are -their -an -them -would -if -what -who -will -##e -been -##er -out -then -up -##n -##t -do -could -man -##ly -more -into -now -very -your -some -little -time -##a -about -like -than -did -upon -can -only -has -any -well -two -##ing -other -see -before -##r -its -good -down -over -know -made -our -after -should -great -old -such -##k -came -must -how -day -never -these -come -much -mister -go -us -where -##l -may -##in -first -way -again -here -went -long -back -##on -say -##al -men -##o -own -##es -am -think -too -away -might -just -even -##le -life -make -most -every -##en -those -eyes -don -shall -##an -##h -##or -still -take -being -##ch -many -hand -while -last -##ot -house -##ar -##ge -once -off -saw -night -let -people -##at -three -yet -found -##st -same -head -get -##id -##it -##m -##g -place -though -face -tell -took -father -young -##ry -##ain -##ice -room -ever -looked -missus -put -asked -under -left -right -things -ll -sir -look -give -king -why -always -heard -world -thing -seemed -new -##ue -##ie -yes -mother -##ate -told -going -mind -##ies -love -##ay -got -##ick -door -woman 
-home -far -knew -soon -each -moment -##ce -work -heart -##are -oh -years -quite -##ow -##ast -find -##ty -done -better -few -side -also -began -water -half -having -enough -##et -seen -called -##ure -##ill -##ion -whole -god -hands -part -turned -lady -course -##nt -gave -both -##ise -##th -##b -light -##ard -##se -poor -whom -set -##us -##ake -##is -felt -myself -##ame -miss -##ide -girl -stood -white -dear -almost -days -name -words -##ad -nor -##om -##ale -##ter -among -##row -voice -end -want -round -until -cried -does -four -next -till -##ic -##ose -##age -small -##est -boy -indeed -##el -##re -since -rather -##ail -feet -ve -five -word -friend -##ath -best -##rown -gone -##ive -##p -matter -wife -##ist -full -sure -taken -others -##ble -cannot -sat -kind -death -##our -air -thus -lay -along -child -near -high -##il -##ss -open -behind -true -##ack -whose -##ks -money -twenty -thou -large -keep -passed -##ock -##ey -##am -black -nature -alone -##ve -doctor -less -##ish -##im -power -leave -lord -##ite -fire -given -##ap -hope -##orn -##ank -body -help -rest -sea -hear -##ply -##gs -speak -##f -##ia -##ut -often -##i -##ars -wish -really -##ine -fell -##ak -##ight -fact -dead -son -coming -##cy -hour -above -##ts -master -itself -used -least -re -sent -around -##ass -##ore -order -ground -use -case -known -prince -##per -##ll -##op -##od -##as -within -##pt -times -##use -##ost -women -hard -##ip -##ace -##ze -bed -reason -held -##oud -state -point -during -##art -care -says -##ught -earth -either -lost -##read -kept -means -red -six -dark -horse -land -##x -became -##ance -second -table -mean -making -read -##ied -manner -ask -short -call -feel -city -##ray -ready -##ung -else -##ep -live -close -##ire -didn -sight -won -person -sun -answer -family -arms -ten -turn -year -##ny -fear -become -soul -letter -##ark -##ief -##and -##ted -rose -idea -able -big -##amp -sort -town -across -##ort -##all -story -##ook -strong -##ape -form -ran -window -##ged -truth -longer -fine -##ugh -bring -thy -##rain -doubt -taking -hair -happy -##ver -##ff -thee -spoke -human -cold -##ord -law -return -##um -pretty -opened -##ory -##vil -##rew -##dge -sound -saying -##un -##ute -met -##ur -##my -river -john -##by -road -talk -war -##low -need -run -eye -wanted -##ult -show -tree -hold -past -##der -street -low -fellow -##ped -sense -##ach -deep -##mer -bad -##ush -ought -book -cut -gold -toward -##arry -##ear -lived -##ying -seven -##ox -hours -blood -##ave -clear -blue -free -tried -##ound -##ell -##ared -green -##ang -##aid -##ent -none -added -church -##ane -wind -##zed -##fe -walked -##ror -cause -##ual -##ote -##one -uncle -##ans -sleep -glad -##aw -##ove -doing -later -seems -beyond -##air -seem -##dy -party -##ring -##rove -arm -laid -queen -living -nearly -age -##unt -number -##lly -##ool -##ign -##ls -trees -foot -##ile -meet -except -sister -early -##ial -##ots -##awn -##ms -##rry -front -##rave -afraid -late -ah -fall -smile -line -real -##ger -seeing -##ade -ship -bear -##ant -miles -##uck -##ield -##oke -try -##attle -led -##ful -##bs -stand -spirit -##ope -drew -##ull -##rt -##ilt -##oth -common -##ect -##ward -fair -wild -mine -##sed -heavy -eight -##den -caught -##oat -##ith -##mit -hardly -##ned -##atch -mary -change -grew -tom -ye -##race -self -##ah -public -##ne -##ode -##its -ago -##ct -ill -aunt -##up -garden -##ase -##ton -##ity -beauty -##oe -play -##oon -forth -##ole -##gin -de -##ag -##uth -##te -thirty -chance -loved -##eck -##less -comes -wood -please -##ob -school -##c -girls 
-struck -##red -court -walk -fifty -length -##nd -mouth -force -pass -##ust -dinner -stay -##ift -art -wonder -wait -##ple -##w -save -##ins -clock -slowly -tears -boys -deal -##led -wall -eat -died -view -dog -##ron -##land -paper -##alth -beside -third -tone -##rs -bright -chair -sweet -##llow -stone -##ike -filled -lips -die -placed -desire -##ea -##ead -##ber -effect -##uit -hall -london -##arm -makes -##tle -##used -act -##rust -heaven -rich -boat -##lain -floor -food -chief -count -##old -top -easy -knows -broken -##u -pay -##ills -##roke -horses -##light -##dle -object -palace -glass -##ily -piece -##em -anne -##ash -##les -##ind -##icked -send -follow -army -wrong -secret -##oof -corner -##ride -below -fast -##aves -saint -##ames -nine -##dd -dress -quiet -french -##ised -joy -##lice -##ross -##ixed -##rn -##man -##resh -couldn -bit -dream -showed -single -week -##gree -morrow -soft -##ird -usual -wished -sudden -giving -##ailed -stop -danger -months -step -moved -peace -sit -cry -##iddle -watch -##able -##atter -##hore -fight -sky -silent -##eter -##road -wide -##hin -##ian -##ream -youth -##ond -##uch -##rely -##oving -easily -##now -##hy -duty -##imple -##ody -##rink -##lew -##ocks -carry -##' -##j -##q -##v -##z diff --git a/vocabularies/librispeech/wordpiece/train_100h_1000_50.metadata.json b/vocabularies/librispeech/wordpiece/train_100h_1000_50.metadata.json deleted file mode 100644 index bfcc239772..0000000000 --- a/vocabularies/librispeech/wordpiece/train_100h_1000_50.metadata.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "train": { - "max_input_length": 2451, - "max_label_length": 220, - "num_entries": 28539 - } -} \ No newline at end of file