PyTorch implementation of NoPeekNN, currently trained on MNIST. We implement a black box model inversion attack to validate the protection afforded by NoPeek.
NoPeekNN is an extension to SplitNNs designed to preserve privacy. While SplitNNs do not send raw data to a potentially untrustworthy central server, it has been shown that raw data can be reverse-engineered from the model. NoPeekNN attempts to limit this by training the model to produce an intermediate data representation (sent between model parts) that is as distinct from the input data as possible while retaining the information necessary to complete the task.
This is achieved by adding a term to the loss function which minimises the distance covariance between the input and the intermediate data.
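As a rough illustration, the NoPeek term can be thought of as a distance correlation (a normalised form of distance covariance) between a batch of inputs and the corresponding intermediate activations. The sketch below is not the code in this repository, just a minimal PyTorch version of that idea; the function and variable names are placeholders.

```python
import torch

def distance_correlation(x, z, eps=1e-9):
    """Illustrative distance correlation between a batch of inputs x
    and intermediate activations z (both flattened per sample)."""
    x = x.flatten(start_dim=1)
    z = z.flatten(start_dim=1)

    # Pairwise Euclidean distance matrices, shape (batch, batch)
    a = torch.cdist(x, x)
    b = torch.cdist(z, z)

    # Double-centre each distance matrix
    A = a - a.mean(dim=0, keepdim=True) - a.mean(dim=1, keepdim=True) + a.mean()
    B = b - b.mean(dim=0, keepdim=True) - b.mean(dim=1, keepdim=True) + b.mean()

    dcov2 = (A * B).mean()    # squared distance covariance
    dvar_x = (A * A).mean()   # squared distance variance of x
    dvar_z = (B * B).mean()   # squared distance variance of z

    dcor2 = dcov2 / (torch.sqrt(dvar_x * dvar_z) + eps)
    return torch.sqrt(torch.clamp(dcor2, min=0.0))
```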
This code has been written with Python 3.7 and PyTorch 1.5.0; however, other versions may work.

If using conda, run `conda env create -f environment.yml` to create an environment, `nopeek`, with all the necessary packages. Run `conda env create -f environment-lock.yml` to create an environment with the exact package versions used to develop this code.
Run `python main.py --nopeek_weight 0.1` to train a SplitNN model on MNIST with a weighting of 0.1 for the NoPeek loss. See other optional arguments with `python main.py --help`.
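Conceptually, the weight balances the usual task loss against the NoPeek term. A hedged sketch of a single training step, reusing the illustrative `distance_correlation` above with placeholder model parts and data (not the actual structure of `main.py`), might look like:

```python
import torch
import torch.nn.functional as F

# Placeholder model segments and optimiser; the real script builds its own
model_part1 = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 64))
model_part2 = torch.nn.Linear(64, 10)
optimiser = torch.optim.Adam(
    list(model_part1.parameters()) + list(model_part2.parameters())
)

nopeek_weight = 0.1  # corresponds to --nopeek_weight 0.1

def training_step(images, labels):
    optimiser.zero_grad()
    intermediates = model_part1(images)   # data sent between model parts
    logits = model_part2(intermediates)

    task_loss = F.cross_entropy(logits, labels)
    nopeek_loss = distance_correlation(images, intermediates)

    loss = task_loss + nopeek_weight * nopeek_loss
    loss.backward()
    optimiser.step()
    return loss.item()
```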
NoPeek loss is computationally demanding and scales with the size of a data batch, so it is recommended to stick to small (<64) batch sizes.
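The cost comes from the pairwise distance matrices behind the distance covariance, which grow quadratically with batch size. The toy snippet below (illustrative only, not part of the training script) makes that visible:

```python
import torch

for n in (16, 32, 64, 128):
    x = torch.randn(n, 784)   # a hypothetical flattened MNIST batch
    d = torch.cdist(x, x)     # (n, n) pairwise distance matrix
    print(f"batch {n}: {d.numel()} pairwise distances")
```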