import torch
from torch.optim import Optimizer
from torch.nn import Module
from typing import Callable

from few_shot.utils import pairwise_distances


def proto_net_episode(model: Module,
                      optimiser: Optimizer,
                      loss_fn: Callable,
                      x: torch.Tensor,
                      y: torch.Tensor,
                      n_shot: int,
                      k_way: int,
                      q_queries: int,
                      distance: str,
                      train: bool):
"""Performs a single training episode for a Prototypical Network.
# Arguments
model: Prototypical Network to be trained.
optimiser: Optimiser to calculate gradient step
loss_fn: Loss function to calculate between predictions and outputs. Should be cross-entropy
x: Input samples of few shot classification task
y: Input labels of few shot classification task
n_shot: Number of examples per class in the support set
k_way: Number of classes in the few shot classification task
q_queries: Number of examples per class in the query set
distance: Distance metric to use when calculating distance between class prototypes and queries
train: Whether (True) or not (False) to perform a parameter update
# Returns
loss: Loss of the Prototypical Network on this task
y_pred: Predicted class probabilities for the query set on this task
"""
    if train:
        # Zero gradients
        model.train()
        optimiser.zero_grad()
    else:
        model.eval()

    # Embed all samples
    embeddings = model(x)

    # Samples are ordered by the NShotWrapper class as follows:
    # k lots of n support samples from a particular class
    # k lots of q query samples from those classes
    support = embeddings[:n_shot*k_way]
    queries = embeddings[n_shot*k_way:]
    prototypes = compute_prototypes(support, k_way, n_shot)

    # Calculate squared distances between all queries and all prototypes
    # Output should have shape (q_queries * k_way, k_way) = (num_queries, k_way)
    distances = pairwise_distances(queries, prototypes, distance)

    # Calculate log p_{phi}(y = k | x)
    log_p_y = (-distances).log_softmax(dim=1)
    loss = loss_fn(log_p_y, y)

    # Prediction probabilities are the softmax over negative distances
    y_pred = (-distances).softmax(dim=1)

    if train:
        # Take gradient step
        loss.backward()
        optimiser.step()

    return loss, y_pred
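

# --- Usage sketch (illustrative, not part of the original module) -----------
# A minimal, hypothetical example of how `proto_net_episode` might be driven,
# assuming: a stand-in nn.Linear as the embedding network, an NLLLoss (the
# episode produces log-probabilities), the 'l2' option of
# few_shot.utils.pairwise_distances, and dummy data ordered the way the
# NShotWrapper comment above describes (n_shot * k_way support samples
# followed by q_queries * k_way query samples, with query labels grouped by
# class). Defining this helper does not run anything on import.
def _example_proto_episode():
    k, n, q, d_in = 5, 1, 3, 16
    model = torch.nn.Linear(d_in, 8)              # toy embedding network
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = torch.nn.NLLLoss()                  # expects log-probabilities
    x = torch.randn(n * k + q * k, d_in)          # support samples then query samples
    y = torch.arange(k).repeat_interleave(q)      # query labels: [0]*q + [1]*q + ...
    return proto_net_episode(model, optimiser, loss_fn, x, y,
                             n_shot=n, k_way=k, q_queries=q,
                             distance='l2', train=True)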


def compute_prototypes(support: torch.Tensor, k: int, n: int) -> torch.Tensor:
    """Compute class prototypes from support samples.

    # Arguments
        support: torch.Tensor. Tensor of shape (n * k, d) where d is the embedding
            dimension.
        k: int. "k-way" i.e. number of classes in the classification task
        n: int. "n-shot" of the classification task

    # Returns
        class_prototypes: Prototypes aka mean embeddings for each class
    """
    # Reshape so the first dimension indexes by class then take the mean
    # along that dimension to generate the "prototypes" for each class
    class_prototypes = support.reshape(k, n, -1).mean(dim=1)
    return class_prototypes
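

# --- Prototype sketch (illustrative, not part of the original module) -------
# A minimal check of `compute_prototypes` under the assumed sample ordering
# (n consecutive support samples per class): averaging a (k * n, d) support
# tensor per class yields a (k, d) prototype tensor. Runs only when this
# file is executed directly.
if __name__ == '__main__':
    dummy_support = torch.randn(2 * 3, 4)             # k=2 classes, n=3 shots, d=4
    dummy_prototypes = compute_prototypes(dummy_support, k=2, n=3)
    print(dummy_prototypes.shape)                     # expected: torch.Size([2, 4])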