Skip to content

Commit

Permalink
Merge pull request cleverhans-lab#356 from goodfeli/dtype
Browse files Browse the repository at this point in the history
support float64
  • Loading branch information
npapernot authored Feb 7, 2018
2 parents 6e59120 + a89f8c8 commit 1a3cf27
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cleverhans/attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def generate(self, x, **kwargs):

# Normalize current gradient and add it to the accumulated gradient
red_ind = list(xrange(1, len(grad.get_shape())))
avoid_zero_div = 1e-12
avoid_zero_div = tf.cast(1e-12, grad.dtype)
grad = grad / tf.maximum(avoid_zero_div,
tf.reduce_mean(tf.abs(grad),
red_ind,
Expand Down
10 changes: 8 additions & 2 deletions tests_tf/test_attacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ class SimpleModel(Model):
"""

def get_logits(self, x):
W1 = tf.constant([[1.5, .3], [-2, 0.3]], dtype=tf.float32)
W1 = tf.constant([[1.5, .3], [-2, 0.3]], dtype=x.dtype)
h1 = tf.nn.sigmoid(tf.matmul(x, W1))
W2 = tf.constant([[-2.4, 1.2], [0.5, -2.3]], dtype=tf.float32)
W2 = tf.constant([[-2.4, 1.2], [0.5, -2.3]], dtype=x.dtype)

res = tf.matmul(h1, W2)
return res

Expand Down Expand Up @@ -174,6 +175,11 @@ def test_generate_np_gives_adversarial_example_l1(self):
def test_generate_np_gives_adversarial_example_l2(self):
self.help_generate_np_gives_adversarial_example(2)

def test_generate_respects_dtype(self):
x = tf.placeholder(dtype=tf.float64, shape=(100, 2))
x_adv = self.attack.generate(x)
self.assertEqual(x_adv.dtype, tf.float64)

def test_targeted_generate_np_gives_adversarial_example(self):
x_val = np.random.rand(100, 2)
x_val = np.array(x_val, dtype=np.float32)
Expand Down

0 comments on commit 1a3cf27

Please sign in to comment.