|
13 | 13 | from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
|
14 | 14 |
|
15 | 15 | try:
|
16 |
| - from torchvision.prototype import models as PM |
| 16 | + from torchvision import prototype |
17 | 17 | except ImportError:
|
18 |
| - PM = None |
| 18 | + prototype = None |
19 | 19 |
|
20 | 20 |
|
21 | 21 | def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
|
@@ -96,9 +96,10 @@ def collate_fn(batch):
|
96 | 96 |
|
97 | 97 |
|
98 | 98 | def main(args):
|
99 |
| - if args.weights and PM is None: |
| 99 | + if args.prototype and prototype is None: |
100 | 100 | raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
|
101 |
| - |
| 101 | + if not args.prototype and args.weights: |
| 102 | + raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.") |
102 | 103 | if args.output_dir:
|
103 | 104 | utils.mkdir(args.output_dir)
|
104 | 105 |
|
@@ -149,11 +150,14 @@ def main(args):
|
149 | 150 | print("Loading validation data")
|
150 | 151 | cache_path = _get_cache_path(valdir)
|
151 | 152 |
|
152 |
| - if not args.weights: |
153 |
| - transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) |
| 153 | + if not args.prototype: |
| 154 | + transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112)) |
154 | 155 | else:
|
155 |
| - weights = PM.get_weight(args.weights) |
156 |
| - transform_test = weights.transforms() |
| 156 | + if args.weights: |
| 157 | + weights = prototype.models.get_weight(args.weights) |
| 158 | + transform_test = weights.transforms() |
| 159 | + else: |
| 160 | + transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171)) |
157 | 161 |
|
158 | 162 | if args.cache_dataset and os.path.exists(cache_path):
|
159 | 163 | print(f"Loading dataset_test from {cache_path}")
|
@@ -204,10 +208,10 @@ def main(args):
|
204 | 208 | )
|
205 | 209 |
|
206 | 210 | print("Creating model")
|
207 |
| - if not args.weights: |
| 211 | + if not args.prototype: |
208 | 212 | model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
|
209 | 213 | else:
|
210 |
| - model = PM.video.__dict__[args.model](weights=args.weights) |
| 214 | + model = prototype.models.video.__dict__[args.model](weights=args.weights) |
211 | 215 | model.to(device)
|
212 | 216 | if args.distributed and args.sync_bn:
|
213 | 217 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
@@ -360,6 +364,12 @@ def parse_args():
|
360 | 364 | parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
|
361 | 365 |
|
362 | 366 | # Prototype models only
|
| 367 | + parser.add_argument( |
| 368 | + "--prototype", |
| 369 | + dest="prototype", |
| 370 | + help="Use prototype model builders instead those from main area", |
| 371 | + action="store_true", |
| 372 | + ) |
363 | 373 | parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
|
364 | 374 |
|
365 | 375 | # Mixed precision training parameters
|
|
0 commit comments