
Commit d84bcba ("update")

1 parent ec4adc0

668 files changed (+52968, -1 lines)


README.md (+1, -1)

```diff
@@ -99,4 +99,4 @@ If you find our paper and/or code helpful, please consider citing:
   year={2024},
   organization={IEEE}
 }
-```
+```
```

app/__init__.py (+30)

```python
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from PIL import Image
import requests

import streamlit as st
import torch


@st.cache()
def load_demo_image():
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
    return raw_image


# lyz modifies cuda

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# cache_root = "/export/home/.cache/lavis/"

cache_root = "/data/xcg/lavis_data/.cache/"
```

app/calculate_coco_features.py (+90)

```python
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

from PIL import Image
import requests
import torch

import os

from lavis.common.registry import registry
from lavis.processors import *
from lavis.models import *
from lavis.common.utils import build_default_model

# lyz modifies cuda

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_demo_image():
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")

    return raw_image


def read_img(filepath):
    raw_image = Image.open(filepath).convert("RGB")

    return raw_image


# model
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
feature_extractor = BlipFeatureExtractor(pretrained=model_url)

feature_extractor.eval()
feature_extractor = feature_extractor.to(device)

# preprocessors
vis_processor = BlipImageEvalProcessor(image_size=224)
text_processor = BlipCaptionProcessor()

# files to process
# file_root = "/export/home/.cache/lavis/coco/images/val2014"
# file_root = "/export/home/.cache/lavis/coco/images/train2014"
file_root = "/data/xcg/lavis_data/coco/images/train2014"
filepaths = os.listdir(file_root)

print(len(filepaths))

caption = "dummy"

path2feat = dict()
bsz = 256

images_in_batch = []
filepaths_in_batch = []

for i, filename in enumerate(filepaths):
    if i % bsz == 0 and i > 0:
        images_in_batch = torch.cat(images_in_batch, dim=0).to(device)
        with torch.no_grad():
            image_features = feature_extractor(
                images_in_batch, caption, mode="image", normalized=True
            )[:, 0]

        for filepath, image_feat in zip(filepaths_in_batch, image_features):
            path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()

        images_in_batch = []
        filepaths_in_batch = []

        print(len(path2feat), image_features.shape)
    else:
        filepath = os.path.join(file_root, filename)

        image = read_img(filepath)
        image = vis_processor(image).unsqueeze(0)

        images_in_batch.append(image)
        filepaths_in_batch.append(filepath)

torch.save(path2feat, "path2feat_coco_train2014.pth")
```
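One detail worth flagging in the loop above: the accumulated batch is flushed only when `i % bsz == 0 and i > 0`, and the current file is appended only in the `else` branch, so every 256th image (indices 256, 512, ...) is never encoded, and the final partial batch is silently dropped. A sketch of an equivalent loop without those gaps, reusing the `read_img`, `vis_processor`, `feature_extractor`, and `caption` defined above:

```python
def flush(images, paths, path2feat):
    """Encode one batch and store normalized CLS features keyed by file name."""
    if not images:
        return
    batch = torch.cat(images, dim=0).to(device)
    with torch.no_grad():
        feats = feature_extractor(batch, caption, mode="image", normalized=True)[:, 0]
    for path, feat in zip(paths, feats):
        path2feat[os.path.basename(path)] = feat.detach().cpu()


images_in_batch, filepaths_in_batch = [], []
for filename in filepaths:
    filepath = os.path.join(file_root, filename)
    images_in_batch.append(vis_processor(read_img(filepath)).unsqueeze(0))
    filepaths_in_batch.append(filepath)
    if len(images_in_batch) == bsz:
        flush(images_in_batch, filepaths_in_batch, path2feat)
        images_in_batch, filepaths_in_batch = [], []

flush(images_in_batch, filepaths_in_batch, path2feat)  # encode the remainder
```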

app/caption.py (+98)

```python
"""
# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import streamlit as st
from app import device, load_demo_image
from app.utils import load_model_cache
from lavis.processors import load_processor
from PIL import Image


def app():
    # ===== layout =====
    model_type = st.sidebar.selectbox("Model:", ["BLIP_base", "BLIP_large"])

    sampling_method = st.sidebar.selectbox(
        "Sampling method:", ["Beam search", "Nucleus sampling"]
    )

    st.markdown(
        "<h1 style='text-align: center;'>Image Description Generation</h1>",
        unsafe_allow_html=True,
    )

    instructions = """Try the provided image or upload your own:"""
    file = st.file_uploader(instructions)

    use_beam = sampling_method == "Beam search"

    col1, col2 = st.columns(2)

    if file:
        raw_img = Image.open(file).convert("RGB")
    else:
        raw_img = load_demo_image()

    col1.header("Image")

    w, h = raw_img.size
    scaling_factor = 720 / w
    resized_image = raw_img.resize((int(w * scaling_factor), int(h * scaling_factor)))

    col1.image(resized_image, use_column_width=True)
    col2.header("Description")

    cap_button = st.button("Generate")

    # ==== event ====
    vis_processor = load_processor("blip_image_eval").build(image_size=384)

    if cap_button:
        if model_type.startswith("BLIP"):
            blip_type = model_type.split("_")[1].lower()
            model = load_model_cache(
                "blip_caption",
                model_type=f"{blip_type}_coco",
                is_eval=True,
                device=device,
            )

            img = vis_processor(raw_img).unsqueeze(0).to(device)
            captions = generate_caption(
                model=model, image=img, use_nucleus_sampling=not use_beam
            )

            col2.write("\n\n".join(captions), use_column_width=True)


def generate_caption(
    model, image, use_nucleus_sampling=False, num_beams=3, max_length=40, min_length=5
):
    samples = {"image": image}

    captions = []
    if use_nucleus_sampling:
        for _ in range(5):
            caption = model.generate(
                samples,
                use_nucleus_sampling=True,
                max_length=max_length,
                min_length=min_length,
                top_p=0.9,
            )
            captions.append(caption[0])
    else:
        caption = model.generate(
            samples,
            use_nucleus_sampling=False,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
        )
        captions.append(caption[0])

    return captions
```
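The two branches of `generate_caption` map directly onto the flags of LAVIS's `model.generate`: beam search produces a single deterministic caption, while nucleus sampling (`top_p=0.9`) is invoked five times to collect diverse candidates. A minimal headless sketch of the same call path, assuming LAVIS's standard `load_model_and_preprocess` loader and the demo image from `app/__init__.py`:

```python
# Illustrative usage outside Streamlit; assumes lavis is installed.
from lavis.models import load_model_and_preprocess

from app import device, load_demo_image
from app.caption import generate_caption

model, vis_processors, _ = load_model_and_preprocess(
    name="blip_caption", model_type="base_coco", is_eval=True, device=device
)

img = vis_processors["eval"](load_demo_image()).unsqueeze(0).to(device)

print(generate_caption(model, img))                             # one beam-search caption
print(generate_caption(model, img, use_nucleus_sampling=True))  # five sampled captions
```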
