# trainer.py, forked from olly-styles/Multiple-Object-Forecasting
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

import metrics

def train_seqseq(encoder, decoder, device, train_loader, encoder_optimizer,
                 decoder_optimizer, epoch, loss_function, learning_rate):
    encoder.train()
    decoder.train()
    total_loss = 0
    ades = []
    fdes = []
    for batch_idx, data in enumerate(tqdm(train_loader)):
        features = data['features'].to(device).float()
        labels = data['labels'].to(device).float()
        dtp_features = data['dtp_features'].to(device).float()
        context = encoder(features)
        output = decoder(context, dtp_features, val=False)
        loss = loss_function(output, labels)
        ades.append(list(metrics.calc_ade(output.cpu().detach().numpy(),
                                          labels.cpu().detach().numpy(),
                                          return_mean=False)))
        fdes.append(list(metrics.calc_fde(output.cpu().detach().numpy(),
                                          labels.cpu().detach().numpy(), 60,
                                          return_mean=False)))
        # Backward and optimize
        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()
        loss.backward()
        # Clip decoder gradients before stepping so the clipped values are
        # the ones actually applied (clip_grad_norm is deprecated; the
        # in-place variant is clip_grad_norm_)
        nn.utils.clip_grad_norm_(decoder.parameters(), 1)
        encoder_optimizer.step()
        decoder_optimizer.step()
        # Additional manual SGD step on the decoder parameters, on top of
        # the optimizer update (kept from the original implementation)
        for p in decoder.parameters():
            p.data.add_(p.grad.data, alpha=-learning_rate)
        # Accumulate as a Python float so the computation graph is released
        total_loss += loss.item()
    # Flatten the per-batch metric lists
    ades = [item for sublist in ades for item in sublist]
    fdes = [item for sublist in fdes for item in sublist]
    print('Train ADE: ', np.round(np.mean(ades), 1))
    print('Train FDE: ', np.round(np.mean(fdes), 1))
    print('Train loss: ', total_loss)

def test_seqseq(encoder, decoder, device, test_loader, loss_function,
                return_predictions=False, phase='Val'):
    encoder.eval()
    decoder.eval()
    ades = []
    fdes = []
    outputs = np.array([])
    targets = np.array([])
    with torch.no_grad():
        for batch_idx, data in enumerate(tqdm(test_loader)):
            features = data['features'].to(device).float()
            labels = data['labels'].to(device).float()
            dtp_features = data['dtp_features'].to(device).float()
            context = encoder(features, val=True)
            output = decoder(context, dtp_features, val=True)
            ades.append(list(metrics.calc_ade(output.cpu().numpy(),
                                              labels.cpu().numpy(),
                                              return_mean=False)))
            fdes.append(list(metrics.calc_fde(output.cpu().numpy(),
                                              labels.cpu().numpy(), 60,
                                              return_mean=False)))
            if return_predictions:
                outputs = np.append(outputs, output.cpu().numpy())
                targets = np.append(targets, labels.cpu().numpy())
    # Flatten the per-batch metric lists
    ades = [item for sublist in ades for item in sublist]
    fdes = [item for sublist in fdes for item in sublist]
    print(phase + ' ADE: ' + str(np.round(np.mean(ades), 1)))
    print(phase + ' FDE: ' + str(np.round(np.mean(fdes), 1)))
    return outputs, targets, np.mean(ades), np.mean(fdes)
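

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file), assuming an
    # encoder/decoder pair whose forward() signatures match the calls above
    # and DataLoaders that yield dicts with 'features', 'labels', and
    # 'dtp_features' keys. Encoder, Decoder, train_dataset, and val_dataset
    # are hypothetical placeholders; the optimizer, loss, and hyperparameters
    # are illustrative choices, not the original training configuration.
    from torch.utils.data import DataLoader

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    encoder = Encoder().to(device)  # hypothetical model class
    decoder = Decoder().to(device)  # hypothetical model class
    learning_rate = 1e-3
    encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
    loss_function = nn.SmoothL1Loss()
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)  # hypothetical dataset
    val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)  # hypothetical dataset

    for epoch in range(10):
        train_seqseq(encoder, decoder, device, train_loader, encoder_optimizer,
                     decoder_optimizer, epoch, loss_function, learning_rate)
        _, _, ade, fde = test_seqseq(encoder, decoder, device, val_loader,
                                     loss_function, phase='Val')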