Skip to content

Commit e69f327

Browse files
authored
Style changes and minor fixes to Insights (#68)
* Style changes and minor fixes to Insights
* fix tests and suppress warning
1 parent 4a197f7 commit e69f327

File tree

5 files changed

+62
-40
lines changed

5 files changed

+62
-40
lines changed

captum/insights/api.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _calculate_attribution(
6262
) -> Tensor:
6363
ig = IntegratedGradients(net)
6464
# TODO support multiple baselines
65-
label = None if label is None or len(label.shape) == 0 else label
65+
label = None if label is None or label.nelement() == 0 else label
6666
attr_ig, _ = ig.attribute(
6767
data,
6868
baselines=baselines[0],
@@ -124,7 +124,7 @@ def _calculate_net_contrib(self, attrs_per_input_feature: List[Tensor]):
124124
if norm > 0:
125125
net_contrib /= norm
126126

127-
return net_contrib
127+
return net_contrib.tolist()
128128

129129
def visualize(self) -> List[VisualizationOutput]:
130130
batch_data = next(self.dataset)
@@ -169,11 +169,11 @@ def visualize(self) -> List[VisualizationOutput]:
169169

170170
label = batch_data.labels[i]
171171

172-
if len(outputs) == 1:
172+
if outputs.nelement() == 1:
173173
scores = outputs
174174
predicted = scores.round().to(torch.int)
175175
else:
176-
scores, predicted = outputs.topk(min(4, len(outputs)))
176+
scores, predicted = outputs.topk(min(4, outputs.shape[-1]))
177177

178178
scores = scores.cpu().squeeze(0)
179179
predicted = predicted.cpu().squeeze(0)

captum/insights/features.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import base64
2+
import warnings
23
from collections import namedtuple
34
from io import BytesIO
45
from typing import Callable, List, Optional, Union
@@ -12,7 +13,10 @@
1213

1314
def _convert_figure_base64(fig):
1415
buff = BytesIO()
15-
fig.savefig(buff, format="png", pad_inches=0.0)
16+
with warnings.catch_warnings():
17+
warnings.simplefilter("ignore")
18+
fig.tight_layout() # removes padding
19+
fig.savefig(buff, format="png")
1620
base64img = base64.b64encode(buff.getvalue()).decode("utf-8")
1721
return base64img
1822

captum/insights/frontend/src/App.css

+15-4
Original file line numberDiff line numberDiff line change
@@ -170,16 +170,27 @@ button {
170170
}
171171

172172
.gallery__item__image img {
173-
height: 130px;
173+
height: 200px;
174174
width: auto;
175175
}
176176

177+
.gallery__item__description {
178+
text-align: center;
179+
}
180+
177181
.bar-chart__group {
178-
padding: 8px 0;
182+
padding: 2px 0;
183+
display: flex;
179184
}
185+
180186
.bar-chart__group__bar {
181-
height: 10px;
182-
border-radius: 4px;
187+
width: 10px;
188+
border-radius: 2px;
189+
flex-shrink: 0;
190+
}
191+
192+
.bar-chart__group__title {
193+
padding-left: 8px;
183194
}
184195

185196
.percentage-blue {

captum/insights/frontend/src/App.js

+21-13
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ function ImageFeature(props) {
160160
<div className="gallery__item__image">
161161
<img src={"data:image/png;base64," + props.data.modified} />
162162
</div>
163-
<div className="gallery__item__description">Gradient Overlay</div>
163+
<div className="gallery__item__description">
164+
Attribution Magnitude
165+
</div>
164166
</div>
165167
</div>
166168
</div>
@@ -206,18 +208,24 @@ function Feature(props) {
206208

207209
class Contributions extends React.Component {
208210
render() {
209-
return this.props.feature_outputs.map(f => (
210-
<div className="bar-chart__group">
211-
<div
212-
className={cx({
213-
"bar-chart__group__bar": true,
214-
[getPercentageColor(f.contribution)]: true
215-
})}
216-
width={f.contribution + "%"}
217-
/>
218-
<div className="bar-chart__group__title">{f.name}</div>
219-
</div>
220-
));
211+
return this.props.feature_outputs.map(f => {
212+
// pad bar height so features with 0 contribution can still be seen
213+
// in graph
214+
const contribution = f.contribution * 100;
215+
const bar_height = contribution > 10 ? contribution : contribution + 10;
216+
return (
217+
<div className="bar-chart__group">
218+
<div
219+
className={cx([
220+
"bar-chart__group__bar",
221+
getPercentageColor(contribution)
222+
])}
223+
style={{ height: bar_height + "px" }}
224+
/>
225+
<div className="bar-chart__group__title">{f.name}</div>
226+
</div>
227+
);
228+
});
221229
}
222230
}
223231

tests/insights/test_contribution.py

+17-18
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from __future__ import print_function
2-
from typing import Callable, List, Optional, Union
32

4-
import torch
5-
import torch.nn as nn
3+
import unittest
4+
from typing import Callable, List, Optional, Union
65

76
from captum.insights.api import AttributionVisualizer, Data
8-
from captum.insights.features import ImageFeature, BaseFeature, FeatureOutput
9-
7+
from captum.insights.features import BaseFeature, FeatureOutput, ImageFeature
108
from tests.attr.helpers.utils import BaseTest
119

10+
import torch
11+
import torch.nn as nn
12+
1213

1314
class RealFeature(BaseFeature):
1415
def __init__(
@@ -97,15 +98,15 @@ def forward(self, img, misc):
9798

9899

99100
def _labelled_img_data(num_samples=10, width=8, height=8, depth=3, num_labels=10):
100-
for i in range(num_samples):
101+
for _ in range(num_samples):
101102
yield torch.empty(depth, height, width).uniform_(0, 1), torch.randint(
102103
num_labels, (1,)
103104
)
104105

105106

106107
def _multi_modal_data(img_dataset, feature_size=256):
107108
def misc_data(length, feature_size=None):
108-
for i in range(length):
109+
for _ in range(length):
109110
yield torch.randn(feature_size)
110111

111112
misc_dataset = misc_data(length=len(img_dataset), feature_size=feature_size)
@@ -165,11 +166,8 @@ def test_one_feature(self):
165166
outputs = visualizer.visualize()
166167

167168
for output in outputs:
168-
contribs = torch.stack(
169-
[feature.contribution for feature in output.feature_outputs]
170-
)
171-
total_contrib = torch.sum(torch.abs(contribs))
172-
self.assertAlmostEqual(total_contrib.item(), 1.0, places=6)
169+
total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
170+
self.assertAlmostEqual(total_contrib, 1.0, places=6)
173171

174172
def test_multi_features(self):
175173
batch_size = 2
@@ -183,7 +181,7 @@ def test_multi_features(self):
183181
img_dataset=img_dataset, feature_size=misc_feature_size
184182
)
185183
# NOTE: using DataLoader to batch the inputs since
186-
# AttributionVisualizer requires the input to be of size `B x ...`
184+
# AttributionVisualizer requires the input to be of size `N x ...`
187185
data_loader = torch.utils.data.DataLoader(
188186
list(dataset), batch_size=batch_size, shuffle=False, num_workers=0
189187
)
@@ -211,13 +209,14 @@ def test_multi_features(self):
211209
outputs = visualizer.visualize()
212210

213211
for output in outputs:
214-
contribs = torch.stack(
215-
[feature.contribution for feature in output.feature_outputs]
216-
)
217-
total_contrib = torch.sum(torch.abs(contribs))
218-
self.assertAlmostEqual(total_contrib.item(), 1.0, places=6)
212+
total_contrib = sum(abs(f.contribution) for f in output.feature_outputs)
213+
self.assertAlmostEqual(total_contrib, 1.0, places=6)
219214

220215
# TODO: add test for multiple models (related to TODO in captum/insights/api.py)
221216
#
222217
# TODO: add test to make the attribs == 0 -- error occurs
223218
# I know (through manual testing) that this breaks some existing code
219+
220+
221+
if __name__ == "__main__":
222+
unittest.main()

0 commit comments

Comments (0)