# test_mplug_owl.py
import torch
from .mplug_owl.processing_mplug_owl import MplugOwlProcessor, MplugOwlImageProcessor
from .mplug_owl.modeling_mplug_owl import MplugOwlForConditionalGeneration
from .mplug_owl.tokenization_mplug_owl import MplugOwlTokenizer
from transformers import AutoTokenizer
from . import get_image
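
# Conversation template in mPLUG-Owl's chat format: the <image> token marks where
# the visual features are spliced in, and generation continues after "AI:".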
prompt_template = '''
The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <image>
Human: {}
AI:
'''
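
# Shared decoding settings: sample with a small top-k for mild diversity;
# max_length bounds the total (prompt + generated) token count.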
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512
}

class TestMplugOwl:
    def __init__(self, device=None):
        model_path = 'MAGAer13/mplug-owl-llama-7b'
        # Load on CPU in float32 first; move_to_device() selects the final device and dtype.
        self.model = MplugOwlForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float32)
        self.image_processor = MplugOwlImageProcessor.from_pretrained(model_path)
        # self.tokenizer = MplugOwlTokenizer.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.processor = MplugOwlProcessor(self.image_processor, self.tokenizer)
        # Parameter-count debugging, kept for reference:
        # import re
        # target_modules = r'.*language_model.*\.(q_proj|v_proj)'  # refer to pipeline/train.py in the mPLUG-Owl repo
        # all_param, trainable_params = 0, 0
        # for name, param in self.model.named_parameters():
        #     all_param += param.numel()
        #     if re.match(target_modules, name):
        #         trainable_params += param.numel()
        # print(all_param)
        # print(trainable_params)
        # exit(0)
        self.model.eval()
        # Always call move_to_device so self.device and self.dtype are set even when
        # device is None; otherwise generate() would fail on the CPU path.
        self.move_to_device(device)

    def move_to_device(self, device=None):
        # torch.device() normalizes both strings ('cuda:0') and torch.device objects.
        if device is not None and torch.device(device).type == 'cuda':
            self.device = device
            # Prefer bfloat16 on GPUs that support it; otherwise fall back to float16.
            self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        else:
            self.device = 'cpu'
            self.dtype = torch.float32
        self.model.to(device=self.device, dtype=self.dtype)

    @torch.no_grad()
    def generate(self, image, question):
        prompts = [prompt_template.format(question)]
        image = get_image(image)
        inputs = self.processor(text=prompts, images=[image], return_tensors='pt')
        # Cast floating-point tensors (the pixel values) to the model dtype and move
        # everything to the target device in a single pass.
        inputs = {k: v.to(self.device, dtype=self.dtype) if v.dtype == torch.float else v.to(self.device)
                  for k, v in inputs.items()}
        res = self.model.generate(**inputs, **generate_kwargs)
        generated_text = self.tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
        return generated_text

    @torch.no_grad()
    def batch_generate(self, image_list, question_list):
        images = [get_image(image) for image in image_list]
        images = [self.image_processor(image, return_tensors='pt').pixel_values for image in images]
        images = torch.cat(images, dim=0).to(self.device, dtype=self.dtype)
        prompts = [prompt_template.format(question) for question in question_list]
        # Text-only call; the batched pixel values are attached separately below.
        inputs = self.processor(text=prompts)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        inputs["pixel_values"] = images
        res = self.model.generate(**inputs, **generate_kwargs)
        outputs = [self.tokenizer.decode(output, skip_special_tokens=True) for output in res.tolist()]
        return outputs
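
# A minimal usage sketch, not part of the original harness. It assumes this file is
# run as a module inside its package (so the relative imports resolve) and that
# 'sample.jpg' stands in for a real image path.
if __name__ == '__main__':
    device = torch.device('cuda') if torch.cuda.is_available() else None
    tester = TestMplugOwl(device=device)
    # Single image, single question.
    print(tester.generate('sample.jpg', 'What is shown in this image?'))
    # Batched: one prompt per image, answers returned in order.
    print(tester.batch_generate(
        ['sample.jpg', 'sample.jpg'],
        ['Describe the image.', 'What colors dominate the image?'],
    ))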