-
Notifications
You must be signed in to change notification settings - Fork 25
/
utils.py
65 lines (55 loc) · 1.96 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import math
import random
import numpy as np
import torch
from constants import *
def get_dset_path(dset_name, dset_type):
return os.path.join('datasets', dset_name, dset_type)
def relative_to_abs(rel_traj, start_pos):
rel_traj = rel_traj.permute(1, 0, 2)
displacement = torch.cumsum(rel_traj, dim=1)
start_pos = torch.unsqueeze(start_pos, dim=1)
abs_traj = displacement + start_pos
return abs_traj.permute(1, 0, 2)
def bce_loss(input, target):
neg_abs = -input.abs()
loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
return loss.mean()
def gan_g_loss(scores_fake):
y_fake = torch.ones_like(scores_fake) * random.uniform(0.7, 1.2)
return bce_loss(scores_fake, y_fake)
def gan_d_loss(scores_real, scores_fake):
y_real = torch.ones_like(scores_real) * random.uniform(0.7, 1.2)
y_fake = torch.zeros_like(scores_fake) * random.uniform(0, 0.3)
loss_real = bce_loss(scores_real, y_real)
loss_fake = bce_loss(scores_fake, y_fake)
return loss_real + loss_fake
def l2_loss(pred_traj, pred_traj_gt, mode='average'):
seq_len, batch, _ = pred_traj.size()
loss = (pred_traj_gt.permute(1, 0, 2) - pred_traj.permute(1, 0, 2))**2
if mode == 'sum':
return torch.sum(loss)
elif mode == 'average':
return torch.sum(loss) / (seq_len * batch)
elif mode == 'raw':
return loss.sum(dim=2).sum(dim=1)
def displacement_error(pred_traj, pred_traj_gt, mode='sum'):
seq_len, _, _ = pred_traj.size()
loss = pred_traj_gt.permute(1, 0, 2) - pred_traj.permute(1, 0, 2)
loss = loss**2
loss = torch.sqrt(loss.sum(dim=2)).sum(dim=1)
if mode == 'sum':
return torch.sum(loss)
elif mode == 'raw':
return loss
def final_displacement_error(
pred_pos, pred_pos_gt, mode='sum'
):
loss = pred_pos_gt - pred_pos
loss = loss**2
loss = torch.sqrt(loss.sum(dim=1))
if mode == 'raw':
return loss
else:
return torch.sum(loss)