Commit b54cf83 (0 parents): 13 changed files with 1,051 additions and 0 deletions.
**`.gitignore`** (new file, 7 additions)

```
.vscode/
__pycache__
baseline/
noise/
cifar/
*.sh
scratchpad.py
```
**`README.md`** (new file, 29 additions)
# Attribution-based Confidence metric for deep neural networks

## Description

This repository provides a PyTorch implementation of the attribution-based confidence (ABC) metric described in this [paper](https://proceedings.neurips.cc/paper/2019/file/bc1ad6e8f86c42a371aff945535baebb-Paper.pdf).

## Installation

1. Download or clone the repository
2. Install the requirements

## Usage

You can specify the download directory for datasets by setting the `DATASETS_ROOT` environment variable.

Scripts for MNIST (with and without background noise) are provided:
```bash
export DATASETS_ROOT="/tmp"
python mnist_baseline.py
python mnist_noise.py
```

A script for CIFAR-10 is also provided:
```bash
export DATASETS_ROOT="/tmp"
python cifar10.py
```

The ABC metric is tested on perturbed inputs, using either rotated samples or alpha blending between two random samples. The displayed metrics are the average ABC score over all samples, over correctly classified samples, and over misclassified samples.
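
The metric can also be called directly from Python. A minimal sketch, assuming a trained classifier and the `abc_metric` function defined in this repository (the module name `metric` and the toy model below are illustrative assumptions):

```python
import torch
import torch.nn as nn

from metric import abc_metric  # assumption: adjust to the actual module name

# Toy stand-in for a trained classifier (illustration only)
model = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 10))

samples = torch.rand(8, 1, 32, 32)    # batch of 8 grayscale 32x32 images
targets = torch.randint(0, 10, (8,))  # ground-truth labels

metricParams = {
    "n_steps": 50,          # integrated gradients discretization steps
    "minPixelValue": 1e-5,  # clamp floor, avoids division by zero
    "minProb": 0.0,         # minimal pixel switching probability
    "patchSize": 4,         # square patch size
    "nConform": 50,         # conformance samples per input
}

scores, correct, incorrect = abc_metric(model, samples, targets, metricParams)
print("average ABC score:", scores.mean().item())
```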
**ABC metric module** (new file, 130 additions)

```python
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn.functional import fold
from captum.attr import IntegratedGradients


def attributions(
    model: nn.Module,
    samples: torch.Tensor,
    n_steps: int = 50,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute attributions for given samples using integrated gradients.

    Args:
        model (nn.Module): Trained model
        samples (torch.Tensor): Samples
        n_steps (int, optional): Discretization steps for integrated gradients. Defaults to 50.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: model predictions, attributions
    """
    baseline = torch.zeros_like(samples)
    model.eval()
    with torch.no_grad():
        predictions = model(samples).squeeze().argmax(dim=1)

    integratedGrads = IntegratedGradients(model.forward, multiply_by_inputs=False)
    attrs = integratedGrads.attribute(samples, baseline, predictions, n_steps=n_steps)
    return predictions, attrs


def abc_metric(
    model: nn.Module, samples: torch.Tensor, targets: torch.Tensor, metricParams: dict
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the ABC score on a batch of samples.

    Absolute attributions are first divided by pixel values. These raw sensitivity
    images are split into small patches, and each pixel is assigned a probability
    (summing to 1 inside every patch). The patches are reassembled into a full
    probability image, whose values serve as Bernoulli parameters for generating
    conformance samples: perturbed copies of the input whose predictions are
    compared against the original prediction.

    Per-sample ABC scores are returned, along with the scores restricted to
    correctly classified and misclassified samples respectively.

    Args:
        model (nn.Module): Trained model
        samples (torch.Tensor): Samples
        targets (torch.Tensor): True targets
        metricParams (dict): ABC metric computation parameters

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: all scores, scores for
        correct classifications, scores for incorrect classifications
    """
    n_steps = metricParams["n_steps"]
    minPixelValue = metricParams["minPixelValue"]
    minProb = metricParams["minProb"]
    nConform = metricParams["nConform"]
    assert minPixelValue > 0
    assert 0 <= minProb <= 1

    batchSize, channels, height, width = samples.shape

    predictions, attrs = attributions(model, samples, n_steps=n_steps)

    # force samples to have minPixelValue as minimum (avoids division by zero)
    filteredSamples = samples.clamp(min=minPixelValue)
    # raw sensitivity images: |attribution / pixel value|, summed over channels
    ratio = (
        torch.abs(attrs / filteredSamples)
        .sum(dim=1)
        .reshape((batchSize, 1, height, width))
    )

    pSize = metricParams["patchSize"]  # patch size (square patches only)

    # split into non-overlapping pSize x pSize patches
    ratioPatches = ratio.unfold(2, pSize, pSize).unfold(3, pSize, pSize)
    # normalize so that each patch sums to 1
    ratioSum = (
        ratioPatches.sum((4, 5))
        .reshape((batchSize, 1, height // pSize, width // pSize, 1, 1))
        .expand_as(ratioPatches)
    )
    probsPatches = ratioPatches / ratioSum  # probability patches

    # reassemble patches into full probability images
    probsPatches = (
        probsPatches.reshape(
            (batchSize, 1, height // pSize, width // pSize, pSize**2)
        )
        .permute(0, 1, 4, 2, 3)
        .squeeze(1)
        .reshape(batchSize, pSize**2, -1)
    )
    probs = fold(probsPatches, (height, width), kernel_size=pSize, stride=pSize)

    scores = []
    scoresCorrect = []
    scoresIncorrect = []
    for sampleIdx, sample in enumerate(samples):
        fullMask = probs[sampleIdx].squeeze()
        # generate conformance samples by masking pixels drawn from the probability image
        conformBatchList = [
            sample * ~(torch.bernoulli(fullMask.clamp_min(minProb)).bool())
            for _ in range(nConform)
        ]
        conformBatch = torch.stack(conformBatchList, dim=0)

        with torch.no_grad():
            # recompute predictions on the conformance samples
            conformPredictions = model(conformBatch).argmax(dim=1, keepdim=True)
        predictedLabels = torch.ones_like(conformPredictions) * predictions[sampleIdx]
        # ABC score: fraction of conformance samples that keep the original prediction
        score = torch.eq(predictedLabels, conformPredictions).sum() / nConform
        scores.append(score)
        if predictions[sampleIdx] == targets[sampleIdx]:
            scoresCorrect.append(score)
        else:
            scoresIncorrect.append(score)

    return (
        torch.stack(scores, dim=0),
        torch.stack(scoresCorrect, dim=0) if len(scoresCorrect) > 0 else None,
        torch.stack(scoresIncorrect, dim=0) if len(scoresIncorrect) > 0 else None,
    )
```
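
The trickiest part of the code above is the patch bookkeeping: `unfold` splits each sensitivity image into non-overlapping patches, each patch is normalized to sum to 1, and `fold` reassembles the result into a full probability image. A standalone sketch of that round trip, with illustrative shapes only:

```python
import torch
from torch.nn.functional import fold

batch, height, width, p = 2, 8, 8, 4
ratio = torch.rand(batch, 1, height, width)  # stand-in for the raw sensitivity images

# split into non-overlapping p x p patches: (batch, 1, H/p, W/p, p, p)
patches = ratio.unfold(2, p, p).unfold(3, p, p)

# normalize so every patch sums to 1 (keepdim lets broadcasting replace expand_as)
probs_patches = patches / patches.sum((4, 5), keepdim=True)

# flatten each patch and order the patch grid row-major, as fold expects:
# (batch, p*p, number_of_patches)
cols = (
    probs_patches.reshape(batch, 1, height // p, width // p, p * p)
    .permute(0, 1, 4, 2, 3)
    .reshape(batch, p * p, -1)
)
probs = fold(cols, (height, width), kernel_size=p, stride=p)

# sanity check: every p x p block of the reassembled image sums to 1
print(probs.unfold(2, p, p).unfold(3, p, p).sum((4, 5)))
```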
**`cifar10.py`** (new file, 44 additions)

```python
import logging

from runner import run
from models import ResNet
from datasets import cifar10Loader
from easydict import EasyDict

from train_utils import checkSavedModel

logger = logging.getLogger()

params = EasyDict()
params.savepoint = "cifar"
params.modelType = "ResNet"
params.resume = True
params.name = "cifar10"
params.epochs = 50
params.lr = 0.001
params.batchSize = 16
params.schedule = [20, 30, 40, 45]  # learning rate schedule milestones
params.gamma = 0.5
params.alpha = 0.01
params.angle = 0
params.evals = ["alphaBlending", "rotation"]
params.metricParams = {
    "n_steps": 50,  # integrated gradients discretization steps
    "minPixelValue": 1e-5,  # minimal pixel value
    "minProb": 0.0,  # minimal pixel switching probability
    "patchSize": 4,  # patch size
    "nConform": 50,  # conformance samples to generate
}


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)  # without this, logger.info is silent
    logger.info(params)
    if params.modelType == "ResNet":
        model = ResNet(3)
    else:
        # argparse.ArgumentError requires an argument object; a plain ValueError fits here
        raise ValueError(f"Invalid model type: {params.modelType}")

    params.modelIsTrained = checkSavedModel(params, model)

    run(params, model, cifar10Loader)
```
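
`runner.py` and `train_utils.py` are among the 13 changed files but are not shown in this view. Judging from `params.schedule` and `params.gamma`, the learning rate presumably follows a step decay; a sketch of the equivalent `torch.optim.lr_scheduler.MultiStepLR` usage (an assumption, since the training loop itself is not visible):

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(10, 10)  # placeholder model
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# assumption: params.schedule / params.gamma drive a decay like this one,
# halving the learning rate at epochs 20, 30, 40 and 45
scheduler = MultiStepLR(optimizer, milestones=[20, 30, 40, 45], gamma=0.5)

for epoch in range(50):
    # ... one training epoch would go here ...
    scheduler.step()
    if epoch + 1 in (20, 30, 40, 45):
        print(f"epoch {epoch + 1}: lr = {optimizer.param_groups[0]['lr']}")
```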
**`datasets.py`** (new file, 170 additions)

```python
from typing import Callable, Optional, Tuple
import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms


def _getRoot(rootDir: Optional[str] = None) -> str:
    if rootDir is None:
        root = os.environ.get("DATASETS_ROOT")
        if root is None:
            root = "_data"
    else:
        root = rootDir
    return root


def getMNIST(
    batchSize: int,
    dataDir: Optional[str] = None,
    trainTransforms: Optional[Callable] = None,
    testTransforms: Optional[Callable] = None,
) -> Tuple[DataLoader, DataLoader]:

    trainDataset = datasets.MNIST(
        _getRoot(dataDir),
        train=True,
        download=True,
        transform=trainTransforms,
        target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
    )
    testDataset = datasets.MNIST(
        _getRoot(dataDir),
        train=False,
        transform=testTransforms,
        target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
    )

    train_loader = DataLoader(
        trainDataset,
        batch_size=batchSize,
        shuffle=True,
        drop_last=True,
    )

    test_loader = DataLoader(
        testDataset,
        batch_size=batchSize,
        shuffle=True,
        drop_last=True,
    )
    return train_loader, test_loader


mnistBaselineTrainTransforms = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]
)

mnistBaselineTestTransforms = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
    ]
)

mnistNoiseTrainTransforms = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        # additive background noise
        transforms.Lambda(lambda x: x + 0.1 * (0.5 + 0.5 * torch.randn_like(x))),
    ]
)

mnistNoiseTestTransforms = transforms.Compose(
    [
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x + 0.1 * (0.5 + 0.5 * torch.randn_like(x))),
    ]
)


def mnistBaselineLoader(batchSize: int):
    return getMNIST(
        batchSize,
        trainTransforms=mnistBaselineTrainTransforms,
        testTransforms=mnistBaselineTestTransforms,
    )


def mnistNoiseLoader(batchSize: int):
    return getMNIST(
        batchSize,
        trainTransforms=mnistNoiseTrainTransforms,
        testTransforms=mnistNoiseTestTransforms,
    )


def getCifar10(
    batchSize: int,
    dataDir: Optional[str] = None,
    trainTransforms: Optional[Callable] = None,
    testTransforms: Optional[Callable] = None,
) -> Tuple[DataLoader, DataLoader]:

    trainDataset = datasets.CIFAR10(
        _getRoot(dataDir),
        train=True,
        download=True,
        transform=trainTransforms,
        target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
    )

    testDataset = datasets.CIFAR10(
        _getRoot(dataDir),
        train=False,
        transform=testTransforms,
        target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
    )

    train_loader = DataLoader(
        trainDataset,
        batch_size=batchSize,
        shuffle=True,
        drop_last=True,
    )

    test_loader = DataLoader(
        testDataset,
        batch_size=batchSize,
        shuffle=True,
        drop_last=True,
    )
    return train_loader, test_loader


cifarTrainTransforms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

cifarTestTransforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)


def cifar10Loader(batchSize: int):
    return getCifar10(
        batchSize,
        trainTransforms=cifarTrainTransforms,
        testTransforms=cifarTestTransforms,
    )
```
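
A short usage sketch for these loaders, tying back to the `DATASETS_ROOT` variable from the README (the batch size is illustrative):

```python
import os

from datasets import cifar10Loader

os.environ["DATASETS_ROOT"] = "/tmp"  # read by _getRoot at call time

train_loader, test_loader = cifar10Loader(batchSize=16)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([16, 3, 32, 32]) torch.Size([16])
```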