Skip to content

Commit 4b61ba1

Browse files
author
Mike
committed
upgrade ConvMLP
1 parent b266cc9 commit 4b61ba1

9 files changed

+2088
-1029
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
#
6+
# http://www.apache.org/licenses/LICENSE-2.0
7+
#
8+
# Unless required by applicable law or agreed to in writing, software
9+
# distributed under the License is distributed on an "AS IS" BASIS,
10+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
# See the License for the specific language governing permissions and
12+
# limitations under the License.
13+
14+
"""Auto Augmentation"""
15+
16+
import random
17+
import numpy as np
18+
from PIL import Image, ImageEnhance, ImageOps
19+
20+
21+
def auto_augment_policy_original():
22+
"""ImageNet auto augment policy"""
23+
policy = [
24+
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
25+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
26+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
27+
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
28+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
29+
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
30+
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
31+
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
32+
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
33+
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
34+
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
35+
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
36+
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
37+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
38+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
39+
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
40+
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
41+
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
42+
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
43+
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
44+
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
45+
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
46+
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
47+
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
48+
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
49+
]
50+
policy = [[SubPolicy(*args) for args in subpolicy] for subpolicy in policy]
51+
return policy
52+
53+
54+
class AutoAugment():
55+
"""Auto Augment
56+
Randomly choose a tuple of augment ops from a list of policy
57+
Then apply the tuple of augment ops to input image
58+
"""
59+
def __init__(self, policy):
60+
self.policy = policy
61+
62+
def __call__(self, image, policy_idx=None):
63+
if policy_idx is None:
64+
policy_idx = random.randint(0, len(self.policy)-1)
65+
66+
sub_policy = self.policy[policy_idx]
67+
for op in sub_policy:
68+
image = op(image)
69+
return image
70+
71+
72+
class SubPolicy:
73+
"""Subpolicy
74+
Read augment name and magnitude, apply augment with probability
75+
Args:
76+
op_name: str, augment operation name
77+
prob: float, if prob > random prob, apply augment
78+
magnitude_idx: int, index of magnitude in preset magnitude ranges
79+
"""
80+
def __init__(self, op_name, prob, magnitude_idx):
81+
# ranges of operations' magnitude
82+
ranges = {
83+
'ShearX': np.linspace(0, 0.3, 10), # [-0.3, 0.3] (by random negative)
84+
'ShearY': np.linspace(0, 0.3, 10), # [-0.3, 0.3] (by random negative)
85+
'TranslateX': np.linspace(0, 150 / 331, 10), #[-0.45, 0.45] (by random negative)
86+
'TranslateY': np.linspace(0, 150 / 331, 10), #[-0.45, 0.45] (by random negative)
87+
'Rotate': np.linspace(0, 30, 10), #[-30, 30] (by random negative)
88+
'Color': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
89+
'Posterize': np.round(np.linspace(8, 4, 10), 0).astype(np.int), #[0, 4]
90+
'Solarize': np.linspace(256, 0, 10), #[0, 256]
91+
'Contrast': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
92+
'Sharpness': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
93+
'Brightness': np.linspace(0, 0.9, 10), #[-0.9, 0.9] (by random negative)
94+
'AutoContrast': [0] * 10, # no range
95+
'Equalize': [0] * 10, # no range
96+
'Invert': [0] * 10, # no range
97+
}
98+
99+
# augmentation operations
100+
# Lambda is not pickleable for DDP
101+
#image_ops = {
102+
# 'ShearX': lambda image, magnitude: shear_x(image, magnitude),
103+
# 'ShearY': lambda image, magnitude: shear_y(image, magnitude),
104+
# 'TranslateX': lambda image, magnitude: translate_x(image, magnitude),
105+
# 'TranslateY': lambda image, magnitude: translate_y(image, magnitude),
106+
# 'Rotate': lambda image, magnitude: rotate(image, magnitude),
107+
# 'AutoContrast': lambda image, magnitude: auto_contrast(image, magnitude),
108+
# 'Invert': lambda image, magnitude: invert(image, magnitude),
109+
# 'Equalize': lambda image, magnitude: equalize(image, magnitude),
110+
# 'Solarize': lambda image, magnitude: solarize(image, magnitude),
111+
# 'Posterize': lambda image, magnitude: posterize(image, magnitude),
112+
# 'Contrast': lambda image, magnitude: contrast(image, magnitude),
113+
# 'Color': lambda image, magnitude: color(image, magnitude),
114+
# 'Brightness': lambda image, magnitude: brightness(image, magnitude),
115+
# 'Sharpness': lambda image, magnitude: sharpness(image, magnitude),
116+
#}
117+
image_ops = {
118+
'ShearX': shear_x,
119+
'ShearY': shear_y,
120+
'TranslateX': translate_x_relative,
121+
'TranslateY': translate_y_relative,
122+
'Rotate': rotate,
123+
'AutoContrast': auto_contrast,
124+
'Invert': invert,
125+
'Equalize': equalize,
126+
'Solarize': solarize,
127+
'Posterize': posterize,
128+
'Contrast': contrast,
129+
'Color': color,
130+
'Brightness': brightness,
131+
'Sharpness': sharpness,
132+
}
133+
134+
self.prob = prob
135+
self.magnitude = ranges[op_name][magnitude_idx]
136+
self.op = image_ops[op_name]
137+
138+
def __call__(self, image):
139+
if self.prob > random.random():
140+
image = self.op(image, self.magnitude)
141+
return image
142+
143+
144+
# PIL Image transforms
145+
# https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.transform
146+
def shear_x(image, magnitude, fillcolor=(128, 128, 128)):
147+
factor = magnitude * random.choice([-1, 1]) # random negative
148+
return image.transform(image.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), fillcolor=fillcolor)
149+
150+
151+
def shear_y(image, magnitude, fillcolor=(128, 128, 128)):
152+
factor = magnitude * random.choice([-1, 1]) # random negative
153+
return image.transform(image.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), fillcolor=fillcolor)
154+
155+
156+
def translate_x_relative(image, magnitude, fillcolor=(128, 128, 128)):
157+
pixels = magnitude * image.size[0]
158+
pixels = pixels * random.choice([-1, 1]) # random negative
159+
return image.transform(image.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), fillcolor=fillcolor)
160+
161+
162+
def translate_y_relative(image, magnitude, fillcolor=(128, 128, 128)):
163+
pixels = magnitude * image.size[0]
164+
pixels = pixels * random.choice([-1, 1]) # random negative
165+
return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), fillcolor=fillcolor)
166+
167+
168+
def translate_x_absolute(image, magnitude, fillcolor=(128, 128, 128)):
169+
magnitude = magnitude * random.choice([-1, 1]) # random negative
170+
return image.transform(image.size, Image.AFFINE, (1, 0, magnitude, 0, 1, 0), fillcolor=fillcolor)
171+
172+
173+
def translate_y_absolute(image, magnitude, fillcolor=(128, 128, 128)):
174+
magnitude = magnitude * random.choice([-1, 1]) # random negative
175+
return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude), fillcolor=fillcolor)
176+
177+
178+
def rotate(image, magnitude):
179+
rot = image.convert("RGBA").rotate(magnitude)
180+
return Image.composite(rot,
181+
Image.new('RGBA', rot.size, (128, ) * 4),
182+
rot).convert(image.mode)
183+
184+
185+
def auto_contrast(image, magnitude=None):
186+
return ImageOps.autocontrast(image)
187+
188+
189+
def invert(image, magnitude=None):
190+
return ImageOps.invert(image)
191+
192+
193+
def equalize(image, magnitude=None):
194+
return ImageOps.equalize(image)
195+
196+
197+
def solarize(image, magnitude):
198+
return ImageOps.solarize(image, magnitude)
199+
200+
201+
def posterize(image, magnitude):
202+
return ImageOps.posterize(image, magnitude)
203+
204+
205+
def contrast(image, magnitude):
206+
magnitude = magnitude * random.choice([-1, 1]) # random negative
207+
return ImageEnhance.Contrast(image).enhance(1 + magnitude)
208+
209+
210+
def color(image, magnitude):
211+
magnitude = magnitude * random.choice([-1, 1]) # random negative
212+
return ImageEnhance.Color(image).enhance(1 + magnitude)
213+
214+
215+
def brightness(image, magnitude):
216+
magnitude = magnitude * random.choice([-1, 1]) # random negative
217+
return ImageEnhance.Brightness(image).enhance(1 + magnitude)
218+
219+
220+
def sharpness(image, magnitude):
221+
magnitude = magnitude * random.choice([-1, 1]) # random negative
222+
return ImageEnhance.Sharpness(image).enhance(1 + magnitude)
223+

0 commit comments

Comments
 (0)