-
Notifications
You must be signed in to change notification settings - Fork 28.4k
/
Copy pathrun_semantic_segmentation_no_trainer.py
658 lines (567 loc) · 25.3 KB
/
run_semantic_segmentation_no_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning any 🤗 Transformers model supported by AutoModelForSemanticSegmentation for semantic segmentation."""
import argparse
import json
import math
import os
import random
from pathlib import Path
import datasets
import numpy as np
import torch
from datasets import load_dataset, load_metric
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.transforms import functional
from tqdm.auto import tqdm
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from huggingface_hub import Repository, hf_hub_download
from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoModelForSemanticSegmentation,
SchedulerType,
default_data_collator,
get_scheduler,
)
from transformers.utils import get_full_repo_name
from transformers.utils.versions import require_version
# Module-level logger created through accelerate's `get_logger` helper.
logger = get_logger(__name__)

# Check the installed `datasets` package against the `>=2.0.0` requirement; the second
# argument is the installation hint (presumably shown when the check fails).
require_version("datasets>=2.0.0", "To fix: pip install -r examples/pytorch/semantic-segmentation/requirements.txt")
def pad_if_smaller(img, size, fill=0):
    """Return `img` padded so that neither side is shorter than `size`.

    Args:
        img: Image-like object exposing a `(width, height)` tuple as `.size`.
        size: Minimum side length required after padding.
        fill: Value forwarded to `functional.pad` for the padded area.

    Returns:
        `img` itself when both sides are already at least `size`; otherwise the
        result of `functional.pad` called with the padding tuple
        `(0, 0, pad_width, pad_height)`.
    """
    if min(img.size) >= size:
        # Nothing to do: the smallest side is already large enough.
        return img

    width, height = img.size
    extra_width = max(size - width, 0)
    extra_height = max(size - height, 0)
    return functional.pad(img, (0, 0, extra_width, extra_height), fill=fill)
class Compose:
    """Apply a sequence of paired `(image, target)` transforms in order."""

    def __init__(self, transforms):
        # Each entry is a callable taking (image, target) and returning a pair.
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed `(image, target)` through every transform and return the final pair."""
        pair = (image, target)
        for stage in self.transforms:
            new_image, new_target = stage(*pair)
            pair = (new_image, new_target)
        return pair
class Identity:
    """Pass-through transform that returns its inputs unchanged."""

    def __init__(self):
        # No configuration is needed for a no-op transform.
        return None

    def __call__(self, image, target):
        """Return `(image, target)` exactly as received."""
        unchanged = (image, target)
        return unchanged
class Resize:
    """Resize an image and its segmentation target to the same `size`."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        """Return the resized `(image, target)` pair.

        The target is resized with nearest-neighbour interpolation, while the
        image uses `functional.resize`'s default interpolation.
        """
        resized_image = functional.resize(image, self.size)
        nearest = transforms.InterpolationMode.NEAREST
        resized_target = functional.resize(target, self.size, interpolation=nearest)
        return resized_image, resized_target
class RandomResize:
    """Resize image and target to a size drawn uniformly from `[min_size, max_size]`."""

    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        # Without an explicit upper bound the range collapses to `min_size`.
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        """Draw one size with `random.randint` and resize both inputs to it.

        The target is resized with nearest-neighbour interpolation.
        """
        chosen = random.randint(self.min_size, self.max_size)
        resized_image = functional.resize(image, chosen)
        resized_target = functional.resize(
            target, chosen, interpolation=transforms.InterpolationMode.NEAREST
        )
        return resized_image, resized_target
class RandomCrop:
    """Take the same random square crop of side `size` from image and target."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        """Pad both inputs up to `size` if needed, then crop them identically.

        The image is padded with the default fill of `pad_if_smaller`; the target
        is padded with 255.
        """
        side = self.size
        padded_image = pad_if_smaller(image, side)
        padded_target = pad_if_smaller(target, side, fill=255)
        # One set of crop parameters is sampled and reused for both inputs.
        region = transforms.RandomCrop.get_params(padded_image, (side, side))
        cropped_image = functional.crop(padded_image, *region)
        cropped_target = functional.crop(padded_target, *region)
        return cropped_image, cropped_target
class RandomHorizontalFlip:
    """Horizontally flip image and target together with probability `flip_prob`."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        """Return the pair flipped (both or neither), decided by one `random.random()` draw."""
        if not random.random() < self.flip_prob:
            return image, target
        flipped_image = functional.hflip(image)
        flipped_target = functional.hflip(target)
        return flipped_image, flipped_target
class PILToTensor:
    """Convert the image with `functional.pil_to_tensor` and the target to an int64 tensor."""

    def __call__(self, image, target):
        """Return `(image_tensor, target_tensor)`.

        The target goes through `np.array` first and is then wrapped as a
        `torch.int64` tensor.
        """
        image_tensor = functional.pil_to_tensor(image)
        target_array = np.array(target)
        target_tensor = torch.as_tensor(target_array, dtype=torch.int64)
        return image_tensor, target_tensor
class ConvertImageDtype:
    """Convert the image tensor to `dtype`; the target is passed through untouched."""

    def __init__(self, dtype):
        self.dtype = dtype

    def __call__(self, image, target):
        """Return `(converted_image, target)` using `functional.convert_image_dtype`."""
        converted = functional.convert_image_dtype(image, self.dtype)
        return converted, target
class Normalize:
    """Normalize the image with the given `mean` and `std`; the target is left as is."""

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        """Return `(normalized_image, target)` using `functional.normalize`."""
        normalized = functional.normalize(image, mean=self.mean, std=self.std)
        return normalized, target
class ReduceLabels:
    """Shift every label down by one and map the former label 0 to 255.

    Label 0 becomes 255, an existing 255 stays 255, and every other label `k`
    becomes `k - 1`. The image is returned untouched.
    """

    @staticmethod
    def _shift_labels(target):
        """Return a new array with the reduced labels; `target` is never modified.

        Args:
            target: An ndarray (used with its own dtype) or anything `np.array`
                accepts (converted to `uint8`).
        """
        if isinstance(target, np.ndarray):
            # Copy first: the masked assignments below would otherwise write
            # into the caller's array.
            labels = target.copy()
        else:
            labels = np.array(target).astype(np.uint8)
        # avoid using underflow conversion: move 0 to 255 before subtracting,
        # then restore the 254 that results from 255 - 1.
        labels[labels == 0] = 255
        labels = labels - 1
        labels[labels == 254] = 255
        return labels

    def __call__(self, image, target):
        """Return `(image, reduced_target)` with the target rebuilt via `Image.fromarray`."""
        reduced = Image.fromarray(self._shift_labels(target))
        return image, reduced
def parse_args():
    """Parse and validate the command-line arguments of the training script.

    Returns:
        The populated `argparse.Namespace`.

    Raises:
        ValueError: If `--push_to_hub` or `--with_tracking` is given without
            `--output_dir`.

    Side effects:
        Creates `output_dir` (if one was given) so later steps can write to it.
    """
    parser = argparse.ArgumentParser(description="Finetune a transformers model on a semantic segmentation task")
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        help="Path to a pretrained model or model identifier from huggingface.co/models.",
        default="nvidia/mit-b0",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        help="Name of the dataset on the hub.",
        default="segments/sidewalk-semantic",
    )
    parser.add_argument(
        "--reduce_labels",
        action="store_true",
        help="Whether or not to reduce all labels by 1 and replace background by 255.",
    )
    parser.add_argument(
        "--train_val_split",
        type=float,
        default=0.15,
        help="Fraction of the dataset to be used for validation.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="Path to a folder in which the model and dataset will be cached.",
    )
    parser.add_argument(
        "--use_auth_token",
        action="store_true",
        help="Whether to use an authentication token to access the model repository.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="Beta1 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="Beta2 for AdamW optimizer",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-8,
        help="Epsilon for AdamW optimizer",
    )
    parser.add_argument("--num_train_epochs", type=int, default=3, help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="polynomial",
        help="The scheduler type to use.",
        choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
    )
    parser.add_argument(
        "--num_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument("--output_dir", type=str, default=None, help="Where to store the final model.")
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument(
        "--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
    )
    parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default=None,
        help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a checkpoint folder.",
    )
    parser.add_argument(
        "--with_tracking",
        required=False,
        action="store_true",
        help="Whether to load in all available experiment trackers from the environment and use them for logging.",
    )
    args = parser.parse_args()

    # Sanity checks: both hub pushes and experiment tracking write into `output_dir`.
    needs_output_dir = args.push_to_hub or args.with_tracking
    if needs_output_dir and args.output_dir is None:
        raise ValueError(
            "Need an `output_dir` to create a repo when `--push_to_hub` or `with_tracking` is specified."
        )

    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    return args
def main():
    """Fine-tune and evaluate a semantic segmentation model with 🤗 Accelerate.

    The function runs the whole script end to end:

    1. Parse the arguments and create the `Accelerator` (with trackers when
       `--with_tracking` is set).
    2. Optionally clone the Hub repository into `output_dir`.
    3. Load the dataset, normalise its column names, carve out a validation
       split when none exists, and download the `id2label` mapping.
    4. Build the model, feature extractor, transforms, dataloaders, optimizer
       and learning-rate scheduler.
    5. Train for `num_train_epochs`, evaluating mean IoU after every epoch and
       saving checkpoints / pushing to the Hub as requested.
    6. Save the final model and write `all_results.json` into `output_dir`.
    """
    args = parse_args()

    # Initialize the accelerator. We will let the accelerator handle device placement for us in this example.
    # If we're using tracking, we also need to initialize it here and it will pick up all supported trackers in the environment
    accelerator = Accelerator(log_with="all", logging_dir=args.output_dir) if args.with_tracking else Accelerator()
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    # We set device_specific to True as we want different data augmentation per device.
    if args.seed is not None:
        set_seed(args.seed, device_specific=True)

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            repo = Repository(args.output_dir, clone_from=repo_name)

            # Make sure intermediate checkpoint folders are ignored by git. Existing
            # entries of a cloned `.gitignore` are kept; only missing patterns are appended.
            gitignore_path = os.path.join(args.output_dir, ".gitignore")
            with open(gitignore_path, "a+") as gitignore:
                gitignore.seek(0)
                existing_text = gitignore.read()
                existing_patterns = set(existing_text.splitlines())
                if existing_text and not existing_text.endswith("\n"):
                    gitignore.write("\n")
                for pattern in ("step_*", "epoch_*"):
                    if pattern not in existing_patterns:
                        gitignore.write(f"{pattern}\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    # Load dataset
    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    # TODO support datasets from local folders
    dataset = load_dataset(args.dataset_name, cache_dir=args.cache_dir)

    # Rename column names to standardized names (only "image" and "label" need to be present)
    if "pixel_values" in dataset["train"].column_names:
        dataset = dataset.rename_columns({"pixel_values": "image"})
    if "annotation" in dataset["train"].column_names:
        dataset = dataset.rename_columns({"annotation": "label"})

    # If we don't have a validation split, split off a percentage of train as validation.
    args.train_val_split = None if "validation" in dataset.keys() else args.train_val_split
    if isinstance(args.train_val_split, float) and args.train_val_split > 0.0:
        split = dataset["train"].train_test_split(args.train_val_split)
        dataset["train"] = split["train"]
        dataset["validation"] = split["test"]

    # Prepare label mappings.
    # We'll include these in the model's config to get human readable labels in the Inference API.
    if args.dataset_name == "scene_parse_150":
        repo_id = "datasets/huggingface/label-files"
        filename = "ade20k-id2label.json"
    else:
        repo_id = f"datasets/{args.dataset_name}"
        filename = "id2label.json"
    # Use a context manager so the downloaded label file is closed right after parsing.
    with open(hf_hub_download(repo_id, filename), "r") as label_file:
        id2label = json.load(label_file)
    id2label = {int(k): v for k, v in id2label.items()}
    label2id = {v: k for k, v in id2label.items()}

    # Load pretrained model and feature extractor
    config = AutoConfig.from_pretrained(args.model_name_or_path, id2label=id2label, label2id=label2id)
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_name_or_path)
    model = AutoModelForSemanticSegmentation.from_pretrained(args.model_name_or_path, config=config)

    # Preprocessing the datasets
    # Define torchvision transforms to be applied to each image + target.
    # Not that straightforward in torchvision: https://github.com/pytorch/vision/issues/9
    # Currently based on official torchvision references: https://github.com/pytorch/vision/blob/main/references/segmentation/transforms.py
    train_transforms = Compose(
        [
            ReduceLabels() if args.reduce_labels else Identity(),
            RandomCrop(size=feature_extractor.size),
            RandomHorizontalFlip(flip_prob=0.5),
            PILToTensor(),
            ConvertImageDtype(torch.float),
            Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
        ]
    )
    # Define torchvision transform to be applied to each image.
    # jitter = ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1)
    val_transforms = Compose(
        [
            ReduceLabels() if args.reduce_labels else Identity(),
            Resize(size=(feature_extractor.size, feature_extractor.size)),
            PILToTensor(),
            ConvertImageDtype(torch.float),
            Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
        ]
    )

    def preprocess_train(example_batch):
        """Apply the training transforms to a batch and stack the results into tensors."""
        pixel_values = []
        labels = []
        for image, target in zip(example_batch["image"], example_batch["label"]):
            image, target = train_transforms(image.convert("RGB"), target)
            pixel_values.append(image)
            labels.append(target)

        encoding = dict()
        encoding["pixel_values"] = torch.stack(pixel_values)
        encoding["labels"] = torch.stack(labels)

        return encoding

    def preprocess_val(example_batch):
        """Apply the validation transforms to a batch and stack the results into tensors."""
        pixel_values = []
        labels = []
        for image, target in zip(example_batch["image"], example_batch["label"]):
            image, target = val_transforms(image.convert("RGB"), target)
            pixel_values.append(image)
            labels.append(target)

        encoding = dict()
        encoding["pixel_values"] = torch.stack(pixel_values)
        encoding["labels"] = torch.stack(labels)

        return encoding

    with accelerator.main_process_first():
        train_dataset = dataset["train"].with_transform(preprocess_train)
        eval_dataset = dataset["validation"].with_transform(preprocess_val)

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=args.per_device_train_batch_size
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=default_data_collator, batch_size=args.per_device_eval_batch_size
    )

    # Optimizer
    optimizer = torch.optim.AdamW(
        list(model.parameters()),
        lr=args.learning_rate,
        betas=[args.adam_beta1, args.adam_beta2],
        eps=args.adam_epsilon,
    )

    # Figure out how many steps we should save the Accelerator states
    if hasattr(args.checkpointing_steps, "isdigit"):
        checkpointing_steps = args.checkpointing_steps
        if args.checkpointing_steps.isdigit():
            checkpointing_steps = int(args.checkpointing_steps)
    else:
        checkpointing_steps = None

    # Scheduler and math around the number of training steps.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    else:
        args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.max_train_steps,
    )

    # Prepare everything with our `accelerator`.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # Instantiate metric
    metric = load_metric("mean_iou")

    if args.with_tracking:
        experiment_config = vars(args)
        # TensorBoard cannot log Enums, need the raw value
        experiment_config["lr_scheduler_type"] = experiment_config["lr_scheduler_type"].value
        accelerator.init_trackers("semantic_segmentation_no_trainer", experiment_config)

    # Train!
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {args.max_train_steps}")
    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    completed_steps = 0
    starting_epoch = 0
    resume_step = None
    # Stays None when no epoch is run, so the final results file is skipped instead of crashing.
    eval_metrics = None

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        accelerator.print(f"Resumed from checkpoint: {args.resume_from_checkpoint}")
        accelerator.load_state(args.resume_from_checkpoint)
        # normpath drops a trailing path separator, which would otherwise make basename return "".
        path = os.path.basename(os.path.normpath(args.resume_from_checkpoint))
        # Extract `epoch_{i}` or `step_{i}`
        training_difference = os.path.splitext(path)[0]

        if "epoch" in training_difference:
            starting_epoch = int(training_difference.replace("epoch_", "")) + 1
            resume_step = None
        else:
            resume_step = int(training_difference.replace("step_", ""))
            starting_epoch = resume_step // len(train_dataloader)
            resume_step -= starting_epoch * len(train_dataloader)

    for epoch in range(starting_epoch, args.num_train_epochs):
        if args.with_tracking:
            total_loss = 0
        model.train()
        for step, batch in enumerate(train_dataloader):
            # We need to skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == starting_epoch:
                if resume_step is not None and step < resume_step:
                    completed_steps += 1
                    continue
            outputs = model(**batch)
            loss = outputs.loss
            # We keep track of the loss at each epoch
            if args.with_tracking:
                total_loss += loss.detach().float()
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            # Update once `gradient_accumulation_steps` micro-batches have been accumulated
            # (or at the end of the epoch). Counting from `step + 1` keeps the number of
            # updates per epoch equal to `num_update_steps_per_epoch`.
            is_last_batch = step == len(train_dataloader) - 1
            if (step + 1) % args.gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                # Checkpoint only right after an optimizer update, so the same
                # `step_{n}` folder is not re-saved on every accumulated micro-batch.
                if isinstance(checkpointing_steps, int):
                    if completed_steps % checkpointing_steps == 0:
                        output_dir = f"step_{completed_steps }"
                        if args.output_dir is not None:
                            output_dir = os.path.join(args.output_dir, output_dir)
                        accelerator.save_state(output_dir)

                        if args.push_to_hub and epoch < args.num_train_epochs - 1:
                            accelerator.wait_for_everyone()
                            unwrapped_model = accelerator.unwrap_model(model)
                            unwrapped_model.save_pretrained(
                                args.output_dir,
                                is_main_process=accelerator.is_main_process,
                                save_function=accelerator.save,
                            )
                            if accelerator.is_main_process:
                                feature_extractor.save_pretrained(args.output_dir)
                                repo.push_to_hub(
                                    commit_message=f"Training in progress {completed_steps} steps",
                                    blocking=False,
                                    auto_lfs_prune=True,
                                )

            if completed_steps >= args.max_train_steps:
                break

        logger.info("***** Running evaluation *****")
        model.eval()
        samples_seen = 0
        for step, batch in enumerate(tqdm(eval_dataloader, disable=not accelerator.is_local_main_process)):
            with torch.no_grad():
                outputs = model(**batch)

            # Bring the logits to the label resolution before taking the per-pixel argmax.
            upsampled_logits = torch.nn.functional.interpolate(
                outputs.logits, size=batch["labels"].shape[-2:], mode="bilinear", align_corners=False
            )
            predictions = upsampled_logits.argmax(dim=1)

            predictions, references = accelerator.gather((predictions, batch["labels"]))

            # If we are in a multiprocess environment, the last batch has duplicates
            if accelerator.num_processes > 1:
                # `step` is zero-based, so the last batch has index `len(eval_dataloader) - 1`.
                if step == len(eval_dataloader) - 1:
                    predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                    references = references[: len(eval_dataloader.dataset) - samples_seen]
                else:
                    samples_seen += references.shape[0]

            metric.add_batch(
                predictions=predictions,
                references=references,
            )

        eval_metrics = metric.compute(
            num_labels=len(id2label),
            ignore_index=255,
            reduce_labels=False,  # we've already reduced the labels before
        )
        logger.info(f"epoch {epoch}: {eval_metrics}")

        if args.with_tracking:
            accelerator.log(
                {
                    "mean_iou": eval_metrics["mean_iou"],
                    "mean_accuracy": eval_metrics["mean_accuracy"],
                    "overall_accuracy": eval_metrics["overall_accuracy"],
                    "train_loss": total_loss,
                    "epoch": epoch,
                    "step": completed_steps,
                },
            )

        if args.push_to_hub and epoch < args.num_train_epochs - 1:
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
            )
            if accelerator.is_main_process:
                feature_extractor.save_pretrained(args.output_dir)
                repo.push_to_hub(
                    commit_message=f"Training in progress epoch {epoch}", blocking=False, auto_lfs_prune=True
                )

        if args.checkpointing_steps == "epoch":
            output_dir = f"epoch_{epoch}"
            if args.output_dir is not None:
                output_dir = os.path.join(args.output_dir, output_dir)
            accelerator.save_state(output_dir)

    if args.output_dir is not None:
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(
            args.output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
        )
        if accelerator.is_main_process:
            feature_extractor.save_pretrained(args.output_dir)
            if args.push_to_hub:
                repo.push_to_hub(commit_message="End of training", auto_lfs_prune=True)

        # Only written when at least one evaluation ran.
        if eval_metrics is not None:
            with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
                json.dump({"eval_overall_accuracy": eval_metrics["overall_accuracy"]}, f)
# Run the training entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()