-
Notifications
You must be signed in to change notification settings - Fork 37
/
Copy path__init__.py
80 lines (68 loc) · 2.51 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gc
import torch
import numpy as np
from PIL import Image
# Hard-coded root directory where the VLP model weights/data are expected.
DATA_DIR = '/root/VLP_web_data'
def skip(*args, **kwargs):
    """No-op stand-in that accepts and ignores any arguments."""
    pass
# Monkey-patch torch's in-place weight initializers to no-ops.
# NOTE(review): presumably done because checkpoints are loaded right after
# model construction, making random init wasted work — confirm against the
# model-loading code in the Test* classes.
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
def get_image(image):
    """Load *image* as a PIL image.

    Args:
        image: Either a filesystem path (str) to an image file, or an
            already-loaded ``PIL.Image.Image`` instance.

    Returns:
        A ``PIL.Image.Image``. A path is opened and converted to RGB; an
        existing PIL image is returned unchanged.

    Raises:
        NotImplementedError: If ``image`` is neither a str nor a PIL image.
    """
    if isinstance(image, str):
        try:
            return Image.open(image).convert("RGB")
        except Exception:
            # Unreadable path is fatal: downstream model code cannot proceed.
            print(f"Fail to read image: {image}")
            exit(-1)
    elif isinstance(image, Image.Image):
        # BUGFIX: was `type(image) is Image.Image`, which rejects the format
        # subclasses (e.g. JpegImageFile) that Image.open actually returns.
        return image
    else:
        raise NotImplementedError(f"Invalid type of Image: {type(image)}")
def get_BGR_image(image):
    """Load *image* (path or PIL image) and return it with channels reversed (RGB -> BGR)."""
    rgb = np.array(get_image(image))
    bgr = rgb[:, :, ::-1]
    return Image.fromarray(np.uint8(bgr))
def get_model(model_name, device=None):
    """Lazily import and instantiate the test wrapper for *model_name*.

    Imports happen inside each branch so that only the selected model's
    (heavy) dependencies are loaded.

    Args:
        model_name: One of 'blip2', 'minigpt4', 'owl', 'otter',
            'instruct_blip', 'vpgtrans', 'llava', 'llama_adapter_v2'.
        device: Optional device passed through to the wrapper's constructor.

    Returns:
        The instantiated Test* wrapper object.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == 'blip2':
        from .test_blip2 import TestBlip2
        return TestBlip2(device)
    elif model_name == 'minigpt4':
        from .test_minigpt4 import TestMiniGPT4
        return TestMiniGPT4(device)
    elif model_name == 'owl':
        from .test_mplug_owl import TestMplugOwl
        return TestMplugOwl(device)
    elif model_name == 'otter':
        from .test_otter import TestOtter
        return TestOtter(device)
    elif model_name == 'instruct_blip':
        from .test_instructblip import TestInstructBLIP
        return TestInstructBLIP(device)
    elif model_name == 'vpgtrans':
        from .test_vpgtrans import TestVPGTrans
        return TestVPGTrans(device)
    elif model_name == 'llava':
        from .test_llava import TestLLaVA
        return TestLLaVA(device)
    elif model_name == 'llama_adapter_v2':
        # FIX: dropped the unused TestLLamaAdapterV2_web import.
        from .test_llama_adapter_v2 import TestLLamaAdapterV2
        return TestLLamaAdapterV2(device)
    else:
        raise ValueError(f"Invalid model_name: {model_name}")
def get_device_name(device: torch.device):
    """Render a ``torch.device`` as a string such as ``"cpu"`` or ``"cuda:0"``.

    The ``:index`` suffix is included only when the device has an index.
    """
    if device.index is None:
        return device.type
    return f"{device.type}:{device.index}"
@torch.inference_mode()
def generate_stream(model, text, image, device=None, keep_in_device=False):
    """Run one inference pass and yield the model's text output.

    Single-shot generator (yields exactly once) — the generator shape is
    presumably required by a streaming UI such as gradio; confirm at call site.

    Args:
        model: A Test* wrapper exposing ``device``, ``move_to_device`` and
            ``generate`` (contract defined elsewhere in this package).
        text: Prompt text forwarded to ``model.generate``.
        image: Array-like image data, converted here to an RGB PIL image.
        device: Target device for this call; model is moved if it differs.
        keep_in_device: If False, the model is moved off ``device`` after
            generation (``move_to_device(None)`` — semantics defined by the
            wrapper; presumably offloads to a default/CPU device, confirm).
    """
    # Normalize the incoming array-like image to an RGB PIL image.
    image = np.array(image, dtype='uint8')
    image = Image.fromarray(image.astype('uint8')).convert('RGB')
    if device != model.device:
        model.move_to_device(device)
    output = model.generate(image, text, device, keep_in_device)
    if not keep_in_device:
        model.move_to_device(None)
    print(f"{'#' * 20} Model out: {output}")
    # Reclaim host and GPU memory between requests.
    gc.collect()
    torch.cuda.empty_cache()
    yield output