Skip to content

Commit 41e8291

Browse files
authored
Add ALBERT to the Tensorflow to Pytorch model conversion cli (#3933)
* Add ALBERT to convert command of transformers-cli * Document ALBERT tf to pytorch model conversion
1 parent 3f42eb9 commit 41e8291

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

docs/source/converting_tensorflow_models.rst

+21-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ A command-line interface is provided to convert original Bert/GPT/GPT-2/Transfor
1212
BERT
1313
^^^^
1414

15-
You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google <https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the `convert_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/transformers/convert_tf_checkpoint_to_pytorch.py>`_ script.
15+
You can convert any TensorFlow checkpoint for BERT (in particular `the pre-trained models released by Google <https://github.com/google-research/bert#pre-trained-models>`_\ ) in a PyTorch save file by using the `convert_bert_original_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_ script.
1616

1717
This CLI takes as input a TensorFlow checkpoint (three files starting with ``bert_model.ckpt``\ ) and the associated configuration file (\ ``bert_config.json``\ ), and creates a PyTorch model for this configuration, loads the weights from the TensorFlow checkpoint in the PyTorch model and saves the resulting model in a standard PyTorch save file that can be imported using ``torch.load()`` (see examples in `run_bert_extract_features.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_extract_features.py>`_\ , `run_bert_classifier.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_classifier.py>`_ and `run_bert_squad.py <https://github.com/huggingface/pytorch-pretrained-BERT/tree/master/examples/run_bert_squad.py>`_\ ).
1818

@@ -33,6 +33,26 @@ Here is an example of the conversion process for a pre-trained ``BERT-Base Uncas
3333
3434
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/bert#pre-trained-models>`__.
3535

36+
ALBERT
37+
^^^^^^
38+
39+
Convert TensorFlow model checkpoints of ALBERT to PyTorch using the `convert_albert_original_tf_checkpoint_to_pytorch.py <https://github.com/huggingface/transformers/blob/master/src/transformers/convert_bert_original_tf_checkpoint_to_pytorch.py>`_ script.
40+
41+
The CLI takes as input a TensorFlow checkpoint (three files starting with ``model.ckpt-best``\ ) and the accompanying configuration file (\ ``albert_config.json``\ ), then creates and saves a PyTorch model. To run this conversion you will need to have TensorFlow and PyTorch installed.
42+
43+
Here is an example of the conversion process for the pre-trained ``ALBERT Base`` model:
44+
45+
.. code-block:: shell
46+
47+
export ALBERT_BASE_DIR=/path/to/albert/albert_base
48+
49+
transformers-cli convert --model_type albert \
50+
--tf_checkpoint $ALBERT_BASE_DIR/model.ckpt-best \
51+
--config $ALBERT_BASE_DIR/albert_config.json \
52+
--pytorch_dump_output $ALBERT_BASE_DIR/pytorch_model.bin
53+
54+
You can download Google's pre-trained models for the conversion `here <https://github.com/google-research/albert#pre-trained-models>`__.
55+
3656
OpenAI GPT
3757
^^^^^^^^^^
3858

src/transformers/commands/convert.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,21 @@ def __init__(
6262
self._finetuning_task_name = finetuning_task_name
6363

6464
def run(self):
65-
if self._model_type == "bert":
65+
if self._model_type == "albert":
66+
try:
67+
from transformers.convert_albert_original_tf_checkpoint_to_pytorch import (
68+
convert_tf_checkpoint_to_pytorch,
69+
)
70+
except ImportError:
71+
msg = (
72+
"transformers can only be used from the commandline to convert TensorFlow models in PyTorch, "
73+
"In that case, it requires TensorFlow to be installed. Please see "
74+
"https://www.tensorflow.org/install/ for installation instructions."
75+
)
76+
raise ImportError(msg)
77+
78+
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
79+
elif self._model_type == "bert":
6680
try:
6781
from transformers.convert_bert_original_tf_checkpoint_to_pytorch import (
6882
convert_tf_checkpoint_to_pytorch,

0 commit comments

Comments
 (0)