# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module implements SETR
Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers
<https://arxiv.org/pdf/2012.15840.pdf>
"""

import paddle
import paddle.nn as nn
from src.models.backbones import ViT_MLA, VisualTransformer
from src.models.decoders import VIT_MLAHead, VIT_MLA_AUXIHead, VisionTransformerUpHead
from src.utils import load_pretrained_model


class SETR(nn.Layer):
    """SETR

    SEgmentation TRansformer (SETR) has three different decoder designs to
    perform pixel-level segmentation. The variants of SETR include SETR_MLA,
    SETR_PUP, and SETR_Naive.

    Attributes:
        encoder: A backbone network for extracting features from the image.
        auxi_head: A boolean indicating whether auxiliary segmentation heads are used.
        decoder_type: Type of the decoder.
        decoder: A decoder module for semantic segmentation.
    """

    def __init__(self, config):
        super(SETR, self).__init__()
        if config.MODEL.ENCODER.TYPE == "ViT_MLA":
            self.encoder = ViT_MLA(config)
        elif config.MODEL.ENCODER.TYPE == "ViT":
            self.encoder = VisualTransformer(config)
        self.auxi_head = config.MODEL.AUX.AUXIHEAD
        self.decoder_type = config.MODEL.DECODER_TYPE

        if self.decoder_type == "VIT_MLAHead":
            self.decoder = VIT_MLAHead(
                config.MODEL.MLA.MLA_CHANNELS,
                config.MODEL.MLA.MLAHEAD_CHANNELS,
                config.DATA.NUM_CLASSES,
                config.MODEL.MLA.MLAHEAD_ALIGN_CORNERS)
            if self.auxi_head:
                self.aux_decoder2 = VIT_MLA_AUXIHead(
                    config.MODEL.MLA.MLA_CHANNELS,
                    config.DATA.NUM_CLASSES,
                    config.MODEL.AUX.AUXHEAD_ALIGN_CORNERS)
                self.aux_decoder3 = VIT_MLA_AUXIHead(
                    config.MODEL.MLA.MLA_CHANNELS,
                    config.DATA.NUM_CLASSES,
                    config.MODEL.AUX.AUXHEAD_ALIGN_CORNERS)
                self.aux_decoder4 = VIT_MLA_AUXIHead(
                    config.MODEL.MLA.MLA_CHANNELS,
                    config.DATA.NUM_CLASSES,
                    config.MODEL.AUX.AUXHEAD_ALIGN_CORNERS)
                self.aux_decoder5 = VIT_MLA_AUXIHead(
                    config.MODEL.MLA.MLA_CHANNELS,
                    config.DATA.NUM_CLASSES,
                    config.MODEL.AUX.AUXHEAD_ALIGN_CORNERS)
        elif (self.decoder_type == "PUP_VisionTransformerUpHead" or
              self.decoder_type == "Naive_VisionTransformerUpHead"):
            self.decoder = VisionTransformerUpHead(
                config.MODEL.PUP.INPUT_CHANNEL,
                config.MODEL.PUP.NUM_CONV,
                config.MODEL.PUP.NUM_UPSAMPLE_LAYER,
                config.MODEL.PUP.CONV3x3_CONV1x1,
                config.MODEL.PUP.ALIGN_CORNERS,
                config.DATA.NUM_CLASSES)
            if self.auxi_head:
                self.aux_decoder2 = VisionTransformerUpHead(
                    config.MODEL.AUXPUP.INPUT_CHANNEL,
                    config.MODEL.AUXPUP.NUM_CONV,
                    config.MODEL.AUXPUP.NUM_UPSAMPLE_LAYER,
                    config.MODEL.AUXPUP.CONV3x3_CONV1x1,
                    config.MODEL.AUXPUP.ALIGN_CORNERS,
                    config.DATA.NUM_CLASSES)
                self.aux_decoder3 = VisionTransformerUpHead(
                    config.MODEL.AUXPUP.INPUT_CHANNEL,
                    config.MODEL.AUXPUP.NUM_CONV,
                    config.MODEL.AUXPUP.NUM_UPSAMPLE_LAYER,
                    config.MODEL.AUXPUP.CONV3x3_CONV1x1,
                    config.MODEL.AUXPUP.ALIGN_CORNERS,
                    config.DATA.NUM_CLASSES)
                self.aux_decoder4 = VisionTransformerUpHead(
                    config.MODEL.AUXPUP.INPUT_CHANNEL,
                    config.MODEL.AUXPUP.NUM_CONV,
                    config.MODEL.AUXPUP.NUM_UPSAMPLE_LAYER,
                    config.MODEL.AUXPUP.CONV3x3_CONV1x1,
                    config.MODEL.AUXPUP.ALIGN_CORNERS,
                    config.DATA.NUM_CLASSES)
                if self.decoder_type == "PUP_VisionTransformerUpHead":
                    self.aux_decoder5 = VisionTransformerUpHead(
                        config.MODEL.AUXPUP.INPUT_CHANNEL,
                        config.MODEL.AUXPUP.NUM_CONV,
                        config.MODEL.AUXPUP.NUM_UPSAMPLE_LAYER,
                        config.MODEL.AUXPUP.CONV3x3_CONV1x1,
                        config.MODEL.AUXPUP.ALIGN_CORNERS,
                        config.DATA.NUM_CLASSES)
        self.init__decoder_lr_coef(config)

    def init__decoder_lr_coef(self, config):
        """Scale the learning rate of decoder (and auxiliary head) layers by DECODER_LR_COEF."""
        for sublayer in self.decoder.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                sublayer.weight.optimize_attr['learning_rate'] = config.TRAIN.DECODER_LR_COEF
                if sublayer.bias is not None:
                    sublayer.bias.optimize_attr['learning_rate'] = config.TRAIN.DECODER_LR_COEF
            if isinstance(sublayer, (nn.SyncBatchNorm, nn.BatchNorm2D, nn.LayerNorm)):
                sublayer.weight.optimize_attr['learning_rate'] = config.TRAIN.DECODER_LR_COEF
                sublayer.bias.optimize_attr['learning_rate'] = config.TRAIN.DECODER_LR_COEF
        if self.auxi_head:
            sublayers = []  # list of sublayer lists, one per auxiliary decoder
            sublayers.append(self.aux_decoder2.sublayers())
            sublayers.append(self.aux_decoder3.sublayers())
            sublayers.append(self.aux_decoder4.sublayers())
            if self.decoder_type == "PUP_VisionTransformerUpHead":
                sublayers.append(self.aux_decoder5.sublayers())
            for sublayer_list in sublayers:
                for sublayer in sublayer_list:
                    if isinstance(sublayer, nn.Conv2D):
                        sublayer.weight.optimize_attr['learning_rate'] = config.TRAIN.DECODER_LR_COEF
                        if sublayer.bias is not None:
                            sublayer.bias.optimize_attr['learning_rate'] = config.TRAIN.DECODER_LR_COEF

    def forward(self, imgs):
        # imgs shape: (B, 3, H, W)
        p2, p3, p4, p5 = self.encoder(imgs)
        preds = []
        if self.decoder_type == "VIT_MLAHead":
            pred = self.decoder(p2, p3, p4, p5)
        elif (self.decoder_type == "PUP_VisionTransformerUpHead" or
              self.decoder_type == "Naive_VisionTransformerUpHead"):
            pred = self.decoder(p5)
        preds.append(pred)
        if self.auxi_head:
            preds.append(self.aux_decoder2(p2))
            preds.append(self.aux_decoder3(p3))
            preds.append(self.aux_decoder4(p4))
            if self.decoder_type == "PUP_VisionTransformerUpHead":
                preds.append(self.aux_decoder5(p5))
        return preds
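
# A minimal usage sketch (illustrative only, not part of the original module). It assumes a
# yacs/CfgNode-style `config` whose keys mirror the ones read in SETR.__init__ above, and a
# hypothetical `get_config` helper plus config file path; the input resolution is also an
# assumption, not a value taken from this repo's shipped configs.
#
#     import paddle
#     from config import get_config                 # hypothetical config loader
#
#     config = get_config('configs/setr_mla.yaml')  # hypothetical config file
#     model = SETR(config)
#     model.eval()
#
#     imgs = paddle.randn([1, 3, 768, 768])         # (B, 3, H, W), resolution assumed
#     with paddle.no_grad():
#         preds = model(imgs)                       # [main prediction, aux predictions...]
#     print([p.shape for p in preds])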