Skip to content

Commit

Permalink
fix merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
matsuyamax committed Oct 4, 2015
2 parents 19c736a + 5bab11e commit 61d76d4
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 35 deletions.
3 changes: 2 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ python:
- "3.4"
# command to install dependencies
install:
- conda install --yes python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib pandas pytest h5py theano
- conda install --yes python=$TRAVIS_PYTHON_VERSION numpy scipy matplotlib pandas pytest h5py
- pip install pytest-cov python-coveralls
- pip install git+git://github.com/Theano/Theano.git
# command to run tests
script:
- PYTHONPATH=$PWD:$PYTHONPATH py.test -v --cov-report term-missing --cov keras tests/
Expand Down
5 changes: 4 additions & 1 deletion docs/sources/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ The `logs` dictionary will contain keys for quantities relevant to the current b
keras.callbacks.ModelCheckpoint(filepath, verbose=0, save_best_only=False)
```

Save the model after every epoch. If `save_best_only=True`, the latest best model according to the validation loss will not be overwritten.
Save the model after every epoch. If `save_best_only=True`, the latest best model according to the validation loss will not be overwritten.
`filepath` can contain named formatting options, which will be filled the value of `epoch` and keys in `logs` (passed in `on_epoch_end`).

For example: if `filepath` is `weights.{epoch:02d}-{val_loss:.2f}.hdf5`, then multiple files will be save with the epoch number and the validation loss.


```python
Expand Down
2 changes: 1 addition & 1 deletion keras/activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def softplus(x):


def relu(x):
return (x + abs(x)) / 2.0
return T.nnet.relu(x)


def tanh(x):
Expand Down
9 changes: 5 additions & 4 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def __init__(self, filepath, monitor='val_loss', verbose=0, save_best_only=False
self.best = np.Inf

def on_epoch_end(self, epoch, logs={}):
filepath = self.filepath.format(epoch=epoch, **logs)
if self.save_best_only:
current = logs.get(self.monitor)
if current is None:
Expand All @@ -191,16 +192,16 @@ def on_epoch_end(self, epoch, logs={}):
if current < self.best:
if self.verbose > 0:
print("Epoch %05d: %s improved from %0.5f to %0.5f, saving model to %s"
% (epoch, self.monitor, self.best, current, self.filepath))
% (epoch, self.monitor, self.best, current, filepath))
self.best = current
self.model.save_weights(self.filepath, overwrite=True)
self.model.save_weights(filepath, overwrite=True)
else:
if self.verbose > 0:
print("Epoch %05d: %s did not improve" % (epoch, self.monitor))
else:
if self.verbose > 0:
print("Epoch %05d: saving model to %s" % (epoch, self.filepath))
self.model.save_weights(self.filepath, overwrite=True)
print("Epoch %05d: saving model to %s" % (epoch, filepath))
self.model.save_weights(filepath, overwrite=True)


class EarlyStopping(Callback):
Expand Down
20 changes: 11 additions & 9 deletions keras/layers/advanced_activations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def __init__(self, alpha=0.3):

def get_output(self, train):
X = self.get_input(train)
return ((X + abs(X)) / 2.0) + self.alpha * ((X - abs(X)) / 2.0)
return T.nnet.relu(X, self.alpha)

def get_config(self):
return {"name": self.__class__.__name__,
Expand All @@ -37,8 +37,8 @@ def __init__(self, input_shape, init='zero', weights=None):

def get_output(self, train):
X = self.get_input(train)
pos = ((X + abs(X)) / 2.0)
neg = self.alphas * ((X - abs(X)) / 2.0)
pos = T.nnet.relu(X)
neg = self.alphas * (X - abs(X)) * 0.5
return pos + neg

def get_config(self):
Expand Down Expand Up @@ -78,6 +78,7 @@ def get_config(self):
"alpha_init": self.alpha_init,
"beta_init": self.beta_init}


class ThresholdedLinear(MaskedLayer):
'''
Thresholded Linear Activation
Expand All @@ -89,14 +90,15 @@ class ThresholdedLinear(MaskedLayer):
def __init__(self, theta=1.0):
super(ThresholdedLinear, self).__init__()
self.theta = theta

def get_output(self, train):
X = self.get_input(train)
return T.switch( abs(X) < self.theta, 0, X )
return T.switch(abs(X) < self.theta, 0, X)

def get_config(self):
return {"name": self.__class__.__name__,
"theta": self.theta}
"theta": self.theta}


class ThresholdedReLu(MaskedLayer):
'''
Expand All @@ -109,11 +111,11 @@ class ThresholdedReLu(MaskedLayer):
def __init__(self, theta=1.0):
super(ThresholdedReLu, self).__init__()
self.theta = theta

def get_output(self, train):
X = self.get_input(train)
return T.switch( X > self.theta, X, 0 )
return T.switch(X > self.theta, X, 0)

def get_config(self):
return {"name": self.__class__.__name__,
"theta": self.theta}
"theta": self.theta}
31 changes: 23 additions & 8 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,29 @@ def get_output(self, train=False):
X = T.reshape(X, (X.shape[0], X.shape[1], X.shape[2], 1)).dimshuffle(0, 2, 1, 3)

border_mode = self.border_mode
if border_mode == 'same':
border_mode = 'full'
assert self.subsample == (1, 1)

conv_out = T.nnet.conv.conv2d(X, self.W, border_mode=border_mode, subsample=self.subsample)
if self.border_mode == 'same':
shift_x = (self.filter_length - 1) // 2
conv_out = conv_out[:, :, shift_x:X.shape[2] + shift_x, :]
if on_gpu() and dnn.dnn_available():
if border_mode == 'same':
assert(self.subsample_length == 1)
pad_x = (self.filter_length - self.subsample_length) // 2
conv_out = dnn.dnn_conv(img=X,
kerns=self.W,
border_mode=(pad_x, 0))
else:
conv_out = dnn.dnn_conv(img=X,
kerns=self.W,
border_mode=border_mode,
subsample=self.subsample)
else:
if border_mode == 'same':
assert(self.subsample_length == 1)
border_mode = 'full'

conv_out = T.nnet.conv.conv2d(X, self.W,
border_mode=border_mode,
subsample=self.subsample)
if self.border_mode == 'same':
shift_x = (self.filter_length - 1) // 2
conv_out = conv_out[:, :, shift_x:X.shape[2] + shift_x, :]

output = self.activation(conv_out + self.b.dimshuffle('x', 0, 'x', 'x'))
output = T.reshape(output, (output.shape[0], output.shape[1], output.shape[2])).dimshuffle(0, 2, 1)
Expand Down
6 changes: 4 additions & 2 deletions keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,12 @@ class Reshape(Layer):
Can't be used as first layer in a model (no fixed input!)
First dimension is assumed to be nb_samples.
'''
def __init__(self, *dims):
def __init__(self, *dims, **kwargs):
super(Reshape, self).__init__()
if type(dims[0]) in [list, tuple]:
if len(dims) > 0 and type(dims[0]) in [list, tuple]:
dims = dims[0]
if len(dims) == 0 and 'dims' in kwargs:
dims = kwargs['dims']
self.dims = tuple(dims)

@property
Expand Down
8 changes: 4 additions & 4 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,17 +312,17 @@ def get_config(self, verbose=0):
pp.pprint(config)
return config

def to_yaml(self):
def to_yaml(self, **kwargs):
# dump model configuration to yaml string
import yaml
config = self.get_config()
return yaml.dump(config)
return yaml.dump(config, **kwargs)

def to_json(self):
def to_json(self, **kwargs):
# dump model configuration to json string
import json
config = self.get_config()
return json.dumps(config)
return json.dumps(config, **kwargs)


class Sequential(Model, containers.Sequential):
Expand Down
20 changes: 15 additions & 5 deletions tests/manual/check_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,22 @@ def on_train_end(self, logs={}):
import warnings
warnings.filterwarnings('error')
try:
passed = False
# this should issue a warning
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, show_accuracy=True, verbose=0, callbacks =[checkpointer])
except:
print("Tests passed")
import sys
sys.exit(0)

raise Exception("Modelcheckpoint tests did not pass")
passed = True
if not passed:
raise Exception("Modelcheckpoint tests did not pass")

print("Test model checkpointer with pattern")
filename = "model_weights.{epoch:04d}.hdf5"
f = os.path.join(path, filename)
nb_epoch = 3
checkpointer = cbks.ModelCheckpoint(f)
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, verbose=0, callbacks=[checkpointer])
for i in range(nb_epoch):
if not os.path.isfile(f.format(epoch=i)):
raise Exception("Model weights were not saved separately for each epoch")

print("Tests passed")

0 comments on commit 61d76d4

Please sign in to comment.