-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathlosses_test.py
85 lines (67 loc) · 2.98 KB
/
losses_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Tests the `losses` module."""
import tensorflow as tf
import losses
class PerceptualLossTest(tf.test.TestCase):
    """Sanity checks for `losses.PerceptualLoss`."""

    def setUp(self):
        super().setUp()
        # Every test below uses its own, freshly constructed loss object.
        self.perceptual = losses.PerceptualLoss()

    def test_identical_inputs(self):
        """Comparing a batch against itself yields (approximately) zero."""
        batch = tf.random.uniform((2, 192, 256, 3))
        self.assertAllClose(self.perceptual(batch, batch), 0.0)

    def test_different_inputs(self):
        """An all-zero batch versus uniform noise scores above 1.0."""
        shape = (2, 192, 256, 3)
        black = tf.zeros(shape)
        noise = tf.random.uniform(shape)
        self.assertAllGreater(self.perceptual(black, noise), 1.0)

    def test_similar_vs_different_inputs(self):
        """Noise is scored as farther from a flat colour than another flat colour."""
        shape = (3, 256, 256, 3)
        bright = tf.ones(shape) * tf.constant([0.9, 0.7, 0.7])
        dark = tf.ones(shape) * tf.constant([0.5, 0.2, 0.2])
        noise = tf.random.uniform(shape)
        # Evaluated in the same order as the comparison: noisy pair first.
        noisy_distance = self.perceptual(bright, noise)
        flat_distance = self.perceptual(bright, dark)
        self.assertAllGreater(noisy_distance, flat_distance)
class CompositeLossTest(tf.test.TestCase):
    """Tests for `losses.CompositeLoss`, a weighted sum of component losses."""

    def setUp(self):
        super().setUp()
        # Each test starts from an empty composite with no components.
        self.composite = losses.CompositeLoss()

    def test_l1_with_weight(self):
        """A single named 'l1' component is scaled by its weight."""
        self.composite.add_loss('l1', 2.0)
        shape = (2, 192, 256, 3)
        target = tf.constant(0.3, shape=shape, dtype=tf.float32)
        prediction = tf.constant(0.5, shape=shape, dtype=tf.float32)
        actual = self.composite(target, prediction)
        expected = tf.reduce_mean(tf.abs(target - prediction)) * 2.0
        self.assertAllClose(actual, expected)

    def test_l1_l2_different_weights(self):
        """Upper-case 'L1'/'L2' names combine with their respective weights."""
        self.composite.add_loss('L1', 1.0)
        self.composite.add_loss('L2', 0.5)
        shape = (2, 192, 256, 3)
        target = tf.constant(127, shape=shape, dtype=tf.int32)
        prediction = tf.constant(215, shape=shape, dtype=tf.int32)
        delta = target - prediction
        mean_abs = tf.cast(tf.reduce_mean(tf.abs(delta)), tf.float32)
        mean_sq = tf.cast(tf.reduce_mean(tf.square(delta)), tf.float32)
        self.assertAllClose(
            self.composite(target, prediction), mean_abs * 1.0 + mean_sq * 0.5)

    def test_composite_loss_equals_sum_of_components(self):
        """Callable components contribute their weighted values to the total."""
        mae = tf.keras.losses.MAE
        perceptual = losses.PerceptualLoss()
        self.composite.add_loss(mae, 1.0)
        self.composite.add_loss(perceptual, 2.0)
        target = tf.random.uniform((1, 192, 256, 3))
        prediction = tf.random.uniform((1, 192, 256, 3))
        combined = self.composite(target, prediction)
        # MAE returns a per-pixel tensor, so average it down to a scalar.
        mae_part = tf.math.reduce_mean(mae(target, prediction))
        perceptual_part = perceptual(target, prediction)
        self.assertAllClose(combined, mae_part * 1.0 + perceptual_part * 2.0)

    def test_duplicate_component_raises_error(self):
        """Registering the same component name twice is rejected."""
        self.composite.add_loss('l1', 1.0)
        with self.assertRaisesRegex(ValueError, 'exist'):
            self.composite.add_loss('l1', 2.0)

    def test_call_before_adding_component_raises_error(self):
        """Evaluating an empty composite fails."""
        target = tf.random.uniform((1, 192, 256, 3))
        prediction = tf.random.uniform((1, 192, 256, 3))
        # NOTE(review): expecting AssertionError suggests CompositeLoss guards
        # this with an `assert`, which `python -O` strips; consider a real
        # exception in losses.py — confirm against its implementation.
        with self.assertRaises(AssertionError):
            self.composite(target, prediction)

    def test_invalid_weight(self):
        """A negative weight is rejected and the message echoes the value."""
        with self.assertRaisesRegex(ValueError, r'-1\.0'):
            self.composite.add_loss('l2', -1.0)
if __name__ == '__main__':
    # Run this module's test cases through TensorFlow's test entry point.
    tf.test.main()