forked from BR-IDL/PaddleViT
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathxcit.py
596 lines (518 loc) · 19.6 KB
/
xcit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
# Copyright (c) 2021 PPViT Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Implement Transformer Class for XCiT
"""
import math
from functools import partial
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from drop import DropPath
# Shared parameter initializers, used by XCiT._init_weights and as
# default_initializer arguments to paddle.create_parameter.
# Constant 1: LayerNorm scales and the XCA temperature.
ones_ = nn.initializer.Constant(value=1.0)
# Constant 0: Linear / LayerNorm biases.
zeros_ = nn.initializer.Constant(value=0.0)
# Truncated normal (std 0.02): Linear weights and the CLS token.
trunc_normal_ = nn.initializer.TruncatedNormal(std=0.02)
class Mlp(nn.Layer):
    """Two-layer feed-forward network used inside the transformer blocks.

    Pipeline: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features (int): size of the input (last) dimension.
        hidden_features (int, optional): width of the hidden layer;
            falls back to ``in_features`` when not given.
        out_features (int, optional): size of the output; falls back to
            ``in_features`` when not given.
        act_layer (type): activation layer class, instantiated without args.
        drop (float): dropout probability, shared by both dropout points.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.0):
        super().__init__()
        width = hidden_features if hidden_features else in_features
        n_out = out_features if out_features else in_features
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, n_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Identity(nn.Layer):
    """Pass-through layer whose output is exactly its input.

    Used as a no-op stand-in for optional sub-layers (e.g. DropPath when the
    rate is 0, or the classifier head when ``num_classes`` is 0) so that
    ``forward`` methods need no branching.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, inputs):
        output = inputs
        return output
class PositionalEncodingFourier(nn.Layer):
    """Sinusoidal 2-D positional encoding followed by a learned 1x1 projection.

    Uses the sine/cosine (Fourier) kernel from "Attention Is All You Need",
    applied separately to normalized row and column coordinates. The two
    ``hidden_dim``-channel encodings are concatenated and projected to ``dim``
    channels with a 1x1 convolution.

    Args:
        hidden_dim (int): number of sinusoidal channels per spatial axis.
        dim (int): output channel count of the projection.
        temperature (float): base of the geometric frequency progression.
    """

    def __init__(self, hidden_dim=32, dim=768, temperature=10000):
        super().__init__()
        self.token_projection = nn.Conv2D(hidden_dim * 2, dim, kernel_size=1)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim

    def forward(self, B, H, W):
        """Return the positional encoding as a [B, dim, H, W] tensor.

        Args:
            B (int): batch size.
            H (int): grid height (number of patch rows).
            W (int): grid width (number of patch columns).
        """
        # No padding mask: every position is valid, so cumulative sums of
        # ones give 1-based row / column indices.
        ones = paddle.ones([B, H, W], dtype="float32")
        y_embed = ones.cumsum(axis=1)
        x_embed = ones.cumsum(axis=2)
        # Normalize coordinates to (0, 2*pi].
        eps = 1e-6
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        # Per-channel frequency divisors. Computed with Python floats so the
        # exponent is a true (not integer) division regardless of how the
        # installed Paddle version treats "/" on integer tensors.
        dim_t = paddle.to_tensor(
            [self.temperature ** (2 * (i // 2) / self.hidden_dim)
             for i in range(self.hidden_dim)],
            dtype="float32")

        pos_x = x_embed.unsqueeze(3) / dim_t
        pos_y = y_embed.unsqueeze(3) / dim_t
        # Interleave sin of even channels with cos of odd channels.
        pos_x = paddle.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
            axis=4).flatten(3)
        pos_y = paddle.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
            axis=4).flatten(3)
        # [B, H, W, 2*hidden_dim] -> [B, 2*hidden_dim, H, W] for the 1x1 conv.
        pos = paddle.concat((pos_y, pos_x), axis=3).transpose([0, 3, 1, 2])
        return self.token_projection(pos)
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution (padding 1) followed by BatchNorm2D.

    Args:
        in_planes (int): input channels.
        out_planes (int): output channels.
        stride (int): convolution stride.

    Returns:
        paddle.nn.Sequential: the conv layer and its batch norm.
    """
    conv = nn.Conv2D(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias_attr=False)
    norm = nn.BatchNorm2D(out_planes)
    return paddle.nn.Sequential(conv, norm)
class ConvPatchEmbed(nn.Layer):
    """Image-to-patch embedding built from a stack of stride-2 3x3 convolutions.

    ``patch_size`` 16 uses four conv stages and ``patch_size`` 8 uses three,
    with GELU between stages. Channel width doubles at every stage and ends
    at ``embed_dim``.

    Args:
        img_size (int or tuple): input image size, (H, W) if a tuple.
        patch_size (int or tuple): total downsampling factor; its first
            element must be 8 or 16.
        in_chans (int): number of input image channels.
        embed_dim (int): output embedding dimension.

    Raises:
        ValueError: if the patch size is not 8 or 16.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        # Accept ints or 2-tuples (XCiT documents both forms).
        if not isinstance(img_size, (tuple, list)):
            img_size = (img_size, img_size)
        if not isinstance(patch_size, (tuple, list)):
            patch_size = (patch_size, patch_size)
        img_size = tuple(img_size)
        patch_size = tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if patch_size[0] == 16:
            self.proj = paddle.nn.Sequential(
                conv3x3(in_chans, embed_dim // 8, 2),
                nn.GELU(),
                conv3x3(embed_dim // 8, embed_dim // 4, 2),
                nn.GELU(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                nn.GELU(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        elif patch_size[0] == 8:
            self.proj = paddle.nn.Sequential(
                conv3x3(in_chans, embed_dim // 4, 2),
                nn.GELU(),
                conv3x3(embed_dim // 4, embed_dim // 2, 2),
                nn.GELU(),
                conv3x3(embed_dim // 2, embed_dim, 2),
            )
        else:
            raise ValueError("For convolutional projection, patch size has to be in [8, 16]")

    def forward(self, x, padding_size=None):
        """Embed an image batch into a token sequence.

        Args:
            x: image batch in NCHW layout (the default layout of nn.Conv2D).
            padding_size: unused; kept for interface compatibility.

        Returns:
            tuple: tokens of shape [B, Hp*Wp, embed_dim] and the grid size (Hp, Wp).
        """
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        # [B, C, Hp, Wp] -> [B, Hp*Wp, C]
        x = x.flatten(2).transpose([0, 2, 1])
        return x, (Hp, Wp)
class LPI(nn.Layer):
    """Local Patch Interaction (LPI) module.

    Lets tokens exchange information explicitly inside small spatial windows,
    complementing the implicit communication of the cross-covariance
    attention. The tokens are reshaped to a feature map and passed through
    two grouped convolutions (``groups=out_features``, i.e. depthwise when
    ``out_features == in_features``) with an activation and BatchNorm2D in
    between.

    Args:
        in_features (int): channels of the incoming tokens.
        hidden_features: unused.
        out_features (int, optional): output channels; falls back to
            ``in_features`` when not given.
        act_layer (type): activation layer class.
        drop: unused.
        kernel_size (int): convolution kernel size; padding is kernel_size // 2.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.0,
                 kernel_size=3):
        super().__init__()
        out_features = out_features or in_features
        conv_kwargs = dict(kernel_size=kernel_size,
                           padding=kernel_size // 2,
                           groups=out_features)
        self.conv1 = paddle.nn.Conv2D(in_features, out_features, **conv_kwargs)
        self.act = act_layer()
        self.bn = nn.BatchNorm2D(in_features)
        self.conv2 = paddle.nn.Conv2D(in_features, out_features, **conv_kwargs)

    def forward(self, x, H, W):
        B, N, C = x.shape
        # tokens [B, N, C] -> feature map [B, C, H, W]
        feat = x.transpose([0, 2, 1]).reshape([B, C, H, W])
        feat = self.conv2(self.bn(self.act(self.conv1(feat))))
        # feature map back to tokens [B, N, C]
        return feat.reshape([B, C, N]).transpose([0, 2, 1])
class ClassAttention(nn.Layer):
    """Class attention as in CaiT (https://arxiv.org/abs/2103.17239).

    Only the token at position 0 (the CLS token) acts as a query. It attends
    over all tokens, and its projected result replaces position 0. The
    remaining tokens are returned unchanged.

    Args:
        dim (int): token embedding dimension.
        num_heads (int): number of attention heads.
        qkv_bias (bool): add a bias to the qkv projection.
        qk_scale (float, optional): overrides the head_dim ** -0.5 scaling.
        attn_drop (float): dropout on the attention weights.
        proj_drop (float): dropout after the output projection.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.0,
                 proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # [B, N, 3*C] -> [3, B, heads, N, head_dim]
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, head_dim])
        qkv = qkv.transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        cls_query = q[:, :, 0:1]                             # [B, heads, 1, d]
        scores = (cls_query * k).sum(axis=-1) * self.scale   # [B, heads, N]
        weights = self.attn_drop(F.softmax(scores, axis=-1))
        # [B, heads, 1, N] @ [B, heads, N, d] -> [B, heads, 1, d]
        cls_out = paddle.matmul(weights.unsqueeze(2), v)
        cls_out = cls_out.transpose([0, 1, 3, 2]).reshape([B, 1, C])
        cls_out = self.proj_drop(self.proj(cls_out))
        return paddle.concat([cls_out, x[:, 1:]], axis=1)
class ClassAttentionBlock(nn.Layer):
    """Class-attention block as in CaiT (https://arxiv.org/abs/2103.17239).

    Applies ClassAttention (which updates only the CLS token) and then an MLP
    on the CLS token. Each branch has an optional LayerScale factor and
    stochastic depth.

    Args:
        dim (int): token embedding dimension.
        num_heads (int): number of attention heads.
        mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
        qkv_bias (bool): add a bias to the qkv projection.
        qk_scale (float, optional): overrides the default attention scaling.
        drop (float): dropout rate for the projection and MLP.
        attn_drop (float): dropout on attention weights.
        drop_path (float): stochastic depth rate (0 disables it).
        act_layer (type): activation layer class for the MLP.
        norm_layer (callable): builds a normalization layer from ``dim``.
        eta (float, optional): LayerScale initial value; ``None`` disables
            LayerScale (gammas fixed to 1.0).
        tokens_norm (bool): if True, ``norm2`` is applied to all tokens;
            otherwise only to the CLS token.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.0,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.0,
                 attn_drop=0.0,
                 drop_path=0.0,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 eta=None,
                 tokens_norm=False):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = ClassAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        # LayerScale initialization (no layerscale when eta is None)
        if eta is not None:
            self.gamma1 = paddle.create_parameter(
                shape=[dim],
                dtype="float32",
                default_initializer=nn.initializer.Constant(value=eta),
            )
            self.gamma2 = paddle.create_parameter(
                shape=[dim],
                dtype="float32",
                default_initializer=nn.initializer.Constant(value=eta),
            )
        else:
            self.gamma1, self.gamma2 = 1.0, 1.0
        # A hack for models pre-trained with layernorm over all the tokens,
        # not just the CLS token.
        self.tokens_norm = tokens_norm

    def forward(self, x, H, W, mask=None):
        """Run the block on tokens ``x`` whose position 0 is the CLS token.

        ``H``, ``W`` and ``mask`` are unused. They are accepted so the block
        can be called like the XCA blocks (``blk(x, Hp, Wp)``).
        """
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        if self.tokens_norm:
            x = self.norm2(x)
        else:
            # Normalize only the CLS token. Built out-of-place with concat
            # rather than slice assignment, which would mutate a tensor that
            # is already part of the autograd graph.
            x = paddle.concat([self.norm2(x[:, 0:1]), x[:, 1:]], axis=1)
        x_res = x
        cls_token = self.gamma2 * self.mlp(x[:, 0:1])
        x = paddle.concat([cls_token, x[:, 1:]], axis=1)
        # NOTE(review): x[:, 1:] is copied from x_res, so this residual add
        # doubles the patch tokens (and drop_path may zero that copy per
        # sample). Left unchanged; presumably intentional for compatibility
        # with pretrained weights. Confirm against the reference XCiT code.
        x = x_res + self.drop_path(x)
        return x
class XCA(nn.Layer):
    """Cross-Covariance Attention (XCA).

    Attention is computed between channels instead of between tokens. Per
    head, queries and keys are L2-normalized along the token axis. Their
    d_h x d_h cross-covariance (Q^T K), multiplied by a learnable per-head
    temperature, is softmax-normalized and used to re-weight the value
    channels.

    Args:
        dim (int): token embedding dimension.
        num_heads (int): number of heads.
        qkv_bias (bool): add a bias to the qkv projection.
        qk_scale: unused; the learnable temperature plays this role.
        attn_drop (float): dropout on the attention weights.
        proj_drop (float): dropout after the output projection.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.0,
                 proj_drop=0.0):
        super().__init__()
        self.num_heads = num_heads
        # Learnable per-head temperature, initialized to 1.
        self.temperature = paddle.create_parameter(
            shape=[num_heads, 1, 1], dtype="float32", default_initializer=ones_)
        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, C // self.num_heads])
        # [B, N, 3, heads, d] -> [3, B, heads, d, N]: channels first, so the
        # attention map is over channels.
        qkv = qkv.transpose([2, 0, 3, 4, 1])
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = F.normalize(q, axis=-1)
        k = F.normalize(k, axis=-1)
        # [B, heads, d, N] @ [B, heads, N, d] -> [B, heads, d, d]
        attn = paddle.matmul(q, k, transpose_y=True) * self.temperature
        attn = self.attn_drop(F.softmax(attn, axis=-1))
        # [B, heads, d, N] -> [B, N, heads, d] -> [B, N, C]
        out = paddle.matmul(attn, v).transpose([0, 3, 1, 2]).reshape([B, N, C])
        return self.proj_drop(self.proj(out))
class XCABlock(nn.Layer):
    """XCiT encoder block: XCA, then LPI, then MLP, each as a pre-norm residual branch.

    Each branch output is multiplied by a per-channel LayerScale vector
    (``gamma1`` for XCA, ``gamma3`` for LPI, ``gamma2`` for the MLP) and
    passed through stochastic depth before being added back.

    Args:
        dim (int): token embedding dimension.
        num_heads (int): number of XCA heads.
        mlp_ratio (float): MLP hidden width as a multiple of ``dim``.
        qkv_bias (bool): add a bias to the qkv projection.
        qk_scale: forwarded to XCA (unused there).
        drop (float): dropout rate for projections and the MLP.
        attn_drop (float): dropout on attention weights.
        drop_path (float): stochastic depth rate (0 disables it).
        act_layer (type): activation layer class.
        norm_layer (callable): builds a normalization layer from ``dim``.
        num_tokens: unused.
        eta (float, optional): LayerScale initial value. ``None`` disables
            LayerScale (gammas fixed to 1.0), as in ClassAttentionBlock.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.0,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.0,
                 attn_drop=0.0,
                 drop_path=0.0,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 num_tokens=196,
                 eta=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = XCA(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.norm3 = norm_layer(dim)
        self.local_mp = LPI(in_features=dim, act_layer=act_layer)
        if eta is not None:
            self.gamma1 = self._layer_scale(dim, eta)
            self.gamma2 = self._layer_scale(dim, eta)
            self.gamma3 = self._layer_scale(dim, eta)
        else:
            # Constant(value=None) cannot build a parameter, so no LayerScale.
            self.gamma1, self.gamma2, self.gamma3 = 1.0, 1.0, 1.0

    @staticmethod
    def _layer_scale(dim, eta):
        """Create a learnable [dim] float32 vector initialized to ``eta``."""
        return paddle.create_parameter(
            shape=[dim],
            dtype="float32",
            default_initializer=nn.initializer.Constant(value=eta),
        )

    def forward(self, x, H, W):
        """Apply the block to tokens ``x`` laid out on an H x W patch grid."""
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W))
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x
class XCiT(nn.Layer):
    """Cross-Covariance Image Transformer (XCiT).

    Based on the timm and DeiT code bases:
    https://github.com/rwightman/pytorch-image-models/tree/master/timm
    https://github.com/facebookresearch/deit/
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=12,
                 mlp_ratio=4.0,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 drop_path_rate=0.0,
                 norm_layer=partial(nn.LayerNorm, epsilon=1e-6),
                 cls_attn_layers=2,
                 use_pos=True,
                 patch_proj="linear",
                 eta=None,
                 tokens_norm=False):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size (8 or 16)
            in_chans (int): number of input channels
            num_classes (int): number of classes for the classification head;
                0 or less replaces the head with Identity
            embed_dim (int): embedding dimension
            depth (int): number of XCA blocks
            num_heads (int): number of attention heads
            mlp_ratio (float): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate, applied uniformly
                to every XCA block
            norm_layer (callable): builds a normalization nn.Layer from a dim;
                None falls back to LayerNorm with epsilon 1e-6
            cls_attn_layers (int): depth of the class attention layers
            use_pos (bool): whether to add the Fourier positional encoding
            patch_proj (str): unused; the convolutional patch embedding is
                always used
            eta (float): layerscale initialization value (None disables it)
            tokens_norm (bool): whether to normalize all tokens or just the
                cls_token in the class attention blocks
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, epsilon=1e-6)
        self.patch_embed = ConvPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = paddle.create_parameter(
            shape=[1, 1, embed_dim], dtype="float32", default_initializer=trunc_normal_
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        # Same stochastic-depth rate for every XCA block (no linear schedule).
        dpr = [drop_path_rate] * depth
        self.blocks = nn.LayerList(
            [
                XCABlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    num_tokens=num_patches,
                    eta=eta,
                )
                for i in range(depth)
            ]
        )
        self.cls_attn_blocks = nn.LayerList(
            [
                ClassAttentionBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    norm_layer=norm_layer,
                    eta=eta,
                    tokens_norm=tokens_norm,
                )
                for _ in range(cls_attn_layers)
            ]
        )
        self.norm = norm_layer(embed_dim)
        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        self.use_pos = use_pos
        # Classifier head
        self.head = (
            nn.Linear(self.num_features, num_classes) if num_classes > 0 else Identity()
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear (trunc-normal weight, zero bias) and LayerNorm (1/0) layers."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                zeros_(m.bias)
            if m.weight is not None:
                ones_(m.weight)

    def forward_features(self, x):
        """Return the normalized CLS-token features, shape [B, embed_dim]."""
        B = x.shape[0]
        x, (Hp, Wp) = self.patch_embed(x)
        if self.use_pos:
            # [B, C, Hp, Wp] -> [B, Hp*Wp, C] to match the token layout.
            pos_encoding = (
                self.pos_embeder(B, Hp, Wp)
                .reshape([B, -1, x.shape[1]])
                .transpose([0, 2, 1])
            )
            x = x + pos_encoding
        x = self.pos_drop(x)
        for blk in self.blocks:
            x = blk(x, Hp, Wp)
        # The CLS token is prepended only for the class-attention stage.
        cls_tokens = self.cls_token.expand([B, -1, -1])
        x = paddle.concat((cls_tokens, x), axis=1)
        for blk in self.cls_attn_blocks:
            x = blk(x, Hp, Wp)
        return self.norm(x)[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
def build_xcit(config):
    """Construct an XCiT model from a config object.

    Reads DATA.IMAGE_SIZE, MODEL.NUM_CLASSES and MODEL.TRANS.{PATCH_SIZE,
    EMBED_DIM, DEPTH, NUM_HEADS, ETA, TOKENS_NORM}. Every other XCiT argument
    keeps its default value.
    """
    trans_cfg = config.MODEL.TRANS
    return XCiT(
        img_size=config.DATA.IMAGE_SIZE,
        patch_size=trans_cfg.PATCH_SIZE,
        embed_dim=trans_cfg.EMBED_DIM,
        num_classes=config.MODEL.NUM_CLASSES,
        depth=trans_cfg.DEPTH,
        num_heads=trans_cfg.NUM_HEADS,
        eta=trans_cfg.ETA,
        tokens_norm=trans_cfg.TOKENS_NORM,
    )