
Commit 9f89fa0

ydshieh and patil-suraj authored
Add Flax image captioning example (#14864)
* add image captioning example
* update README
* fix style & quality
* simplify
* apply review suggestions
* Apply suggestions from code review (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Apply suggestions from code review (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Apply review suggestions
* add comments about using np instead jax array
* remove unused lines
* add model creation script
* only support from_pretrained
* fix style
* fix
* not use cache_dir when creating model
* fix tokenizer creation
* update README
* fix quality
* apply suggestion
* simplify some blocks
* Update examples/flax/image-captioning/README.md
* Update examples/flax/image-captioning/run_image_captioning_flax.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* apply suggestion

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
1 parent 2e9af29 commit 9f89fa0

File tree

3 files changed: +1388 -0 lines changed

examples/flax/image-captioning/README.md (+68 lines)
@@ -0,0 +1,68 @@
# Image Captioning (vision-encoder-text-decoder model) training example

The following example showcases how to fine-tune a vision-encoder-text-decoder model for image captioning
using the JAX/Flax backend, leveraging the 🤗 Transformers library's [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel).

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way, which enables simple and efficient model parallelism.

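To make the "immutable, purely functional" point concrete, here is a tiny, self-contained JAX sketch (not part of the example scripts): parameters live in an explicit pytree, the loss is a pure function of those parameters, and each training step returns a new pytree instead of mutating the model in place.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x, y):
    # A toy linear model: the output depends only on the inputs and the params pytree.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit  # trace the pure function once, then reuse the compiled version on GPU/TPU
def update(params, x, y, lr=0.1):
    grads = jax.grad(loss_fn)(params, x, y)
    # "Updating" the model means building a new params pytree; the old one is untouched.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


params = {"w": jnp.zeros((3, 1)), "b": jnp.zeros((1,))}
x, y = jnp.ones((8, 3)), jnp.ones((8, 1))
params = update(params, x, y)  # returns fresh parameters rather than modifying state
```
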
`run_image_captioning_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets
library or use your own files (jsonlines or csv), then fine-tune the model described above on it.

For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files. You will also find examples of these below.

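For reference, a minimal way to produce such a `jsonlines` file is sketched below. The file name and the example records are hypothetical, but the `image_path`/`caption` column names match the `--image_column`/`--caption_column` flags used in the training command further down.

```python
import json

# Hypothetical records: each line of the file is one JSON object with an image path and its caption.
records = [
    {"image_path": "data/train2017/0001.jpg", "caption": "A plate of food on a wooden table."},
    {"image_path": "data/train2017/0002.jpg", "caption": "A giraffe standing in a grassy field."},
]

with open("train.json", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")
```
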
### Download COCO dataset (2017)
This example uses the COCO dataset (2017) through a custom dataset script, which requires users to manually download the
COCO dataset before training.

```bash
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
cd ..
```

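Optionally, you can sanity-check the download by loading the dataset with 🤗 Datasets before training. The snippet below reuses the same dataset script, config name, and `data_dir` as the training command further down; depending on your `datasets` version, loading a Hub-hosted dataset script may also require `trust_remote_code=True`.

```python
from datasets import load_dataset

# Same dataset script, config, and data_dir as in the training command below.
ds = load_dataset("ydshieh/coco_dataset_script", "2017", data_dir="./data")

print(ds)  # shows the available splits and the number of examples per split
example = ds["train"][0]
print(example["image_path"], example["caption"])  # the columns the training command points at
```
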
### Create a model from a vision encoder model and a text decoder model
Next, we create a [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel) instance from a pre-trained vision encoder ([ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.FlaxViTModel)) and a pre-trained text decoder ([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.FlaxGPT2Model)):

```bash
python3 create_model_from_encoder_decoder_models.py \
    --output_dir model \
    --encoder_model_name_or_path google/vit-base-patch16-224-in21k \
    --decoder_model_name_or_path gpt2
```

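The script saves the model weights, the feature extractor, and the tokenizer into `model/`, so the directory can be loaded back like any other checkpoint. A quick, optional check (illustrative only, not part of the example scripts):

```python
from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

model = FlaxVisionEncoderDecoderModel.from_pretrained("model")
feature_extractor = AutoFeatureExtractor.from_pretrained("model")
tokenizer = AutoTokenizer.from_pretrained("model")

# The decoder (GPT2) has no native pad/decoder_start tokens, so the creation script
# falls back to bos/eos; these ids are what generate() relies on during evaluation.
print(model.config.decoder_start_token_id, model.config.pad_token_id, model.config.eos_token_id)
```
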
### Train the model
Finally, we can run the example script to train the model:

```bash
python3 run_image_captioning_flax.py \
    --output_dir ./image-captioning-training-results \
    --model_name_or_path model \
    --dataset_name ydshieh/coco_dataset_script \
    --dataset_config_name=2017 \
    --data_dir $PWD/data \
    --image_column image_path \
    --caption_column caption \
    --do_train --do_eval --predict_with_generate \
    --num_train_epochs 1 \
    --eval_steps 500 \
    --learning_rate 3e-5 --warmup_steps 0 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --overwrite_output_dir \
    --max_target_length 32 \
    --num_beams 8 \
    --preprocessing_num_workers 16 \
    --logging_steps 10 \
    --block_size 16384 \
    --push_to_hub
```

This should finish in about 1h30 on Cloud TPU, with a validation loss of 2.0153 and a ROUGE2 score of 14.64
after 1 epoch. Training statistics can be accessed on [Models](https://huggingface.co/ydshieh/image-captioning-training-results/tensorboard).

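After training, captions can be generated from the saved checkpoint with a few lines of Flax code. The following is an illustrative sketch rather than part of the example scripts; the checkpoint path and generation settings are assumptions mirroring the `--output_dir`, `--max_target_length`, and `--num_beams` values used above.

```python
import requests
from PIL import Image

from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

checkpoint = "./image-captioning-training-results"  # or the Hub repo created by --push_to_hub
model = FlaxVisionEncoderDecoderModel.from_pretrained(checkpoint)
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Flax models take numpy arrays, hence return_tensors="np".
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
output_ids = model.generate(pixel_values, max_length=32, num_beams=8).sequences
caption = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
print(caption)
```
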
examples/flax/image-captioning/create_model_from_encoder_decoder_models.py (+118 lines)
@@ -0,0 +1,118 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.

The cross-attention will be randomly initialized.
"""

from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoTokenizer,
    FlaxVisionEncoderDecoderModel,
    HfArgumentParser,
)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model will be written."},
    )
    encoder_model_name_or_path: str = field(
        metadata={
            "help": "The encoder model checkpoint for weights initialization. "
            "Don't set if you want to train an encoder model from scratch."
        },
    )
    decoder_model_name_or_path: str = field(
        metadata={
            "help": "The decoder model checkpoint for weights initialization. "
            "Don't set if you want to train a decoder model from scratch."
        },
    )
    encoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
    )
    decoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
    )


def main():
    parser = HfArgumentParser((ModelArguments,))
    (model_args,) = parser.parse_args_into_dataclasses()

    # Load pretrained model and tokenizer

    # Use explicitly specified encoder config
    if model_args.encoder_config_name:
        encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name)
    # Use pretrained encoder model's config
    else:
        encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path)

    # Use explicitly specified decoder config
    if model_args.decoder_config_name:
        decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name)
    # Use pretrained decoder model's config
    else:
        decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path)

    # necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
        decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
        encoder_config=encoder_config,
        decoder_config=decoder_config,
    )

    # GPT2 only has bos/eos tokens but not decoder_start/pad tokens
    decoder_start_token_id = decoder_config.decoder_start_token_id
    pad_token_id = decoder_config.pad_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = decoder_config.bos_token_id
    if pad_token_id is None:
        pad_token_id = decoder_config.eos_token_id

    # This is necessary to make Flax's generate() work
    model.config.eos_token_id = decoder_config.eos_token_id
    model.config.decoder_start_token_id = decoder_start_token_id
    model.config.pad_token_id = pad_token_id

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)

    model.save_pretrained(model_args.output_dir)
    feature_extractor.save_pretrained(model_args.output_dir)
    tokenizer.save_pretrained(model_args.output_dir)


if __name__ == "__main__":
    main()

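As context for the bos/eos fallback above: the stock `gpt2` checkpoint defines a single special token (used as both bos and eos) and no pad or decoder-start tokens, which is what the script compensates for. A small illustrative check, assuming the public `gpt2` checkpoint:

```python
from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT2 defines bos/eos (the same <|endoftext|> token) but no decoder_start/pad tokens.
print(config.bos_token_id, config.eos_token_id)            # both 50256 for the stock gpt2 checkpoint
print(config.decoder_start_token_id, config.pad_token_id)  # None None
print(tokenizer.pad_token)                                 # None, until the script sets it to eos
```
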
