-
Notifications
You must be signed in to change notification settings - Fork 248
/
Copy pathtest_proto.py
68 lines (56 loc) · 2.38 KB
/
test_proto.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import unittest
from torch.utils.data import DataLoader
from few_shot.core import NShotTaskSampler
from few_shot.datasets import DummyDataset, OmniglotDataset, MiniImageNet
from few_shot.models import get_few_shot_encoder
from few_shot.proto import compute_prototypes
class TestProtoNets(unittest.TestCase):
    """Tests for prototype computation and the few-shot encoder."""

    @classmethod
    def setUpClass(cls):
        """Build one synthetic dataset shared by every test in this class."""
        # 20 classes x 1000 samples each; the largest task exercised below is
        # 20-way, so every class must be available to the sampler.
        cls.dataset = DummyDataset(samples_per_class=1000, n_classes=20)

    def _test_n_k_q_combination(self, n, k, q):
        """Check ``compute_prototypes`` on a single n-shot, k-way, q-query task.

        Args:
            n: Number of support samples per class.
            k: Number of classes in the task.
            q: Number of query samples per class.
        """
        n_shot_taskloader = DataLoader(
            self.dataset,
            batch_sampler=NShotTaskSampler(self.dataset, 100, n, k, q),
        )

        # Only one task is needed. ``next(iter(...))`` states that directly and
        # fails with StopIteration (rather than a NameError on ``x``) if the
        # sampler yields no batches at all.
        x, y = next(iter(n_shot_taskloader))

        # The first n * k items of the batch are treated as the support set.
        support = x[:n * k]
        support_labels = y[:n * k]
        prototypes = compute_prototypes(support, k, n)

        # By construction the second feature of samples from the
        # DummyDataset is equal to the label.
        # As class prototypes are constructed from the means of the support
        # set items of a particular class the value of the second feature
        # of the class prototypes should be equal to the label of that class.
        # Indexing with ``i * n`` assumes the support set is laid out class by
        # class in contiguous runs of n samples (same assumption as
        # ``compute_prototypes(support, k, n)``).
        for i in range(k):
            # Compare as plain floats so a failure prints readable numbers
            # instead of tensor reprs and does not rely on tensor truthiness.
            self.assertEqual(
                float(support_labels[i * n]),
                float(prototypes[i, 1]),
                'Prototypes computed incorrectly!'
            )

    def test_compute_prototypes(self):
        """Prototypes must encode their class label for several (n, k, q) tasks."""
        test_combinations = [
            (1, 5, 5),
            (5, 5, 5),
            (1, 20, 5),
            (5, 20, 5),
        ]
        for n, k, q in test_combinations:
            # subTest names the failing combination in the report and keeps
            # checking the remaining combinations after a failure.
            with self.subTest(n=n, k=k, q=q):
                self._test_n_k_q_combination(n, k, q)

    def _assert_embedding_size(self, num_input_channels, dataset, expected_size, msg):
        """Assert the encoder maps ``dataset``'s first sample to ``expected_size`` features.

        Args:
            num_input_channels: Channel count passed to ``get_few_shot_encoder``.
            dataset: Indexable dataset; ``dataset[0][0]`` is assumed to be the
                first sample's image tensor (TODO confirm against the dataset
                classes).
            expected_size: Expected size of dimension 1 of the encoder output.
            msg: Message reported if the assertion fails.
        """
        encoder = get_few_shot_encoder(num_input_channels=num_input_channels).float()
        # unsqueeze(0) prepends a batch dimension of size 1.
        sample = dataset[0][0].unsqueeze(0).float()
        self.assertEqual(encoder(sample).shape[1], expected_size, msg)

    def test_create_model(self):
        """Check the encoder produces the embedding sizes specified in the paper."""
        omniglot = OmniglotDataset('background')
        self._assert_embedding_size(
            1,
            omniglot,
            64,
            'Encoder network should produce 64 dimensional embeddings on Omniglot dataset.'
        )

        mini_imagenet = MiniImageNet('background')
        self._assert_embedding_size(
            3,
            mini_imagenet,
            1600,
            'Encoder network should produce 1600 dimensional embeddings on miniImageNet dataset.'
        )