Add batch_size to noise tunnel #555
Conversation
Force-pushed from c37830a to 666bc5a
Looks great, this is a really useful feature! :)
captum/attr/_core/noise_tunnel.py
Outdated
@@ -79,13 +80,16 @@ def multiplies_by_inputs(self):
     @noise_tunnel_n_samples_deprecation_decorator
     def attribute(
         self,
-        inputs: Union[Tensor, Tuple[Tensor, ...]],
+        inputs: TensorOrTupleOfTensorsGeneric,
nit: I think the reason the type hint initially didn't use TensorOrTupleOfTensorsGeneric is that the output type doesn't necessarily match the input, particularly when used with layer attribution. To cover all cases, the return type would need to be something like Tensor, Tuple[Tensor, ...], Tuple[Tensor, Tensor], Tuple[Tuple[Tensor], Tensor], etc. But this is more readable and does cover most use cases, so either way works.
That's a good point! I assumed that that is the reason why we don't have overrides for different attribute signatures in this case.
This method was also lacking a return type hint, which I added, also using TensorOrTupleOfTensorsGeneric. I was thinking that we could use the TensorOrTupleOfTensorsGeneric shortcut instead of Union[Tensor, Tuple[Tensor, ...]], but it looks like there needs to be agreement between inputs and output because TensorOrTupleOfTensorsGeneric is a generic. In that case I should probably also avoid using it for the output type.
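To illustrate the constraint being discussed: because TensorOrTupleOfTensorsGeneric is a TypeVar, using it for both the parameter and the return type promises the type checker that the output type matches the input type. A minimal sketch of that behavior, with a stand-in Tensor class (the point here is the type relationship, not the runtime behavior):

```python
from typing import Tuple, TypeVar


class Tensor:
    """Stand-in for torch.Tensor; only the type relationships matter here."""


# A TypeVar constrained to Tensor or Tuple[Tensor, ...]: within a single
# call, every occurrence must resolve to the *same* concrete type.
TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)


def attribute(inputs: TensorOrTupleOfTensorsGeneric) -> TensorOrTupleOfTensorsGeneric:
    # Declaring the same TypeVar on input and output promises "Tensor in,
    # Tensor out; tuple in, tuple out" -- a promise that layer attribution
    # cannot always keep, which is the reviewer's point above.
    return inputs
```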
# if the algorithm supports targets, baselines and/or
# additional_forward_args they will be expanded based
# on the nt_samples_partition and corresponding kwargs
# variables will be updated accordingly
nit: I think this came up before, but feature_mask might also need to be expanded for perturbation-based methods. Not related to this diff though, so feel free to leave that change for later.
That's a good point! I'll add feature mask support here.
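For context, the kind of argument expansion the quoted comment describes (and which would apply to feature_mask as well) could be sketched roughly as follows. The function name and list-based types are illustrative only, not Captum's actual tensor-based implementation:

```python
from typing import List, Sequence


def expand_for_noise_samples(values: Sequence, nt_samples_partition: int) -> List:
    """Repeat each per-example value nt_samples_partition times so that
    targets / baselines / additional_forward_args stay aligned with inputs
    that were tiled along the batch dimension for the noise samples."""
    # e.g. [t0, t1] with nt_samples_partition=3 -> [t0, t0, t0, t1, t1, t1]
    return [v for v in values for _ in range(nt_samples_partition)]
```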
@@ -36,6 +40,25 @@ def _get_multiargs_basic_config() -> Tuple[
     return model, inputs, grads, additional_forward_args


 def _get_multiargs_basic_config_large() -> Tuple[
nit: It might be good to add a GPU / DataParallel test with this functionality if possible, it should only require adding a new config similar to these existing NoiseTunnel configs with this new parameter. https://github.com/pytorch/captum/blob/master/tests/attr/helpers/test_config.py#L282
Good point! I'll add DP tests.
Force-pushed from 98b8abf to d9669ce
@NarineK has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Adding support for batch_size in NoiseTunnel, as proposed in #497.
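As a rough illustration of the batching idea (names here are hypothetical and do not mirror the PR's internals): instead of attributing all nt_samples noisy copies of the input at once, the samples can be split into chunks of at most batch_size, with each chunk processed sequentially and the results aggregated, bounding peak memory:

```python
from typing import List


def partition_samples(nt_samples: int, batch_size: int) -> List[int]:
    """Split nt_samples into chunk sizes of at most batch_size.

    Each chunk would correspond to one pass through the wrapped
    attribution method; attributions are aggregated afterwards.
    """
    n_full, remainder = divmod(nt_samples, batch_size)
    return [batch_size] * n_full + ([remainder] if remainder else [])
```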