from functools import partial
from typing import Any, Callable, Optional

import torch
from torch import nn
from torchvision.ops.misc import Conv3dNormActivation

from ...transforms._presets import VideoClassification
from ...utils import _log_api_usage_once
from .._api import register_model, Weights, WeightsEnum
from .._meta import _KINETICS400_CATEGORIES
from .._utils import _ovewrite_named_param


__all__ = [
    "S3D",
    "S3D_Weights",
    "s3d",
]


class TemporalSeparableConv(nn.Sequential):
    """Factorizes a k x k x k 3D convolution into a 1 x k x k spatial
    convolution followed by a k x 1 x 1 temporal convolution, as in the S3D
    architecture."""

    def __init__(
        self,
        in_planes: int,
        out_planes: int,
        kernel_size: int,
        stride: int,
        padding: int,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__(
            # Spatial convolution: operates on each frame independently.
            Conv3dNormActivation(
                in_planes,
                out_planes,
                kernel_size=(1, kernel_size, kernel_size),
                stride=(1, stride, stride),
                padding=(0, padding, padding),
                bias=False,
                norm_layer=norm_layer,
            ),
            # Temporal convolution: mixes information across frames.
            Conv3dNormActivation(
                out_planes,
                out_planes,
                kernel_size=(kernel_size, 1, 1),
                stride=(stride, 1, 1),
                padding=(padding, 0, 0),
                bias=False,
                norm_layer=norm_layer,
            ),
        )

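# A quick shape sanity check (illustrative only, not part of the module): the
# stride is applied spatially by the first convolution and temporally by the
# second, so a (1, 3, 16, 224, 224) clip is halved in all three dimensions:
#
#   conv = TemporalSeparableConv(3, 64, kernel_size=7, stride=2, padding=3, norm_layer=nn.BatchNorm3d)
#   conv(torch.rand(1, 3, 16, 224, 224)).shape  # torch.Size([1, 64, 8, 112, 112])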

class SepInceptionBlock3D(nn.Module):
    """Inception block in which every 3x3x3 branch convolution is replaced by
    a :class:`TemporalSeparableConv`. The four branch outputs are concatenated
    along the channel dimension."""

    def __init__(
        self,
        in_planes: int,
        b0_out: int,
        b1_mid: int,
        b1_out: int,
        b2_mid: int,
        b2_out: int,
        b3_out: int,
        norm_layer: Callable[..., nn.Module],
    ):
        super().__init__()

        self.branch0 = Conv3dNormActivation(in_planes, b0_out, kernel_size=1, stride=1, norm_layer=norm_layer)
        self.branch1 = nn.Sequential(
            Conv3dNormActivation(in_planes, b1_mid, kernel_size=1, stride=1, norm_layer=norm_layer),
            TemporalSeparableConv(b1_mid, b1_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer),
        )
        self.branch2 = nn.Sequential(
            Conv3dNormActivation(in_planes, b2_mid, kernel_size=1, stride=1, norm_layer=norm_layer),
            TemporalSeparableConv(b2_mid, b2_out, kernel_size=3, stride=1, padding=1, norm_layer=norm_layer),
        )
        self.branch3 = nn.Sequential(
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=1, padding=1),
            Conv3dNormActivation(in_planes, b3_out, kernel_size=1, stride=1, norm_layer=norm_layer),
        )

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        out = torch.cat((x0, x1, x2, x3), 1)

        return out

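# Because the branches are concatenated, the block's output channel count is
# b0_out + b1_out + b2_out + b3_out. An illustrative check against the first
# block used in S3D below (not part of the module):
#
#   block = SepInceptionBlock3D(192, 64, 96, 128, 16, 32, 32, nn.BatchNorm3d)
#   block(torch.rand(1, 192, 8, 28, 28)).shape  # torch.Size([1, 256, 8, 28, 28]), since 64 + 128 + 32 + 32 == 256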

class S3D(nn.Module):
    """S3D main class.

    Args:
        num_classes (int): number of classes for the classification task. Default: 400.
        dropout (float): dropout probability. Default: 0.0.
        norm_layer (Optional[Callable]): Module specifying the normalization layer to use.

    Inputs:
        x (Tensor): batch of videos with dimensions (batch, channel, time, height, width)
    """

    def __init__(
        self,
        num_classes: int = 400,
        dropout: float = 0.0,
        norm_layer: Optional[Callable[..., torch.nn.Module]] = None,
    ) -> None:
        super().__init__()
        _log_api_usage_once(self)

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm3d, eps=0.001, momentum=0.001)

        self.features = nn.Sequential(
            # Stem: separable 7x7x7 convolution plus spatial pooling.
            TemporalSeparableConv(3, 64, 7, 2, 3, norm_layer),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            Conv3dNormActivation(
                64,
                64,
                kernel_size=1,
                stride=1,
                norm_layer=norm_layer,
            ),
            TemporalSeparableConv(64, 192, 3, 1, 1, norm_layer),
            nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1)),
            # Inception stages, interleaved with spatio-temporal pooling.
            SepInceptionBlock3D(192, 64, 96, 128, 16, 32, 32, norm_layer),
            SepInceptionBlock3D(256, 128, 128, 192, 32, 96, 64, norm_layer),
            nn.MaxPool3d(kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1)),
            SepInceptionBlock3D(480, 192, 96, 208, 16, 48, 64, norm_layer),
            SepInceptionBlock3D(512, 160, 112, 224, 24, 64, 64, norm_layer),
            SepInceptionBlock3D(512, 128, 128, 256, 24, 64, 64, norm_layer),
            SepInceptionBlock3D(512, 112, 144, 288, 32, 64, 64, norm_layer),
            SepInceptionBlock3D(528, 256, 160, 320, 32, 128, 128, norm_layer),
            nn.MaxPool3d(kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)),
            SepInceptionBlock3D(832, 256, 160, 320, 32, 128, 128, norm_layer),
            SepInceptionBlock3D(832, 384, 192, 384, 48, 128, 128, norm_layer),
        )
        self.avgpool = nn.AvgPool3d(kernel_size=(2, 7, 7), stride=1)
        self.classifier = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Conv3d(1024, num_classes, kernel_size=1, stride=1, bias=True),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        # Collapse the remaining spatio-temporal dimensions into per-clip
        # logits of shape (batch, num_classes).
        x = torch.mean(x, dim=(2, 3, 4))
        return x

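# An illustrative shape trace for a minimal 16-frame clip (not part of the
# module): features downsample time by 8x and space by 32x; avgpool and the
# 1x1x1 classifier convolution then reduce to per-class logits.
#
#   model = S3D(num_classes=400)
#   x = torch.rand(1, 3, 16, 224, 224)  # (N, C, T, H, W)
#   model.features(x).shape             # torch.Size([1, 1024, 2, 7, 7])
#   model(x).shape                      # torch.Size([1, 400])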

class S3D_Weights(WeightsEnum):
    KINETICS400_V1 = Weights(
        url="https://download.pytorch.org/models/s3d-1bd8ae63.pth",
        transforms=partial(
            VideoClassification,
            crop_size=(224, 224),
            resize_size=(256, 256),
            mean=(0.5, 0.5, 0.5),
            std=(0.5, 0.5, 0.5),
        ),
        meta={
            "min_size": (224, 224),
            "min_temporal_size": 14,
            "categories": _KINETICS400_CATEGORIES,
            "recipe": "https://github.com/pytorch/vision/pull/6412#issuecomment-1219687434",
            "_docs": (
                "The weights are ported from a community repository. The accuracies are estimated on clip-level "
                "with parameters `frame_rate=15`, `clips_per_video=1`, and `clip_len=128`."
            ),
            "num_params": 8320048,
            "_metrics": {
                "Kinetics-400": {
                    "acc@1": 67.315,
                    "acc@5": 87.593,
                }
            },
        },
    )
    DEFAULT = KINETICS400_V1

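# A note on typical usage (an assumption about the caller, not part of this
# file): calling `S3D_Weights.KINETICS400_V1.transforms()` instantiates the
# matching VideoClassification preprocessing preset, i.e. resize to 256x256,
# center-crop to 224x224, and normalization with mean/std 0.5 per channel.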

@register_model()
def s3d(*, weights: Optional[S3D_Weights] = None, progress: bool = True, **kwargs: Any) -> S3D:
    """Constructs the Separable 3D CNN (S3D) model.

    Reference: `Rethinking Spatiotemporal Feature Learning <https://arxiv.org/abs/1712.04851>`__.

    Args:
        weights (:class:`~torchvision.models.video.S3D_Weights`, optional): The
            pretrained weights to use. See
            :class:`~torchvision.models.video.S3D_Weights` below for more
            details and possible values. By default, no pre-trained weights
            are used.
        progress (bool): If True, displays a progress bar of the download to stderr. Default is True.
        **kwargs: parameters passed to the ``torchvision.models.video.S3D`` base class.
            Please refer to the `source code
            <https://github.com/pytorch/vision/blob/main/torchvision/models/video/s3d.py>`_
            for more details about this class.

    .. autoclass:: torchvision.models.video.S3D_Weights
        :members:
    """
    weights = S3D_Weights.verify(weights)

    if weights is not None:
        # When pretrained weights are requested, force num_classes to match them.
        _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))

    model = S3D(**kwargs)

    if weights is not None:
        model.load_state_dict(weights.get_state_dict(progress=progress))

    return model
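
# A minimal usage sketch (illustrative, not part of the library):
#
#   model = s3d(weights=S3D_Weights.KINETICS400_V1)
#   model.eval()
#   clip = torch.rand(1, 3, 16, 224, 224)  # (N, C, T, H, W); T >= 14 per the weights' meta
#   with torch.no_grad():
#       logits = model(clip)  # shape (1, 400), one score per Kinetics-400 class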