|
2 | 2 |
|
3 | 3 | from enum import Enum
|
4 | 4 | import numpy as np
|
| 5 | +import warnings |
5 | 6 |
|
| 7 | +from IPython.core.display import display, HTML |
6 | 8 |
|
7 |
| -class VisualizeMethod(Enum): |
| 9 | +import matplotlib.pyplot as plt |
| 10 | +from matplotlib.colors import LinearSegmentedColormap |
| 11 | +from mpl_toolkits.axes_grid1 import make_axes_locatable |
| 12 | + |
| 13 | + |
| 14 | +class ImageVisualizeMethod(Enum): |
8 | 15 | heat_map = 1
|
9 |
| - masked_image = 2 |
10 |
| - alpha_scaled = 3 |
11 |
| - blended_heat_map = 4 |
| 16 | + blended_heat_map = 2 |
| 17 | + original_image = 3 |
| 18 | + masked_image = 4 |
| 19 | + alpha_scaling = 5 |
12 | 20 |
|
13 | 21 | class VisualizeSign(Enum):
|
14 | 22 | positive = 1
|
15 | 23 | absolute_value = 2
|
16 | 24 | negative = 3
|
17 | 25 | all = 4
|
18 | 26 |
|
19 |
| -green = [0, 255, 0] |
20 |
| -red = [255, 0, 0] |
21 |
| -blue = [0, 0, 255] |
| 27 | +def _prepare_image(attr_visual): |
| 28 | + return np.clip(attr_visual.astype(int), 0, 255) |
22 | 29 |
|
| 30 | +def _normalize_scale(attr, scale_factor): |
| 31 | + if abs(scale_factor) < 1e-5: |
| 32 | + warnings.warn("Attempting to normalize by value approximately 0, skipping normalization. This likely means that attribution values are all close to 0.") |
| 33 | + attr_norm = attr / scale_factor |
| 34 | + return np.clip(attr_norm, -1, 1) |
23 | 35 |
|
24 |
| -def visualize_image_attr(attr, original_image=None, channel_transpose=False, sign="positive", method="heat_map",outlier_perc=2): |
25 |
| - if channel_transpose: |
26 |
| - attr = np.transpose(attr, (1,2,0)) |
27 |
| - if original_image is not None: |
28 |
| - original_image = np.transpose(original_image, (1,2,0)) |
| 36 | +def _cumulative_sum_threshold(values, percentile): |
| 37 | + sorted_vals = np.sort(values.flatten()) |
| 38 | + cum_sums = np.cumsum(sorted_vals) |
| 39 | + threshold_id = np.where(cum_sums >= cum_sums[-1] * 0.01 * percentile)[0][0] |
| 40 | + return sorted_vals[threshold_id] |
29 | 41 |
|
30 |
| - # Combine RGB attribution channels |
| 42 | +def _normalize_image_attr(attr, sign, outlier_perc=2): |
31 | 43 | attr_combined = np.sum(attr, axis=2)
|
32 |
| - heat_map_color = None |
33 |
| - |
34 |
| - if VisualizeMethod[method] == VisualizeMethod.blended_heat_map: |
35 |
| - assert original_image is not None, "Image must be provided for blended heat map." |
36 |
| - return (0.6 * np.expand_dims(np.mean(original_image, axis=2),axis=2) + 0.4 * visualize_image_attr(attr=attr, original_image=original_image, channel_transpose=False, sign=sign, method="heat_map",outlier_perc=outlier_perc)).astype(int) |
37 |
| - |
38 |
| - if VisualizeSign[sign] == VisualizeSign.all: |
39 |
| - assert VisualizeMethod[method] == VisualizeMethod.heat_map, "Heat Map is the only supported visualization approach for both positive and negative attribution." |
40 |
| - return visualize_image_attr(attr=attr, original_image=original_image, channel_transpose=False, sign="positive", method=method,outlier_perc=outlier_perc) + visualize_image_attr(attr=attr, original_image=original_image, channel_transpose=False, sign="negative", method=method,outlier_perc=outlier_perc) |
41 |
| - |
42 | 44 | # Choose appropriate signed values and rescale, removing given outlier percentage.
|
43 |
| - if VisualizeSign[sign] == VisualizeSign.positive: |
| 45 | + if VisualizeSign[sign] == VisualizeSign.all: |
| 46 | + attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(np.abs(attr_combined), 100 - outlier_perc)) |
| 47 | + elif VisualizeSign[sign] == VisualizeSign.positive: |
44 | 48 | attr_combined = (attr_combined > 0) * attr_combined
|
45 |
| - attr_combined = attr_combined / np.percentile(attr_combined, 100 - outlier_perc) |
46 |
| - attr_combined[attr_combined > 1] = 1 |
47 |
| - heat_map_color = green |
| 49 | + attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(attr_combined, 100 - outlier_perc)) |
48 | 50 | elif VisualizeSign[sign] == VisualizeSign.negative:
|
49 | 51 | attr_combined = (attr_combined < 0) * attr_combined
|
50 |
| - attr_combined = attr_combined / np.percentile(attr_combined, outlier_perc) |
51 |
| - attr_combined[attr_combined > 1] = 1 |
52 |
| - heat_map_color = red |
| 52 | + attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(attr_combined, outlier_perc)) |
53 | 53 | elif VisualizeSign[sign] == VisualizeSign.absolute_value:
|
54 | 54 | attr_combined = np.abs(attr_combined)
|
55 |
| - attr_combined = attr_combined / np.percentile(attr_combined, 100 - outlier_perc) |
56 |
| - attr_combined[attr_combined > 1] = 1 |
57 |
| - heat_map_color = blue |
| 55 | + attr_combined = _normalize_scale(attr_combined, _cumulative_sum_threshold(attr_combined, 100 - outlier_perc)) |
58 | 56 | else:
|
59 | 57 | raise AssertionError("Visualize Sign type is not valid.")
|
| 58 | + return attr_combined |
| 59 | + |
| 60 | +def visualize_image_attr(attr, original_image=None,method="heat_map",sign="absolute_value", plt_fig_axis=None,outlier_perc=2,cmap=None,alpha_overlay=0.5,show_colorbar=False, title = None, show_figure=True): |
| 61 | + r""" |
| 62 | + Visualizes attribution for a given image by normalizing attribution values |
| 63 | + of the desired sign (positive, negative, absolute value, or all) and displaying |
| 64 | + them using the desired mode in a matplotlib figure. |
60 | 65 |
|
61 |
| - # Apply chosen visualization method. |
62 |
| - if VisualizeMethod[method] == VisualizeMethod.heat_map: |
63 |
| - return (np.expand_dims(attr_combined, 2) * heat_map_color).astype(int) |
64 |
| - elif VisualizeMethod[method] == VisualizeMethod.masked_image: |
65 |
| - assert original_image is not None, "Image must be provided for masking." |
66 |
| - assert np.shape(original_image)[:-1] == np.shape(attr_combined), "Image dimensions do not match attribution dimensions for masking." |
67 |
| - return (np.expand_dims(attr_combined, 2) * original_image).astype(int) |
68 |
| - elif VisualizeMethod[method] == VisualizeMethod.alpha_scaled: |
69 |
| - assert original_image is not None, "Image must be provided for masking." |
70 |
| - assert np.shape(original_image)[:-1] == np.shape(attr_combined), "Image dimensions do not match attribution dimensions for adding alpha channel." |
71 |
| - # Concatenate alpha channel and return |
72 |
| - return np.concatenate((original_image, (255*np.expand_dims(attr_combined, 2)).astype(int)), axis=2) |
73 |
| - elif VisualizeMethod[method] == VisualizeMethod.blended_heat_map: |
74 |
| - return np.concatenate((original_image, (255*np.expand_dims(attr_combined, 2)).astype(int)), axis=2) |
| 66 | + Args: |
| 67 | +
|
| 68 | + attr (numpy.array): Numpy array corresponding to attributions to be |
| 69 | + visualized. Shape must be in the form (H, W, C), with |
| 70 | + channels as last dimension. Shape must also match that of |
| 71 | + the original image if provided. |
| 72 | + original_image (numpy.array): Numpy array corresponding to original |
| 73 | + image. Shape must be in the form (H, W, C), with |
| 74 | + channels as the last dimension. Image can be provided either |
| 75 | + with values in range 0-1 or 0-255. This is a necessary |
| 76 | + argument for any visualization method which utilizes |
| 77 | + the original image. |
| 78 | + Default: None |
| 79 | + method (string): Chosen method for visualizing attribution. Supported |
| 80 | + options are: |
| 81 | + 1. `heat_map` - Display heat map of chosen attributions |
| 82 | + 2. `blended_heat_map` - Overlay heat map over greyscale |
| 83 | + version of original image. Parameter alpha_overlay |
| 84 | + corresponds to alpha of heat map. |
| 85 | + 3. `original_image` - Only display original image. |
| 86 | + 4. `masked_image` - Mask image (pixel-wise multiply) |
| 87 | + by normalized attribution values. |
| 88 | + 5. `alpha_scaling` - Sets alpha channel of each pixel |
| 89 | + to be equal to normalized attribution value. |
| 90 | + Default: `heat_map` |
| 91 | + sign (string): Chosen sign of attributions to visualize. Supported |
| 92 | + options are: |
| 93 | + 1. `positive` - Displays only positive pixel attributions. |
| 94 | + 2. `absolute_value` - Displays absolute value of attributions. |
| 95 | + 3. `negative` - Displays only negative pixel attributions. |
| 96 | + 4. `all` - Displays both positive and negative attribution |
| 97 | + values. This is not supported for `masked_image` or |
| 98 | + `alpha_scaling` modes, since signed information cannot |
| 99 | + be represented in these modes. |
| 100 | + Default: `absolute_value` |
| 101 | + plt_fig_axis (tuple, optional): Tuple of matplotlib.pyplot.figure and axis |
| 102 | + on which to visualize. If None is provided, then a new figure |
| 103 | + and axis are created. |
| 104 | + Default: None |
| 105 | + outlier_perc (float, optional): Top attribution values which correspond |
| 106 | + to a total of outlier_perc percentage of the total attribution |
| 107 | + are set to 1 and scaling is performed using the minimum of |
| 108 | + these values. For sign=`all`, outliers and scale value are |
| 109 | + computed using absolute value of attributions. |
| 110 | + Default: 2 |
| 111 | + cmap (string, optional): String corresponding to desired colormap for |
| 112 | + heatmap visualization. This defaults to "Reds" for negative |
| 113 | + sign, "Blues" for absolute value, "Greens" for positive sign, |
| 114 | + and a spectrum from red to green for all. Note that this |
| 115 | + argument is only used for visualizations displaying heatmaps. |
| 116 | + Default: None |
| 117 | + alpha_overlay (float, optional): Alpha to set for heatmap when using |
| 118 | + `blended_heat_map` visualization mode, which overlays the |
| 119 | + heat map over the greyscaled original image. |
| 120 | + Default: 0.5 |
| 121 | + show_colorbar (boolean, optional): Displays colorbar for heatmap below |
| 122 | + the visualization. If given method does not use a heatmap, |
| 123 | + then a colormap axis is created and hidden. This is |
| 124 | + necessary for appropriate alignment when visualizing |
| 125 | + multiple plots, some with heatmaps and some without. |
| 126 | + Default: False |
| 127 | + title (string, optional): Title string for plot. If None, no title is |
| 128 | + set. |
| 129 | + Default: None |
| 130 | + show_figure (boolean, optional): If true, calls plt.show to render |
| 131 | + plot after creating figure. |
| 132 | +
|
| 133 | +
|
| 134 | + Return: |
| 135 | +
|
| 136 | + figure (matplotlib.pyplot.figure): Figure object on which visualization |
| 137 | + is created. If plt_fig_axis argument is given, this is the |
| 138 | + same figure provided. |
| 139 | + axis (matplotlib.pyplot.axis): Axis object on which visualization |
| 140 | + is created. If plt_fig_axis argument is given, this is the |
| 141 | + same axis provided. |
| 142 | +
|
| 143 | + Examples:: |
| 144 | +
|
| 145 | + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, |
| 146 | + >>> # and returns an Nx10 tensor of class probabilities. |
| 147 | + >>> net = ImageClassifier() |
| 148 | + >>> ig = IntegratedGradients(net) |
| 149 | + >>> # Computes integrated gradients for class 3 for a given image . |
| 150 | + >>> attribution, delta = ig.attribute(orig_image, target=3) |
| 151 | + >>> # Displays blended heat map visualization of computed attributions. |
| 152 | + >>> _ = visualize_image_attr(attribution, orig_image, "blended_heat_map") |
| 153 | + """ |
| 154 | + # Create plot if figure, axis not provided |
| 155 | + if plt_fig_axis is not None: |
| 156 | + plt_fig, plt_axis = plt_fig_axis |
75 | 157 | else:
|
76 |
| - raise AssertionError("Visualize Method type is not valid.") |
| 158 | + plt_fig, plt_axis = plt.subplots() |
| 159 | + |
| 160 | + if np.max(original_image) <= 1.0: |
| 161 | + original_image = _prepare_image(original_image * 255) |
| 162 | + |
| 163 | + # Remove ticks and tick labels from plot. |
| 164 | + plt_axis.xaxis.set_ticks_position('none') |
| 165 | + plt_axis.yaxis.set_ticks_position('none') |
| 166 | + plt_axis.set_yticklabels([]) |
| 167 | + plt_axis.set_xticklabels([]) |
| 168 | + |
| 169 | + heat_map = None |
| 170 | + # Show original image |
| 171 | + if ImageVisualizeMethod[method] == ImageVisualizeMethod.original_image: |
| 172 | + plt_axis.imshow(original_image) |
| 173 | + else: |
| 174 | + # Choose appropriate signed attributions and normalize. |
| 175 | + norm_attr = _normalize_image_attr(attr, sign, outlier_perc) |
| 176 | + |
| 177 | + # Set default colormap and bounds based on sign. |
| 178 | + if VisualizeSign[sign] == VisualizeSign.all: |
| 179 | + default_cmap = LinearSegmentedColormap.from_list("RdWhGn",["red","white","green"]) |
| 180 | + vmin, vmax = -1, 1 |
| 181 | + elif VisualizeSign[sign] == VisualizeSign.positive: |
| 182 | + default_cmap = "Greens" |
| 183 | + vmin, vmax = 0, 1 |
| 184 | + elif VisualizeSign[sign] == VisualizeSign.negative: |
| 185 | + default_cmap = "Reds" |
| 186 | + vmin, vmax = 0, 1 |
| 187 | + elif VisualizeSign[sign] == VisualizeSign.absolute_value: |
| 188 | + default_cmap = "Blues" |
| 189 | + vmin, vmax = 0, 1 |
| 190 | + else: |
| 191 | + raise AssertionError("Visualize Sign type is not valid.") |
| 192 | + cmap = cmap if cmap is not None else default_cmap |
| 193 | + |
| 194 | + # Show appropriate image visualization. |
| 195 | + if ImageVisualizeMethod[method] == ImageVisualizeMethod.heat_map: |
| 196 | + heat_map = plt_axis.imshow(norm_attr,cmap=cmap,vmin=vmin, vmax=vmax) |
| 197 | + elif ImageVisualizeMethod[method] == ImageVisualizeMethod.blended_heat_map: |
| 198 | + plt_axis.imshow(np.mean(original_image, axis=2),cmap="gray") |
| 199 | + heat_map = plt_axis.imshow(norm_attr, cmap=cmap, vmin=vmin, vmax=vmax, alpha=alpha_overlay) |
| 200 | + elif ImageVisualizeMethod[method] == ImageVisualizeMethod.masked_image: |
| 201 | + assert VisualizeSign[sign] != VisualizeSign.all, "Cannot display masked image with both positive and negative attributions, choose a different sign option." |
| 202 | + plt_axis.imshow(_prepare_image(original_image*np.expand_dims(norm_attr,2))) |
| 203 | + elif ImageVisualizeMethod[method] == ImageVisualizeMethod.alpha_scaling: |
| 204 | + assert VisualizeSign[sign] != VisualizeSign.all, "Cannot display alpha scaling with both positive and negative attributions, choose a different sign option." |
| 205 | + plt_axis.imshow(np.concatenate([original_image, _prepare_image(np.expand_dims(norm_attr,2)*255)],axis=2),cmap="gray") |
| 206 | + else: |
| 207 | + raise AssertionError("Visualize Method type is not valid.") |
| 208 | + |
| 209 | + # Add colorbar. If given method is not a heatmap and no colormap is relevant, |
| 210 | + # then a colormap axis is created and hidden. This is necessary for appropriate |
| 211 | + # alignment when visualizing multiple plots, some with heatmaps and some |
| 212 | + # without. |
| 213 | + if show_colorbar: |
| 214 | + axis_separator = make_axes_locatable(plt_axis) |
| 215 | + colorbar_axis = axis_separator.append_axes("bottom", size="5%", pad=0.1) |
| 216 | + if heat_map: |
| 217 | + plt_fig.colorbar(heat_map, orientation="horizontal",cax=colorbar_axis) |
| 218 | + else: |
| 219 | + colorbar_axis.axis('off') |
| 220 | + if title: |
| 221 | + plt_axis.set_title(title) |
| 222 | + |
| 223 | + if show_figure: |
| 224 | + plt.show() |
| 225 | + |
| 226 | + return plt_fig, plt_axis |
| 227 | + |
| 228 | +def visualize_image_attr_multiple(methods, signs, attr, original_image=None, titles = None, figsize=(8, 6), show_figure=True, **kwargs): |
| 229 | + r""" |
| 230 | + Visualizes attribution using multiple visualization methods displayed |
| 231 | + in a 1 x k grid, where k is the number of desired visualizations. |
| 232 | +
|
| 233 | + Args: |
| 234 | +
|
| 235 | + methods (list of strings): List of strings of length k, defining method |
| 236 | + for each visualization. Each method must be a valid |
| 237 | + string argument for method to visualize_image_attr. |
| 238 | + signs (list of strings): List of strings of length k, defining signs for |
| 239 | + each visualization. Each sign must be a valid |
| 240 | + string argument for sign to visualize_image_attr. |
| 241 | + attr (numpy.array): Numpy array corresponding to attributions to be |
| 242 | + visualized. Shape must be in the form (H, W, C), with |
| 243 | + channels as last dimension. Shape must also match that of |
| 244 | + the original image if provided. |
| 245 | + original_image (numpy.array): Numpy array corresponding to original |
| 246 | + image. Shape must be in the form (H, W, C), with |
| 247 | + channels as the last dimension. Image can be provided either |
| 248 | + with values in range 0-1 or 0-255. This is a necessary |
| 249 | + argument for any visualization method which utilizes |
| 250 | + the original image. |
| 251 | + Default: None |
| 252 | + titles (list of strings): List of strings of length k, providing a |
| 253 | + title string for each plot. If None is provided, no titles |
| 254 | + are added to subplots. |
| 255 | + Default: None |
| 256 | + figsize (tuple, optional): Size of figure created. |
| 257 | + Default: (8, 6) |
| 258 | + show_figure (boolean, optional): If true, calls plt.show to render |
| 259 | + plot after creating figure. |
| 260 | + **kwargs (Any, optional): Any additional arguments which will be passed |
| 261 | + to every individual visualization. Such arguments include |
| 262 | + `show_colorbar`, `alpha_overlay`, `cmap`, etc. |
| 263 | +
|
| 264 | +
|
| 265 | + Return: |
| 266 | +
|
| 267 | + figure (matplotlib.pyplot.figure): Figure object on which visualization |
| 268 | + is created. If plt_fig_axis argument is given, this is the |
| 269 | + same figure provided. |
| 270 | + axis (matplotlib.pyplot.axis): Axis object on which visualization |
| 271 | + is created. If plt_fig_axis argument is given, this is the |
| 272 | + same axis provided. |
| 273 | +
|
| 274 | + Examples:: |
| 275 | +
|
| 276 | + >>> # ImageClassifier takes a single input tensor of images Nx3x32x32, |
| 277 | + >>> # and returns an Nx10 tensor of class probabilities. |
| 278 | + >>> net = ImageClassifier() |
| 279 | + >>> ig = IntegratedGradients(net) |
| 280 | + >>> # Computes integrated gradients for class 3 for a given image . |
| 281 | + >>> attribution, delta = ig.attribute(orig_image, target=3) |
| 282 | + >>> # Displays original image and heat map visualization of |
| 283 | + >>> # computed attributions side by side. |
| 284 | + >>> _ = visualize_mutliple_image_attr(["original_image", "heat_map"], |
| 285 | + >>> ["all", "positive"], attribution, orig_image) |
| 286 | + """ |
| 287 | + assert len(methods) == len(signs), "Methods and signs array lengths must match." |
| 288 | + plt_fig = plt.figure(figsize=figsize) |
| 289 | + plt_axis = plt_fig.subplots(1, len(methods)) |
| 290 | + for i in range(len(methods)): |
| 291 | + visualize_image_attr(attr, original_image=original_image, method=methods[i], sign=signs[i], plt_fig_axis=(plt_fig, plt_axis[i]), show_figure=False, title=titles[i] if titles else None, **kwargs) |
| 292 | + plt_fig.tight_layout() |
| 293 | + if show_figure: |
| 294 | + plt.show() |
| 295 | + return plt_fig, plt_axis |
77 | 296 |
|
78 | 297 | # These visualization methods are for text and are partially copied from
|
79 | 298 | # experiments conducted by Davide Testuggine at Facebook.
|
|
0 commit comments