-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathT5VisionModelFrozen.py
33 lines (26 loc) · 1.52 KB
/
T5VisionModelFrozen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from math import comb
from transformers import T5Tokenizer, T5ForConditionalGeneration
from tqdm import tqdm
import clip
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from architectures.T5VisionModel import T5VisionModel
from create_mapping import CrossModalMapping
class T5VisionModelFrozen(T5VisionModel):
    """``T5VisionModel`` variant with the T5 transformer and vision encoder frozen.

    After the parent constructor has built the model, gradients are disabled
    for every parameter of ``T5_model.encoder``, ``T5_model.decoder`` and
    ``vision_model``. Gradients are then re-enabled for the T5 shared
    token-embedding module (``T5_model.shared``). Finally, the number of
    ``T5_model`` parameters that still require gradients is printed.

    All constructor arguments are forwarded unchanged to ``T5VisionModel``.
    """

    def __init__(self, device, vision_encoder="ViT-B/32", T5_version="t5-small",
                 max_source_length=512, max_target_length=128, use_image_info=True,
                 vision_checkpoint=None, mapping_checkpoint=None,
                 retrieval_function=None, use_quantifier=True):
        super().__init__(
            device,
            vision_encoder=vision_encoder,
            T5_version=T5_version,
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            use_image_info=use_image_info,
            vision_checkpoint=vision_checkpoint,
            mapping_checkpoint=mapping_checkpoint,
            retrieval_function=retrieval_function,
            use_quantifier=use_quantifier,
        )

        # Freeze the pretrained backbones.
        self.T5_model.encoder.requires_grad_(False)
        self.T5_model.decoder.requires_grad_(False)
        self.vision_model.requires_grad_(False)

        # Re-enable the shared embedding table. This must stay AFTER the freezes
        # above: the encoder/decoder token embeddings are presumably the same
        # module as `shared` (Hugging Face T5 convention — confirm), so the
        # encoder/decoder freeze would otherwise leave it frozen too.
        self.T5_model.shared.requires_grad_(True)

        # NOTE(review): `T5_model.lm_head` (if present) is never frozen here.
        # Whether it shares its weight with `shared` depends on the checkpoint's
        # `tie_word_embeddings` setting; confirm this is the intended trainable set.

        # numel() gives an exact int element count. np.prod(size()) yields the
        # float 1.0 for 0-dim parameters, which would turn the total into a float.
        trainable_params = sum(
            para.numel() for para in self.T5_model.parameters() if para.requires_grad
        )
        print(f"Freezing T5 model to have {trainable_params} trainable parameters ...")