Skip to content

Commit

Permalink
Replaces predict with predict_proba.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 326227257
tensorflower-gardener committed Aug 12, 2020
1 parent 59192e6 commit f8515df
Showing 2 changed files with 7 additions and 26 deletions.
31 changes: 6 additions & 25 deletions tensorflow_privacy/privacy/membership_inference_attack/models.py
Original file line number Diff line number Diff line change
@@ -115,8 +115,13 @@ def predict(self, input_features):
Args:
input_features : A vector of features with the same semantics as x_train
passed to train_model.
Returns:
An array of probabilities denoting whether the example belongs to test.
"""
raise NotImplementedError()
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict_proba(input_features)[:, 1]


class LogisticRegressionAttacker(TrainedAttacker):
@@ -132,12 +137,6 @@ def train_model(self, input_features, is_training_labels):
model.fit(input_features, is_training_labels)
self.model = model

def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)


class MultilayerPerceptronAttacker(TrainedAttacker):
"""Multilayer perceptron attacker."""
@@ -155,12 +154,6 @@ def train_model(self, input_features, is_training_labels):
model.fit(input_features, is_training_labels)
self.model = model

def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)


class RandomForestAttacker(TrainedAttacker):
"""Random forest attacker."""
@@ -182,12 +175,6 @@ def train_model(self, input_features, is_training_labels):
model.fit(input_features, is_training_labels)
self.model = model

def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)


class KNearestNeighborsAttacker(TrainedAttacker):
"""K nearest neighbor attacker."""
@@ -201,9 +188,3 @@ def train_model(self, input_features, is_training_labels):
knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
model.fit(input_features, is_training_labels)
self.model = model

def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)
Original file line number Diff line number Diff line change
@@ -26,7 +26,7 @@ class TrainedAttackerTest(absltest.TestCase):
def test_base_attacker_train_and_predict(self):
base_attacker = models.TrainedAttacker()
self.assertRaises(NotImplementedError, base_attacker.train_model, [], [])
self.assertRaises(NotImplementedError, base_attacker.predict, [])
self.assertRaises(AssertionError, base_attacker.predict, [])

def test_predict_before_training(self):
lr_attacker = models.LogisticRegressionAttacker()

0 comments on commit f8515df

Please sign in to comment.