Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added PILToTensor and ConvertImageDtype classes in reference scripts #4495

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions references/detection/presets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torch

import transforms as T


Expand All @@ -6,21 +8,24 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
if data_augmentation == 'hflip':
self.transforms = T.Compose([
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
])
elif data_augmentation == 'ssd':
self.transforms = T.Compose([
T.RandomPhotometricDistort(),
T.RandomZoomOut(fill=list(mean)),
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
])
elif data_augmentation == 'ssdlite':
self.transforms = T.Compose([
T.RandomIoUCrop(),
T.RandomHorizontalFlip(p=hflip_prob),
T.ToTensor(),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
])
else:
raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
Expand Down
22 changes: 20 additions & 2 deletions references/detection/transforms.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import List, Tuple, Dict, Optional

import torch
import torchvision

from torch import nn, Tensor
from torchvision.transforms import functional as F
from torchvision.transforms import transforms as T
from typing import List, Tuple, Dict, Optional


def _flip_coco_person_keypoints(kps, width):
Expand Down Expand Up @@ -52,6 +52,24 @@ def forward(self, image: Tensor,
return image, target


class PILToTensor(nn.Module):
    """Detection transform that runs the image through ``F.pil_to_tensor``.

    The optional ``target`` is returned exactly as it was received.
    """

    def forward(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        # Only the image is converted; the target is passed straight through.
        return F.pil_to_tensor(image), target


class ConvertImageDtype(nn.Module):
    """Detection transform that converts the image to a fixed dtype.

    The dtype given at construction is forwarded to
    ``F.convert_image_dtype`` on every call; the optional ``target`` is
    returned unchanged.
    """

    def __init__(self, dtype: torch.dtype) -> None:
        super().__init__()
        # Remembered here so forward() can pass it along on each call.
        self.dtype = dtype

    def forward(
        self,
        image: Tensor,
        target: Optional[Dict[str, Tensor]] = None,
    ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
        converted = F.convert_image_dtype(image, self.dtype)
        return converted, target


class RandomIoUCrop(nn.Module):
def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
Expand Down
8 changes: 6 additions & 2 deletions references/segmentation/presets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import torch

import transforms as T


Expand All @@ -11,7 +13,8 @@ def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.4
trans.append(T.RandomHorizontalFlip(hflip_prob))
trans.extend([
T.RandomCrop(crop_size),
T.ToTensor(),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
])
self.transforms = T.Compose(trans)
Expand All @@ -24,7 +27,8 @@ class SegmentationPresetEval:
# Build the evaluation transform pipeline and store it on self.transforms.
# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were lost, so old and new lines appear side by side — confirm
# against the real file before relying on it.
def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
self.transforms = T.Compose([
# The same base_size value is passed for both RandomResize arguments.
T.RandomResize(base_size, base_size),
# Presumably the removed (old) line, replaced by the two entries that
# follow it — TODO confirm.
T.ToTensor(),
T.PILToTensor(),
T.ConvertImageDtype(torch.float),
T.Normalize(mean=mean, std=std),
])

Expand Down
15 changes: 11 additions & 4 deletions references/segmentation/transforms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np
from PIL import Image
import random

import numpy as np
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
Expand Down Expand Up @@ -75,14 +74,22 @@ def __call__(self, image, target):
return image, target


# NOTE(review): this span is a rendered diff hunk whose +/- markers and
# indentation were lost, so two versions of the class appear interleaved —
# confirm against the real file before relying on it.
# Presumably the removed (old) class header; the PR title says PILToTensor
# was added — TODO confirm.
class ToTensor(object):
class PILToTensor:
# Converts the image with F.pil_to_tensor and the target to an int64 tensor.
def __call__(self, image, target):
image = F.pil_to_tensor(image)
# Presumably a removed line: dtype conversion has its own ConvertImageDtype
# class later in this file — TODO confirm.
image = F.convert_image_dtype(image)
# The target goes through a NumPy array and is then typed as int64.
target = torch.as_tensor(np.array(target), dtype=torch.int64)
return image, target


class ConvertImageDtype:
    """Segmentation transform that converts the image to a fixed dtype.

    The dtype given at construction is forwarded to
    ``F.convert_image_dtype``; the target is returned unchanged.
    """

    def __init__(self, dtype):
        # Remembered here so __call__ can pass it along on each call.
        self.dtype = dtype

    def __call__(self, image, target):
        return F.convert_image_dtype(image, self.dtype), target


class Normalize(object):
def __init__(self, mean, std):
self.mean = mean
Expand Down