-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathpretrain_datasets.py
102 lines (85 loc) · 3.53 KB
/
pretrain_datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from copy import deepcopy
import os
from typing import List, Tuple
from PIL import Image
import pandas as pd
import torch
from torch.utils.data import Dataset
import tokenizers
import random
def pil_loader(path: str) -> Image.Image:
    """Load the image at *path* and return it converted to RGB.

    The file is opened through an explicit handle so it is closed
    deterministically, avoiding a ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as fh:
        # convert() forces the pixel data to load while the file is still open
        return Image.open(fh).convert('RGB')
class MultimodalBertDataset(Dataset):
    """Paired image/report dataset for multimodal BERT pretraining.

    Reads ``(image_path, report_content)`` rows from
    ``<data_root>/training.csv``, tokenizes each report with a WordPiece
    tokenizer (``mimic_wordpiece.json``, loaded from the working directory),
    and produces randomly whole-word-masked token ids for the MLM objective.

    ``__getitem__`` returns ``(image, ids, attention_mask, type_ids,
    masked_ids)`` where each token tensor has shape ``(1, max_caption_length)``.
    """

    # Token id written in place of masked tokens — presumably the [MASK] id
    # in mimic_wordpiece.json; TODO confirm against the vocab file.
    MASK_TOKEN_ID = 3
    # Probability of masking each eligible (word-initial, non-special) token.
    MASK_PROB = 0.5

    def __init__(
        self,
        data_root,
        transform,
        max_caption_length: int = 100
    ):
        """
        Args:
            data_root: directory containing ``training.csv``.
            transform: callable applied to each loaded PIL image.
            max_caption_length: tokenizer truncation/padding length.
        """
        self.max_caption_length = max_caption_length
        self.data_root = data_root
        self.transform = transform
        self.images_list, self.report_list = self.read_csv()
        self.tokenizer = tokenizers.Tokenizer.from_file("mimic_wordpiece.json")
        # Padding/truncation settings never change, so configure them once
        # here instead of on every __getitem__ call.
        self.tokenizer.enable_truncation(max_length=self.max_caption_length)
        self.tokenizer.enable_padding(length=self.max_caption_length)
        # Reverse vocab (id -> token string), used to detect '##' suffix pieces.
        self.idxtoword = {v: k for k, v in self.tokenizer.get_vocab().items()}

    def __len__(self):
        return len(self.images_list)

    def _random_mask(self, tokens):
        """Return a masked copy of ``tokens`` (shape ``(1, L)``).

        Whole-word masking: a word-initial piece is masked with probability
        ``MASK_PROB``; a ``'##'`` suffix piece always follows the decision
        made for the piece before it. Position 0 (the [CLS] slot) and the
        final position are never touched; the scan stops at the first
        padding token (id 0).
        """
        masked_tokens = deepcopy(tokens)
        for i in range(1, masked_tokens.shape[1] - 1):
            if masked_tokens[0][i] == 0:
                break  # reached padding
            is_suffix = self.idxtoword[masked_tokens[0][i].item()].startswith('##')
            if is_suffix:
                if masked_tokens[0][i - 1] == self.MASK_TOKEN_ID:
                    # Earlier piece of this word was masked: mask the suffix too.
                    masked_tokens[0][i] = self.MASK_TOKEN_ID
                # Otherwise the word was kept; keep the suffix as well.
                continue
            if random.random() < self.MASK_PROB:
                masked_tokens[0][i] = self.MASK_TOKEN_ID
        return masked_tokens

    def __getitem__(self, index):
        image = self.transform(pil_loader(self.images_list[index]))
        # Prepend [CLS] manually; the tokenizer pads/truncates to
        # max_caption_length (configured in __init__).
        sent = '[CLS] ' + self.report_list[index]
        encoded = self.tokenizer.encode(sent)
        ids = torch.tensor(encoded.ids).unsqueeze(0)
        attention_mask = torch.tensor(encoded.attention_mask).unsqueeze(0)
        type_ids = torch.tensor(encoded.type_ids).unsqueeze(0)
        masked_ids = self._random_mask(ids)
        return image, ids, attention_mask, type_ids, masked_ids

    def read_csv(self):
        """Load training.csv and return its (image_path, report_content) columns."""
        csv_path = os.path.join(self.data_root, 'training.csv')
        df = pd.read_csv(csv_path, sep=',')
        return df["image_path"], df["report_content"]

    def collate_fn(self, instances: List[Tuple]):
        """Collate a list of ``__getitem__`` outputs into a batch dict.

        Returns a dict with ``image`` of shape ``(B, ...)`` and token tensors
        of shape ``(B, L)``.

        Fix: the previous implementation used ``torch.stack(...).squeeze()``,
        which collapsed the batch dimension for a batch of size 1 — the
        ``(1, 1, L)`` stack squeezed to ``(L,)`` instead of ``(1, L)``.
        Concatenating the per-sample ``(1, L)`` tensors along dim 0 yields
        ``(B, L)`` for every batch size.
        """
        images, ids, attention_masks, type_ids, masked_ids = zip(*instances)
        return {
            "image": torch.stack(images),
            "labels": torch.cat(ids, dim=0),
            "attention_mask": torch.cat(attention_masks, dim=0),
            "type_ids": torch.cat(type_ids, dim=0),
            "ids": torch.cat(masked_ids, dim=0),
        }