import unittest
import numpy as np
import torch
from torch.nn.modules.distance import CosineSimilarity, PairwiseDistance
from few_shot.utils import *
from config import PATH


class TestDistance(unittest.TestCase):
    def test_query_support_distances(self):
        # Create some dummy data with easily verifiable distances
        q = 1  # 1 query per class
        k = 3  # 3 way classification
        d = 2  # embedding dimension of two

        query = torch.zeros([q * k, d], dtype=torch.double)
        query[0] = torch.Tensor([0, 0])
        query[1] = torch.Tensor([0, 1])
        query[2] = torch.Tensor([1, 0])

        support = torch.zeros([k, d], dtype=torch.double)
        support[0] = torch.Tensor([1, 1])
        support[1] = torch.Tensor([1, 2])
        support[2] = torch.Tensor([2, 2])
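
        # For example, the squared L2 distance between query[0] = [0, 0] and
        # support[2] = [2, 2] should come out as 2**2 + 2**2 = 8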

        distances = pairwise_distances(query, support, 'l2')

        self.assertEqual(
            distances.shape, (q * k, k),
            'Output should have shape (q * k, k).'
        )

        # Calculate squared distances by iterating through all query-support pairs
        for i, q_ in enumerate(query):
            for j, s_ in enumerate(support):
                self.assertEqual(
                    (q_ - s_).pow(2).sum(),
                    distances[i, j].item(),
                    'The jth column of the ith row should be the squared distance between the '
                    'ith query sample and the jth support sample.'
                )

        # Create some dummy data with easily verifiable distances
        q = 1  # 1 query per class
        k = 3  # 3 way classification
        d = 2  # embedding dimension of two

        query = torch.zeros([q * k, d], dtype=torch.double)
        query[0] = torch.Tensor([1, 0])
        query[1] = torch.Tensor([0, 1])
        query[2] = torch.Tensor([1, 1])

        support = torch.zeros([k, d], dtype=torch.double)
        support[0] = torch.Tensor([1, 1])
        support[1] = torch.Tensor([-1, -1])
        support[2] = torch.Tensor([0, 2])
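
        # For example, query[2] = [1, 1] and support[1] = [-1, -1] point in opposite
        # directions, so their cosine similarity is -1 and the expected cosine distance
        # is 1 - (-1) = 2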

        distances = pairwise_distances(query, support, 'cosine')

        # Calculate distances by iterating through all query-support pairs
        for i, q_ in enumerate(query):
            for j, s_ in enumerate(support):
                self.assertTrue(
                    torch.isclose(1 - CosineSimilarity(dim=0)(q_, s_), distances[i, j], atol=2e-8),
                    'The jth column of the ith row should be the cosine distance between the '
                    'ith query sample and the jth support sample.'
                )

    def test_no_nans_on_zero_vectors(self):
        """The cosine distance calculation involves dividing by the vector magnitudes, which
        can cause divisions by zero to occur.
        """
        # Create some dummy data with easily verifiable distances
        q = 1  # 1 query per class
        k = 3  # 3 way classification
        d = 2  # embedding dimension of two

        query = torch.zeros([q * k, d], dtype=torch.double)
        query[0] = torch.Tensor([0, 0])  # First query sample is all zeros
        query[1] = torch.Tensor([0, 1])
        query[2] = torch.Tensor([1, 1])

        support = torch.zeros([k, d], dtype=torch.double)
        support[0] = torch.Tensor([1, 1])
        support[1] = torch.Tensor([-1, -1])
        support[2] = torch.Tensor([0, 0])  # Third support sample is all zeros

        distances = pairwise_distances(query, support, 'cosine')

        self.assertTrue(torch.isnan(distances).sum() == 0, 'Cosine distances involving 0-vectors should not be nan')
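

# The tests above treat `pairwise_distances` (imported from few_shot.utils) as a black box.
# The helper below is only a minimal sketch of the behaviour those tests assume: an
# (n_query, n_support) matrix where 'l2' is the *squared* Euclidean distance and 'cosine'
# is 1 - cosine similarity, with a small epsilon guarding against zero-magnitude vectors.
# It is illustrative only and may differ from the real implementation.
def _reference_pairwise_distances(x: torch.Tensor, y: torch.Tensor, matching_fn: str,
                                  eps: float = 1e-8) -> torch.Tensor:
    if matching_fn == 'l2':
        # Squared Euclidean distance via broadcasting: (n, 1, d) - (1, m, d) -> (n, m, d)
        return (x.unsqueeze(1) - y.unsqueeze(0)).pow(2).sum(dim=2)
    elif matching_fn == 'cosine':
        # Normalise each vector; adding eps to the norms keeps zero vectors finite
        x_norm = x / (x.norm(dim=1, keepdim=True) + eps)
        y_norm = y / (y.norm(dim=1, keepdim=True) + eps)
        return 1 - torch.mm(x_norm, y_norm.t())
    else:
        raise ValueError('Unsupported matching function: {}'.format(matching_fn))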


class TestAutogradGraphRetrieval(unittest.TestCase):
    def test_retrieval(self):
        """Create a simple autograd graph and check that the output is what is expected."""
        x = torch.ones(2, 2, requires_grad=True)
        y = x + 2
        # The operation on the next line will create two edges because the y variable
        # is used twice.
        z = y * y
        out = z.mean()

        nodes, edges = autograd_graph(out)

        # This is quite a brittle test as it will break if the names of the autograd Functions change
        # TODO: Less brittle test
        expected_nodes = [
            'MeanBackward1',
            'MulBackward0',
            'AddBackward0',
            'AccumulateGrad',
        ]
        self.assertEqual(
            set(expected_nodes),
            set(n.__class__.__name__ for n in nodes),
            'autograd_graph() must return all nodes in the autograd graph.'
        )

        # Check that every returned edge is one of the expected edges
        expected_edges = [
            ('MulBackward0', 'MeanBackward1'),   # z = y * y, out = z.mean()
            ('AddBackward0', 'MulBackward0'),    # y = x + 2, z = y * y
            ('AccumulateGrad', 'AddBackward0'),  # x = torch.ones(2, 2, requires_grad=True), y = x + 2
        ]
        for e in edges:
            self.assertIn(
                (e[0].__class__.__name__, e[1].__class__.__name__),
                expected_edges,
                'autograd_graph() must only return edges that exist in the autograd graph.'
            )

        # Check for two edges between the AddBackward0 node and the MulBackward0 node
        num_y_squared_edges = 0
        for e in edges:
            if e[0].__class__.__name__ == 'AddBackward0' and e[1].__class__.__name__ == 'MulBackward0':
                num_y_squared_edges += 1

        self.assertEqual(
            num_y_squared_edges,
            2,
            'autograd_graph() must return multiple edges between nodes if they exist.'
        )
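

# `autograd_graph` is imported from few_shot.utils, so the test above only checks the shape of
# its output: a list of autograd Function nodes and a list of (child, parent) edges, with
# duplicate edges preserved. The sketch below illustrates one way such a graph could be
# collected by walking `grad_fn.next_functions`; it is an assumption for illustration, not
# the library's actual implementation.
def _sketch_autograd_graph(tensor: torch.Tensor):
    nodes, edges = [], []

    def _traverse(fn):
        if fn is None:
            return
        if fn not in nodes:
            nodes.append(fn)
        for next_fn, _ in fn.next_functions:
            if next_fn is not None:
                # Record every (child, parent) pair so that duplicates, such as the two
                # AddBackward0 -> MulBackward0 edges produced by z = y * y, are kept
                edges.append((next_fn, fn))
                _traverse(next_fn)

    _traverse(tensor.grad_fn)
    return nodes, edges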