
Interpretability Methods for tf.keras models with Tensorflow 2.x


a143416/tf-explain

 
 


tf-explain


tf-explain implements interpretability methods as Tensorflow 2.x callbacks to make neural networks easier to understand. See Introducing tf-explain, Interpretability for Tensorflow 2.0.

Documentation: https://tf-explain.readthedocs.io

Installation

tf-explain is available on PyPI as an alpha release. To install it:

virtualenv venv -p python3.6
source venv/bin/activate
pip install tf-explain

tf-explain is compatible with Tensorflow 2.x. Tensorflow is not declared as a dependency, so that you can choose between the full and the CPU-only versions. In addition to the previous install, run:

# For CPU or GPU
pip install tensorflow==2.2.0

OpenCV is also a dependency. To install it, run:

# For CPU or GPU
pip install opencv-python

Quickstart

tf-explain offers two ways to apply interpretability methods. The full list of methods is given in the Available Methods section.

On trained model

The simplest option is to load a trained model and apply the methods to it.

import tensorflow as tf
from tf_explain.core.grad_cam import GradCAM

IMAGE_PATH = "./cat.jpg"  # placeholder: path to your own image

# Load a pretrained model (or your own)
model = tf.keras.applications.vgg16.VGG16(weights="imagenet", include_top=True)

# Load a sample image (or several)
img = tf.keras.preprocessing.image.load_img(IMAGE_PATH, target_size=(224, 224))
img = tf.keras.preprocessing.image.img_to_array(img)
data = ([img], None)  # (images, labels); labels are not needed here

# Start explainer
explainer = GradCAM()
grid = explainer.explain(data, model, class_index=281)  # 281 is the tabby cat index in ImageNet

explainer.save(grid, ".", "grad_cam.png")

During training

To monitor your model during training, you can also use tf-explain as a Keras callback and see the results directly in TensorBoard.

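from tf_explain.callbacks.grad_cam import GradCAMCallback

# Placeholders: your own compiled tf.keras model, training and validation data
model = [...]
output_dir = "./logs"  # directory where the callback writes its TensorBoard summaries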

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Available Methods

  1. Activations Visualization
  2. Vanilla Gradients
  3. Gradients*Inputs
  4. Occlusion Sensitivity
  5. Grad CAM (Class Activation Maps)
  6. SmoothGrad
  7. Integrated Gradients

Activations Visualization

Visualize the output of a specific activation layer for a given input

from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback

model = [...]

callbacks = [
    ActivationsVisualizationCallback(
        validation_data=(x_val, y_val),
        layers_name=["activation_1"],
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
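The callback above takes care of plotting. As a rough, framework-agnostic sketch of what such a visualization involves, the NumPy snippet below tiles the channels of one activation map of shape (H, W, C) into a square grid, normalizing each channel to [0, 255]. In tf.keras, the activations themselves could be obtained from a sub-model such as tf.keras.Model(inputs=model.inputs, outputs=model.get_layer("activation_1").output). The function name activations_grid and the per-channel min-max normalization are illustrative assumptions, not tf-explain's implementation.

```python
import math

import numpy as np


def activations_grid(activations):
    """Tile the channels of one activation map (H, W, C) into a square uint8 grid.

    Illustrative sketch: each channel is min-max normalised independently to [0, 255].
    """
    h, w, c = activations.shape
    n = math.ceil(math.sqrt(c))  # grid side length, in tiles
    grid = np.zeros((n * h, n * w), dtype=np.uint8)
    for k in range(c):
        channel = activations[:, :, k].astype(np.float64)
        span = channel.max() - channel.min()
        # Constant channels would divide by zero, so they are rendered black
        scaled = (channel - channel.min()) / span if span > 0 else np.zeros_like(channel)
        row, col = divmod(k, n)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = (scaled * 255).astype(np.uint8)
    return grid


# Hypothetical activations: a 4x4 feature map with 5 channels
acts = np.random.default_rng(0).random((4, 4, 5))
grid = activations_grid(acts)
print(grid.shape)  # (12, 12): 5 channels need a 3x3 grid of 4x4 tiles
```

Unused tiles (here the last four of the 3x3 grid) stay black.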

Vanilla Gradients

Visualize the importance of each input pixel through the gradient of the class score with respect to the input image

from tf_explain.callbacks.vanilla_gradients import VanillaGradientsCallback

model = [...]

callbacks = [
    VanillaGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
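To make the idea concrete, here is a minimal NumPy sketch of vanilla gradients. In a real tf.keras model the gradient would come from tf.GradientTape; here a central finite-difference approximation stands in so the sketch needs only NumPy. The toy linear "model", the helper names, and the choice of summing absolute gradients over colour channels for display are assumptions for illustration, not tf-explain's code.

```python
import numpy as np


def numerical_gradient(f, x, eps=1e-4):
    """Central-difference approximation of d f(x) / d x for a scalar-valued f."""
    grad = np.zeros_like(x, dtype=np.float64)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x, dtype=np.float64)
        step[idx] = eps
        grad[idx] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


def vanilla_gradients(score_fn, image, class_index):
    """Gradient of the class score w.r.t. every input value, plus an (H, W) saliency map."""
    grad = numerical_gradient(lambda img: score_fn(img)[class_index], image)
    saliency = np.abs(grad).sum(axis=-1)  # collapse colour channels for display
    return grad, saliency


# Hypothetical toy "model": a linear classifier over a 3x3 RGB image with 2 classes
rng = np.random.default_rng(0)
W = rng.normal(size=(2, 3 * 3 * 3))


def score_fn(img):
    return W @ img.ravel()


image = rng.random((3, 3, 3))
grad, saliency = vanilla_gradients(score_fn, image, class_index=1)
```

For this linear toy model the gradient of class 1 is just row 1 of W reshaped to the image shape, which makes the result easy to check.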

Gradients*Inputs

Variant of Vanilla Gradients that weights the gradients by the input values

from tf_explain.callbacks.gradients_inputs import GradientsInputsCallback

model = [...]

callbacks = [
    GradientsInputsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
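A minimal sketch of the Gradients*Inputs computation, assuming a hypothetical bias-free linear classifier so that the gradient is known analytically (in tf.keras it would come from tf.GradientTape). The attribution is the element-wise product of the gradient and the input; for such a linear model the attributions sum exactly to the class score, which is a handy sanity check.

```python
import numpy as np


def gradients_inputs(grad_fn, image, class_index):
    """Element-wise product of the input and the gradient of the class score."""
    return grad_fn(image, class_index) * image


# Hypothetical bias-free linear classifier over a 4x4 RGB image with 3 classes
rng = np.random.default_rng(1)
W = rng.normal(size=(3, 4 * 4 * 3))


def scores(img):
    return W @ img.ravel()


def grad_fn(img, class_index):
    # For a linear model, the gradient of score c is simply row c of W
    return W[class_index].reshape(img.shape)


image = rng.random((4, 4, 3))
attributions = gradients_inputs(grad_fn, image, class_index=2)
# For a bias-free linear model the attributions sum to the class score
print(np.isclose(attributions.sum(), scores(image)[2]))  # True
```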

Occlusion Sensitivity

Visualize how parts of the image affect the neural network's confidence by iteratively occluding patches of it

from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback

model = [...]

callbacks = [
    OcclusionSensitivityCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        patch_size=4,
        output_dir=output_dir,
    ),
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)

Occlusion Sensitivity for the tabby class (the stripes differentiate the tabby cat from the other ImageNet cat classes)
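The occlusion procedure can be sketched in a few lines of NumPy: slide a patch_size x patch_size patch over the image, fill it with a constant, and record how much the class confidence drops. This is an illustrative sketch, not tf-explain's implementation; the fill value of 0.5 assumes images scaled to [0, 1], and predict_fn is a hypothetical stand-in for model.predict on a single image.

```python
import numpy as np


def occlusion_sensitivity(predict_fn, image, class_index, patch_size, fill_value=0.5):
    """Drop in class confidence when each patch_size x patch_size patch is occluded.

    Returns an (H // patch_size, W // patch_size) map; larger values mean the
    occluded region mattered more for the prediction.
    """
    h, w = image.shape[:2]
    base = predict_fn(image)[class_index]
    rows, cols = h // patch_size, w // patch_size
    sensitivity = np.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            occluded = image.copy()
            occluded[i * patch_size:(i + 1) * patch_size,
                     j * patch_size:(j + 1) * patch_size] = fill_value
            sensitivity[i, j] = base - predict_fn(occluded)[class_index]
    return sensitivity


# Hypothetical toy model: class 0 confidence is the mean brightness of the
# top-left 4x4 corner, so only that patch should matter.
def predict_fn(img):
    p = img[:4, :4].mean()
    return np.array([p, 1.0 - p])


image = np.zeros((8, 8, 3))
image[:4, :4] = 1.0  # bright top-left corner
sens = occlusion_sensitivity(predict_fn, image, class_index=0, patch_size=4)
# sens[0, 0] is 0.5; all other entries are 0.0
```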

Grad CAM

Visualize how parts of the image affect the neural network's output by weighting the activation maps of a convolutional layer with their gradients

From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization

from tf_explain.callbacks.grad_cam import GradCAMCallback

model = [...]

callbacks = [
    GradCAMCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
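The core Grad-CAM computation from the paper can be sketched in NumPy, assuming the activations of one convolutional layer and the gradients of the class score with respect to them are already available (in tf.keras they would typically come from tf.GradientTape on a model that outputs both that layer and the predictions). Each channel is weighted by its spatially averaged gradient, the weighted channels are summed, and a ReLU keeps only positive evidence. The final scaling to [0, 1] is an illustrative choice; in practice the heatmap is then resized to the input size and overlaid on the image.

```python
import numpy as np


def grad_cam(activations, gradients):
    """Grad-CAM heatmap from one conv layer's activations and gradients.

    activations, gradients: arrays of shape (H, W, K); gradients are
    d(class score) / d(activations). Returns an (H, W) map scaled to [0, 1].
    """
    weights = gradients.mean(axis=(0, 1))  # alpha_k: global-average-pooled gradients
    cam = np.maximum((activations * weights).sum(axis=-1), 0.0)  # ReLU of weighted sum
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Hypothetical 2x2 feature maps with 2 channels
activations = np.array([[[1.0, 0.0], [0.0, 1.0]],
                        [[2.0, 0.0], [0.0, 0.0]]])
gradients = np.stack([np.full((2, 2), 1.0), np.full((2, 2), -1.0)], axis=-1)
heatmap = grad_cam(activations, gradients)
# heatmap == [[0.5, 0.0], [1.0, 0.0]]
```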

SmoothGrad

Visualize gradients on the input towards the decision, stabilized by averaging them over noisy copies of the input

From SmoothGrad: removing noise by adding noise

from tf_explain.callbacks.smoothgrad import SmoothGradCallback

model = [...]

callbacks = [
    SmoothGradCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        num_samples=20,
        noise=1.,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
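A minimal NumPy sketch of SmoothGrad: average the class-score gradient over num_samples copies of the input perturbed with zero-mean Gaussian noise. Here noise is assumed to be the standard deviation of that noise, and grad_fn is a hypothetical stand-in for a tf.GradientTape gradient computation; this is not tf-explain's implementation.

```python
import numpy as np


def smoothgrad(grad_fn, image, class_index, num_samples=20, noise=1.0, seed=0):
    """Average the class-score gradient over noisy copies of the input.

    noise is used here as the standard deviation of zero-mean Gaussian noise.
    """
    rng = np.random.default_rng(seed)
    total = np.zeros_like(image, dtype=np.float64)
    for _ in range(num_samples):
        noisy = image + rng.normal(scale=noise, size=image.shape)
        total += grad_fn(noisy, class_index)
    return total / num_samples


# Hypothetical toy model with score_c(x) = (c + 1) * sum(x ** 2),
# whose gradient 2 * (c + 1) * x varies with the input.
def grad_fn(img, class_index):
    return 2.0 * (class_index + 1) * img


image = np.ones((2, 2, 3))
smoothed = smoothgrad(grad_fn, image, class_index=0, num_samples=2000, noise=0.1)
```

Because the toy gradient is linear in the input and the noise has zero mean, the smoothed gradient should be close to the noise-free gradient 2 * image; on a real network, averaging mainly suppresses the high-frequency fluctuations of single-sample gradients.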

Integrated Gradients

Visualize the average of the gradients along a path from a baseline to the input, scaled by the difference between the input and the baseline

From Axiomatic Attribution for Deep Networks

from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback

model = [...]

callbacks = [
    IntegratedGradientsCallback(
        validation_data=(x_val, y_val),
        class_index=0,
        n_steps=20,
        output_dir=output_dir,
    )
]

model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
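Integrated Gradients can be sketched with a Riemann sum: evaluate the gradient at n_steps images interpolated between a baseline (assumed here to be an all-zero, i.e. black, image) and the input, average those gradients, and multiply by (input - baseline). The toy quadratic model and grad_fn are hypothetical stand-ins for a tf.keras model and tf.GradientTape; the check at the end uses the completeness property from the paper, which says the attributions sum to score(input) - score(baseline).

```python
import numpy as np


def integrated_gradients(grad_fn, image, class_index, n_steps=20, baseline=None):
    """Riemann-sum approximation of Integrated Gradients.

    Averages the gradient at n_steps images interpolated between a baseline
    (all zeros, i.e. a black image, unless given) and the input, then scales
    by (input - baseline).
    """
    if baseline is None:
        baseline = np.zeros_like(image, dtype=np.float64)
    alphas = np.linspace(0.0, 1.0, n_steps)
    grads = [grad_fn(baseline + a * (image - baseline), class_index) for a in alphas]
    return (image - baseline) * np.mean(grads, axis=0)


# Hypothetical single-output toy model with score(x) = sum(x ** 2), so grad(x) = 2 * x
def score(img):
    return float((img ** 2).sum())


def grad_fn(img, class_index):
    return 2.0 * img  # class_index is unused by this single-output toy model


image = np.random.default_rng(2).random((3, 3, 3))
ig = integrated_gradients(grad_fn, image, class_index=0, n_steps=20)
# Completeness: attributions sum to score(input) - score(baseline)
print(np.isclose(ig.sum(), score(image) - score(np.zeros_like(image))))  # True
```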

Roadmap

Contributing

To contribute to the project, please read the dedicated section.
