-
Notifications
You must be signed in to change notification settings - Fork 248
/
Copy pathtest_matching.py
64 lines (50 loc) · 1.93 KB
/
test_matching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import unittest
import torch
from torch.utils.data import DataLoader
from few_shot.core import NShotTaskSampler
from few_shot.datasets import DummyDataset
from few_shot.matching import matching_net_predictions
from few_shot.utils import pairwise_distances
@unittest.skipUnless(torch.cuda.is_available(),
                     'Matching network prediction test requires a CUDA device.')
class TestMatchingNets(unittest.TestCase):
    """Sanity checks for ``matching_net_predictions``.

    Builds n-shot, k-way, q-query tasks from a ``DummyDataset``, derives an
    attention matrix from cosine distances between query and support samples,
    and checks the shape and row normalisation of the resulting predictions.

    The attention tensor is moved to CUDA, so the whole class is skipped
    (rather than erroring) when no CUDA device is available.
    """

    @classmethod
    def setUpClass(cls):
        """Create one dummy dataset shared by every task combination."""
        super().setUpClass()
        cls.dataset = DummyDataset(samples_per_class=1000, n_classes=20)

    def _test_n_k_q_combination(self, n, k, q):
        """Check matching-network predictions for one (n, k, q) task.

        Args:
            n: Number of support samples per class.
            k: Number of classes in the task.
            q: Number of query samples per class.
        """
        n_shot_taskloader = DataLoader(
            self.dataset,
            batch_sampler=NShotTaskSampler(self.dataset, 100, n, k, q),
        )
        # Only a single n-shot, k-way task is needed; its labels are unused.
        x, _ = next(iter(n_shot_taskloader))

        # Drop column 0 and keep the remaining (dummy label) features. The
        # first n * k rows are the support set and the rest are queries.
        # Uniform noise is added so that queries and support samples are not
        # exact duplicates (i.e. distances are not trivially 0).
        support = x[:n * k, 1:]
        queries = x[n * k:, 1:]
        support = support + torch.rand_like(support)
        queries = queries + torch.rand_like(queries)

        distances = pairwise_distances(queries, support, 'cosine')
        # "Attention" is a softmax over negated distances, so nearer support
        # samples get more weight. It is moved to CUDA as in the original
        # test; presumably matching_net_predictions expects GPU inputs —
        # TODO confirm.
        attention = (-distances).softmax(dim=1).cuda()
        y_pred = matching_net_predictions(attention, n, k, q)

        self.assertEqual(
            y_pred.shape,
            (q * k, k),
            'Matching Network predictions must have shape (q * k, k).'
        )

        y_pred_sum = y_pred.sum(dim=1)
        # ones_like keeps y_pred_sum's dtype and device, so the comparison
        # does not depend on predictions being float64.
        self.assertTrue(
            torch.allclose(y_pred_sum, torch.ones_like(y_pred_sum)),
            'Matching Network predictions probabilities must sum to 1 for each '
            'query sample.'
        )

    def test_matching_net_predictions(self):
        """Run the prediction checks over several (n, k, q) combinations."""
        test_combinations = [
            (1, 5, 5),
            (5, 5, 5),
            (1, 20, 5),
            (5, 20, 5),
        ]
        for n, k, q in test_combinations:
            # subTest reports which combination failed and keeps checking
            # the remaining ones.
            with self.subTest(n=n, k=k, q=q):
                self._test_n_k_q_combination(n, k, q)