# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for ViT

Contains AverageMeter for monitoring, get_exclude_from_weight_decay_fn
and WarmupCosineScheduler for training.
"""

import math
import numpy as np
from paddle.optimizer.lr import LRScheduler


class MyPrint():
    """Print a tensor and its shape, used for debugging"""
    def __init__(self):
        self.cnt = 0

    def myprint(self, prefix, var, cnt=None, save=None):
        """Print a tensor and its shape, optionally saving it to .npy

        Args:
            prefix: str, info printed on the first and last lines
            var: Tensor, the tensor to print
            cnt: int, printing stops once self.cnt reaches this value;
                 if None, always print
            save: str, file name (should end with .npy) to save the tensor
                  to; if None, nothing is saved
        """
        if cnt is None or self.cnt < cnt:
            print(f'------------ {prefix} ---------------')
            print(var.shape, var)
            print(f'------------ END {prefix} ---------------')
            if save is not None:
                var = var.numpy()
                with open(save, 'wb') as ofile:
                    np.save(ofile, var)
            self.cnt += 1


class AverageMeter():
    """Meter for monitoring losses"""
    def __init__(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0
        self.reset()

    def reset(self):
        """reset all values to zeros"""
        self.avg = 0
        self.sum = 0
        self.cnt = 0
        self.val = 0

    def update(self, val, n=1):
        """update avg by val and n, where val is the avg of n values"""
        self.val = val
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def get_exclude_from_weight_decay_fn(exclude_list=[]):
    """Set params with no weight decay during training

    For certain params, e.g., the positional encoding in ViT, weight decay
    may not be needed during learning; this method is used to find
    these params.

    Args:
        exclude_list: a list of param name suffixes to exclude from
                      weight decay.
    Returns:
        exclude_from_weight_decay_fn: a function that takes a param name and
            returns False if the param should be excluded from weight decay
            (True otherwise), or None if exclude_list is empty.
    """
    if len(exclude_list) == 0:
        exclude_from_weight_decay_fn = None
    else:
        def exclude_fn(param):
            for name in exclude_list:
                if param.endswith(name):
                    return False
            return True
        exclude_from_weight_decay_fn = exclude_fn
    return exclude_from_weight_decay_fn
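

# --- Illustrative usage sketch (not part of the original PPViT module) ---
# A minimal example, under assumed settings, of wiring the function above
# into paddle.optimizer.AdamW: its `apply_decay_param_fun` argument applies
# weight decay only to params whose names make the given fn return True, so
# names ending with a suffix in `exclude_list` are skipped. The suffixes
# 'pos_embed' and 'cls_token', the decay value, and `model` are assumptions
# for illustration, not names taken from this file.
def _example_build_adamw(model, lr_scheduler):
    """Illustrative only: AdamW with some params excluded from weight decay."""
    import paddle  # local import: the rest of this module does not need it
    exclude_fn = get_exclude_from_weight_decay_fn(
        exclude_list=['pos_embed', 'cls_token'])  # assumed param name suffixes
    return paddle.optimizer.AdamW(learning_rate=lr_scheduler,
                                  parameters=model.parameters(),
                                  weight_decay=0.05,  # assumed decay strength
                                  apply_decay_param_fun=exclude_fn)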


class WarmupCosineScheduler(LRScheduler):
    """Warmup Cosine Scheduler

    First applies linear warmup, then a cosine decay schedule.
    Linearly increases the learning rate from "warmup_start_lr" to
    "start_lr" over "warmup_epochs", then decays it along a cosine
    curve from "start_lr" to "end_lr" over the remaining
    "total_epochs - warmup_epochs".

    Attributes:
        learning_rate: the starting learning rate (without warmup),
                       not used here!
        warmup_start_lr: warmup starting learning rate
        start_lr: the starting learning rate (without warmup)
        end_lr: the ending learning rate after the whole schedule
        warmup_epochs: # of epochs for warmup
        total_epochs: # of total epochs (including warmup)
    """
    def __init__(self,
                 learning_rate,
                 warmup_start_lr,
                 start_lr,
                 end_lr,
                 warmup_epochs,
                 total_epochs,
                 cycles=0.5,
                 last_epoch=-1,
                 verbose=False):
        """init WarmupCosineScheduler"""
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.warmup_start_lr = warmup_start_lr
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.cycles = cycles
        super(WarmupCosineScheduler, self).__init__(learning_rate,
                                                    last_epoch,
                                                    verbose)

    def get_lr(self):
        """return lr value"""
        # linear warmup: interpolate from warmup_start_lr to start_lr
        if self.last_epoch < self.warmup_epochs:
            val = (self.start_lr - self.warmup_start_lr) * float(
                self.last_epoch) / float(self.warmup_epochs) + self.warmup_start_lr
            return val

        # cosine decay: from start_lr down to end_lr over the remaining epochs
        progress = float(self.last_epoch - self.warmup_epochs) / float(
            max(1, self.total_epochs - self.warmup_epochs))
        val = max(0.0, 0.5 * (1. + math.cos(
            math.pi * float(self.cycles) * 2.0 * progress)))
        val = max(0.0, val * (self.start_lr - self.end_lr) + self.end_lr)
        return val
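

# --- Illustrative usage sketch (not part of the original PPViT module) ---
# A quick way to inspect the schedule: step the scheduler once per epoch and
# print the resulting learning rate. The hyperparameter values below are
# assumptions chosen only to make the linear warmup and cosine decay visible.
if __name__ == '__main__':
    sched = WarmupCosineScheduler(learning_rate=0.003,  # unused by get_lr
                                  warmup_start_lr=1e-6,
                                  start_lr=0.003,
                                  end_lr=5e-4,
                                  warmup_epochs=5,
                                  total_epochs=20)
    for epoch in range(20):
        print(f'epoch {epoch:2d}: lr = {sched.get_lr():.6f}')
        sched.step()  # advance last_epoch; call once per epoch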