#!/usr/bin/env python3
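"""Tests comparing per-sample gradients computed by SampleGradientWrapper
against gradients obtained by running each sample individually through
standard autograd."""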
import unittest
from typing import Callable, Tuple

import torch
from captum._utils.gradient import apply_gradient_requirements
from captum._utils.sample_gradient import (
    _reset_sample_grads,
    SampleGradientWrapper,
    SUPPORTED_MODULES,
)
from packaging import version
from tests.helpers.basic import assertTensorAlmostEqual, BaseTest
from tests.helpers.basic_models import (
    BasicModel_ConvNet_One_Conv,
    BasicModel_ConvNetWithPaddingDilation,
    BasicModel_MultiLayer,
)
from torch import Tensor
from torch.nn import Module


class Test(BaseTest):
    def test_sample_grads_linear_sum(self) -> None:
        model = BasicModel_MultiLayer(multi_input_module=True)
        inp = (torch.randn(6, 3), torch.randn(6, 3))
        self._compare_sample_grads_per_sample(
            model, inp, lambda x: torch.sum(x), "sum"
        )

    def test_sample_grads_linear_mean(self) -> None:
        model = BasicModel_MultiLayer(multi_input_module=True)
        inp = (20 * torch.randn(6, 3),)
        self._compare_sample_grads_per_sample(model, inp, lambda x: torch.mean(x))

    def test_sample_grads_conv_sum(self) -> None:
        model = BasicModel_ConvNet_One_Conv()
        inp = (123 * torch.randn(6, 1, 4, 4),)
        self._compare_sample_grads_per_sample(
            model, inp, lambda x: torch.sum(x), "sum"
        )

    def test_sample_grads_conv_mean_multi_inp(self) -> None:
        model = BasicModel_ConvNet_One_Conv()
        inp = (20 * torch.randn(6, 1, 4, 4), 9 * torch.randn(6, 1, 4, 4))
        self._compare_sample_grads_per_sample(model, inp, lambda x: torch.mean(x))

    def test_sample_grads_modified_conv_mean(self) -> None:
        if version.parse(torch.__version__) < version.parse("1.8.0"):
            raise unittest.SkipTest(
                "Skipping sample gradient test for conv module with padding "
                "and dilation since torch version < 1.8"
            )
        model = BasicModel_ConvNetWithPaddingDilation()
        inp = (20 * torch.randn(6, 1, 5, 5),)
        self._compare_sample_grads_per_sample(
            model, inp, lambda x: torch.mean(x), "mean"
        )

    def test_sample_grads_modified_conv_sum(self) -> None:
        if version.parse(torch.__version__) < version.parse("1.8.0"):
            raise unittest.SkipTest(
                "Skipping sample gradient test for conv module with padding "
                "and dilation since torch version < 1.8"
            )
        model = BasicModel_ConvNetWithPaddingDilation()
        inp = (20 * torch.randn(6, 1, 5, 5),)
        self._compare_sample_grads_per_sample(
            model, inp, lambda x: torch.sum(x), "sum"
        )

    def _compare_sample_grads_per_sample(
        self,
        model: Module,
        inputs: Tuple[Tensor, ...],
        loss_fn: Callable,
        loss_type: str = "mean",
    ) -> None:
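        """Compute per-sample gradients for ``model`` on the full batch via
        SampleGradientWrapper, then re-run each sample individually through
        standard autograd and check that the two agree.
        """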
        wrapper = SampleGradientWrapper(model)
        wrapper.add_hooks()
        apply_gradient_requirements(inputs)
        out = model(*inputs)
        wrapper.compute_param_sample_gradients(loss_fn(out), loss_type)
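
        # Compare against gradients obtained by running each sample through the
        # model separately with vanilla autograd.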
        batch_size = inputs[0].shape[0]
        for i in range(batch_size):
            model.zero_grad()
            single_inp = tuple(inp[i : i + 1] for inp in inputs)
            out = model(*single_inp)
            loss_fn(out).backward()
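
            # For each supported layer, the single-sample .grad should match
            # the i-th entry of the stored sample_grad.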
            for layer in model.modules():
                if isinstance(layer, tuple(SUPPORTED_MODULES.keys())):
                    assertTensorAlmostEqual(
                        self,
                        layer.weight.grad,
                        layer.weight.sample_grad[i],  # type: ignore
                        mode="max",
                    )
                    assertTensorAlmostEqual(
                        self,
                        layer.bias.grad,
                        layer.bias.sample_grad[i],  # type: ignore
                        mode="max",
                    )

    def test_sample_grads_layer_modules(self) -> None:
        """
        Tests that when the `layer_modules` argument is passed to
        `SampleGradientWrapper`, per-sample gradients are computed only for
        the specified layers.
        """
        model = BasicModel_ConvNet_One_Conv()
        inp = (20 * torch.randn(6, 1, 4, 4), 9 * torch.randn(6, 1, 4, 4))

        # Candidates for `layer_modules`: the modules whose parameters we want
        # per-sample gradients for.
        layer_moduless = [[model.conv1], [model.fc1], [model.conv1, model.fc1]]
        # All modules we will check, hard-coded.
        all_modules = [model.conv1, model.fc1]

        for layer_modules in layer_moduless:
            # The wrapper is invoked multiple times, so reset sample grads
            # before each run.
            for module in all_modules:
                _reset_sample_grads(module)

            # Compute sample grads for only the specified layers.
            wrapper = SampleGradientWrapper(model, layer_modules)
            wrapper.add_hooks()
            apply_gradient_requirements(inp)
            out = model(*inp)
            wrapper.compute_param_sample_gradients(torch.sum(out), "sum")
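
            # Check which modules actually received per-sample gradients.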
            for module in all_modules:
                if module in layer_modules:
                    # If sample grads were computed for this layer, none of its
                    # parameters' `sample_grad` attributes can still be an int:
                    # they were all reset to 0 at the start of the loop, and
                    # computing sample grads overrides that 0 with a tensor.
                    for parameter in module.parameters():
                        assert not isinstance(parameter.sample_grad, int)
                else:
                    # Layers we did not request sample grads for should still
                    # have `sample_grad == 0`, since it was never overwritten.
                    for parameter in module.parameters():
                        assert parameter.sample_grad == 0
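

if __name__ == "__main__":
    # Convenience entry point (not required by the test runner) so the module
    # can also be executed directly.
    unittest.main()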