Neuron Aggregation (pytorch#495)
Summary:
This adds support for neuron aggregation: neuron_selector can now be a function that returns a custom aggregate of a layer's neurons, for all neuron methods other than NeuronConductance, which depends on output gradients. The neuron_index argument was renamed to neuron_selector, and a deprecation decorator was added to warn when the old parameter is passed as a keyword argument. This decorator can be removed prior to the 0.4.0 release.

Documentation of the new callable functionality has been added to NeuronDeepLift; it will be propagated to the other relevant methods after review.
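For illustration, a minimal sketch of the new callable selector (reusing the `model`, `model.lin1`, and `input` placeholders from the README example in the diff below; the aggregation function itself is hypothetical and not part of this commit):

```python
import torch
from captum.attr import NeuronDeepLift

# The callable receives the layer's output (batch x num_neurons for a linear
# layer) and should return one aggregated value per example.
def aggregate_first_two_neurons(layer_output: torch.Tensor) -> torch.Tensor:
    return layer_output[:, 0:2].sum(dim=1)

nd = NeuronDeepLift(model, model.lin1)

# Previously only an index (or tuple of indices) could be passed; a callable
# aggregating several neurons now also works for neuron methods other than
# NeuronConductance.
attributions = nd.attribute(input, neuron_selector=aggregate_first_two_neurons)
```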

Pull Request resolved: pytorch#495

Reviewed By: miguelmartin75

Differential Revision: D24346065

Pulled By: vivekmig

fbshipit-source-id: c3853e19256de4c8c32a8ff615965bf513a5cd22
vivekmig authored and facebook-github-bot committed Oct 29, 2020
1 parent 31f266d commit aae5228
Showing 22 changed files with 568 additions and 221 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -324,7 +324,7 @@ In this case, we choose to analyze the first neuron in the linear layer.

```python
nc = NeuronConductance(model, model.lin1)
attributions = nc.attribute(input, neuron_index=1, target=0)
attributions = nc.attribute(input, neuron_selector=1, target=0)
print('Neuron Attributions:', attributions)
```
Output
27 changes: 27 additions & 0 deletions captum/_utils/common.py
@@ -417,6 +417,15 @@ def _select_targets(output: Tensor, target: TargetType) -> Tensor:
raise AssertionError("Target type %r is not valid." % target)


def _contains_slice(target: Union[int, Tuple[Union[int, slice], ...]]) -> bool:
    if isinstance(target, tuple):
        for index in target:
            if isinstance(index, slice):
                return True
        return False
    return isinstance(target, slice)


def _verify_select_column(
output: Tensor, target: Union[int, Tuple[Union[int, slice], ...]]
) -> Tensor:
@@ -427,6 +436,24 @@ def _verify_select_column(
return output[(slice(None), *target)]


def _verify_select_neuron(
    layer_output: Tuple[Tensor, ...],
    selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tensor:
    if callable(selector):
        return selector(layer_output if len(layer_output) > 1 else layer_output[0])

    assert len(layer_output) == 1, (
        "Cannot select neuron index from layer with multiple tensors,"
        "consider providing a neuron selector function instead."
    )

    selected_neurons = _verify_select_column(layer_output[0], selector)
    if _contains_slice(selector):
        return selected_neurons.reshape(selected_neurons.shape[0], -1).sum(1)
    return selected_neurons


def _extract_device(
module: Module,
hook_inputs: Union[None, Tensor, Tuple[Tensor, ...]],
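For context, a rough illustration of how the new `_verify_select_neuron` helper resolves the different selector forms (the layer output and values below are invented for this example; the module path matches the file shown above):

```python
import torch
from captum._utils.common import _verify_select_neuron

# Stand-in for a hooked layer output: a tuple holding one (batch=2, neurons=4) tensor.
layer_output = (torch.arange(8.0).reshape(2, 4),)

# An integer or tuple selector indexes into the single layer tensor.
_verify_select_neuron(layer_output, 1)                # tensor([1., 5.])

# A selector containing a slice sums over the selected neurons per example.
_verify_select_neuron(layer_output, (slice(0, 2),))   # tensor([1., 9.])

# A callable receives the layer output directly and can aggregate arbitrarily.
_verify_select_neuron(layer_output, lambda out: out.max(dim=1).values)  # tensor([3., 7.])
```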
54 changes: 28 additions & 26 deletions captum/_utils/gradient.py
@@ -9,7 +9,7 @@
from torch import Tensor, device
from torch.nn import Module

from .common import _reduce_list, _run_forward, _sort_key_list, _verify_select_column
from .common import _reduce_list, _run_forward, _sort_key_list, _verify_select_neuron
from .typing import (
Literal,
ModuleOrModuleList,
@@ -125,22 +125,20 @@ def _neuron_gradients(
inputs: Union[Tensor, Tuple[Tensor, ...]],
saved_layer: Dict[device, Tuple[Tensor, ...]],
key_list: List[device],
gradient_neuron_index: Union[int, Tuple[Union[int, slice], ...]],
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
) -> Tuple[Tensor, ...]:
with torch.autograd.set_grad_enabled(True):
gradient_tensors = []
for key in key_list:
assert (
len(saved_layer[key]) == 1
), "Cannot compute neuron gradients for layer with multiple tensors."
current_out_tensor = _verify_select_column(
saved_layer[key][0], gradient_neuron_index
current_out_tensor = _verify_select_neuron(
saved_layer[key], gradient_neuron_selector
)
gradient_tensors.append(
torch.autograd.grad(
torch.unbind(current_out_tensor),
torch.unbind(current_out_tensor)
if current_out_tensor.numel() > 1
else current_out_tensor,
inputs,
grad_outputs=torch.unbind(torch.ones_like(current_out_tensor)),
)
)
_total_gradients = _reduce_list(gradient_tensors, sum)
@@ -187,7 +185,7 @@ def _forward_layer_eval(
inputs,
layer,
additional_forward_args=additional_forward_args,
gradient_neuron_index=None,
gradient_neuron_selector=None,
grad_enabled=grad_enabled,
device_ids=device_ids,
attribute_to_layer_input=attribute_to_layer_input,
@@ -369,7 +367,7 @@ def _forward_layer_eval_with_neuron_grads(
layer: Module,
additional_forward_args: Any = None,
*,
gradient_neuron_index: Union[int, Tuple[Union[int, slice], ...]],
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
@@ -383,7 +381,7 @@
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: Module,
additional_forward_args: Any = None,
gradient_neuron_index: None = None,
gradient_neuron_selector: None = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
@@ -397,7 +395,7 @@
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: List[Module],
additional_forward_args: Any = None,
gradient_neuron_index: None = None,
gradient_neuron_selector: None = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
@@ -410,7 +408,7 @@
inputs: Union[Tensor, Tuple[Tensor, ...]],
layer: ModuleOrModuleList,
additional_forward_args: Any = None,
gradient_neuron_index: Union[None, int, Tuple[Union[int, slice], ...]] = None,
gradient_neuron_selector: Union[
None, int, Tuple[Union[int, slice], ...], Callable
] = None,
grad_enabled: bool = False,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
@@ -421,7 +421,7 @@
]:
"""
This method computes forward evaluation for a particular layer using a
forward hook. If a gradient_neuron_index is provided, then gradients with
forward hook. If a gradient_neuron_selector is provided, then gradients with
respect to that neuron in the layer output are also returned.
These functionalities are combined due to the behavior of DataParallel models
@@ -435,7 +435,7 @@
evals in a dictionary protected by a lock, analogous to the gather implementation
for the core PyTorch DataParallel implementation.
"""
grad_enabled = True if gradient_neuron_index is not None or grad_enabled else False
grad_enabled = True if gradient_neuron_selector is not None else grad_enabled

with torch.autograd.set_grad_enabled(grad_enabled):
saved_layer = _forward_layer_distributed_eval(
@@ -450,12 +450,12 @@
# key_list is a list of devices in appropriate ordering for concatenation.
# If only one key exists (standard model), key list simply has one element.
key_list = _sort_key_list(list(next(iter(saved_layer.values())).keys()), device_ids)
if gradient_neuron_index is not None:
if gradient_neuron_selector is not None:
assert isinstance(
layer, Module
), "Cannot compute neuron gradients for multiple layers simultaneously!"
inp_grads = _neuron_gradients(
inputs, saved_layer[layer], key_list, gradient_neuron_index
inputs, saved_layer[layer], key_list, gradient_neuron_selector
)
return (
_gather_distributed_tensors(saved_layer[layer], key_list=key_list),
@@ -479,7 +479,7 @@ def compute_layer_gradients_and_eval(
target_ind: TargetType = None,
additional_forward_args: Any = None,
*,
gradient_neuron_index: Union[int, Tuple[int, ...]],
gradient_neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
@@ -494,7 +494,7 @@
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Any = None,
gradient_neuron_index: None = None,
gradient_neuron_selector: None = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
@@ -509,7 +509,7 @@
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Any = None,
gradient_neuron_index: None = None,
gradient_neuron_selector: None = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
@@ -523,7 +523,9 @@
inputs: Union[Tensor, Tuple[Tensor, ...]],
target_ind: TargetType = None,
additional_forward_args: Any = None,
gradient_neuron_index: Union[None, int, Tuple[int, ...]] = None,
gradient_neuron_selector: Union[
None, int, Tuple[Union[int, slice], ...], Callable
] = None,
device_ids: Union[None, List[int]] = None,
attribute_to_layer_input: bool = False,
output_fn: Union[None, Callable] = None,
@@ -659,12 +661,12 @@ def compute_layer_gradients_and_eval(
if isinstance(layer, Module):
layer_grads = all_grads[0]

if gradient_neuron_index is not None:
if gradient_neuron_selector is not None:
assert isinstance(
layer, Module
), "Cannot compute neuron gradients for multiple layers simultaneously!"
inp_grads = _neuron_gradients(
inputs, saved_layer[layer], key_list, gradient_neuron_index
inputs, saved_layer[layer], key_list, gradient_neuron_selector
)
return (
cast(Tuple[Tensor, ...], layer_grads),
@@ -676,7 +678,7 @@

def construct_neuron_grad_fn(
layer: Module,
neuron_index: Union[int, Tuple[Union[int, slice], ...]],
neuron_selector: Union[int, Tuple[Union[int, slice], ...], Callable],
device_ids: Union[None, List[int]] = None,
attribute_to_neuron_input: bool = False,
) -> Callable:
@@ -691,7 +693,7 @@ def grad_fn(
inputs,
layer,
additional_forward_args,
gradient_neuron_index=neuron_index,
gradient_neuron_selector=neuron_selector,
device_ids=device_ids,
attribute_to_layer_input=attribute_to_neuron_input,
)
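Based on the `_neuron_gradients` changes above, a callable selector used with gradient-based neuron methods should return a 1-D tensor with one aggregated value per example. A standalone sketch of the gradient pattern (the weights, layer output, and selector below are illustrative only):

```python
import torch

inputs = torch.randn(2, 3, requires_grad=True)
weight = torch.randn(3, 4)
layer_output = inputs @ weight            # stand-in for a hooked layer output

def selector(out: torch.Tensor) -> torch.Tensor:
    # One aggregated activation per example -> shape (batch,)
    return out[:, 0:2].sum(dim=1)

current_out = selector(layer_output)      # shape (2,)

# A pattern similar to the updated autograd.grad call: unbind the per-example
# aggregates unless the selector already produced a single scalar.
grads = torch.autograd.grad(
    torch.unbind(current_out) if current_out.numel() > 1 else current_out,
    inputs,
)
```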
24 changes: 15 additions & 9 deletions captum/attr/_core/neuron/neuron_conductance.py
@@ -20,7 +20,12 @@
from ..._utils.approximation_methods import approximation_parameters
from ..._utils.attribution import GradientAttribution, NeuronAttribution
from ..._utils.batching import _batch_attribution
from ..._utils.common import _format_input_baseline, _reshape_and_sum, _validate_input
from ..._utils.common import (
    _format_input_baseline,
    _reshape_and_sum,
    _validate_input,
    neuron_index_deprecation_decorator,
)


class NeuronConductance(NeuronAttribution, GradientAttribution):
Expand All @@ -46,7 +51,7 @@ def __init__(
modification of it
layer (torch.nn.Module): Layer for which neuron attributions are computed.
Attributions for a particular neuron in the input or output
of this layer are computed using the argument neuron_index
of this layer are computed using the argument neuron_selector
in the attribute method.
Currently, only layers with a single tensor input or output
are supported.
@@ -85,10 +90,11 @@ def __init__(
self._multiply_by_inputs = multiply_by_inputs

@log_usage()
@neuron_index_deprecation_decorator
def attribute(
self,
inputs: TensorOrTupleOfTensorsGeneric,
neuron_index: Union[int, Tuple[int, ...]],
neuron_selector: Union[int, Tuple[int, ...]],
baselines: BaselineType = None,
target: TargetType = None,
additional_forward_args: Any = None,
@@ -108,7 +114,7 @@ def attribute(
that for all given input tensors, dimension 0 corresponds
to the number of examples, and if multiple input tensors
are provided, the examples must be aligned appropriately.
neuron_index (int or tuple): Index of neuron in output of given
neuron_selector (int or tuple): Index of neuron in output of given
layer for which attribution is desired. Length of
this tuple must be one less than the number of
dimensions in the output of the given layer (since
@@ -260,7 +266,7 @@ def attribute(
n_steps,
inputs=inputs,
baselines=baselines,
neuron_index=neuron_index,
neuron_selector=neuron_selector,
target=target,
additional_forward_args=additional_forward_args,
method=method,
@@ -269,7 +275,7 @@
else:
attrs = self._attribute(
inputs=inputs,
neuron_index=neuron_index,
neuron_selector=neuron_selector,
baselines=baselines,
target=target,
additional_forward_args=additional_forward_args,
@@ -282,7 +288,7 @@
def _attribute(
self,
inputs: Tuple[Tensor, ...],
neuron_index: Union[int, Tuple[int, ...]],
neuron_selector: Union[int, Tuple[int, ...]],
baselines: Tuple[Union[Tensor, int, float], ...],
target: TargetType = None,
additional_forward_args: Any = None,
@@ -333,7 +339,7 @@ def _attribute(
inputs=scaled_features_tpl,
target_ind=expanded_target,
additional_forward_args=input_additional_args,
gradient_neuron_index=neuron_index,
gradient_neuron_selector=neuron_selector,
device_ids=self.device_ids,
attribute_to_layer_input=attribute_to_neuron_input,
)
@@ -348,7 +354,7 @@
# Multiplies by appropriate gradient of output with respect to hidden neurons
# mid_grads is a 1D Tensor of length num_steps*internal_batch_size,
# containing mid layer gradient for each input step.
mid_grads = _verify_select_column(layer_gradients, neuron_index)
mid_grads = _verify_select_column(layer_gradients, neuron_selector)

scaled_input_gradients = tuple(
input_grad
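The `neuron_index_deprecation_decorator` imported above is not shown in these hunks; a plausible sketch of such a decorator (an assumption about its behavior, not necessarily the actual Captum implementation) would be:

```python
import functools
import warnings


def neuron_index_deprecation_decorator(func):
    # Sketch only: forward a legacy `neuron_index` keyword argument to the new
    # `neuron_selector` parameter and warn that it is deprecated.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if "neuron_index" in kwargs:
            warnings.warn(
                "neuron_index is deprecated; use neuron_selector instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            kwargs["neuron_selector"] = kwargs.pop("neuron_index")
        return func(*args, **kwargs)

    return wrapper
```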