Skip to content

Commit

Permalink
make ResNet work as picklable model
Browse files Browse the repository at this point in the history
  • Loading branch information
goodfeli committed Oct 11, 2018
1 parent 0bb4250 commit 1d42834
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 14 deletions.
63 changes: 50 additions & 13 deletions cleverhans/model_zoo/madry_lab_challenges/cifar10_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,40 @@ def get_output_shape(self):
class ResNet(NoRefModel):
"""ResNet model."""

def __init__(self, layers, input_shape, scope=None):
    """ResNet constructor.

    :param layers: a list of layers in CleverHans format
      each with set_input_shape() and fprop() methods.
    :param input_shape: 4-tuple describing input shape,
      e.g. (None, 32, 32, 3)
    :param scope: string name of scope for Variables
      This works in two ways.
      If scope is None, the variables are not put in a scope, and the
      model is compatible with Saver.restore from the public downloads
      for the CIFAR10 Challenge.
      If the scope is a string, then Saver.restore won't work, but the
      model functions as a picklable NoRefModel that finds its variables
      based on the scope.
    """
    super(ResNet, self).__init__(scope, 10, {}, scope is not None)
    if scope is None:
        # There is no scope to identify this model's variables by, so
        # snapshot the graph collections around build() and keep whatever
        # was newly added.
        before = list(tf.trainable_variables())
        before_vars = list(tf.global_variables())
        self.build(layers, input_shape)
        after = list(tf.trainable_variables())
        after_vars = list(tf.global_variables())
        self.params = [param for param in after if param not in before]
        self.vars = [var for var in after_vars if var not in before_vars]
    else:
        # Scoped mode: create every variable under the model's scope.
        with tf.variable_scope(self.scope):
            self.build(layers, input_shape)

def get_vars(self):
    """Return the Variables belonging to this model.

    When the instance carries an explicit ``vars`` list (assigned by the
    constructor in the scope-less case) that list is returned as is;
    otherwise the lookup is delegated to the parent class.
    """
    if not hasattr(self, "vars"):
        return super(ResNet, self).get_vars()
    return self.vars

def build(self, layers, input_shape):
self.layer_names = []
self.layers = layers
self.input_shape = input_shape
Expand All @@ -57,8 +81,16 @@ def __init__(self, layers, input_shape):
def make_input_placeholder(self):
    """Create a float32 placeholder of shape (None, 32, 32, 3) for inputs."""
    input_shape = (None, 32, 32, 3)
    return tf.placeholder(tf.float32, input_shape)

def make_label_placeholder(self):
    """Create a float32 placeholder of shape (None, 10) for labels."""
    label_shape = (None, 10)
    return tf.placeholder(tf.float32, label_shape)

def fprop(self, x, set_ref=False):
    """Run the forward pass, entering the model's variable scope if any.

    :param x: input passed through to ``_fprop``
    :param set_ref: flag passed through to ``_fprop``
    :return: whatever ``_fprop`` returns
    """
    if self.scope is None:
        # Scope-less mode: variables live at the top level, so no
        # variable_scope is entered. This path previously called the
        # non-existent ``self._prop`` and raised AttributeError.
        return self._fprop(x, set_ref)
    # AUTO_REUSE lets repeated fprop calls share the same variables.
    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
        return self._fprop(x, set_ref)

def _fprop(self, x, set_ref=False):
states = []
for layer in self.layers:
if set_ref:
Expand Down Expand Up @@ -170,18 +202,23 @@ def __init__(self, num_hid):
def set_input_shape(self, input_shape):
    """Record input/output shapes and create the layer's variables.

    :param input_shape: 2-element sequence (batch_size, dim)
    """
    batch, width = input_shape
    self.dim = width
    self.input_shape = [batch, width]
    self.output_shape = [batch, self.num_hid]
    # Called for its side effect only; the returned variables are not kept.
    self.make_vars()

def make_vars(self):
    """Create or look up the layer's weight and bias variables.

    The variables are named 'DW' and 'biases' inside variable scope
    'logit'; because the scope uses AUTO_REUSE, calling this repeatedly
    yields the same variables instead of creating new ones.

    :return: tuple ``(w, b)`` with shapes [self.dim, self.num_hid] and
      [self.num_hid]
    """
    with tf.variable_scope('logit', reuse=tf.AUTO_REUSE):
        w = tf.get_variable(
            'DW', [self.dim, self.num_hid],
            initializer=tf.initializers.variance_scaling(
                distribution='uniform'))
        b = tf.get_variable('biases', [self.num_hid],
                            initializer=tf.initializers.constant())
    return w, b

def fprop(self, x):
    """Apply the layer's affine map to ``x`` via ``tf.nn.xw_plus_b``.

    The weight and bias are fetched through ``make_vars`` on every call
    rather than read from instance attributes, so the layer object does
    not hold Variable references.

    :param x: input tensor; presumably [batch, self.dim] -- confirm
      against callers
    :return: result of ``tf.nn.xw_plus_b(x, w, b)``
    """
    w, b = self.make_vars()
    return tf.nn.xw_plus_b(x, w, b)


def _batch_norm(name, x):
Expand Down Expand Up @@ -292,12 +329,12 @@ def fprop(self, x):
return tf.reshape(x, [-1, self.output_width])


def make_wresnet(nb_classes=10, input_shape=(None, 32, 32, 3), scope=None):
    """Assemble the layer list and wrap it in a ResNet model.

    :param nb_classes: number of units in the final Linear layer
    :param input_shape: 4-tuple input shape forwarded to ResNet
    :param scope: variable scope name forwarded to ResNet; None keeps
      the variables unscoped
    :return: the constructed ResNet
    """
    layers = [Input(),
              Conv2D(),  # the whole ResNet is basically created in this layer
              Flatten(),
              Linear(nb_classes),
              Softmax()]

    # Build the model exactly once, passing the scope through.
    model = ResNet(layers, input_shape, scope)
    return model
2 changes: 1 addition & 1 deletion cleverhans/serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class NoRefModel(Model):
"""
A Model that can be pickled because it contains no references to any
Variables (e.g. it identifies Variables only by name).
The Model must be able to find all of its Variables via get_params
The Model must be able to find all of its Variables via get_vars
for them to be pickled.
Note that NoRefModel may have different Variable names after it is
restored, e.g. if the unpickling is run with a different enclosing
Expand Down

0 comments on commit 1d42834

Please sign in to comment.