Skip to content

Commit

Permalink
Revert "Fix sample_weight and class_weight in validation"
Browse files: browse the repository at this point in the history
This reverts commit 9773e81.
  • Loading branch information
fchollet committed Aug 26, 2015
1 parent 9773e81 commit 3c4f0ac
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 47 deletions.
66 changes: 22 additions & 44 deletions keras/models.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_function_name(o):

class Model(object):
def _fit(self, f, ins, out_labels=[], batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
val_f=None, val_ins=None, shuffle=True, metrics=[]):
validation_split=0., val_f=None, val_ins=None, shuffle=True, metrics=[]):
'''
Abstract fit function for f(*ins). Assume that f returns a list, labelled by out_labels.
'''
Expand All @@ -169,6 +169,13 @@ def _fit(self, f, ins, out_labels=[], batch_size=128, nb_epoch=100, verbose=1, c
do_validation = True
if verbose:
print("Train on %d samples, validate on %d samples" % (len(ins[0]), len(val_ins[0])))
else:
if 0 < validation_split < 1:
do_validation = True
split_at = int(len(ins[0]) * (1 - validation_split))
(ins, val_ins) = (slice_X(ins, 0, split_at), slice_X(ins, split_at))
if verbose:
print("Train on %d samples, validate on %d samples" % (len(ins[0]), len(val_ins[0])))

nb_train_sample = len(ins[0])
index_array = np.arange(nb_train_sample)
Expand Down Expand Up @@ -444,6 +451,7 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],

X = standardize_X(X)
y = standardize_y(y)
sample_weight = standardize_weights(y, class_weight=class_weight, sample_weight=sample_weight)

val_f = None
val_ins = None
Expand All @@ -453,27 +461,14 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
else:
val_f = self._test
if validation_data:
if len(validation_data) == 2:
try:
X_val, y_val = validation_data
sample_weight_val = np.ones(y_val.shape[:-1] + (1,))
elif len(validation_data) == 3:
X_val, y_val, sample_weight_val = validation_data
else:
raise Exception("Invalid format for validation data; provide a tuple (X_val, y_val) or (X_val, y_val, sample_weight). \
except:
raise Exception("Invalid format for validation data; provide a tuple (X_val, y_val). \
X_val may be a numpy array or a list of numpy arrays depending on your model input.")
X_val = standardize_X(X_val)
y_val = standardize_y(y_val)
val_ins = X_val + [y_val, sample_weight_val]

elif 0 < validation_split < 1:
split_at = int(len(X[0]) * (1 - validation_split))
X, X_val = (slice_X(X, 0, split_at), slice_X(X, split_at))
y, y_val = (slice_X(y, 0, split_at), slice_X(y, split_at))
if sample_weight:
sample_weight, sample_weight_val = (slice_X(sample_weight, 0, split_at), slice_X(sample_weight, split_at))
else:
sample_weight_val = np.ones(y_val.shape[:-1] + (1,))
val_ins = X_val + [y_val, sample_weight_val]
val_ins = X_val + [y_val, np.ones(y_val.shape[:-1] + (1,))]

if show_accuracy:
f = self._train_with_acc
Expand All @@ -482,12 +477,11 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
f = self._train
out_labels = ['loss']

sample_weight = standardize_weights(y, class_weight=class_weight, sample_weight=sample_weight)
ins = X + [y, sample_weight]
metrics = ['loss', 'acc', 'val_loss', 'val_acc']
return self._fit(f, ins, out_labels=out_labels, batch_size=batch_size, nb_epoch=nb_epoch,
verbose=verbose, callbacks=callbacks,
val_f=val_f, val_ins=val_ins,
validation_split=validation_split, val_f=val_f, val_ins=val_ins,
shuffle=shuffle, metrics=metrics)

def predict(self, X, batch_size=128, verbose=0):
Expand Down Expand Up @@ -630,8 +624,8 @@ def train_on_batch(self, data, class_weight={}, sample_weight={}):

def test_on_batch(self, data, sample_weight={}):
# data is a dictionary mapping input names to arrays
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name)) for name in self.output_order]
sample_weight = [standardize_weights(data[name]) for name in self.output_order]

ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
return self._test(*ins)

Expand All @@ -642,46 +636,30 @@ def predict_on_batch(self, data):

def fit(self, data, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
validation_split=0., validation_data=None, shuffle=True, class_weight={}, sample_weight={}):
X = [data[name] for name in self.input_order]
y = [standardize_y(data[name]) for name in self.output_order]
sample_weight_list = [standardize_weights(data[name],
sample_weight=sample_weight.get(name)) for name in self.output_order]
class_weight_list = [class_weight.get(name) for name in self.output_order]
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name),
class_weight=class_weight.get(name)) for name in self.output_order]
ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight

val_f = None
val_ins = None
if validation_data or validation_split:
val_f = self._test
if validation_data:
# can't use sample weights with validation data at this point
sample_weight = [standardize_weights(validation_data[name]) for name in self.output_order]
val_ins = [validation_data[name] for name in self.input_order] + [standardize_y(validation_data[name]) for name in self.output_order] + sample_weight

elif 0 < validation_split < 1:
split_at = int(len(X[0]) * (1 - validation_split))
X, X_val = (slice_X(X, 0, split_at), slice_X(X, split_at))
y, y_val = (slice_X(y, 0, split_at), slice_X(y, split_at))
sample_weight_list, sample_weight_list_val = (slice_X(sample_weight_list, 0, split_at), slice_X(sample_weight_list, split_at))
val_ins = X_val + y_val + sample_weight_val

f = self._train
out_labels = ['loss']
metrics = ['loss', 'val_loss']

sample_weight_list = [standardize_weights(y[i],
sample_weight=sample_weight_list[i],
class_weight=class_weight_list[i]) for i in range(len(self.output_order))]
ins = X + y + sample_weight_list

history = self._fit(f, ins, out_labels=out_labels, batch_size=batch_size, nb_epoch=nb_epoch,
verbose=verbose, callbacks=callbacks,
val_f=val_f, val_ins=val_ins,
validation_split=validation_split, val_f=val_f, val_ins=val_ins,
shuffle=shuffle, metrics=metrics)
return history

def evaluate(self, data, batch_size=128, verbose=0, sample_weight={}):
sample_weight = [standardize_weights(data[name],
sample_weight=sample_weight.get(name)) for name in self.output_order]
sample_weight = [standardize_weights(data[name], sample_weight.get(name)) for name in self.output_order]

ins = [data[name] for name in self.input_order] + [standardize_y(data[name]) for name in self.output_order] + sample_weight
outs = self._test_loop(self._test, ins, batch_size, verbose)
Expand Down
5 changes: 2 additions & 3 deletions tests/auto/test_graph_model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -144,8 +144,6 @@ def test_2o_1i_sample_weights(self):

weights1 = np.random.uniform(size=y_train.shape[0])
weights2 = np.random.uniform(size=y2_train.shape[0])
weights1_test = np.random.uniform(size=y_test.shape[0])
weights2_test = np.random.uniform(size=y2_test.shape[0])

graph.compile('rmsprop', {'output1': 'mse', 'output2': 'mse'})

Expand All @@ -155,11 +153,12 @@ def test_2o_1i_sample_weights(self):
assert(type(out == dict))
assert(len(out) == 2)
loss = graph.test_on_batch({'input1': X_test, 'output1': y_test, 'output2': y2_test},
sample_weight={'output1': weights1_test, 'output2': weights2_test})
sample_weight={'output1': weights1, 'output2': weights2})
loss = graph.train_on_batch({'input1': X_train, 'output1': y_train, 'output2': y2_train},
sample_weight={'output1': weights1, 'output2': weights2})
loss = graph.evaluate({'input1': X_train, 'output1': y_train, 'output2': y2_train},
sample_weight={'output1': weights1, 'output2': weights2})
print(loss)

def test_recursive(self):
print('test layer-like API')
Expand Down

0 comments on commit 3c4f0ac

Please sign in to comment.