Skip to content

Commit 2547a44

Browse files
committed
backward_hook for Grad-CAM
1 parent 4512276 commit 2547a44

10 files changed

+182
-0
lines changed

Code/4_viewer/6_hook_for_grad_cam.py

+182
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# coding: utf-8
2+
"""
3+
通过实现Grad-CAM学习module中的forward_hook和backward_hook函数
4+
"""
5+
import cv2
6+
import os
7+
import numpy as np
8+
from PIL import Image
9+
import torch
10+
import torch.nn as nn
11+
import torch.nn.functional as F
12+
import torchvision.transforms as transforms
13+
14+
15+
class Net(nn.Module):
16+
def __init__(self):
17+
super(Net, self).__init__()
18+
self.conv1 = nn.Conv2d(3, 6, 5)
19+
self.pool1 = nn.MaxPool2d(2, 2)
20+
self.conv2 = nn.Conv2d(6, 16, 5)
21+
self.pool2 = nn.MaxPool2d(2, 2)
22+
self.fc1 = nn.Linear(16 * 5 * 5, 120)
23+
self.fc2 = nn.Linear(120, 84)
24+
self.fc3 = nn.Linear(84, 10)
25+
26+
def forward(self, x):
27+
x = self.pool1(F.relu(self.conv1(x)))
28+
x = self.pool1(F.relu(self.conv2(x)))
29+
x = x.view(-1, 16 * 5 * 5)
30+
x = F.relu(self.fc1(x))
31+
x = F.relu(self.fc2(x))
32+
x = self.fc3(x)
33+
return x
34+
35+
36+
def img_transform(img_in, transform):
37+
"""
38+
将img进行预处理,并转换成模型输入所需的形式—— B*C*H*W
39+
:param img_roi: np.array
40+
:return:
41+
"""
42+
img = img_in.copy()
43+
img = Image.fromarray(np.uint8(img))
44+
img = transform(img)
45+
img = img.unsqueeze(0) # C*H*W --> B*C*H*W
46+
return img
47+
48+
49+
def img_preprocess(img_in):
50+
"""
51+
读取图片,转为模型可读的形式
52+
:param img_in: ndarray, [H, W, C]
53+
:return: PIL.image
54+
"""
55+
img = img_in.copy()
56+
img = cv2.resize(img,(32, 32))
57+
img = img[:, :, ::-1] # BGR --> RGB
58+
transform = transforms.Compose([
59+
transforms.ToTensor(),
60+
transforms.Normalize([0.4948052, 0.48568845, 0.44682974], [0.24580306, 0.24236229, 0.2603115])
61+
])
62+
img_input = img_transform(img, transform)
63+
return img_input
64+
65+
66+
def backward_hook(module, grad_in, grad_out):
67+
grad_block.append(grad_out[0].detach())
68+
69+
70+
def farward_hook(module, input, output):
71+
fmap_block.append(output)
72+
73+
74+
def show_cam_on_image(img, mask, out_dir):
75+
heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
76+
heatmap = np.float32(heatmap) / 255
77+
cam = heatmap + np.float32(img)
78+
cam = cam / np.max(cam)
79+
80+
path_cam_img = os.path.join(out_dir, "cam.jpg")
81+
path_raw_img = os.path.join(out_dir, "raw.jpg")
82+
if not os.path.exists(out_dir):
83+
os.makedirs(out_dir)
84+
cv2.imwrite(path_cam_img, np.uint8(255 * cam))
85+
cv2.imwrite(path_raw_img, np.uint8(255 * img))
86+
87+
88+
def comp_class_vec(ouput_vec, index=None):
89+
"""
90+
计算类向量
91+
:param ouput_vec: tensor
92+
:param index: int,指定类别
93+
:return: tensor
94+
"""
95+
if not index:
96+
index = np.argmax(ouput_vec.cpu().data.numpy())
97+
index = index[np.newaxis, np.newaxis]
98+
index = torch.from_numpy(index)
99+
one_hot = torch.zeros(1, 10).scatter_(1, index, 1)
100+
one_hot.requires_grad = True
101+
class_vec = torch.sum(one_hot * output) # one_hot = 11.8605
102+
103+
return class_vec
104+
105+
106+
def gen_cam(feature_map, grads):
107+
"""
108+
依据梯度和特征图,生成cam
109+
:param feature_map: np.array, in [C, H, W]
110+
:param grads: np.array, in [C, H, W]
111+
:return: np.array, [H, W]
112+
"""
113+
cam = np.zeros(feature_map.shape[1:], dtype=np.float32) # cam shape (H, W)
114+
115+
weights = np.mean(grads, axis=(1, 2)) #
116+
117+
for i, w in enumerate(weights):
118+
cam += w * feature_map[i, :, :]
119+
120+
cam = np.maximum(cam, 0)
121+
cam = cv2.resize(cam, (32, 32))
122+
cam -= np.min(cam)
123+
cam /= np.max(cam)
124+
125+
return cam
126+
127+
128+
if __name__ == '__main__':
129+
130+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
131+
path_img = os.path.join(BASE_DIR, "../../Data/cam_img/", "test_img_1.png")
132+
path_net = os.path.join(BASE_DIR, "../../Data/", "net_params_72p.pkl")
133+
output_dir = os.path.join(BASE_DIR, "../../Result/backward_hook_cam/")
134+
135+
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
136+
fmap_block = list()
137+
grad_block = list()
138+
139+
# 图片读取;网络加载
140+
img = cv2.imread(path_img, 1) # H*W*C
141+
img_input = img_preprocess(img)
142+
net = Net()
143+
net.load_state_dict(torch.load(path_net))
144+
145+
# 注册hook
146+
net.conv2.register_forward_hook(farward_hook)
147+
net.conv2.register_backward_hook(backward_hook)
148+
149+
# forward
150+
output = net(img_input)
151+
idx = np.argmax(output.cpu().data.numpy())
152+
print("predict: {}".format(classes[idx]))
153+
154+
# backward
155+
net.zero_grad()
156+
class_loss = comp_class_vec(output)
157+
class_loss.backward()
158+
159+
# 生成cam
160+
grads_val = grad_block[0].cpu().data.numpy().squeeze()
161+
fmap = fmap_block[0].cpu().data.numpy().squeeze()
162+
cam = gen_cam(fmap, grads_val)
163+
164+
# 保存cam图片
165+
img_show = np.float32(cv2.resize(img, (32, 32))) / 255
166+
show_cam_on_image(img_show, cam, output_dir)
167+
168+
169+
170+
171+
172+
173+
174+
175+
176+
177+
178+
179+
180+
181+
182+

Data/cam_img/.DS_Store

6 KB
Binary file not shown.

Data/cam_img/test_img_1.png

7.68 KB
Loading

Data/cam_img/test_img_2.png

7.93 KB
Loading

Data/cam_img/test_img_3.png

30.2 KB
Loading

Data/cam_img/test_img_4.png

10.3 KB
Loading

Data/cam_img/test_img_5.png

25.7 KB
Loading

Data/cam_img/test_img_6.png

16.5 KB
Loading

Data/cam_img/test_img_7.png

18.2 KB
Loading

Data/cam_img/test_img_8.png

23 KB
Loading

0 commit comments

Comments
 (0)