Skip to content

Commit 4bf6c6e

Browse files
authoredJan 21, 2022
Adding prototype flag on reference scripts (#5248)
* Adding prototype flag on reference scripts. * Import prototype instead of models/transforms. * Correcting exception type. * fixing none referencing
1 parent 7d4bdd4 commit 4bf6c6e

File tree

5 files changed

+102
-43
lines changed

5 files changed

+102
-43
lines changed
 

‎references/classification/train.py

+21-8
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@
1616

1717

1818
try:
19-
from torchvision.prototype import models as PM
19+
from torchvision import prototype
2020
except ImportError:
21-
PM = None
21+
prototype = None
2222

2323

2424
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
@@ -154,13 +154,18 @@ def load_data(traindir, valdir, args):
154154
print(f"Loading dataset_test from {cache_path}")
155155
dataset_test, _ = torch.load(cache_path)
156156
else:
157-
if not args.weights:
157+
if not args.prototype:
158158
preprocessing = presets.ClassificationPresetEval(
159159
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
160160
)
161161
else:
162-
weights = PM.get_weight(args.weights)
163-
preprocessing = weights.transforms()
162+
if args.weights:
163+
weights = prototype.models.get_weight(args.weights)
164+
preprocessing = weights.transforms()
165+
else:
166+
preprocessing = prototype.transforms.ImageNetEval(
167+
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
168+
)
164169

165170
dataset_test = torchvision.datasets.ImageFolder(
166171
valdir,
@@ -186,8 +191,10 @@ def load_data(traindir, valdir, args):
186191

187192

188193
def main(args):
189-
if args.weights and PM is None:
194+
if args.prototype and prototype is None:
190195
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
196+
if not args.prototype and args.weights:
197+
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
191198
if args.output_dir:
192199
utils.mkdir(args.output_dir)
193200

@@ -229,10 +236,10 @@ def main(args):
229236
)
230237

231238
print("Creating model")
232-
if not args.weights:
239+
if not args.prototype:
233240
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
234241
else:
235-
model = PM.__dict__[args.model](weights=args.weights, num_classes=num_classes)
242+
model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
236243
model.to(device)
237244

238245
if args.distributed and args.sync_bn:
@@ -491,6 +498,12 @@ def get_args_parser(add_help=True):
491498
)
492499

493500
# Prototype models only
501+
parser.add_argument(
502+
"--prototype",
503+
dest="prototype",
504+
help="Use prototype model builders instead those from main area",
505+
action="store_true",
506+
)
494507
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
495508

496509
return parser

‎references/detection/train.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@
3434

3535

3636
try:
37-
from torchvision.prototype import models as PM
37+
from torchvision import prototype
3838
except ImportError:
39-
PM = None
39+
prototype = None
4040

4141

4242
def get_dataset(name, image_set, transform, data_path):
@@ -50,11 +50,14 @@ def get_dataset(name, image_set, transform, data_path):
5050
def get_transform(train, args):
5151
if train:
5252
return presets.DetectionPresetTrain(args.data_augmentation)
53-
elif not args.weights:
53+
elif not args.prototype:
5454
return presets.DetectionPresetEval()
5555
else:
56-
weights = PM.get_weight(args.weights)
57-
return weights.transforms()
56+
if args.weights:
57+
weights = prototype.models.get_weight(args.weights)
58+
return weights.transforms()
59+
else:
60+
return prototype.transforms.CocoEval()
5861

5962

6063
def get_args_parser(add_help=True):
@@ -141,6 +144,12 @@ def get_args_parser(add_help=True):
141144
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
142145

143146
# Prototype models only
147+
parser.add_argument(
148+
"--prototype",
149+
dest="prototype",
150+
help="Use prototype model builders instead those from main area",
151+
action="store_true",
152+
)
144153
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
145154

146155
# Mixed precision training parameters
@@ -150,8 +159,10 @@ def get_args_parser(add_help=True):
150159

151160

152161
def main(args):
153-
if args.weights and PM is None:
162+
if args.prototype and prototype is None:
154163
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
164+
if not args.prototype and args.weights:
165+
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
155166
if args.output_dir:
156167
utils.mkdir(args.output_dir)
157168

@@ -193,12 +204,12 @@ def main(args):
193204
if "rcnn" in args.model:
194205
if args.rpn_score_thresh is not None:
195206
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
196-
if not args.weights:
207+
if not args.prototype:
197208
model = torchvision.models.detection.__dict__[args.model](
198209
pretrained=args.pretrained, num_classes=num_classes, **kwargs
199210
)
200211
else:
201-
model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
212+
model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
202213
model.to(device)
203214
if args.distributed and args.sync_bn:
204215
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

‎references/optical_flow/train.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@
1010
from torchvision.datasets import KittiFlow, FlyingChairs, FlyingThings3D, Sintel, HD1K
1111

1212
try:
13-
from torchvision.prototype import models as PM
14-
from torchvision.prototype.models import optical_flow as PMOF
13+
from torchvision import prototype
1514
except ImportError:
16-
PM = PMOF = None
15+
prototype = None
1716

1817

1918
def get_train_dataset(stage, dataset_root):
@@ -133,9 +132,12 @@ def inner_loop(blob):
133132
def validate(model, args):
134133
val_datasets = args.val_dataset or []
135134

136-
if args.weights:
137-
weights = PM.get_weight(args.weights)
138-
preprocessing = weights.transforms()
135+
if args.prototype:
136+
if args.weights:
137+
weights = prototype.models.get_weight(args.weights)
138+
preprocessing = weights.transforms()
139+
else:
140+
preprocessing = prototype.transforms.RaftEval()
139141
else:
140142
preprocessing = OpticalFlowPresetEval()
141143

@@ -192,10 +194,14 @@ def train_one_epoch(model, optimizer, scheduler, train_loader, logger, args):
192194

193195

194196
def main(args):
197+
if args.prototype and prototype is None:
198+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
199+
if not args.prototype and args.weights:
200+
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
195201
utils.setup_ddp(args)
196202

197-
if args.weights:
198-
model = PMOF.__dict__[args.model](weights=args.weights)
203+
if args.prototype:
204+
model = prototype.models.optical_flow.__dict__[args.model](weights=args.weights)
199205
else:
200206
model = torchvision.models.optical_flow.__dict__[args.model](pretrained=args.pretrained)
201207

@@ -317,7 +323,6 @@ def get_args_parser(add_help=True):
317323
)
318324
# TODO: resume, pretrained, and weights should be in an exclusive arg group
319325
parser.add_argument("--pretrained", action="store_true", help="Whether to use pretrained weights")
320-
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
321326

322327
parser.add_argument(
323328
"--num_flow_updates",
@@ -336,6 +341,15 @@ def get_args_parser(add_help=True):
336341
required=True,
337342
)
338343

344+
# Prototype models only
345+
parser.add_argument(
346+
"--prototype",
347+
dest="prototype",
348+
help="Use prototype model builders instead those from main area",
349+
action="store_true",
350+
)
351+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load.")
352+
339353
return parser
340354

341355

‎references/segmentation/train.py

+19-8
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313

1414
try:
15-
from torchvision.prototype import models as PM
15+
from torchvision import prototype
1616
except ImportError:
17-
PM = None
17+
prototype = None
1818

1919

2020
def get_dataset(dir_path, name, image_set, transform):
@@ -35,11 +35,14 @@ def sbd(*args, **kwargs):
3535
def get_transform(train, args):
3636
if train:
3737
return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
38-
elif not args.weights:
38+
elif not args.prototype:
3939
return presets.SegmentationPresetEval(base_size=520)
4040
else:
41-
weights = PM.get_weight(args.weights)
42-
return weights.transforms()
41+
if args.weights:
42+
weights = prototype.models.get_weight(args.weights)
43+
return weights.transforms()
44+
else:
45+
return prototype.transforms.VocEval(resize_size=520)
4346

4447

4548
def criterion(inputs, target):
@@ -97,8 +100,10 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
97100

98101

99102
def main(args):
100-
if args.weights and PM is None:
103+
if args.prototype and prototype is None:
101104
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
105+
if not args.prototype and args.weights:
106+
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
102107
if args.output_dir:
103108
utils.mkdir(args.output_dir)
104109

@@ -130,14 +135,14 @@ def main(args):
130135
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
131136
)
132137

133-
if not args.weights:
138+
if not args.prototype:
134139
model = torchvision.models.segmentation.__dict__[args.model](
135140
pretrained=args.pretrained,
136141
num_classes=num_classes,
137142
aux_loss=args.aux_loss,
138143
)
139144
else:
140-
model = PM.segmentation.__dict__[args.model](
145+
model = prototype.models.segmentation.__dict__[args.model](
141146
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
142147
)
143148
model.to(device)
@@ -278,6 +283,12 @@ def get_args_parser(add_help=True):
278283
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
279284

280285
# Prototype models only
286+
parser.add_argument(
287+
"--prototype",
288+
dest="prototype",
289+
help="Use prototype model builders instead those from main area",
290+
action="store_true",
291+
)
281292
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
282293

283294
# Mixed precision training parameters

‎references/video_classification/train.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
1414

1515
try:
16-
from torchvision.prototype import models as PM
16+
from torchvision import prototype
1717
except ImportError:
18-
PM = None
18+
prototype = None
1919

2020

2121
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
@@ -96,9 +96,10 @@ def collate_fn(batch):
9696

9797

9898
def main(args):
99-
if args.weights and PM is None:
99+
if args.prototype and prototype is None:
100100
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
101-
101+
if not args.prototype and args.weights:
102+
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
102103
if args.output_dir:
103104
utils.mkdir(args.output_dir)
104105

@@ -149,11 +150,14 @@ def main(args):
149150
print("Loading validation data")
150151
cache_path = _get_cache_path(valdir)
151152

152-
if not args.weights:
153-
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
153+
if not args.prototype:
154+
transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112))
154155
else:
155-
weights = PM.get_weight(args.weights)
156-
transform_test = weights.transforms()
156+
if args.weights:
157+
weights = prototype.models.get_weight(args.weights)
158+
transform_test = weights.transforms()
159+
else:
160+
transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
157161

158162
if args.cache_dataset and os.path.exists(cache_path):
159163
print(f"Loading dataset_test from {cache_path}")
@@ -204,10 +208,10 @@ def main(args):
204208
)
205209

206210
print("Creating model")
207-
if not args.weights:
211+
if not args.prototype:
208212
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
209213
else:
210-
model = PM.video.__dict__[args.model](weights=args.weights)
214+
model = prototype.models.video.__dict__[args.model](weights=args.weights)
211215
model.to(device)
212216
if args.distributed and args.sync_bn:
213217
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -360,6 +364,12 @@ def parse_args():
360364
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
361365

362366
# Prototype models only
367+
parser.add_argument(
368+
"--prototype",
369+
dest="prototype",
370+
help="Use prototype model builders instead those from main area",
371+
action="store_true",
372+
)
363373
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
364374

365375
# Mixed precision training parameters

0 commit comments

Comments
 (0)