Skip to content

Commit 0cdcd7a

Browse files
authored
Remove task arg in load_dataset in image-classification example (#28408)
* Remove `task` arg in `load_dataset` in image-classification example
* Manage case where "train" is not in dataset
* Add new args to manage image and label column names
* Similar to audio-classification example
* Fix README
* Update tests
1 parent edb1702 commit 0cdcd7a

File tree

5 files changed

+71
-18
lines changed

5 files changed

+71
-18
lines changed

examples/pytorch/image-classification/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ python run_image_classification.py \
4141
--dataset_name beans \
4242
--output_dir ./beans_outputs/ \
4343
--remove_unused_columns False \
44+
--label_column_name labels \
4445
--do_train \
4546
--do_eval \
4647
--push_to_hub \
@@ -197,7 +198,7 @@ accelerate test
197198
that will check everything is ready for training. Finally, you can launch training with
198199

199200
```bash
200-
accelerate launch run_image_classification_trainer.py
201+
accelerate launch run_image_classification_no_trainer.py --image_column_name img
201202
```
202203

203204
This command is the same and will work for:

examples/pytorch/image-classification/run_image_classification.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ class DataTrainingArguments:
111111
)
112112
},
113113
)
114+
image_column_name: str = field(
115+
default="image",
116+
metadata={"help": "The name of the dataset column containing the image data. Defaults to 'image'."},
117+
)
118+
label_column_name: str = field(
119+
default="label",
120+
metadata={"help": "The name of the dataset column containing the labels. Defaults to 'label'."},
121+
)
114122

115123
def __post_init__(self):
116124
if self.dataset_name is None and (self.train_dir is None and self.validation_dir is None):
@@ -175,12 +183,6 @@ class ModelArguments:
175183
)
176184

177185

178-
def collate_fn(examples):
179-
pixel_values = torch.stack([example["pixel_values"] for example in examples])
180-
labels = torch.tensor([example["labels"] for example in examples])
181-
return {"pixel_values": pixel_values, "labels": labels}
182-
183-
184186
def main():
185187
# See all possible arguments in src/transformers/training_args.py
186188
# or by passing the --help flag to this script.
@@ -255,7 +257,6 @@ def main():
255257
data_args.dataset_name,
256258
data_args.dataset_config_name,
257259
cache_dir=model_args.cache_dir,
258-
task="image-classification",
259260
token=model_args.token,
260261
)
261262
else:
@@ -268,9 +269,27 @@ def main():
268269
"imagefolder",
269270
data_files=data_files,
270271
cache_dir=model_args.cache_dir,
271-
task="image-classification",
272272
)
273273

274+
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
275+
if data_args.image_column_name not in dataset_column_names:
276+
raise ValueError(
277+
f"--image_column_name {data_args.image_column_name} not found in dataset '{data_args.dataset_name}'. "
278+
"Make sure to set `--image_column_name` to the correct image column - one of "
279+
f"{', '.join(dataset_column_names)}."
280+
)
281+
if data_args.label_column_name not in dataset_column_names:
282+
raise ValueError(
283+
f"--label_column_name {data_args.label_column_name} not found in dataset '{data_args.dataset_name}'. "
284+
"Make sure to set `--label_column_name` to the correct label column - one of "
285+
f"{', '.join(dataset_column_names)}."
286+
)
287+
288+
def collate_fn(examples):
289+
pixel_values = torch.stack([example["pixel_values"] for example in examples])
290+
labels = torch.tensor([example[data_args.label_column_name] for example in examples])
291+
return {"pixel_values": pixel_values, "labels": labels}
292+
274293
# If we don't have a validation split, split off a percentage of train as validation.
275294
data_args.train_val_split = None if "validation" in dataset.keys() else data_args.train_val_split
276295
if isinstance(data_args.train_val_split, float) and data_args.train_val_split > 0.0:
@@ -280,7 +299,7 @@ def main():
280299

281300
# Prepare label mappings.
282301
# We'll include these in the model's config to get human readable labels in the Inference API.
283-
labels = dataset["train"].features["labels"].names
302+
labels = dataset["train"].features[data_args.label_column_name].names
284303
label2id, id2label = {}, {}
285304
for i, label in enumerate(labels):
286305
label2id[label] = str(i)
@@ -354,13 +373,15 @@ def compute_metrics(p):
354373
def train_transforms(example_batch):
355374
"""Apply _train_transforms across a batch."""
356375
example_batch["pixel_values"] = [
357-
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]
376+
_train_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
358377
]
359378
return example_batch
360379

361380
def val_transforms(example_batch):
362381
"""Apply _val_transforms across a batch."""
363-
example_batch["pixel_values"] = [_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch["image"]]
382+
example_batch["pixel_values"] = [
383+
_val_transforms(pil_img.convert("RGB")) for pil_img in example_batch[data_args.image_column_name]
384+
]
364385
return example_batch
365386

366387
if training_args.do_train:

examples/pytorch/image-classification/run_image_classification_no_trainer.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,18 @@ def parse_args():
189189
action="store_true",
190190
help="Whether or not to enable to load a pretrained model whose head dimensions are different.",
191191
)
192+
parser.add_argument(
193+
"--image_column_name",
194+
type=str,
195+
default="image",
196+
help="The name of the dataset column containing the image data. Defaults to 'image'.",
197+
)
198+
parser.add_argument(
199+
"--label_column_name",
200+
type=str,
201+
default="label",
202+
help="The name of the dataset column containing the labels. Defaults to 'label'.",
203+
)
192204
args = parser.parse_args()
193205

194206
# Sanity checks
@@ -272,7 +284,7 @@ def main():
272284
# download the dataset.
273285
if args.dataset_name is not None:
274286
# Downloading and loading a dataset from the hub.
275-
dataset = load_dataset(args.dataset_name, task="image-classification")
287+
dataset = load_dataset(args.dataset_name)
276288
else:
277289
data_files = {}
278290
if args.train_dir is not None:
@@ -282,11 +294,24 @@ def main():
282294
dataset = load_dataset(
283295
"imagefolder",
284296
data_files=data_files,
285-
task="image-classification",
286297
)
287298
# See more about loading custom images at
288299
# https://huggingface.co/docs/datasets/v2.0.0/en/image_process#imagefolder.
289300

301+
dataset_column_names = dataset["train"].column_names if "train" in dataset else dataset["validation"].column_names
302+
if args.image_column_name not in dataset_column_names:
303+
raise ValueError(
304+
f"--image_column_name {args.image_column_name} not found in dataset '{args.dataset_name}'. "
305+
"Make sure to set `--image_column_name` to the correct image column - one of "
306+
f"{', '.join(dataset_column_names)}."
307+
)
308+
if args.label_column_name not in dataset_column_names:
309+
raise ValueError(
310+
f"--label_column_name {args.label_column_name} not found in dataset '{args.dataset_name}'. "
311+
"Make sure to set `--label_column_name` to the correct label column - one of "
312+
f"{', '.join(dataset_column_names)}."
313+
)
314+
290315
# If we don't have a validation split, split off a percentage of train as validation.
291316
args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
292317
if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
@@ -296,7 +321,7 @@ def main():
296321

297322
# Prepare label mappings.
298323
# We'll include these in the model's config to get human readable labels in the Inference API.
299-
labels = dataset["train"].features["labels"].names
324+
labels = dataset["train"].features[args.label_column_name].names
300325
label2id = {label: str(i) for i, label in enumerate(labels)}
301326
id2label = {str(i): label for i, label in enumerate(labels)}
302327

@@ -355,12 +380,16 @@ def main():
355380

356381
def preprocess_train(example_batch):
357382
"""Apply _train_transforms across a batch."""
358-
example_batch["pixel_values"] = [train_transforms(image.convert("RGB")) for image in example_batch["image"]]
383+
example_batch["pixel_values"] = [
384+
train_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
385+
]
359386
return example_batch
360387

361388
def preprocess_val(example_batch):
362389
"""Apply _val_transforms across a batch."""
363-
example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
390+
example_batch["pixel_values"] = [
391+
val_transforms(image.convert("RGB")) for image in example_batch[args.image_column_name]
392+
]
364393
return example_batch
365394

366395
with accelerator.main_process_first():
@@ -376,7 +405,7 @@ def preprocess_val(example_batch):
376405
# DataLoaders creation:
377406
def collate_fn(examples):
378407
pixel_values = torch.stack([example["pixel_values"] for example in examples])
379-
labels = torch.tensor([example["labels"] for example in examples])
408+
labels = torch.tensor([example[args.label_column_name] for example in examples])
380409
return {"pixel_values": pixel_values, "labels": labels}
381410

382411
train_dataloader = DataLoader(

examples/pytorch/test_accelerate_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,7 @@ def test_run_image_classification_no_trainer(self):
322322
--output_dir {tmp_dir}
323323
--with_tracking
324324
--checkpointing_steps 1
325+
--label_column_name labels
325326
""".split()
326327

327328
run_command(self._launch_args + testargs)

examples/pytorch/test_pytorch_examples.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,7 @@ def test_run_image_classification(self):
398398
--max_steps 10
399399
--train_val_split 0.1
400400
--seed 42
401+
--label_column_name labels
401402
""".split()
402403

403404
if is_torch_fp16_available_on_device(torch_device):

0 commit comments

Comments
 (0)