"""
The `fit` function in this file implements a slightly modified version
of the Keras `model.fit()` API.
"""
import torch
from torch.optim import Optimizer
from torch.nn import Module
from torch.utils.data import DataLoader
from typing import Callable, List, Union
from few_shot.callbacks import DefaultCallback, ProgressBarLogger, CallbackList, Callback
from few_shot.metrics import NAMED_METRICS


def gradient_step(model: Module, optimiser: Optimizer, loss_fn: Callable, x: torch.Tensor, y: torch.Tensor, **kwargs):
    """Takes a single gradient step.

    # Arguments
        model: Model to be fitted
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        x: Input samples
        y: Input targets
    """
    model.train()
    optimiser.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred
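
# A minimal sketch of a single supervised step on dummy data (the model,
# optimiser and tensors below are hypothetical, not part of this module):
#
#     model = torch.nn.Linear(10, 2)
#     optimiser = torch.optim.SGD(model.parameters(), lr=0.1)
#     x, y = torch.randn(4, 10), torch.randint(0, 2, (4,))
#     loss, y_pred = gradient_step(model, optimiser,
#                                  torch.nn.functional.cross_entropy, x, y)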


def batch_metrics(model: Module, y_pred: torch.Tensor, y: torch.Tensor, metrics: List[Union[str, Callable]],
                  batch_logs: dict):
    """Calculates metrics for the current training batch.

    # Arguments
        model: Model being fit
        y_pred: Predictions for a particular batch
        y: Labels for a particular batch
        metrics: Metrics to evaluate, either names from `NAMED_METRICS` or callables
        batch_logs: Dictionary of logs for the current batch
    """
    model.eval()
    for m in metrics:
        if isinstance(m, str):
            batch_logs[m] = NAMED_METRICS[m](y, y_pred)
        else:
            # Assume the metric is a callable function and store its result
            # under its name rather than overwriting the whole logs dictionary
            batch_logs[m.__name__] = m(y, y_pred)

    return batch_logs
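
# Sketch of the resulting logs, assuming 'categorical_accuracy' is registered
# in NAMED_METRICS (values hypothetical):
#
#     batch_metrics(model, y_pred, y, ['categorical_accuracy'], {'loss': 0.42})
#     # -> {'loss': 0.42, 'categorical_accuracy': 0.75}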


def fit(model: Module, optimiser: Optimizer, loss_fn: Callable, epochs: int, dataloader: DataLoader,
        prepare_batch: Callable, metrics: List[Union[str, Callable]] = None, callbacks: List[Callback] = None,
        verbose: bool = True, fit_function: Callable = gradient_step, fit_function_kwargs: dict = None):
    """Function to abstract away the training loop.

    The benefit of this function is that it allows training scripts to be much more readable and allows for easy
    re-use of common training functionality, provided it is written as a subclass of `few_shot.callbacks.Callback`
    (following the Keras API).

    # Arguments
        model: Model to be fitted.
        optimiser: Optimiser to calculate gradient step from loss
        loss_fn: Loss function to calculate between predictions and outputs
        epochs: Number of epochs of fitting to be performed
        dataloader: `torch.utils.data.DataLoader` instance to fit the model to
        prepare_batch: Callable to perform any desired preprocessing
        metrics: Optional list of metrics to evaluate the model with
        callbacks: Additional functionality to incorporate into training such as logging metrics to csv, model
            checkpointing, learning rate scheduling etc... See `few_shot.callbacks` for more.
        verbose: All print output is muted if this argument is `False`
        fit_function: Function for calculating gradients. Leave as default for simple supervised training on labelled
            batches. For more complex training procedures (meta-learning etc...) you will need to write your own
            fit_function
        fit_function_kwargs: Keyword arguments to pass to `fit_function`
    """
    # Determine number of samples:
    num_batches = len(dataloader)
    batch_size = dataloader.batch_size

    # Normalise optional arguments so we never iterate over or unpack `None`
    # and never share a mutable default between calls
    metrics = metrics or []
    fit_function_kwargs = fit_function_kwargs or {}

    callbacks = CallbackList([DefaultCallback(), ] + (callbacks or []) + [ProgressBarLogger(), ])
    callbacks.set_model(model)
    callbacks.set_params({
        'num_batches': num_batches,
        'batch_size': batch_size,
        'verbose': verbose,
        'metrics': metrics,
        'prepare_batch': prepare_batch,
        'loss_fn': loss_fn,
        'optimiser': optimiser
    })

    if verbose:
        print('Begin training...')

    callbacks.on_train_begin()

    for epoch in range(1, epochs + 1):
        callbacks.on_epoch_begin(epoch)

        epoch_logs = {}
        for batch_index, batch in enumerate(dataloader):
            batch_logs = dict(batch=batch_index, size=(batch_size or 1))

            callbacks.on_batch_begin(batch_index, batch_logs)

            x, y = prepare_batch(batch)

            loss, y_pred = fit_function(model, optimiser, loss_fn, x, y, **fit_function_kwargs)
            batch_logs['loss'] = loss.item()

            # Loop through all metrics
            batch_logs = batch_metrics(model, y_pred, y, metrics, batch_logs)

            callbacks.on_batch_end(batch_index, batch_logs)

        # Run on epoch end
        callbacks.on_epoch_end(epoch, epoch_logs)

    # Run on train end
    if verbose:
        print('Finished.')

    callbacks.on_train_end()
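

if __name__ == '__main__':
    # A minimal, hypothetical usage sketch of `fit` on random data. It assumes
    # 'categorical_accuracy' is registered in NAMED_METRICS; the real training
    # scripts in this repository build few-shot samplers and prepare_batch
    # functions instead of using raw tensors like this.
    from torch.utils.data import TensorDataset

    model = torch.nn.Linear(10, 2)
    optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    dataloader = DataLoader(dataset, batch_size=8)

    fit(
        model,
        optimiser,
        loss_fn=torch.nn.functional.cross_entropy,
        epochs=2,
        dataloader=dataloader,
        prepare_batch=lambda batch: batch,  # batches are already (x, y) pairs
        metrics=['categorical_accuracy'],
    )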