Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ma3oun committed Mar 7, 2022
0 parents commit b54cf83
Show file tree
Hide file tree
Showing 13 changed files with 1,051 additions and 0 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
.vscode/
__pycache__
baseline/
noise/
cifar/
*.sh
scratchpad.py
29 changes: 29 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Attribution-based Confidence metric for deep neural networks
## Description

This repository provides a PyTorch implementation of the attribution-based confidence (ABC) metric as described in this [paper](https://proceedings.neurips.cc/paper/2019/file/bc1ad6e8f86c42a371aff945535baebb-Paper.pdf).

## Installation

1. Download or clone the repository
2. Install the requirements

## Usage

You can specify the directory for dataset download by setting the DATASETS_ROOT environment variable.

Scripts for MNIST (with and without background noise) are provided:
```bash
export DATASETS_ROOT="/tmp"
python mnist_baseline.py
python mnist_noise.py
```

A script for Cifar10 is also provided:
```bash
export DATASETS_ROOT="/tmp"
python cifar10.py
```

The abc metric is tested using rotated data or alpha blending between two random samples.
The displayed metrics are the average abc score, the average abc score for correctly classified samples and the average abc score for misclassified samples.
130 changes: 130 additions & 0 deletions abc_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from typing import Tuple
import torch
import torch.nn as nn
from torch.nn.functional import fold
from captum.attr import IntegratedGradients


def attributions(
    model: nn.Module,
    samples: torch.Tensor,
    n_steps: int = 50,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute integrated-gradients attributions for the model's own predictions.

    The model is switched to eval mode (and left in that mode). The class
    predicted for each sample is used as the attribution target and an
    all-zeros tensor of the same shape as ``samples`` is used as the baseline.

    Args:
        model (nn.Module): Trained model. Its output is interpreted as one row
            of class scores per sample.
        samples (torch.Tensor): Batch of samples, batch dimension first.
        n_steps (int, optional): Discretization step for integrated gradient.
            Defaults to 50.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: model predictions (one class index
        per sample), attributions
    """
    baseline = torch.zeros_like(samples)
    model.eval()
    with torch.no_grad():
        logits = model(samples)
        # Flatten everything after the batch dimension rather than calling
        # ``squeeze()``: squeeze also removes the batch dimension when the
        # batch holds a single sample, and ``argmax(dim=1)`` then raises.
        predictions = logits.flatten(start_dim=1).argmax(dim=1)

    # ``multiply_by_inputs=False`` is what the original positional ``False``
    # selected; spelled out so the intent is visible at the call site.
    integratedGrads = IntegratedGradients(model.forward, multiply_by_inputs=False)
    attrs = integratedGrads.attribute(samples, baseline, predictions, n_steps=n_steps)
    return predictions, attrs


def abc_metric(
    model: nn.Module, samples: torch.Tensor, targets: torch.Tensor, metricParams: dict
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the ABC score on a batch of samples.

    Absolute attributions are first divided by pixel values. These raw
    sensitivity images are then split into small square patches. Each pixel is
    then assigned a probability (that sums up to 1 inside every patch). The
    patches are reassembled to generate a full probability image. The
    probability image gives Bernoulli parameters used to draw ``nConform``
    masked copies of each sample; the ABC score of a sample is the fraction of
    those copies for which the model keeps its original prediction.

    Args:
        model (nn.Module): Trained model
        samples (torch.Tensor): Samples, shaped (batch, channels, height, width)
        targets (torch.Tensor): True targets
        metricParams (dict): ABC metric computation parameters. Keys read here:
            ``n_steps``, ``minPixelValue`` (must be > 0), ``minProb`` (must be
            in [0, 1]), ``nConform`` and ``patchSize``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: per-sample scores for
        the whole batch, scores of the correctly classified samples, scores of
        the misclassified samples. Either of the last two is ``None`` when the
        batch contains no sample of that kind.
    """
    n_steps = metricParams["n_steps"]
    minPixelValue = metricParams["minPixelValue"]
    minProb = metricParams["minProb"]
    nConform = metricParams["nConform"]
    assert minPixelValue > 0
    assert 0 <= minProb <= 1

    batchSize, _, height, width = samples.shape

    predictions, attrs = attributions(model, samples, n_steps=n_steps)

    # force samples to have minPixelValue as minimum so the division below is
    # always defined
    filteredSamples = samples.clamp(min=minPixelValue)
    # raw attribution images: |attribution / pixel| summed over the channels
    ratio = torch.abs(attrs / filteredSamples).sum(dim=1, keepdim=True)

    pSize = metricParams["patchSize"]  # patch size (square patches only)

    # splitting into patches:
    # (batch, 1, height // pSize, width // pSize, pSize, pSize)
    ratioPatches = ratio.unfold(2, pSize, pSize).unfold(3, pSize, pSize)
    ratioSum = ratioPatches.sum((4, 5), keepdim=True)  # one total per patch

    # A patch whose total is zero (or not finite) cannot be normalized: the
    # division would give NaN, which torch.bernoulli does not accept as a
    # probability. Such a patch carries no ranking information, so every pixel
    # in it gets the same probability, which keeps the patch summing to 1.
    validPatch = torch.isfinite(ratioSum) & (ratioSum > 0)
    safeSum = torch.where(validPatch, ratioSum, torch.ones_like(ratioSum))
    probsPatches = torch.where(
        validPatch,
        ratioPatches / safeSum,
        torch.full_like(ratioPatches, 1.0 / pSize**2),
    )  # probability patches

    # reassemble patches into images: fold expects (batch, pSize**2, nPatches)
    probsPatches = (
        probsPatches.reshape(
            (batchSize, 1, height // pSize, width // pSize, pSize**2)
        )
        .permute(0, 1, 4, 2, 3)
        .squeeze(1)
        .reshape(batchSize, pSize**2, -1)
    )
    probs = fold(
        probsPatches, (height, width), kernel_size=pSize, stride=pSize
    )  # probability images, (batch, 1, height, width)

    # minProb is a floor on the switching probability; it does not depend on
    # the sample, so apply it once instead of once per drawn mask.
    probs = probs.clamp_min(minProb)

    scores = []
    scoresCorrect = []
    scoresIncorrect = []
    for sampleIdx, sample in enumerate(samples):
        # index the single channel explicitly: squeeze() would also drop a
        # height or width of 1
        fullMask = probs[sampleIdx, 0]
        # sample images using the probability image: a pixel is zeroed (in all
        # channels) wherever the Bernoulli draw is 1
        conformBatch = torch.stack(
            [
                sample * ~(torch.bernoulli(fullMask).bool())
                for _ in range(nConform)
            ],
            dim=0,
        )  # conform images batch

        with torch.no_grad():
            conformPredictions = model(conformBatch).argmax(
                dim=1, keepdim=True
            )  # compute new predictions
        # fraction of masked copies that keep the original prediction
        score = (
            torch.eq(conformPredictions, predictions[sampleIdx]).sum() / nConform
        )  # abc score
        scores.append(score)
        if predictions[sampleIdx] == targets[sampleIdx]:
            scoresCorrect.append(score)
        else:
            scoresIncorrect.append(score)

    return (
        torch.stack(scores, dim=0),
        torch.stack(scoresCorrect, dim=0) if len(scoresCorrect) > 0 else None,
        torch.stack(scoresIncorrect, dim=0) if len(scoresIncorrect) > 0 else None,
    )
44 changes: 44 additions & 0 deletions cifar10.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from argparse import ArgumentError
import logging

from runner import run
from models import ResNet
from datasets import cifar10Loader
from easydict import EasyDict

from train_utils import checkSavedModel

logger = logging.getLogger()

# Experiment configuration for the CIFAR-10 run. The whole object is handed to
# checkSavedModel() and run() below; how each field is consumed is defined in
# those modules.
params = EasyDict()
params.savepoint = "cifar"  # presumably the checkpoint directory — confirm in train_utils
params.modelType = "ResNet"
params.resume = True
params.name = "cifar10"
params.epochs = 50
params.lr = 0.001
params.batchSize = 16
params.schedule = [20, 30, 40, 45]  # learning rate schedule
params.gamma = 0.5  # presumably the learning-rate decay factor — confirm in runner
params.alpha = 0.01
params.angle = 0
params.evals = ["alphaBlending", "rotation"]
params.metricParams = {
    "n_steps": 50,  # integrated gradients discretization steps
    "minPixelValue": 1e-5,  # minimal pixel value
    "minProb": 0.0,  # minimal pixel switching probability
    "patchSize": 4,  # patch size
    "nConform": 50,  # conformance samples to generate
}


if __name__ == "__main__":
    logger.info(params)
    if params.modelType == "ResNet":
        model = ResNet(3)
    else:
        # argparse.ArgumentError takes (argument, message). Called with the
        # message alone it raises TypeError instead of reporting the problem,
        # so pass None for the argument slot.
        raise ArgumentError(None, f"Invalid model type: {params.modelType}")

    params.modelIsTrained = checkSavedModel(params, model)

    run(params, model, cifar10Loader)
170 changes: 170 additions & 0 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
from typing import Iterable, Callable, Tuple
import os
import torch
from torch.utils.data import TensorDataset
from torchvision import datasets
import torchvision.transforms as transforms


def _getRoot(rootDir: str = None) -> str:
    """Resolve the directory used to store datasets.

    An explicit ``rootDir`` takes precedence. Otherwise the ``DATASETS_ROOT``
    environment variable is used, with ``"_data"`` as the fallback when that
    variable is not set.
    """
    if rootDir is not None:
        return rootDir
    return os.environ.get("DATASETS_ROOT", "_data")


def getMNIST(
    batchSize: int,
    dataDir: str = None,
    trainTransforms: Callable = None,
    testTransforms: Callable = None,
) -> Tuple[torch.utils.data.DataLoader, torch.utils.data.DataLoader]:
    """Build the MNIST train and test data loaders.

    The training split is downloaded into the dataset root if needed. Labels
    of both splits are converted to tensors. Both loaders shuffle and drop the
    last incomplete batch.

    Args:
        batchSize (int): Batch size used by both loaders.
        dataDir (str, optional): Dataset root. When None, the root is resolved
            by ``_getRoot``. Defaults to None.
        trainTransforms (Callable, optional): Transform applied to training
            images; passed unchanged as the dataset's ``transform``.
        testTransforms (Callable, optional): Transform applied to test images;
            passed unchanged as the dataset's ``transform``.

    Returns:
        Tuple[DataLoader, DataLoader]: train loader, test loader
    """
    root = _getRoot(dataDir)  # resolve once so both splits share one root

    trainDataset = datasets.MNIST(
        root,
        train=True,
        download=True,
        transform=trainTransforms,
        target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
    )
    testDataset = datasets.MNIST(
        root,
        train=False,
        transform=testTransforms,
        target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
    )

    train_loader = torch.utils.data.DataLoader(
        trainDataset,
        batch_size=batchSize,
        shuffle=True,
        drop_last=True,
    )

    # NOTE(review): the test loader also shuffles and drops the last partial
    # batch; kept as-is because evaluation code outside this file may rely on
    # it — confirm before changing.
    test_loader = torch.utils.data.DataLoader(
        testDataset,
        batch_size=batchSize,
        shuffle=True,
        drop_last=True,
    )
    return train_loader, test_loader


def _mnistTrainSteps():
    """Return fresh instances of the steps shared by both MNIST training pipelines."""
    return [
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]


def _mnistTestSteps():
    """Return fresh instances of the steps shared by both MNIST test pipelines."""
    return [
        transforms.Resize(32),
        transforms.ToTensor(),
    ]


def _mnistNoiseStep():
    """Return the additive background-noise step (Gaussian, mean 0.05, std 0.05)."""
    return transforms.Lambda(lambda x: x + 0.1 * (0.5 + 0.5 * torch.randn_like(x)))


# Baseline pipelines: resize (plus augmentation for training), then to tensor.
mnistBaselineTrainTransforms = transforms.Compose(_mnistTrainSteps())

mnistBaselineTestTransforms = transforms.Compose(_mnistTestSteps())

# Noise pipelines: same as the baseline ones, with noise added to the tensor.
mnistNoiseTrainTransforms = transforms.Compose(_mnistTrainSteps() + [_mnistNoiseStep()])

mnistNoiseTestTransforms = transforms.Compose(_mnistTestSteps() + [_mnistNoiseStep()])


def mnistBaselineLoader(batchSize: int):
    """Return what ``getMNIST`` builds for the baseline (noise-free) MNIST transforms."""
    transformKwargs = {
        "trainTransforms": mnistBaselineTrainTransforms,
        "testTransforms": mnistBaselineTestTransforms,
    }
    return getMNIST(batchSize, **transformKwargs)


def mnistNoiseLoader(batchSize: int):
    """Return what ``getMNIST`` builds for the MNIST transforms with background noise."""
    transformKwargs = {
        "trainTransforms": mnistNoiseTrainTransforms,
        "testTransforms": mnistNoiseTestTransforms,
    }
    return getMNIST(batchSize, **transformKwargs)


def getCifar10(
    batchSize: int,
    dataDir: str = None,
    trainTransforms: Iterable[Callable] = None,
    testTransforms: Iterable[Callable] = None,
) -> tuple:
    """Build the CIFAR-10 train and test data loaders.

    Args:
        batchSize (int): Batch size used by both loaders.
        dataDir (str, optional): Dataset root, resolved through ``_getRoot``.
        trainTransforms: Transform passed to the training dataset.
        testTransforms: Transform passed to the test dataset.

    Returns:
        tuple: train loader, test loader (both shuffle and drop the last
        incomplete batch)
    """
    # Only the training split asks for a download, as in the original setup.
    splitSpecs = (
        {"train": True, "download": True, "transform": trainTransforms},
        {"train": False, "transform": testTransforms},
    )

    # Datasets first (train, then test); labels are converted to tensors.
    splitDatasets = [
        datasets.CIFAR10(
            _getRoot(dataDir),
            target_transform=transforms.Lambda(lambda y: torch.tensor(y)),
            **spec,
        )
        for spec in splitSpecs
    ]

    trainLoader, testLoader = (
        torch.utils.data.DataLoader(
            dataset,
            batch_size=batchSize,
            shuffle=True,
            drop_last=True,
        )
        for dataset in splitDatasets
    )
    return trainLoader, testLoader


# (mean, std) triples handed to Normalize, one value per image channel.
_CIFAR_NORMALIZATION = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: augmentation, tensor conversion, then normalization.
cifarTrainTransforms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize(*_CIFAR_NORMALIZATION),
    ]
)

# Test pipeline: tensor conversion and normalization only. The module-level
# name ``testTransforms`` refers to the same object and is kept for
# compatibility.
testTransforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(*_CIFAR_NORMALIZATION),
    ]
)
cifarTestTransforms = testTransforms


def cifar10Loader(batchSize: int):
    """Return what ``getCifar10`` builds for the CIFAR-10 train and test transforms."""
    transformKwargs = {
        "trainTransforms": cifarTrainTransforms,
        "testTransforms": cifarTestTransforms,
    }
    return getCifar10(batchSize, **transformKwargs)
Loading

0 comments on commit b54cf83

Please sign in to comment.