Skip to content

Commit 9e001ee

Browse files
committed
Make the logic more understandable via DRY
1 parent fc6c42e commit 9e001ee

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

keras/models.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,13 @@ def weighted(y_true, y_pred, weights, mask=None):
7575
filtered_y_pred = y_pred[weights.nonzero()[:-1]]
7676
filtered_weights = weights[weights.nonzero()]
7777
obj_output = fn(filtered_y_true, filtered_y_pred)
78+
weighted = filtered_weights * obj_output
7879
if mask is None:
7980
# Instead of calling mean() here, we divide by the sum of filtered_weights.
80-
return (filtered_weights.flatten() * obj_output.flatten()).sum() / filtered_weights.sum()
81+
return weighted.sum() / filtered_weights.sum()
8182
else:
82-
# We assume the time index to be masked is axis=1
8383
filtered_mask = mask[weights.nonzero()[:-1]]
84-
wc = filtered_weights * obj_output
85-
# Divide by mask.sum() here not filtered_mask.sum() since otherwise interactions
86-
# between sample_weight and masks cause issues.
87-
return (wc * filtered_mask).sum() / (filtered_mask * filtered_weights).sum()
84+
return weighted.sum() / (filtered_mask * filtered_weights).sum()
8885
return weighted
8986

9087

0 commit comments

Comments
 (0)