Skip to content

Commit db360db

Browse files
committed
Updated visualizations
1 parent e0a8322 commit db360db

File tree

4 files changed

+838
-196
lines changed

4 files changed

+838
-196
lines changed

captum/attr/_utils/visualization.py

+267-48
Original file line numberDiff line numberDiff line change
@@ -2,78 +2,297 @@
22

33
from enum import Enum
44
import numpy as np
5+
import warnings
56

7+
from IPython.core.display import display, HTML
68

7-
class VisualizeMethod(Enum):
9+
import matplotlib.pyplot as plt
10+
from matplotlib.colors import LinearSegmentedColormap
11+
from mpl_toolkits.axes_grid1 import make_axes_locatable
12+
13+
14+
class ImageVisualizeMethod(Enum):
815
heat_map = 1
9-
masked_image = 2
10-
alpha_scaled = 3
11-
blended_heat_map = 4
16+
blended_heat_map = 2
17+
original_image = 3
18+
masked_image = 4
19+
alpha_scaling = 5
1220

1321
class VisualizeSign(Enum):
1422
positive = 1
1523
absolute_value = 2
1624
negative = 3
1725
all = 4
1826

19-
green = [0, 255, 0]
20-
red = [255, 0, 0]
21-
blue = [0, 0, 255]
27+
def _prepare_image(attr_visual):
28+
return np.clip(attr_visual.astype(int), 0, 255)
2229

30+
def _normalize_scale(attr, scale_factor):
31+
if abs(scale_factor) < 1e-5:
32+
warnings.warn("Attempting to normalize by value approximately 0, skipping normalization. This likely means that attribution values are all close to 0.")
33+
attr_norm = attr / scale_factor
34+
return np.clip(attr_norm, -1, 1)
2335

24-
def visualize_image_attr(attr, original_image=None, channel_transpose=False, sign="positive", method="heat_map",outlier_perc=2):
25-
if channel_transpose:
26-
attr = np.transpose(attr, (1,2,0))
27-
if original_image is not None:
28-
original_image = np.transpose(original_image, (1,2,0))
36+
def _cumulative_sum_threshold(values, percentile):
37+
sorted_vals = np.sort(values.flatten())
38+
cum_sums = np.cumsum(sorted_vals)
39+
threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0]
40+
return sorted_vals[threshold_id]
2941

30-
# Combine RGB attribution channels
42+
def _normalize_image_attr(attr, sign, outlier_perc=2):
3143
attr_combined = np.sum(attr, axis=2)
32-
heat_map_color = None
33-
34-
if VisualizeMethod[method] == VisualizeMethod.blended_heat_map:
35-
assert original_image is not None, "Image must be provided for blended heat map."
36-
return (0.6 * np.expand_dims(np.mean(original_image, axis=2),axis=2) + 0.4 * visualize_image_attr(attr=attr, original_image=original_image, channel_transpose=False, sign=sign, method="heat_map",outlier_perc=outlier_perc)).astype(int)
37-
38-
if VisualizeSign[sign] == VisualizeSign.all:
39-
assert VisualizeMethod[method] == VisualizeMethod.heat_map, "Heat Map is the only supported visualization approach for both positive and negative attribution."
40-
return visualize_image_attr(attr=attr, original_image=original_image, channel_transpose=False, sign="positive", method=method,outlier_perc=outlier_perc) + visualize_image_attr(attr=attr, original_image=original_image, channel_transpose=False, sign="negative", method=method,outlier_perc=outlier_perc)
41-
4244
# Choose appropriate signed values and rescale, removing given outlier percentage.
43-
if VisualizeSign[sign] == VisualizeSign.positive:
45+
if VisualizeSign[sign] == VisualizeSign.all:
46+
attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc))
47+
elif VisualizeSign[sign] == VisualizeSign.positive:
4448
attr_combined = (attr_combined > 0) * attr_combined
45-
attr_combined = attr_combined / np.percentile(attr_combined, 100 - outlier_perc)
46-
attr_combined[attr_combined > 1] = 1
47-
heat_map_color = green
49+
attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(attr_combined, 100 - outlier_perc))
4850
elif VisualizeSign[sign] == VisualizeSign.negative:
4951
attr_combined = (attr_combined < 0) * attr_combined
50-
attr_combined = attr_combined / np.percentile(attr_combined, outlier_perc)
51-
attr_combined[attr_combined > 1] = 1
52-
heat_map_color = red
52+
attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(attr_combined, outlier_perc))
5353
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
5454
attr_combined = np.abs(attr_combined)
55-
attr_combined = attr_combined / np.percentile(attr_combined, 100 - outlier_perc)
56-
attr_combined[attr_combined > 1] = 1
57-
heat_map_color = blue
55+
attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(attr_combined, 100 - outlier_perc))
5856
else:
5957
raise AssertionError("Visualize Sign type is not valid.")
58+
return attr_combined
59+
60+
def visualize_image_attr(attr, original_image=None,method="heat_map",sign="absolute_value", plt_fig_axis=None,outlier_perc=2,cmap=None,alpha_overlay=0.5,show_colorbar=False, title = None, show_figure=True):
61+
r"""
62+
Visualizes attribution for a given image by normalizing attribution values
63+
of the desired sign (positive, negative, absolute value, or all) and displaying
64+
them using the desired mode in a matplotlib figure.
6065
61-
# Apply chosen visualization method.
62-
if VisualizeMethod[method] == VisualizeMethod.heat_map:
63-
return (np.expand_dims(attr_combined, 2) * heat_map_color).astype(int)
64-
elif VisualizeMethod[method] == VisualizeMethod.masked_image:
65-
assert original_image is not None, "Image must be provided for masking."
66-
assert np.shape(original_image)[:-1] == np.shape(attr_combined), "Image dimensions do not match attribution dimensions for masking."
67-
return (np.expand_dims(attr_combined, 2) * original_image).astype(int)
68-
elif VisualizeMethod[method] == VisualizeMethod.alpha_scaled:
69-
assert original_image is not None, "Image must be provided for masking."
70-
assert np.shape(original_image)[:-1] == np.shape(attr_combined), "Image dimensions do not match attribution dimensions for adding alpha channel."
71-
# Concatenate alpha channel and return
72-
return np.concatenate((original_image, (255*np.expand_dims(attr_combined, 2)).astype(int)), axis=2)
73-
elif VisualizeMethod[method] == VisualizeMethod.blended_heat_map:
74-
return np.concatenate((original_image, (255*np.expand_dims(attr_combined, 2)).astype(int)), axis=2)
66+
Args:
67+
68+
attr (numpy.array): Numpy array corresponding to attributions to be
69+
visualized. Shape must be in the form (H, W, C), with
70+
channels as last dimension. Shape must also match that of
71+
the original image if provided.
72+
original_image (numpy.array): Numpy array corresponding to original
73+
image. Shape must be in the form (H, W, C), with
74+
channels as the last dimension. Image can be provided either
75+
with values in range 0-1 or 0-255. This is a necessary
76+
argument for any visualization method which utilizes
77+
the original image.
78+
Default: None
79+
method (string): Chosen method for visualizing attribution. Supported
80+
options are:
81+
1. `heat_map` - Display heat map of chosen attributions
82+
2. `blended_heat_map` - Overlay heat map over greyscale
83+
version of original image. Parameter alpha_overlay
84+
corresponds to alpha of heat map.
85+
3. `original_image` - Only display original image.
86+
4. `masked_image` - Mask image (pixel-wise multiply)
87+
by normalized attribution values.
88+
5. `alpha_scaling` - Sets alpha channel of each pixel
89+
to be equal to normalized attribution value.
90+
Default: `heat_map`
91+
sign (string): Chosen sign of attributions to visualize. Supported
92+
options are:
93+
1. `positive` - Displays only positive pixel attributions.
94+
2. `absolute_value` - Displays absolute value of attributions.
95+
3. `negative` - Displays only negative pixel attributions.
96+
4. `all` - Displays both positive and negative attribution
97+
values. This is not supported for `masked_image` or
98+
`alpha_scaling` modes, since signed information cannot
99+
be represented in these modes.
100+
Default: `absolute_value`
101+
plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis
102+
on which to visualize. If None is provided, then a new figure
103+
and axis are created.
104+
Default: None
105+
outlier_perc (float, optional): Top attribution values which correspond
106+
to a total of outlier_perc percentage of the total attribution
107+
are set to 1 and scaling is performed using the minimum of
108+
these values. For sign=`all`, outliers and scale value are
109+
computed using absolute value of attributions.
110+
Default: 2
111+
cmap (string, optional): String corresponding to desired colormap for
112+
heatmap visualization. This defaults to "Reds" for negative
113+
sign, "Blues" for absolute value, "Greens" for positive sign,
114+
and a spectrum from red to green for all. Note that this
115+
argument is only used for visualizations displaying heatmaps.
116+
Default: None
117+
alpha_overlay (float, optional): Alpha to set for heatmap when using
118+
`blended_heat_map` visualization mode, which overlays the
119+
heat map over the greyscaled original image.
120+
Default: 0.5
121+
show_colorbar (boolean, optional): Displays colorbar for heatmap below
122+
the visualization. If given method does not use a heatmap,
123+
then a colormap axis is created and hidden. This is
124+
necessary for appropriate alignment when visualizing
125+
multiple plots, some with heatmaps and some without.
126+
Default: False
127+
title (string, optional): Title string for plot. If None, no title is
128+
set.
129+
Default: None
130+
show_figure (boolean, optional): If true, calls plt.show to render
131+
plot after creating figure.
132+
133+
134+
Return:
135+
136+
figure (matplotlib.pyplot.figure): Figure object on which visualization
137+
is created. If plt_fig_axis argument is given, this is the
138+
same figure provided.
139+
axis (matplotlib.pyplot.axis): Axis object on which visualization
140+
is created. If plt_fig_axis argument is given, this is the
141+
same axis provided.
142+
143+
Examples::
144+
145+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
146+
>>> # and returns an Nx10 tensor of class probabilities.
147+
>>> net = ImageClassifier()
148+
>>> ig = IntegratedGradients(net)
149+
>>> # Computes integrated gradients for class 3 for a given image .
150+
>>> attribution, delta = ig.attribute(orig_image, target=3)
151+
>>> # Displays blended heat map visualization of computed attributions.
152+
>>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map")
153+
"""
154+
# Create plot if figure, axis not provided
155+
if plt_fig_axis is not None:
156+
plt_fig, plt_axis = plt_fig_axis
75157
else:
76-
raise AssertionError("Visualize Method type is not valid.")
158+
plt_fig, plt_axis = plt.subplots()
159+
160+
if np.max(original_image) <= 1.0:
161+
original_image = _prepare_image(original_image * 255)
162+
163+
# Remove ticks and tick labels from plot.
164+
plt_axis.xaxis.set_ticks_position('none')
165+
plt_axis.yaxis.set_ticks_position('none')
166+
plt_axis.set_yticklabels([])
167+
plt_axis.set_xticklabels([])
168+
169+
heat_map = None
170+
# Show original image
171+
if ImageVisualizeMethod[method] == ImageVisualizeMethod.original_image:
172+
plt_axis.imshow(original_image)
173+
else:
174+
# Choose appropriate signed attributions and normalize.
175+
norm_attr = _normalize_image_attr(attr, sign, outlier_perc)
176+
177+
# Set default colormap and bounds based on sign.
178+
if VisualizeSign[sign] == VisualizeSign.all:
179+
default_cmap = LinearSegmentedColormap.from_list("RdWhGn",["red","white","green"])
180+
vmin, vmax = -1, 1
181+
elif VisualizeSign[sign] == VisualizeSign.positive:
182+
default_cmap = "Greens"
183+
vmin, vmax = 0, 1
184+
elif VisualizeSign[sign] == VisualizeSign.negative:
185+
default_cmap = "Reds"
186+
vmin, vmax = 0, 1
187+
elif VisualizeSign[sign] == VisualizeSign.absolute_value:
188+
default_cmap = "Blues"
189+
vmin, vmax = 0, 1
190+
else:
191+
raise AssertionError("Visualize Sign type is not valid.")
192+
cmap = cmap if cmap is not None else default_cmap
193+
194+
# Show appropriate image visualization.
195+
if ImageVisualizeMethod[method] == ImageVisualizeMethod.heat_map:
196+
heat_map = plt_axis.imshow(norm_attr,cmap=cmap,vmin=vmin, vmax=vmax)
197+
elif ImageVisualizeMethod[method] == ImageVisualizeMethod.blended_heat_map:
198+
plt_axis.imshow(np.mean(original_image, axis=2),cmap="gray")
199+
heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay)
200+
elif ImageVisualizeMethod[method] == ImageVisualizeMethod.masked_image:
201+
assert VisualizeSign[sign] != VisualizeSign.all, "Cannot display masked image with both positive and negative attributions, choose a different sign option."
202+
plt_axis.imshow(_prepare_image(original_image*np.expand_dims(norm_attr,2)))
203+
elif ImageVisualizeMethod[method] == ImageVisualizeMethod.alpha_scaling:
204+
assert VisualizeSign[sign] != VisualizeSign.all, "Cannot display alpha scaling with both positive and negative attributions, choose a different sign option."
205+
plt_axis.imshow(np.concatenate([original_image, _prepare_image(np.expand_dims(norm_attr,2)*255)],axis=2),cmap="gray")
206+
else:
207+
raise AssertionError("Visualize Method type is not valid.")
208+
209+
# Add colorbar. If given method is not a heatmap and no colormap is relevant,
210+
# then a colormap axis is created and hidden. This is necessary for appropriate
211+
# alignment when visualizing multiple plots, some with heatmaps and some
212+
# without.
213+
if show_colorbar:
214+
axis_separator = make_axes_locatable(plt_axis)
215+
colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1)
216+
if heat_map:
217+
plt_fig.colorbar(heat_map, orientation="horizontal",cax=colorbar_axis)
218+
else:
219+
colorbar_axis.axis('off')
220+
if title:
221+
plt_axis.set_title(title)
222+
223+
if show_figure:
224+
plt.show()
225+
226+
return plt_fig, plt_axis
227+
228+
def visualize_image_attr_multiple(methods, signs, attr, original_image=None, titles = None, figsize=(8, 6), show_figure=True, **kwargs):
229+
r"""
230+
Visualizes attribution using multiple visualization methods displayed
231+
in a 1 x k grid, where k is the number of desired visualizations.
232+
233+
Args:
234+
235+
methods (list of strings): List of strings of length k, defining method
236+
for each visualization. Each method must be a valid
237+
string argument for method to visualize_image_attr.
238+
signs (list of strings): List of strings of length k, defining signs for
239+
each visualization. Each sign must be a valid
240+
string argument for sign to visualize_image_attr.
241+
attr (numpy.array): Numpy array corresponding to attributions to be
242+
visualized. Shape must be in the form (H, W, C), with
243+
channels as last dimension. Shape must also match that of
244+
the original image if provided.
245+
original_image (numpy.array): Numpy array corresponding to original
246+
image. Shape must be in the form (H, W, C), with
247+
channels as the last dimension. Image can be provided either
248+
with values in range 0-1 or 0-255. This is a necessary
249+
argument for any visualization method which utilizes
250+
the original image.
251+
Default: None
252+
titles (list of strings): List of strings of length k, providing a
253+
title string for each plot. If None is provided, no titles
254+
are added to subplots.
255+
Default: None
256+
figsize (tuple, optional): Size of figure created.
257+
Default: (8, 6)
258+
show_figure (boolean, optional): If true, calls plt.show to render
259+
plot after creating figure.
260+
**kwargs (Any, optional): Any additional arguments which will be passed
261+
to every individual visualization. Such arguments include
262+
`show_colorbar`, `alpha_overlay`, `cmap`, etc.
263+
264+
265+
Return:
266+
267+
figure (matplotlib.pyplot.figure): Figure object on which visualization
268+
is created. If plt_fig_axis argument is given, this is the
269+
same figure provided.
270+
axis (matplotlib.pyplot.axis): Axis object on which visualization
271+
is created. If plt_fig_axis argument is given, this is the
272+
same axis provided.
273+
274+
Examples::
275+
276+
>>> # ImageClassifier takes a single input tensor of images Nx3x32x32,
277+
>>> # and returns an Nx10 tensor of class probabilities.
278+
>>> net = ImageClassifier()
279+
>>> ig = IntegratedGradients(net)
280+
>>> # Computes integrated gradients for class 3 for a given image .
281+
>>> attribution, delta = ig.attribute(orig_image, target=3)
282+
>>> # Displays original image and heat map visualization of
283+
>>> # computed attributions side by side.
284+
>>> _ = visualize_mutliple_image_attr(["original_image", "heat_map"],
285+
>>> ["all", "positive"], attribution, orig_image)
286+
"""
287+
assert len(methods) == len(signs), "Methods and signs array lengths must match."
288+
plt_fig = plt.figure(figsize=figsize)
289+
plt_axis = plt_fig.subplots(1, len(methods))
290+
for i in range(len(methods)):
291+
visualize_image_attr(attr, original_image=original_image, method=methods[i], sign=signs[i], plt_fig_axis=(plt_fig, plt_axis[i]), show_figure=False, title=titles[i] if titles else None, **kwargs)
292+
plt_fig.tight_layout()
293+
if show_figure:
294+
plt.show()
295+
return plt_fig, plt_axis
77296

78297
# These visualization methods are for text and are partially copied from
79298
# experiments conducted by Davide Testuggine at Facebook.

0 commit comments

Comments
 (0)