forked from tensorflow/privacy
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A callback and a function to be called at the end of training for Keras to perform a membership inference attack.
PiperOrigin-RevId: 323805663
1 parent
dcbfaa3
commit cea9e01
Showing
3 changed files
with
284 additions
and
0 deletions.
There are no files selected for viewing
108 changes: 108 additions & 0 deletions
108
tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright 2020, The TensorFlow Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Lint as: python3 | ||
"""A callback and a function in keras for membership inference attack.""" | ||
|
||
from absl import logging | ||
|
||
import tensorflow.compat.v1 as tf | ||
|
||
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia | ||
from tensorflow_privacy.privacy.membership_inference_attack.utils import log_loss | ||
from tensorflow_privacy.privacy.membership_inference_attack.utils import write_to_tensorboard | ||
|
||
|
||
def calculate_losses(model, data, labels):
  """Runs the model on `data` and scores every prediction against `labels`.

  Args:
    model: model used to make the predictions; only `model.predict` is called
    data: samples passed to `model.predict`
    labels: true labels of the samples (integer valued)

  Returns:
    A `(predictions, losses)` pair: the output of `model.predict(data)` and the
    result of `log_loss(labels, predictions)`.
  """
  # NOTE(review): the predictions are handed to `log_loss` exactly as the model
  # emits them. If the model outputs logits rather than probabilities, confirm
  # that `log_loss` is meant to receive them in that form.
  predictions = model.predict(data)
  return predictions, log_loss(labels, predictions)
|
||
|
||
class MembershipInferenceCallback(tf.keras.callbacks.Callback):
  """Callback to perform membership inference attack on epoch end."""

  def __init__(self, in_train, out_train, attack_classifiers,
               tensorboard_dir=None):
    """Initializes the callback.

    Args:
      in_train: (in_training samples, in_training labels)
      out_train: (out_training samples, out_training labels)
      attack_classifiers: a list of classifiers to be used by attacker, must be
        a subset of ['lr', 'mlp', 'rf', 'knn']
      tensorboard_dir: directory for tensorboard summary; when falsy, nothing
        is written to tensorboard
    """
    # Let the Keras base class initialise its own attributes before adding ours.
    super().__init__()
    self._in_train_data, self._in_train_labels = in_train
    self._out_train_data, self._out_train_labels = out_train
    self._attack_classifiers = attack_classifiers
    # Setup tensorboard writer if tensorboard_dir is specified
    if tensorboard_dir:
      # The writer is created while a fresh graph is the default graph.
      with tf.Graph().as_default():
        self._writer = tf.summary.FileWriter(tensorboard_dir)
      logging.info('Will write to tensorboard.')
    else:
      self._writer = None

  def on_epoch_end(self, epoch, logs=None):
    """Runs the attack against the current model and reports the results.

    Args:
      epoch: index of the epoch that just ended; forwarded to
        `write_to_tensorboard` together with the attack advantage
      logs: unused
    """
    results = run_attack_on_keras_model(
        self.model,
        (self._in_train_data, self._in_train_labels),
        (self._out_train_data, self._out_train_labels),
        self._attack_classifiers)
    print('all_thresh_loss_advantage', results['all_thresh_loss_advantage'])
    logging.info(results)

    # Write to tensorboard only if tensorboard_dir was specified; without a
    # writer there is nothing to write to.
    if self._writer is not None:
      write_to_tensorboard(self._writer, ['attack advantage'],
                           [results['all_thresh_loss_advantage']], epoch)
|
||
|
||
def run_attack_on_keras_model(model, in_train, out_train, attack_classifiers):
  """Runs the membership inference attacks against a trained model.

  Args:
    model: model to be tested
    in_train: a (in_training samples, in_training labels) tuple
    out_train: a (out_training samples, out_training labels) tuple
    attack_classifiers: a list of classifiers to be used by attacker, must be
      a subset of ['lr', 'mlp', 'rf', 'knn']

  Returns:
    Whatever `mia.run_all_attacks` returns for the losses, predictions and
    labels of the two sample sets.
  """
  member_data, member_labels = in_train
  non_member_data, non_member_labels = out_train

  # Predictions and per-sample losses, members first and then non-members.
  member_pred, member_loss = calculate_losses(
      model, member_data, member_labels)
  non_member_pred, non_member_loss = calculate_losses(
      model, non_member_data, non_member_labels)

  return mia.run_all_attacks(
      member_loss, non_member_loss,
      member_pred, non_member_pred,
      member_labels, non_member_labels,
      attack_classifiers=attack_classifiers)
|
104 changes: 104 additions & 0 deletions
104
tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_example.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright 2020, The TensorFlow Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Lint as: python3 | ||
"""An example for using keras_evaluation.""" | ||
|
||
from absl import app | ||
from absl import flags | ||
|
||
import numpy as np | ||
import tensorflow.compat.v1 as tf | ||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import MembershipInferenceCallback | ||
from tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation import run_attack_on_keras_model | ||
|
||
# Short alias for the optimizer class instantiated in main().
GradientDescentOptimizer = tf.train.GradientDescentOptimizer

FLAGS = flags.FLAGS

# Command-line flags of the example.
flags.DEFINE_float(
    name='learning_rate', default=0.15, help='Learning rate for training')
flags.DEFINE_integer(name='batch_size', default=256, help='Batch size')
flags.DEFINE_integer(name='epochs', default=10, help='Number of epochs')
flags.DEFINE_string(name='model_dir', default=None, help='Model directory.')
|
||
|
||
def cnn_model():
  """Builds the uncompiled CNN used by this example.

  Returns:
    A `tf.keras.Sequential` model declared for inputs of shape (28, 28, 1)
    whose last layer is a 10-unit dense layer without an activation.
  """
  layers = tf.keras.layers
  stack = [
      layers.Conv2D(16, 8, strides=2, padding='same',
                    activation='relu', input_shape=(28, 28, 1)),
      layers.MaxPool2D(2, 1),
      layers.Conv2D(32, 4, strides=2, padding='valid',
                    activation='relu'),
      layers.MaxPool2D(2, 1),
      layers.Flatten(),
      layers.Dense(32, activation='relu'),
      layers.Dense(10),
  ]
  return tf.keras.Sequential(stack)
|
||
|
||
def load_mnist():
  """Loads MNIST and converts it to float32 images and int32 labels.

  Returns:
    A `(train_data, train_labels, test_data, test_labels)` tuple. The images
    are float32 arrays divided by 255 and reshaped to (num_samples, 28, 28, 1);
    the labels are int32 arrays.
  """

  def as_images(raw):
    # Convert to float32, divide by 255 and add a trailing channel axis.
    scaled = np.array(raw, dtype=np.float32) / 255
    return scaled.reshape((scaled.shape[0], 28, 28, 1))

  def as_labels(raw):
    return np.array(raw, dtype=np.int32)

  train_split, test_split = tf.keras.datasets.mnist.load_data()
  train_images, train_targets = train_split
  test_images, test_targets = test_split

  return (as_images(train_images), as_labels(train_targets),
          as_images(test_images), as_labels(test_targets))
|
||
|
||
def main(unused_argv):
  """Trains the example CNN on MNIST and runs the attack during and after it.

  The attack callback runs at the end of every epoch; once training is
  finished the attack is run one more time on the final model and its
  advantage is printed.

  Args:
    unused_argv: positional command-line arguments; ignored
  """
  train_data, train_labels, test_data, test_labels = load_mnist()
  members = (train_data, train_labels)
  non_members = (test_data, test_labels)

  # Build the model and compile it with SGD and a from-logits loss.
  model = cnn_model()
  model.compile(
      optimizer=GradientDescentOptimizer(learning_rate=FLAGS.learning_rate),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])

  # Callback that runs the membership inference attack.
  mia_callback = MembershipInferenceCallback(
      members, non_members, [], FLAGS.model_dir)

  # Train model with Keras
  model.fit(
      train_data,
      train_labels,
      epochs=FLAGS.epochs,
      validation_data=non_members,
      batch_size=FLAGS.batch_size,
      callbacks=[mia_callback],
      verbose=2)

  print('End of training attack')
  attack_results = run_attack_on_keras_model(model, members, non_members, [])
  print('all_thresh_loss_advantage',
        attack_results['all_thresh_loss_advantage'])


if __name__ == '__main__':
  app.run(main)
72 changes: 72 additions & 0 deletions
72
tensorflow_privacy/privacy/membership_inference_attack/keras_evaluation_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright 2020, The TensorFlow Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# Lint as: python3 | ||
"""Tests for tensorflow_privacy.privacy.membership_inference_attack.keras_evaluation.""" | ||
|
||
from absl.testing import absltest | ||
|
||
import numpy as np | ||
import tensorflow.compat.v1 as tf | ||
|
||
from tensorflow_privacy.privacy.membership_inference_attack import keras_evaluation | ||
|
||
|
||
class UtilsTest(absltest.TestCase):
  """Tests for `calculate_losses` and `run_attack_on_keras_model`."""

  def setUp(self):
    """Builds random data and a small compiled model before each test."""
    super().setUp()

    self.ntrain, self.ntest = 50, 100
    self.nclass = 5
    self.ndim = 10

    # Generate random training and test data from a seeded generator so that
    # a failing run can be reproduced.
    rng = np.random.RandomState(0)
    self.train_data = rng.rand(self.ntrain, self.ndim)
    self.test_data = rng.rand(self.ntest, self.ndim)
    self.train_labels = rng.randint(self.nclass, size=self.ntrain)
    self.test_labels = rng.randint(self.nclass, size=self.ntest)

    self.model = tf.keras.Sequential([tf.keras.layers.Dense(self.nclass)])

    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    self.model.compile(optimizer='Adam', loss=loss, metrics=['accuracy'])

  def test_calculate_losses(self):
    """Test calculating the loss."""
    pred, loss = keras_evaluation.calculate_losses(self.model, self.train_data,
                                                   self.train_labels)
    self.assertEqual(pred.shape, (self.ntrain, self.nclass))
    self.assertEqual(loss.shape, (self.ntrain,))

    pred, loss = keras_evaluation.calculate_losses(self.model, self.test_data,
                                                   self.test_labels)
    self.assertEqual(pred.shape, (self.ntest, self.nclass))
    self.assertEqual(loss.shape, (self.ntest,))

  def test_run_attack_on_keras_model(self):
    """Test the attack."""
    results = keras_evaluation.run_attack_on_keras_model(
        self.model,
        (self.train_data, self.train_labels),
        (self.test_data, self.test_labels),
        [])
    self.assertIsInstance(results, dict)
    self.assertIn('all_thresh_loss_auc', results)
    self.assertIn('all_thresh_loss_advantage', results)


if __name__ == '__main__':
  absltest.main()