-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtransform.py
111 lines (98 loc) · 3.21 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
Based on https://github.com/mlfoundations/open_clip
"""
from typing import Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import (
Normalize,
Compose,
RandomResizedCrop,
InterpolationMode,
ToTensor,
Resize,
CenterCrop,
)
class ResizeMaxSize(nn.Module):
    """Resize an image so its longest side is ``max_size``, then pad it square.

    The aspect ratio is preserved by the resize; the shorter side is then
    padded symmetrically (any odd pixel goes to the right/bottom) with
    ``fill`` so the result is ``max_size`` x ``max_size``.

    Args:
        max_size: Target side length in pixels. Must be an ``int``.
        interpolation: Interpolation mode forwarded to ``F.resize``.
        fn: ``"min"`` stores the builtin ``min`` on ``self.fn``; any other
            value stores ``max``.
        fill: Fill value forwarded to ``F.pad``.

    Raises:
        TypeError: If ``max_size`` is not an ``int``.
    """

    def __init__(
        self, max_size, interpolation=InterpolationMode.BICUBIC, fn="max", fill=0
    ):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        # Both branches used to yield ``min``; the default "max" now maps to max.
        # NOTE(review): forward() still always scales by the longest side, so
        # this attribute does not influence the output geometry.
        self.fn = min if fn == "min" else max
        self.fill = fill

    def forward(self, img):
        """Return ``img`` resized and padded to ``max_size`` x ``max_size``.

        Args:
            img: A ``torch.Tensor`` whose last two dimensions are
                (height, width), or an object exposing ``size`` as
                ``(width, height)`` (e.g. a PIL image).
        """
        if isinstance(img, torch.Tensor):
            # F.resize / F.pad treat tensors as [..., H, W], so the spatial
            # dimensions are the last two, not the first two.
            height, width = img.shape[-2:]
        else:
            width, height = img.size

        scale = self.max_size / float(max(height, width))
        # Clamp to 1 so an extreme aspect ratio never rounds a side down to 0.
        new_size = tuple(max(1, round(dim * scale)) for dim in (height, width))
        if scale != 1.0:
            img = F.resize(img, new_size, self.interpolation)

        # Pad independently of whether a resize happened: an image whose
        # longest side already equals max_size may still be non-square.
        pad_h = self.max_size - new_size[0]
        pad_w = self.max_size - new_size[1]
        if pad_h or pad_w:
            img = F.pad(
                img,
                # Order expected by F.pad: left, top, right, bottom.
                padding=[
                    pad_w // 2,
                    pad_h // 2,
                    pad_w - pad_w // 2,
                    pad_h - pad_h // 2,
                ],
                fill=self.fill,
            )
        return img
def _convert_to_rgb(image):
    """Return the result of ``image.convert("RGB")``."""
    rgb_image = image.convert("RGB")
    return rgb_image
def image_transform(
    image_size: int,
    is_train: bool,
    mean: Optional[Tuple[float, ...]] = None,
    std: Optional[Tuple[float, ...]] = None,
    resize_longest_max: bool = False,
    fill_color: int = 0,
):
    """Build the image preprocessing pipeline.

    Every pipeline ends with RGB conversion, ``ToTensor`` and ``Normalize``;
    only the leading geometric step(s) differ:

    * ``is_train``: ``RandomResizedCrop`` with scale ``(0.9, 1.0)``, bicubic.
    * otherwise, with ``resize_longest_max``: ``ResizeMaxSize`` using
      ``fill_color`` as the padding fill.
    * otherwise: bicubic ``Resize`` followed by ``CenterCrop``.

    Args:
        image_size: Target size. A two-element list/tuple with equal entries
            is collapsed to that single value.
        is_train: Selects the training pipeline when true.
        mean: Per-channel normalization mean; a falsy value selects the
            built-in default.
        std: Per-channel normalization std; a falsy value selects the
            built-in default.
        resize_longest_max: Evaluation only; use ``ResizeMaxSize`` instead of
            ``Resize`` + ``CenterCrop``.
        fill_color: Fill passed to ``ResizeMaxSize``.

    Returns:
        A ``Compose`` of the selected transforms.
    """
    if not mean:
        mean = (0.48145466, 0.4578275, 0.40821073)  # OpenAI dataset mean
    if not std:
        std = (0.26862954, 0.26130258, 0.27577711)  # OpenAI dataset std

    is_square_pair = (
        isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]
    )
    if is_square_pair:
        # A square size is passed as an int so that Resize() performs an
        # aspect-preserving resize of the shortest edge.
        image_size = image_size[0]

    normalize = Normalize(mean=mean, std=std)

    if is_train:
        leading = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            )
        ]
    elif resize_longest_max:
        leading = [ResizeMaxSize(image_size, fill=fill_color)]
    else:
        leading = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    # Shared tail: force RGB, convert to a tensor, then normalize.
    return Compose(leading + [_convert_to_rgb, ToTensor(), normalize])