# test_ipu_losses.py (forked from datamol-io/graphium)
import unittest as ut
import torch
from torch.nn import BCELoss, MSELoss, L1Loss, BCEWithLogitsLoss
from copy import deepcopy
import pytest
from graphium.ipu.ipu_losses import BCELossIPU, MSELossIPU, L1LossIPU, BCEWithLogitsLossIPU, HybridCELossIPU
from graphium.trainer.losses import HybridCELoss
@pytest.mark.ipu
class test_Losses(ut.TestCase):
    """Check graphium's IPU loss wrappers against the reference ``torch.nn`` losses.

    Each ``*IPU`` loss must reproduce the ``torch.nn`` value on fully specified
    targets. When some target entries are NaN, it must match the ``torch.nn``
    loss evaluated on the non-NaN entries only.
    """

    # Shared fixtures, computed once when the class body executes (at import time).
    # NOTE: this seeds torch's *global* RNG as a side effect.
    torch.manual_seed(42)
    preds = torch.rand((100, 10), dtype=torch.float32)
    target = torch.rand((100, 10), dtype=torch.float32)
    th = 0.7  # values above this threshold become 1.0 in the binarized tensors
    nan_th = 0.2  # target entries below this threshold are replaced by NaN
    preds_greater = preds > th  # kept for compatibility; not used by the tests below
    target_greater = (target > th).to(torch.float32)
    is_nan = target < nan_th  # mask of the entries that are set to NaN
    target_greater_nan = deepcopy(target_greater)
    target_greater_nan[is_nan] = torch.nan
    target_nan = deepcopy(target)
    target_nan[is_nan] = torch.nan

    # Seed for the per-test random weights. A dedicated generator keeps the weights
    # independent of the global RNG state, and therefore of test execution order.
    weight_seed = 0

    def _weight(self, shape):
        """Return reproducible float32 weights drawn uniformly from [0, 1).

        Args:
            shape: Anything accepted by ``torch.rand`` as a size (int or ``torch.Size``).
        """
        generator = torch.Generator().manual_seed(self.weight_seed)
        return torch.rand(shape, dtype=torch.float32, generator=generator)

    def _assert_losses_match(self, loss_true, loss_ipu, case, loss_name, with_nan=False):
        """Assert that neither loss is NaN and that both agree to 6 decimal places.

        Args:
            loss_true: Scalar tensor from the reference ``torch.nn`` loss.
            loss_ipu: Scalar tensor from the matching ``*IPU`` loss.
            case: Scenario label used in messages, e.g. ``"Regular"`` or ``"Weighted"``.
            loss_name: Reference loss class name used in messages, e.g. ``"BCELoss"``.
            with_nan: Whether the target contained NaNs (only changes the messages).
        """
        nan_note = " with target_nan" if with_nan else ""
        self.assertFalse(loss_true.isnan(), f"{case} {loss_name}{nan_note} is NaN")
        self.assertFalse(loss_ipu.isnan(), f"{case} {loss_name}IPU{nan_note} is NaN")
        # Use the same float32-friendly tolerance for every case. The default of 7
        # places is too tight for float32 reductions.
        self.assertAlmostEqual(
            loss_true.item(),
            loss_ipu.item(),
            places=6,
            msg=f"{case} {loss_name}{nan_note} is different",
        )

    def test_bce(self):
        """BCELossIPU matches BCELoss: regular, weighted, and with NaN targets."""
        preds = deepcopy(self.preds)
        target = deepcopy(self.target_greater)
        target_nan = deepcopy(self.target_greater_nan)
        not_nan = ~target_nan.isnan()

        # Regular loss
        self._assert_losses_match(
            BCELoss()(preds, target), BCELossIPU()(preds, target), "Regular", "BCELoss"
        )

        # Weighted loss: one weight per column, shape (preds.shape[1],)
        weight = self._weight(preds.shape[1])
        self._assert_losses_match(
            BCELoss(weight=weight)(preds, target),
            BCELossIPU(weight=weight)(preds, target),
            "Weighted",
            "BCELoss",
        )

        # NaN targets: the reference loss only sees the non-NaN entries
        self._assert_losses_match(
            BCELoss()(preds[not_nan], target[not_nan]),
            BCELossIPU()(preds, target_nan),
            "Regular",
            "BCELoss",
            with_nan=True,
        )

        # Weighted with NaN targets: element-wise weights, masked like preds/target
        weight = self._weight(preds.shape)
        self._assert_losses_match(
            BCELoss(weight=weight[not_nan])(preds[not_nan], target_nan[not_nan]),
            BCELossIPU(weight=weight)(preds, target_nan),
            "Weighted",
            "BCELoss",
            with_nan=True,
        )

    def test_mse(self):
        """MSELossIPU matches MSELoss: regular and with NaN targets."""
        preds = deepcopy(self.preds)
        target = deepcopy(self.target)
        target_nan = deepcopy(self.target_nan)
        not_nan = ~target_nan.isnan()

        # Regular loss
        self._assert_losses_match(
            MSELoss()(preds, target), MSELossIPU()(preds, target), "Regular", "MSELoss"
        )

        # NaN targets: the reference loss only sees the non-NaN entries
        self._assert_losses_match(
            MSELoss()(preds[not_nan], target[not_nan]),
            MSELossIPU()(preds, target_nan),
            "Regular",
            "MSELoss",
            with_nan=True,
        )

    def test_l1(self):
        """L1LossIPU matches L1Loss: regular and with NaN targets."""
        preds = deepcopy(self.preds)
        target = deepcopy(self.target)
        target_nan = deepcopy(self.target_nan)
        not_nan = ~target_nan.isnan()

        # Regular loss
        self._assert_losses_match(
            L1Loss()(preds, target), L1LossIPU()(preds, target), "Regular", "L1Loss"
        )

        # NaN targets: the reference loss only sees the non-NaN entries
        self._assert_losses_match(
            L1Loss()(preds[not_nan], target[not_nan]),
            L1LossIPU()(preds, target_nan),
            "Regular",
            "L1Loss",
            with_nan=True,
        )

    def test_bce_logits(self):
        """BCEWithLogitsLossIPU matches BCEWithLogitsLoss: regular, weighted, and with NaN targets."""
        preds = deepcopy(self.preds)
        target = deepcopy(self.target_greater)
        target_nan = deepcopy(self.target_greater_nan)
        not_nan = ~target_nan.isnan()

        # Regular loss
        self._assert_losses_match(
            BCEWithLogitsLoss()(preds, target),
            BCEWithLogitsLossIPU()(preds, target),
            "Regular",
            "BCEWithLogitsLoss",
        )

        # Weighted loss: one weight per column, shape (preds.shape[1],)
        weight = self._weight(preds.shape[1])
        self._assert_losses_match(
            BCEWithLogitsLoss(weight=weight)(preds, target),
            BCEWithLogitsLossIPU(weight=weight)(preds, target),
            "Weighted",
            "BCEWithLogitsLoss",
        )

        # NaN targets: the reference loss only sees the non-NaN entries
        self._assert_losses_match(
            BCEWithLogitsLoss()(preds[not_nan], target[not_nan]),
            BCEWithLogitsLossIPU()(preds, target_nan),
            "Regular",
            "BCEWithLogitsLoss",
            with_nan=True,
        )

        # Weighted with NaN targets: element-wise weights, masked like preds/target
        weight = self._weight(preds.shape)
        self._assert_losses_match(
            BCEWithLogitsLoss(weight=weight[not_nan])(preds[not_nan], target_nan[not_nan]),
            BCEWithLogitsLossIPU(weight=weight)(preds, target_nan),
            "Weighted",
            "BCEWithLogitsLoss",
            with_nan=True,
        )