|
| 1 | +# Copyright (c) 2021 PPViT Authors. All Rights Reserved. |
| 2 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 3 | +# you may not use this file except in compliance with the License. |
| 4 | +# You may obtain a copy of the License at |
| 5 | +# |
| 6 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +# |
| 8 | +# Unless required by applicable law or agreed to in writing, software |
| 9 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 10 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 11 | +# See the License for the specific language governing permissions and |
| 12 | +# limitations under the License. |
| 13 | + |
| 14 | +"""Auto Augmentation""" |
| 15 | + |
| 16 | +import random |
| 17 | +import numpy as np |
| 18 | +from PIL import Image, ImageEnhance, ImageOps |
| 19 | + |
| 20 | + |
def auto_augment_policy_original():
    """ImageNet auto augment policy
    Returns:
        list of sub-policies, each one a list of SubPolicy ops built from
        the (op_name, prob, magnitude_idx) tuples below
    """
    # each row is one sub-policy: a pair of (op_name, prob, magnitude_idx)
    specs = [
        [('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
        [('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
        [('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
        [('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
        [('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
        [('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
        [('Rotate', 0.8, 8), ('Color', 0.4, 0)],
        [('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
        [('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Rotate', 0.8, 8), ('Color', 1.0, 2)],
        [('Color', 0.8, 8), ('Solarize', 0.8, 7)],
        [('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
        [('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
        [('Color', 0.4, 0), ('Equalize', 0.6, 3)],
        [('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
        [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
        [('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
        [('Color', 0.6, 4), ('Contrast', 1.0, 8)],
        [('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
    ]

    policy = []
    for row in specs:
        ops = []
        for op_name, prob, magnitude_idx in row:
            ops.append(SubPolicy(op_name, prob, magnitude_idx))
        policy.append(ops)
    return policy
| 52 | + |
| 53 | + |
class AutoAugment():
    """Auto Augment
    Pick one sub-policy (a sequence of augment ops) from the policy list,
    at random unless an index is passed in, then run its ops on the input
    image one after another
    Args:
        policy: list of sub-policies, each an iterable of callables taking
            an image and returning an image
    """
    def __init__(self, policy):
        self.policy = policy

    def __call__(self, image, policy_idx=None):
        # an explicit index bypasses the random draw entirely
        chosen = policy_idx
        if chosen is None:
            last = len(self.policy) - 1
            chosen = random.randint(0, last)

        result = image
        for augment in self.policy[chosen]:
            result = augment(result)
        return result
| 70 | + |
| 71 | + |
class SubPolicy:
    """Subpolicy
    Read augment name and magnitude, apply augment with probability
    Args:
        op_name: str, augment operation name (a key of the tables in __init__)
        prob: float, if prob > random prob, apply augment
        magnitude_idx: int, index of magnitude in preset magnitude ranges
            (each range holds 10 entries)
    Raises:
        KeyError: if op_name is not a known operation
        IndexError: if magnitude_idx is outside the preset range
    """
    def __init__(self, op_name, prob, magnitude_idx):
        # ranges of operations' magnitude, 10 preset levels per operation
        # "random negative" means the op function itself flips the sign at random
        ranges = {
            'ShearX': np.linspace(0, 0.3, 10),  # [0, 0.3], random negative
            'ShearY': np.linspace(0, 0.3, 10),  # [0, 0.3], random negative
            'TranslateX': np.linspace(0, 150 / 331, 10),  # [0, ~0.45] of image size, random negative
            'TranslateY': np.linspace(0, 150 / 331, 10),  # [0, ~0.45] of image size, random negative
            'Rotate': np.linspace(0, 30, 10),  # [0, 30]; rotate() applies no random negative
            'Color': np.linspace(0, 0.9, 10),  # [0, 0.9], random negative
            # builtin int instead of np.int: the alias was deprecated in
            # NumPy 1.20 and removed in 1.24 (np.int was just builtin int)
            'Posterize': np.round(np.linspace(8, 4, 10), 0).astype(int),  # 8 down to 4
            'Solarize': np.linspace(256, 0, 10),  # 256 down to 0
            'Contrast': np.linspace(0, 0.9, 10),  # [0, 0.9], random negative
            'Sharpness': np.linspace(0, 0.9, 10),  # [0, 0.9], random negative
            'Brightness': np.linspace(0, 0.9, 10),  # [0, 0.9], random negative
            'AutoContrast': [0] * 10,  # no range
            'Equalize': [0] * 10,  # no range
            'Invert': [0] * 10,  # no range
        }

        # augmentation operations
        # plain functions are referenced here (not lambdas) because
        # lambdas are not pickleable for DDP
        image_ops = {
            'ShearX': shear_x,
            'ShearY': shear_y,
            'TranslateX': translate_x_relative,
            'TranslateY': translate_y_relative,
            'Rotate': rotate,
            'AutoContrast': auto_contrast,
            'Invert': invert,
            'Equalize': equalize,
            'Solarize': solarize,
            'Posterize': posterize,
            'Contrast': contrast,
            'Color': color,
            'Brightness': brightness,
            'Sharpness': sharpness,
        }

        self.prob = prob
        self.magnitude = ranges[op_name][magnitude_idx]
        self.op = image_ops[op_name]

    def __call__(self, image):
        # apply the op only when the random draw falls below prob;
        # otherwise the image is returned untouched
        if self.prob > random.random():
            image = self.op(image, self.magnitude)
        return image
| 142 | + |
| 143 | + |
| 144 | +# PIL Image transforms |
| 145 | +# https://pillow.readthedocs.io/en/stable/reference/Image.html#PIL.Image.Image.transform |
def shear_x(image, magnitude, fillcolor=(128, 128, 128)):
    """Affine shear with the x-row coefficient set to +/- magnitude
    Args:
        image: PIL image
        magnitude: shear factor, its sign is picked at random
        fillcolor: color passed to transform for the uncovered area
    """
    signed = magnitude * random.choice([-1, 1])  # random negative
    coeffs = (1, signed, 0, 0, 1, 0)
    return image.transform(image.size, Image.AFFINE, coeffs, fillcolor=fillcolor)
| 149 | + |
| 150 | + |
def shear_y(image, magnitude, fillcolor=(128, 128, 128)):
    """Affine shear with the y-row coefficient set to +/- magnitude
    Args:
        image: PIL image
        magnitude: shear factor, its sign is picked at random
        fillcolor: color passed to transform for the uncovered area
    """
    signed = magnitude * random.choice([-1, 1])  # random negative
    coeffs = (1, 0, 0, signed, 1, 0)
    return image.transform(image.size, Image.AFFINE, coeffs, fillcolor=fillcolor)
| 154 | + |
| 155 | + |
def translate_x_relative(image, magnitude, fillcolor=(128, 128, 128)):
    """Horizontal translation by a fraction of the image width
    Args:
        image: PIL image
        magnitude: offset as a fraction of image.size[0], sign picked at random
        fillcolor: color passed to transform for the uncovered area
    """
    offset = magnitude * image.size[0] * random.choice([-1, 1])  # random negative
    coeffs = (1, 0, offset, 0, 1, 0)
    return image.transform(image.size, Image.AFFINE, coeffs, fillcolor=fillcolor)
| 160 | + |
| 161 | + |
def translate_y_relative(image, magnitude, fillcolor=(128, 128, 128)):
    """Vertical translation by a fraction of the image height
    Args:
        image: PIL image
        magnitude: offset as a fraction of the image height, sign picked at random
        fillcolor: color passed to transform for the uncovered area
    """
    # PIL's size is (width, height): a vertical offset must scale with
    # size[1]; size[0] is only correct for square images
    pixels = magnitude * image.size[1]
    pixels = pixels * random.choice([-1, 1])  # random negative
    return image.transform(image.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), fillcolor=fillcolor)
| 166 | + |
| 167 | + |
def translate_x_absolute(image, magnitude, fillcolor=(128, 128, 128)):
    """Horizontal translation by an absolute offset
    Args:
        image: PIL image
        magnitude: offset used directly as the affine x constant, sign picked at random
        fillcolor: color passed to transform for the uncovered area
    """
    offset = magnitude * random.choice([-1, 1])  # random negative
    coeffs = (1, 0, offset, 0, 1, 0)
    return image.transform(image.size, Image.AFFINE, coeffs, fillcolor=fillcolor)
| 171 | + |
| 172 | + |
def translate_y_absolute(image, magnitude, fillcolor=(128, 128, 128)):
    """Vertical translation by an absolute offset
    Args:
        image: PIL image
        magnitude: offset used directly as the affine y constant, sign picked at random
        fillcolor: color passed to transform for the uncovered area
    """
    offset = magnitude * random.choice([-1, 1])  # random negative
    coeffs = (1, 0, 0, 0, 1, offset)
    return image.transform(image.size, Image.AFFINE, coeffs, fillcolor=fillcolor)
| 176 | + |
| 177 | + |
def rotate(image, magnitude):
    """Rotate by magnitude degrees and fill the uncovered area with gray
    The image is rotated in RGBA mode, composited over a canvas filled
    with 128 in every channel using the rotated image as mask, then
    converted back to the input mode. The angle is used as given
    (no random negative here).
    """
    rotated = image.convert("RGBA").rotate(magnitude)
    gray_canvas = Image.new('RGBA', rotated.size, (128, ) * 4)
    merged = Image.composite(rotated, gray_canvas, rotated)
    return merged.convert(image.mode)
| 183 | + |
| 184 | + |
def auto_contrast(image, magnitude=None):
    """Apply ImageOps.autocontrast; magnitude is accepted but unused"""
    result = ImageOps.autocontrast(image)
    return result
| 187 | + |
| 188 | + |
def invert(image, magnitude=None):
    """Apply ImageOps.invert; magnitude is accepted but unused"""
    result = ImageOps.invert(image)
    return result
| 191 | + |
| 192 | + |
def equalize(image, magnitude=None):
    """Apply ImageOps.equalize; magnitude is accepted but unused"""
    result = ImageOps.equalize(image)
    return result
| 195 | + |
| 196 | + |
def solarize(image, magnitude):
    """Apply ImageOps.solarize with magnitude as its second argument"""
    threshold = magnitude
    return ImageOps.solarize(image, threshold)
| 199 | + |
| 200 | + |
def posterize(image, magnitude):
    """Apply ImageOps.posterize with magnitude as its second argument"""
    bits = magnitude
    return ImageOps.posterize(image, bits)
| 203 | + |
| 204 | + |
def contrast(image, magnitude):
    """Enhance contrast by a factor of 1 +/- magnitude (sign picked at random)"""
    sign = random.choice([-1, 1])  # random negative
    factor = 1 + magnitude * sign
    enhancer = ImageEnhance.Contrast(image)
    return enhancer.enhance(factor)
| 208 | + |
| 209 | + |
def color(image, magnitude):
    """Enhance color by a factor of 1 +/- magnitude (sign picked at random)"""
    sign = random.choice([-1, 1])  # random negative
    factor = 1 + magnitude * sign
    enhancer = ImageEnhance.Color(image)
    return enhancer.enhance(factor)
| 213 | + |
| 214 | + |
def brightness(image, magnitude):
    """Enhance brightness by a factor of 1 +/- magnitude (sign picked at random)"""
    sign = random.choice([-1, 1])  # random negative
    factor = 1 + magnitude * sign
    enhancer = ImageEnhance.Brightness(image)
    return enhancer.enhance(factor)
| 218 | + |
| 219 | + |
def sharpness(image, magnitude):
    """Enhance sharpness by a factor of 1 +/- magnitude (sign picked at random)"""
    sign = random.choice([-1, 1])  # random negative
    factor = 1 + magnitude * sign
    enhancer = ImageEnhance.Sharpness(image)
    return enhancer.enhance(factor)
| 223 | + |
0 commit comments