Skip to content

Commit

Permalink
Make the logic more understandable via DRY
Browse files Browse the repository at this point in the history
  • Loading branch information
wxs committed Aug 24, 2015
1 parent fc6c42e commit 9e001ee
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,13 @@ def weighted(y_true, y_pred, weights, mask=None):
filtered_y_pred = y_pred[weights.nonzero()[:-1]]
filtered_weights = weights[weights.nonzero()]
obj_output = fn(filtered_y_true, filtered_y_pred)
weighted = filtered_weights * obj_output
if mask is None:
# Instead of calling mean() here, we divide by the sum of filtered_weights.
return (filtered_weights.flatten() * obj_output.flatten()).sum() / filtered_weights.sum()
return weighted.sum() / filtered_weights.sum()
else:
# We assume the time index to be masked is axis=1
filtered_mask = mask[weights.nonzero()[:-1]]
wc = filtered_weights * obj_output
# Divide by mask.sum() here not filtered_mask.sum() since otherwise interactions
# between sample_weight and masks cause issues.
return (wc * filtered_mask).sum() / (filtered_mask * filtered_weights).sum()
return weighted.sum() / (filtered_mask * filtered_weights).sum()
return weighted


Expand Down

0 comments on commit 9e001ee

Please sign in to comment.