Skip to content

Commit

Permalink
attempt to make model picklable
Browse files Browse the repository at this point in the history
  • Loading branch information
goodfeli committed Oct 10, 2018
1 parent db006e7 commit 531d219
Showing 1 changed file with 34 additions and 29 deletions.
63 changes: 34 additions & 29 deletions cleverhans/model_zoo/madry_lab_challenges/cifar10_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import numpy as np
import tensorflow as tf
from cleverhans.model import Model
from cleverhans.serial import NoRefModel


class Layer(object):
Expand All @@ -22,7 +22,7 @@ def get_output_shape(self):
return self.output_shape


class ResNet(Model):
class ResNet(NoRefModel):
"""ResNet model."""

def __init__(self, layers, input_shape):
Expand All @@ -33,36 +33,41 @@ def __init__(self, layers, input_shape):
each with set_input_shape() and fprop() methods.
input_shape: 4-tuple describing input shape (e.g None, 32, 32, 3)
"""
super(ResNet, self).__init__('', 10, {})
self.layer_names = []
self.layers = layers
self.input_shape = input_shape
if isinstance(layers[-1], Softmax):
layers[-1].name = 'probs'
layers[-2].name = 'logits'
else:
layers[-1].name = 'logits'
for i, layer in enumerate(self.layers):
if hasattr(layer, 'name'):
name = layer.name
super(ResNet, self).__init__('', 10, {}, True)
with tf.variable_scope(self.scope):
self.layer_names = []
self.layers = layers
self.input_shape = input_shape
if isinstance(layers[-1], Softmax):
layers[-1].name = 'probs'
layers[-2].name = 'logits'
else:
name = layer.__class__.__name__ + str(i)
layer.name = name
self.layer_names.append(name)
layers[-1].name = 'logits'
for i, layer in enumerate(self.layers):
if hasattr(layer, 'name'):
name = layer.name
else:
name = layer.__class__.__name__ + str(i)
layer.name = name
self.layer_names.append(name)

layer.set_input_shape(input_shape)
input_shape = layer.get_output_shape()
layer.set_input_shape(input_shape)
input_shape = layer.get_output_shape()

def make_input_placeholder(self):
return tf.placeholder(tf.float32, (None, 32, 32, 3))

def fprop(self, x, set_ref=False):
states = []
for layer in self.layers:
if set_ref:
layer.ref = x
x = layer.fprop(x)
assert x is not None
states.append(x)
states = dict(zip(self.layer_names, states))
return states
with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
states = []
for layer in self.layers:
if set_ref:
layer.ref = x
x = layer.fprop(x)
assert x is not None
states.append(x)
states = dict(zip(self.layer_names, states))
return states

def add_internal_summaries(self):
pass
Expand Down Expand Up @@ -287,7 +292,7 @@ def fprop(self, x):
return tf.reshape(x, [-1, self.output_width])


def make_madry_wresnet(nb_classes=10, input_shape=(None, 32, 32, 3)):
def make_wresnet(nb_classes=10, input_shape=(None, 32, 32, 3)):
layers = [Input(),
Conv2D(), # the whole ResNet is basically created in this layer
Flatten(),
Expand Down

0 comments on commit 531d219

Please sign in to comment.