# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
""" Chamfer distance in Pytorch.
Author: Charles R. Qi
"""
import torch
import torch.nn as nn
import numpy as np


def huber_loss(error, delta=1.0):
"""
Args:
error: Torch tensor (d1,d2,...,dk)
Returns:
loss: Torch tensor (d1,d2,...,dk)
x = error = pred - gt or dist(pred,gt)
0.5 * |x|^2 if |x|<=d
0.5 * d^2 + d * (|x|-d) if |x|>d
Ref: https://github.com/charlesq34/frustum-pointnets/blob/master/models/model_util.py
"""
abs_error = torch.abs(error)
#quadratic = torch.min(abs_error, torch.FloatTensor([delta]))
quadratic = torch.clamp(abs_error, max=delta)
linear = (abs_error - quadratic)
loss = 0.5 * quadratic**2 + delta * linear
return loss
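
# Illustrative sketch (not part of the original votenet code): a quick sanity check
# of huber_loss on a hand-picked tensor, assuming the default delta=1.0. Errors with
# |x| <= 1 fall in the quadratic region, larger errors in the linear region.
def _demo_huber_loss():
    err = torch.tensor([0.5, 1.0, 2.0])
    # expected: 0.5*0.25 = 0.125, 0.5*1.0 = 0.5, 0.5*1.0 + 1.0*(2.0-1.0) = 1.5
    print(huber_loss(err))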


def nn_distance(pc1, pc2, l1smooth=False, delta=1.0, l1=False):
"""
Input:
pc1: (B,N,C) torch tensor
pc2: (B,M,C) torch tensor
l1smooth: bool, whether to use l1smooth loss
delta: scalar, the delta used in l1smooth loss
Output:
dist1: (B,N) torch float32 tensor
idx1: (B,N) torch int64 tensor
dist2: (B,M) torch float32 tensor
idx2: (B,M) torch int64 tensor
"""
    N = pc1.shape[1]
    M = pc2.shape[1]
    # Tile both clouds so every point in pc1 is paired with every point in pc2.
    pc1_expand_tile = pc1.unsqueeze(2).repeat(1,1,M,1)  # (B,N,M,C)
    pc2_expand_tile = pc2.unsqueeze(1).repeat(1,N,1,1)  # (B,N,M,C)
    pc_diff = pc1_expand_tile - pc2_expand_tile
    if l1smooth:
        pc_dist = torch.sum(huber_loss(pc_diff, delta), dim=-1) # (B,N,M)
    elif l1:
        pc_dist = torch.sum(torch.abs(pc_diff), dim=-1) # (B,N,M)
    else:
        pc_dist = torch.sum(pc_diff**2, dim=-1) # (B,N,M)
    # For each point, take the distance to (and index of) its nearest neighbour
    # in the other cloud.
    dist1, idx1 = torch.min(pc_dist, dim=2) # (B,N)
    dist2, idx2 = torch.min(pc_dist, dim=1) # (B,M)
    return dist1, idx1, dist2, idx2
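
# Illustrative sketch (not in the original file): one common way to reduce the
# per-point nearest-neighbour distances into a single symmetric Chamfer loss.
# The simple mean-plus-mean reduction here is an assumption; votenet's own loss
# functions combine dist1/dist2 differently depending on the head.
def _chamfer_loss(pc1, pc2):
    dist1, _, dist2, _ = nn_distance(pc1, pc2)
    return torch.mean(dist1) + torch.mean(dist2)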


def demo_nn_distance():
    np.random.seed(0)
    pc1arr = np.random.random((1,5,3))
    pc2arr = np.random.random((1,6,3))
    pc1 = torch.from_numpy(pc1arr.astype(np.float32))
    pc2 = torch.from_numpy(pc2arr.astype(np.float32))
    dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2)
    print(dist1)
    print(idx1)
    # Verify against a brute-force numpy computation of the squared distances.
    dist = np.zeros((5,6))
    for i in range(5):
        for j in range(6):
            dist[i,j] = np.sum((pc1arr[0,i,:] - pc2arr[0,j,:]) ** 2)
    print(dist)
    print('-'*30)
    print('L1smooth dists:')
    dist1, idx1, dist2, idx2 = nn_distance(pc1, pc2, True)
    print(dist1)
    print(idx1)
    # Same brute-force check for the huber (l1smooth) variant with delta=1.0.
    dist = np.zeros((5,6))
    for i in range(5):
        for j in range(6):
            error = np.abs(pc1arr[0,i,:] - pc2arr[0,j,:])
            quad = np.minimum(error, 1.0)
            linear = error - quad
            loss = 0.5*quad**2 + 1.0*linear
            dist[i,j] = np.sum(loss)
    print(dist)


if __name__ == '__main__':
    demo_nn_distance()