-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy pathmodeling_flamingo.py
939 lines (797 loc) · 34.4 KB
/
modeling_flamingo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
import random
from dataclasses import dataclass
from typing import Optional, Callable
import torch
import torch.nn as nn
from transformers.modeling_utils import PreTrainedModel
from transformers.models.auto import AutoModelForCausalLM, AutoTokenizer, AutoModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
from einops import rearrange, repeat
from accelerate.hooks import add_hook_to_module, AlignDevicesHook
from .configuration_flamingo import FlamingoConfig
# Maps a lower-cased substring of a causal-LM class name to the dotted attribute
# path of that model's nn.ModuleList of decoder blocks.
# Lookup (see _infer_decoder_layers_attr_name) returns the FIRST key, in
# insertion order, that occurs in the lower-cased class name. A key that is a
# prefix of another key must therefore come after it: "gptneox" precedes
# "gptneo", otherwise GPTNeoXForCausalLM would resolve to "transformer.h".
__KNOWN_DECODER_LAYERS_ATTR_NAMES = {
    "opt": "model.decoder.layers",
    "gptneox": "gpt_neox.layers",
    "gptneo": "transformer.h",
    "gptj": "transformer.h",
    "gpt-j": "transformer.h",
    "pythia": "gpt_neox.layers",
    "llama": "model.layers",
}
def _infer_decoder_layers_attr_name(model: nn.Module):
    """
    Guess the dotted attribute path of ``model``'s decoder-layer ModuleList.

    The lower-cased class name of ``model`` is compared against the keys of
    ``__KNOWN_DECODER_LAYERS_ATTR_NAMES`` in insertion order. The value of the
    first key that occurs as a substring is returned.

    Args:
        model: the causal language model whose decoder layers are wanted.

    Returns:
        str: dotted attribute path, e.g. ``"model.layers"``.

    Raises:
        ValueError: if no known key matches the model's class name.
    """
    class_name = model.__class__.__name__
    lowered_name = class_name.lower()
    for key, attr_name in __KNOWN_DECODER_LAYERS_ATTR_NAMES.items():
        if key.lower() in lowered_name:
            return attr_name
    raise ValueError(
        "We require the attribute name for the nn.ModuleList in the decoder storing "
        f"the transformer block layers, but could not infer it for {class_name}. "
        "Please supply this string manually."
    )
def extend_instance(obj, mixin):
    """
    Graft ``mixin`` onto an existing instance by swapping its class.

    ``obj`` is re-classed to a freshly created type that has the same name as
    its original class and the bases ``(mixin, original_class)``. Because the
    mixin comes first in the MRO, its methods (notably ``forward``) override
    the original ones and can reach them through ``super()``.
    """
    original_cls = obj.__class__
    new_bases = (mixin, original_cls)
    obj.__class__ = type(original_cls.__name__, new_bases, {})
def getattr_recursive(obj, att):
    """
    Resolve a dotted attribute path on ``obj``.

    ``getattr_recursive(obj, "a.b.c")`` is the same as ``obj.a.b.c``. An empty
    path, or a single trailing dot, does not trigger a further lookup.
    """
    segments = att.split(".")
    # A trailing empty segment ("" or "x.") ends the walk without another getattr.
    if not segments[-1]:
        segments.pop()
    current = obj
    for segment in segments:
        current = getattr(current, segment)
    return current
def setattr_recursive(obj, att, val):
    """
    Assign ``val`` at a dotted attribute path on ``obj``.

    ``setattr_recursive(obj, "a.b.c", val)`` is the same as ``obj.a.b.c = val``.
    Everything before the last dot is resolved with ``getattr_recursive``.
    """
    parent_path, separator, leaf = att.rpartition(".")
    target = getattr_recursive(obj, parent_path) if separator else obj
    setattr(target, leaf, val)
def exists(val):
    """Return True when ``val`` is anything other than ``None`` (falsy values count)."""
    return not (val is None)
class FlamingoPerceiverBlock(nn.Module):
    """
    One Perceiver-resampler layer: multi-head attention followed by a residual feed-forward.

    Queries come from the latents. Keys and values come from the
    concatenation, along the token axis, of the normalised media features and
    the normalised latents. Both sub-blocks are residual.
    """

    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8, mult: int = 4):
        super().__init__()
        inner_dim = dim_head * heads
        hidden_dim = dim * mult
        self.scale = dim_head**-0.5
        self.heads = heads
        # Submodules are created in a fixed order; their names are the state_dict keys.
        self.norm_media = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        self.feed_forward = nn.ModuleList(
            [
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_dim, bias=False),
                nn.GELU(),
                nn.Linear(hidden_dim, dim, bias=False),
            ]
        )

    def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): image features, shape (b, T, n1, D)
            latents (torch.Tensor): latent features, shape (b, T, n2, D)
        Returns:
            torch.Tensor: updated latents, shape (b, T, n2, D)
        """
        num_heads = self.heads
        media = self.norm_media(x)
        latent_residual = latents
        normed_latents = self.norm_latents(latents)

        queries = self.to_q(normed_latents)
        kv_source = torch.cat((media, normed_latents), dim=-2)
        keys, values = self.to_kv(kv_source).chunk(2, dim=-1)

        split_heads = "b t n (h d) -> b h t n d"
        queries = rearrange(queries, split_heads, h=num_heads)
        keys = rearrange(keys, split_heads, h=num_heads)
        values = rearrange(values, split_heads, h=num_heads)
        queries = queries * self.scale

        # Scaled dot-product attention with the row max subtracted for stability.
        scores = torch.einsum("... i d, ... j d -> ... i j", queries, keys)
        scores = scores - scores.amax(dim=-1, keepdim=True).detach()
        weights = scores.softmax(dim=-1)
        attended = torch.einsum("... i j, ... j d -> ... i d", weights, values)
        attended = rearrange(attended, "b h t n d -> b t n (h d)", h=num_heads)

        hidden = self.to_out(attended) + latent_residual
        ff_out = hidden
        for sublayer in self.feed_forward:
            ff_out = sublayer(ff_out)
        return ff_out + hidden
class FlamingoPerceiverResampler(nn.Module):
    """
    Resample a variable number of visual tokens per media item into a fixed set of latents.

    A learned latent array of shape (num_latents, dim) is broadcast to every
    (batch, media) pair and refined by ``depth`` FlamingoPerceiverBlock layers.
    Optional learned frame and media-time embeddings are added to the input
    when ``max_num_frames`` / ``max_num_media`` are given.
    """

    def __init__(
        self,
        *,
        dim: int,
        depth: int = 6,
        dim_head: int = 64,
        heads: int = 8,
        num_latents: int = 64,
        max_num_media: Optional[int] = None,
        max_num_frames: Optional[int] = None,
        ff_mult: int = 4,
    ):
        super().__init__()
        # Parameters are created in this order so random initialisation is reproducible.
        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        if max_num_frames is not None:
            self.frame_embs = nn.Parameter(torch.randn(max_num_frames, dim))
        else:
            self.frame_embs = None
        if max_num_media is not None:
            self.media_time_embs = nn.Parameter(torch.randn(max_num_media, 1, dim))
        else:
            self.media_time_embs = None
        self.layers = nn.ModuleList(
            [
                FlamingoPerceiverBlock(
                    dim=dim, dim_head=dim_head, heads=heads, mult=ff_mult
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): image features, shape (b, T, F, v, D)
        Returns:
            torch.Tensor: shape (b, T, n, D), where n is the ``num_latents`` given at construction
        """
        batch, num_media, num_frames, num_tokens = x.shape[:4]

        if self.frame_embs is not None:
            x = x + repeat(
                self.frame_embs[:num_frames],
                "F d -> b T F v d",
                b=batch,
                T=num_media,
                v=num_tokens,
            )
        # Merge the frame and spatial-token axes into one token axis.
        x = rearrange(x, "b T F v d -> b T (F v) d")
        if self.media_time_embs is not None:
            x = x + self.media_time_embs[:num_media]

        latents = repeat(self.latents, "n d -> b T n d", b=batch, T=num_media)
        for perceiver_block in self.layers:
            latents = perceiver_block(x, latents)
        return self.norm(latents)
class FlamingoMaskedCrossAttention(nn.Module):
    """
    Multi-head cross-attention from text tokens (queries) to resampled media latents (keys/values).

    When ``media_locations`` is given, each text position may attend only to the
    latents of the media item(s) its position in the sequence allows (see ``forward``).
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_visual: int,
        dim_head: int = 64,
        heads: int = 8,
        only_attend_immediate_media: bool = True,
        only_attend_previous: bool = True,
    ):
        super().__init__()
        self.scale = dim_head**-0.5
        self.heads = heads
        inner_dim = dim_head * heads
        self.norm = nn.LayerNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # keys/values are projected from the visual width, queries from the text width
        self.to_kv = nn.Linear(dim_visual, inner_dim * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)
        # whether for text to only attend to immediate preceding image, or all previous images
        self.only_attend_immediate_media = only_attend_immediate_media
        # NOTE(review): stored but not read in forward(); the per-call
        # `attend_previous` argument controls that behavior instead.
        self.only_attend_previous = only_attend_previous

    def forward(
        self,
        x: torch.Tensor,
        media: torch.Tensor,
        media_locations: Optional[torch.BoolTensor] = None,
        attend_previous: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): text features
                shape (B, T_txt, D_txt)
            media (torch.Tensor): image features
                shape (B, T_img, n, D_img) where n is the dim of the latents
            media_locations: boolean mask identifying the media tokens in x
                shape (B, T_txt)
            attend_previous: bool
                If false, ignores immediately preceding image and starts attending when following image
        Returns:
            torch.Tensor: shape (B, T_txt, D_txt)
        """
        _, T_img, n = media.shape[:3]
        h = self.heads

        x = self.norm(x)

        q = self.to_q(x)
        # flatten all media items' latents into one key/value sequence of length T_img * n
        media = rearrange(media, "b t n d -> b (t n) d")

        k, v = self.to_kv(media).chunk(2, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=h)
        k = rearrange(k, "b n (h d) -> b h n d", h=h)
        v = rearrange(v, "b n (h d) -> b h n d", h=h)

        q = q * self.scale

        sim = torch.einsum("... i d, ... j d -> ... i j", q, k)

        if exists(media_locations):
            # at each boolean of True, increment the time counter (relative to media time)
            # text_time[b, i] = number of media tokens at positions <= i, i.e. the
            # 1-based index of the most recent media item (0 = none seen yet).
            text_time = media_locations.cumsum(dim=-1)
            # 1-based media index for each of the T_img media items
            media_time = torch.arange(T_img, device=x.device) + 1

            if not attend_previous:
                # non-media positions are shifted to point at the NEXT media item
                text_time[~media_locations] += 1
                # make sure max is still the number of images in the sequence
                # (positions pointing past the last media item are reset to 0 = no media)
                text_time[
                    text_time
                    > repeat(
                        torch.count_nonzero(media_locations, dim=1),
                        "b -> b i",
                        i=text_time.shape[1],
                    )
                ] = 0

            # text time must equal media time if only attending to most immediate image
            # otherwise, as long as text time is greater than media time (if attending to all previous images / media)
            mask_op = torch.eq if self.only_attend_immediate_media else torch.ge

            # mask shape broadcasts to (B, 1, T_txt, T_img * n); each media index is
            # repeated n times so it covers all latents of that media item
            text_to_media_mask = mask_op(
                rearrange(text_time, "b i -> b 1 i 1"),
                repeat(media_time, "j -> 1 1 1 (j n)", n=n),
            )
            sim = sim.masked_fill(~text_to_media_mask, -torch.finfo(sim.dtype).max)

        # subtract the row max before softmax for numerical stability
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        if exists(media_locations) and self.only_attend_immediate_media:
            # any text without a preceding media needs to have attention zeroed out
            # (otherwise a fully-masked row would softmax to a uniform distribution)
            # NOTE(review): with only_attend_immediate_media=False such rows are not
            # zeroed and keep that uniform attention — confirm this is intended.
            text_without_media_mask = text_time == 0
            text_without_media_mask = rearrange(
                text_without_media_mask, "b i -> b 1 i 1"
            )
            attn = attn.masked_fill(text_without_media_mask, 0.0)

        out = torch.einsum("... i j, ... j d -> ... i d", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
class FlamingoGatedCrossAttentionBlock(nn.Module):
    """
    Tanh-gated residual cross-attention followed by a tanh-gated residual feed-forward.

    Each residual branch is multiplied by ``tanh`` of a learnable scalar gate
    initialised to 0, so at construction both branches are scaled by zero.
    """

    def __init__(
        self,
        *,
        dim: int,
        dim_visual: int,
        dim_head: int = 64,
        heads: int = 8,
        ff_mult: int = 4,
        only_attend_immediate_media: bool = True,
        only_attend_previous: bool = True,
    ):
        super().__init__()
        hidden_dim = dim * ff_mult
        # Submodule / parameter names and creation order are part of the checkpoint layout.
        self.attn = FlamingoMaskedCrossAttention(
            dim=dim,
            dim_visual=dim_visual,
            dim_head=dim_head,
            heads=heads,
            only_attend_immediate_media=only_attend_immediate_media,
            only_attend_previous=only_attend_previous,
        )
        self.attn_gate = nn.Parameter(torch.tensor([0.0]))
        self.feed_forward = nn.ModuleList(
            [
                nn.LayerNorm(dim),
                nn.Linear(dim, hidden_dim, bias=False),
                nn.GELU(),
                nn.Linear(hidden_dim, dim, bias=False),
            ]
        )
        self.ff_gate = nn.Parameter(torch.tensor([0.0]))

    def forward(
        self,
        x: torch.Tensor,
        media: torch.Tensor,
        media_locations: Optional[torch.BoolTensor] = None,
        attend_previous: bool = True,
    ) -> torch.Tensor:
        """Apply gated cross-attention to ``media`` and a gated feed-forward to text features ``x``."""
        cross = self.attn(
            x,
            media,
            media_locations=media_locations,
            attend_previous=attend_previous,
        )
        x = x + cross * self.attn_gate.tanh()

        ff_hidden = x
        for sublayer in self.feed_forward:
            ff_hidden = sublayer(ff_hidden)
        return x + ff_hidden * self.ff_gate.tanh()
class FlamingoLayer(nn.Module):
    """
    Wrap a language-model decoder layer with an optional gated cross-attention layer.

    The visual features, media locations and ``attend_previous`` flag are not
    passed through ``forward``. They are stored on the layer beforehand via the
    ``condition_*`` methods, because the wrapped model's own forward only
    passes text-side arguments down to its decoder layers.
    Idea borrowed from https://github.com/dhansmair/flamingo-mini/.
    """

    def __init__(
        self, gated_cross_attn_layer: Optional[nn.Module], decoder_layer: nn.Module
    ):
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        # Conditioning state, set by the condition_* methods; None means "not conditioned".
        self.vis_x = None
        self.media_locations = None
        self.attend_previous = None

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned (i.e. visual features are set)."""
        return self.vis_x is not None

    def condition_vis_x(self, vis_x) -> None:
        """Store the visual features used as cross-attention keys/values (None clears them)."""
        self.vis_x = vis_x

    def condition_media_locations(self, media_locations) -> None:
        """Store the boolean media-token mask for the current text input (None clears it)."""
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous) -> None:
        """Store the ``attend_previous`` flag forwarded to the cross-attention (None clears it)."""
        self.attend_previous = attend_previous

    def forward(
        self,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **decoder_layer_kwargs,
    ):
        """
        Run the gated cross-attention (if present), then the wrapped decoder layer.

        Args:
            lang_x: text hidden states passed to both sub-layers.
            attention_mask: forwarded unchanged to the decoder layer.
            **decoder_layer_kwargs: forwarded unchanged to the decoder layer.

        Returns:
            Whatever the wrapped decoder layer returns.

        Raises:
            ValueError: if a cross-attention layer exists and any conditioning
                value (vis_x, media_locations, attend_previous) has not been set.
        """
        if self.gated_cross_attn_layer is None:
            return self.decoder_layer(
                lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
            )

        if self.vis_x is None:
            raise ValueError("vis_x must be conditioned before forward pass")
        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")
        if self.attend_previous is None:
            # Without this check, a None flag would reach the cross-attention
            # and be treated as False.
            raise ValueError("attend_previous must be conditioned before forward pass")

        lang_x = self.gated_cross_attn_layer(
            lang_x,
            self.vis_x,
            media_locations=self.media_locations,
            attend_previous=self.attend_previous,
        )
        return self.decoder_layer(
            lang_x, attention_mask=attention_mask, **decoder_layer_kwargs
        )
class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    Grafted onto an already-built causal LM instance with ``extend_instance``.
    ``init_flamingo`` wraps every decoder layer in a ``FlamingoLayer``.
    ``forward`` then conditions those layers on the media-token positions of
    ``input_ids`` before delegating to the wrapped model's own ``forward``.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Record the dotted attribute path of the decoder-layer ModuleList."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the decoder-layer ModuleList (FlamingoLayers after init_flamingo)."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder-layer ModuleList with ``value``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id: int,
        vis_hidden_size: int,
        cross_attn_every_n_layers: int,
        use_media_placement_augmentation: bool,
        only_attend_previous: bool,
    ):
        """
        Initialize Flamingo by adding a new gated cross attn to the decoder.

        Store the media token id for computing the media locations.

        Args:
            media_token_id: token id of ``<image>``; its positions in ``input_ids``
                become the media locations.
            vis_hidden_size: width of the visual features fed to cross-attention.
            cross_attn_every_n_layers: a gated cross-attention block is placed in
                front of every n-th decoder layer (1-based); others get none.
            use_media_placement_augmentation: stored only; ``forward`` does not read it.
            only_attend_previous: value used as ``attend_previous`` on every forward.
        """
        decoder_layers = self._get_decoder_layers()
        gated_cross_attn_layers = nn.ModuleList(
            [
                FlamingoGatedCrossAttentionBlock(
                    dim=self.config.hidden_size,
                    dim_visual=vis_hidden_size,
                    only_attend_previous=only_attend_previous,
                )
                if (layer_idx + 1) % cross_attn_every_n_layers == 0
                else None
                for layer_idx, _ in enumerate(decoder_layers)
            ]
        )
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    FlamingoLayer(gated_cross_attn_layer, decoder_layer)
                    for gated_cross_attn_layer, decoder_layer in zip(
                        gated_cross_attn_layers, decoder_layers
                    )
                ]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.only_attend_previous = only_attend_previous
        self.initialized_flamingo = True

    def forward(self, *args, **kwargs):
        """
        Condition the Flamingo layers on the media locations before forward().

        ``input_ids`` is taken from the keyword arguments if present, otherwise
        from the first positional argument.

        Raises:
            ValueError: if ``init_flamingo`` has not been called.
        """
        # getattr with a default so an uninitialised instance reports the
        # intended ValueError rather than an AttributeError.
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError(
                "Flamingo layers are not initialized. Please call `init_flamingo` first."
            )

        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else args[0]
        media_locations = input_ids == self.media_token_id
        # IMPORTANT: `attend_previous` is fixed to the configured value rather than
        # randomised by media-placement augmentation, because training data is laid
        # out as <image>caption<|endofchunk|>.
        attend_previous = self.only_attend_previous

        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)

        return super().forward(*args, **kwargs)  # the wrapped model's own forward

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())

    def clear_conditioned_layers(self) -> None:
        """Reset the visual features, media locations and attend_previous on every layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)
class FlamingoPreTrainedModel(PreTrainedModel):
    """
    Abstract base for the Flamingo models.

    Binds ``FlamingoConfig`` to the Hugging Face ``PreTrainedModel`` machinery
    for downloading, saving and loading pretrained checkpoints.
    """

    config_class = FlamingoConfig
    base_model_prefix = "flamingo"
    supports_gradient_checkpointing = True
    # Module classes that should not be split across devices when a device map is built.
    _no_split_modules = ["FlamingoPerceiverBlock", "CLIPEncoderLayer", "FlamingoLayer"]

    def _init_weights(self, module):
        """Flamingo adds no initialisation scheme of its own; defer to the base class."""
        return super()._init_weights(module)

    def _set_gradient_checkpointing(self, module, value=False):
        """Store the gradient-checkpointing flag, but only on FlamingoModel instances."""
        if not isinstance(module, FlamingoModel):
            return
        module.gradient_checkpointing = value
class FlamingoModel(FlamingoPreTrainedModel):
    """
    Flamingo: a CLIP vision encoder and a Perceiver resampler whose output feeds
    gated cross-attention layers inserted into a LLaMA causal language model.
    """

    config_class = FlamingoConfig

    def __init__(
        self,
        config: FlamingoConfig,
    ):
        super().__init__(config)
        # TODO: hardcoded (instead of AutoModelForCausalLM / AutoTokenizer /
        # AutoModel.from_config) because the Auto* path is too slow.
        text_tokenizer = LlamaTokenizer.from_pretrained(
            config.text_config._name_or_path
        )
        lang_encoder = LlamaForCausalLM(config=config.text_config)
        vision_encoder = CLIPVisionModel(config=config.vision_config)

        # Register the interleaving markers, and add a pad token if the tokenizer lacks one.
        text_tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
        )
        if text_tokenizer.pad_token is None:
            text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.text_tokenizer = text_tokenizer
        # encode() may prepend a BOS token, so the special token's id is the last element.
        self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
        self.media_token_id = text_tokenizer.encode("<image>")[-1]

        # Turn the plain LM into a Flamingo LM and grow its embeddings for the new tokens.
        extend_instance(lang_encoder, FlamingoLMMixin)
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
        lang_encoder.resize_token_embeddings(len(text_tokenizer))
        self.lang_encoder = lang_encoder

        self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
        self.use_media_placement_augmentation = config.use_media_placement_augmentation
        self.only_attend_previous = config.only_attend_previous

        # NOTE(review): open_clip-era flag; appears unused by HF CLIPVisionModel.
        vision_encoder.output_tokens = True
        self.vision_encoder = vision_encoder

        # The perceiver and the cross-attention key/value projections consume the
        # vision encoder's patch features, so their width must equal its hidden
        # size (1024 for CLIP ViT-L/14, the previously hardcoded value).
        self.vis_dim = config.vision_config.hidden_size
        self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
        self.lang_encoder.init_flamingo(
            media_token_id=self.media_token_id,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=self.cross_attn_every_n_layers,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
            only_attend_previous=self.only_attend_previous,
        )
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.lang_encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.lang_encoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.lang_encoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.lang_encoder.set_output_embeddings(new_embeddings)

    def get_image_encoder(self) -> nn.Module:
        return self.vision_encoder

    def get_lang_encoder(self) -> nn.Module:
        return self.lang_encoder

    def tie_weights(self):
        return super().tie_weights()

    def init_weights(self):
        """
        Set which parameters are trainable.

        Freezes the vision encoder and every language-model parameter except the
        gated cross-attention layers, the input embeddings and ``lm_head``. The
        perceiver is not touched here. This overrides
        ``PreTrainedModel.init_weights`` (reached from ``post_init``), so no weight
        initialisation is performed.
        """
        # Freeze all parameters in vision encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        # Freeze all parameters in lang encoders except gated_cross_attn_layers
        for name, param in self.lang_encoder.named_parameters():
            if "gated_cross_attn_layer" not in name:
                param.requires_grad = False
        # Unfreeze LM input embeddings and output head
        self.lang_encoder.get_input_embeddings().requires_grad_(True)
        self.lang_encoder.lm_head.requires_grad_(True)

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            use_cached_vision_x: if True, reuse the visual features already stored on
                the decoder layers; vision_x must then be None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
                same set of images will be reused in another subsequent
                forward pass.
            past_key_values: pre-computed values to pass to language model.
                See past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.
            **kwargs: forwarded unchanged to the language model.

        Returns:
            The language model's output (CausalLMOutputWithPast).
        """
        assert (
            vision_x is not None
        ) or use_cached_vision_x, (
            "Must provide either vision_x or use_cached_vision_x to True."
        )

        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(vision_x=vision_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            # last_hidden_state without the first (CLS) token
            vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)

        # Match the language model's device and dtype before resampling, as
        # FlamingoForConditionalGeneration does (needed when the vision encoder
        # lives on another device or runs in another precision).
        dtype = self.lang_encoder.lm_head.weight.dtype
        vision_x = self.perceiver(
            vision_x.to(self.lang_encoder.device, dtype=dtype)
        )  # reshapes to (b, T, n, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)
class FlamingoForConditionalGeneration(FlamingoPreTrainedModel):
    """
    Flamingo with text generation: a CLIP vision encoder and a Perceiver resampler
    whose output feeds gated cross-attention layers inserted into a LLaMA causal
    language model.
    """

    config_class = FlamingoConfig

    def __init__(
        self,
        config: FlamingoConfig,
    ):
        super().__init__(config)
        # TODO: hardcoded (instead of AutoModel / AutoModelForCausalLM /
        # AutoTokenizer) because the Auto* path is too slow.
        text_tokenizer = LlamaTokenizer.from_pretrained(
            config.text_config._name_or_path
        )
        lang_encoder = LlamaForCausalLM(config=config.text_config)
        vision_encoder = CLIPVisionModel(config=config.vision_config)

        # Register the interleaving markers, and add a pad token if the tokenizer lacks one.
        text_tokenizer.add_special_tokens(
            {"additional_special_tokens": ["<|endofchunk|>", "<image>"]}
        )
        if text_tokenizer.pad_token is None:
            text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.text_tokenizer = text_tokenizer
        # encode() may prepend a BOS token, so the special token's id is the last element.
        self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
        self.media_token_id = text_tokenizer.encode("<image>")[-1]

        # Turn the plain LM into a Flamingo LM and grow its embeddings for the new tokens.
        extend_instance(lang_encoder, FlamingoLMMixin)
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
        lang_encoder.resize_token_embeddings(len(text_tokenizer))
        self.lang_encoder = lang_encoder

        self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
        self.use_media_placement_augmentation = config.use_media_placement_augmentation
        self.only_attend_previous = config.only_attend_previous

        # NOTE(review): open_clip-era flag; appears unused by HF CLIPVisionModel.
        vision_encoder.output_tokens = True
        self.vision_encoder = vision_encoder

        # The perceiver and the cross-attention key/value projections consume the
        # vision encoder's patch features, so their width must equal its hidden
        # size (1024 for CLIP ViT-L/14, the previously hardcoded value).
        self.vis_dim = config.vision_config.hidden_size
        self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)
        self.lang_encoder.init_flamingo(
            media_token_id=self.media_token_id,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=self.cross_attn_every_n_layers,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
            only_attend_previous=self.only_attend_previous,
        )
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.lang_encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.lang_encoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.lang_encoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.lang_encoder.set_output_embeddings(new_embeddings)

    def get_image_encoder(self) -> nn.Module:
        return self.vision_encoder

    def get_lang_encoder(self) -> nn.Module:
        return self.lang_encoder

    def init_weights(self):
        """
        Set which parameters are trainable.

        Freezes the vision encoder and every language-model parameter except the
        gated cross-attention layers, the input embeddings and ``lm_head``. The
        perceiver is not touched here. This overrides
        ``PreTrainedModel.init_weights`` (reached from ``post_init``), so no weight
        initialisation is performed.
        """
        # Freeze all parameters in vision encoder
        for param in self.vision_encoder.parameters():
            param.requires_grad = False
        # Freeze all parameters in lang encoders except gated_cross_attn_layers
        for name, param in self.lang_encoder.named_parameters():
            if "gated_cross_attn_layer" not in name:
                param.requires_grad = False
        # Unfreeze LM input embeddings and output head
        self.lang_encoder.get_input_embeddings().requires_grad_(True)
        self.lang_encoder.lm_head.requires_grad_(True)

    def forward(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cached_vision_x: bool = False,
        clear_conditioned_layers: bool = True,
        past_key_values: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Forward pass of Flamingo.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W) with F=1
            lang_x (torch.Tensor): Language input ids
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            labels (torch.Tensor, optional): Labels. Defaults to None.
            use_cached_vision_x: if True, reuse the visual features already stored on
                the decoder layers; vision_x must then be None.
            clear_conditioned_layers: if True, clear the conditioned layers
                once the forward pass is completed. Set this to False if the
                same set of images will be reused in another subsequent
                forward pass.
            past_key_values: pre-computed values to pass to language model.
                See past_key_values documentation in Hugging Face
                CausalLM models.
            use_cache: whether to use cached key values. See use_cache
                documentation in Hugging Face CausalLM models.
            **kwargs: forwarded unchanged to the language model.

        Returns:
            The language model's output (CausalLMOutputWithPast).
        """
        assert (
            vision_x is not None
        ) or use_cached_vision_x, (
            "Must provide either vision_x or use_cached_vision_x to True."
        )

        if use_cached_vision_x:
            # Case: use cached; vision_x should be cached and other
            # vision-related inputs should not be provided.
            assert (
                vision_x is None
            ), "Expect vision_x to be None when use_cached_vision_x is True."
            assert self.lang_encoder.is_conditioned()
        else:
            # Case: do not use caching (i.e. this is a standard forward pass);
            self._encode_vision_x(vision_x=vision_x)

        output = self.lang_encoder(
            input_ids=lang_x,
            attention_mask=attention_mask,
            labels=labels,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )

        if clear_conditioned_layers:
            self.lang_encoder.clear_conditioned_layers()

        return output

    def _encode_vision_x(self, vision_x: torch.Tensor):
        """
        Compute media tokens from vision input by passing it through vision encoder and conditioning language model.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                Images in the same chunk are collated along T_img, and frames are collated along F
                Currently only F=1 is supported (single-frame videos)

        rearrange code based on https://github.com/dhansmair/flamingo-mini
        """
        assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
        b, T, F = vision_x.shape[:3]
        assert F == 1, "Only single frame supported"

        vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
        with torch.no_grad():
            # last_hidden_state without the first (CLS) token
            vision_x = self.vision_encoder(vision_x)[0][:, 1:, :]
        vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)

        # Match the language model's device and dtype before resampling.
        dtype = self.lang_encoder.lm_head.weight.dtype
        vision_x = self.perceiver(
            vision_x.to(self.lang_encoder.device, dtype=dtype)
        )  # reshapes to (b, T, n, d)

        for layer in self.lang_encoder._get_decoder_layers():
            layer.condition_vis_x(vision_x)

    @torch.no_grad()
    def generate(
        self,
        vision_x: torch.Tensor,
        lang_x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        num_beams: int = 1,
        max_new_tokens: Optional[int] = None,
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        no_repeat_ngram_size: int = 0,
        prefix_allowed_tokens_fn: Optional[
            Callable[[int, torch.Tensor], list[int]]
        ] = None,
        length_penalty: float = 1.0,
        num_return_sequences: int = 1,
        do_sample: bool = False,
        early_stopping: bool = False,
        **kwargs,
    ):
        """
        Generate text conditioned on vision and language inputs.

        Generation stops at the ``<|endofchunk|>`` token, which is passed as
        ``eos_token_id``. The decoder layers are always un-conditioned afterwards,
        even if generation raises.

        Args:
            vision_x (torch.Tensor): Vision input
                shape (B, T_img, F, C, H, W)
                images in the same chunk are collated along T_img, and frames are collated along F
                currently only F=1 is supported (single-frame videos)
            lang_x (torch.Tensor): Language input
                shape (B, T_txt)
            attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
            num_beams (int, optional): Number of beams. Defaults to 1. When > 1,
                vision_x is repeated num_beams times along the batch axis.
            max_new_tokens (int, optional): Maximum new tokens. Defaults to None.
            temperature (float, optional): Temperature. Defaults to 1.0.
            top_k (int, optional): Top k. Defaults to 0.
            top_p (float, optional): Top p. Defaults to 1.0.
            no_repeat_ngram_size (int, optional): No repeat ngram size. Defaults to 0.
            prefix_allowed_tokens_fn (callable, optional): Forwarded to generate.
            length_penalty (float, optional): Length penalty. Defaults to 1.0.
            num_return_sequences (int, optional): Number of return sequences. Defaults to 1.
            do_sample (bool, optional): Do sample. Defaults to False.
            early_stopping (bool, optional): Early stopping. Defaults to False.
            **kwargs: further generation options (e.g. ``max_length``), forwarded
                unchanged to the language model's ``generate``.

        Returns:
            torch.Tensor: lang_x with generated tokens appended to it
        """
        if hasattr(self, "_hf_hook"):
            # add a hook to make sure that the output of lang_encoder is mapped to the same device as the lang_x
            hook = AlignDevicesHook(
                execution_device=lang_x.device,
                io_same_device=True,
                place_submodules=False,
            )
            add_hook_to_module(self.lang_encoder, hook)

        if num_beams > 1:
            vision_x = vision_x.repeat_interleave(num_beams, dim=0)

        try:
            self._encode_vision_x(vision_x=vision_x)
            output = self.lang_encoder.generate(
                lang_x,
                attention_mask=attention_mask,
                eos_token_id=self.eoc_token_id,
                num_beams=num_beams,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                no_repeat_ngram_size=no_repeat_ngram_size,
                length_penalty=length_penalty,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
                early_stopping=early_stopping,
                **kwargs,
            )
        finally:
            # Drop references to the visual features even if generation fails.
            self.lang_encoder.clear_conditioned_layers()

        return output