trainer.py
import os
import torch
import numpy as np

def save(model, ckpt_num, dir_name):
    """Save model weights, unwrapping nn.DataParallel when more than one GPU is in use."""
    os.makedirs(dir_name, exist_ok=True)
    if torch.cuda.device_count() > 1:
        # With multiple GPUs the model is assumed to be wrapped in nn.DataParallel,
        # so save the underlying module's weights to keep the checkpoint portable.
        torch.save(model.module.state_dict(), os.path.join(dir_name, 'model_%s' % ckpt_num))
    else:
        torch.save(model.state_dict(), os.path.join(dir_name, 'model_%s' % ckpt_num))
    print('model saved!')
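
# A minimal sketch of restoring a checkpoint written by save() above. The model
# class and the checkpoint number (5) are hypothetical, not part of this file;
# the weights round-trip through torch.load / load_state_dict.
#
#   model = MyEmbeddingModel()                                    # hypothetical model class
#   state = torch.load(os.path.join(dir_name, 'model_5'), map_location='cpu')
#   model.load_state_dict(state)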

def fit(train_loader, model, loss_fn, optimizer, scheduler, nb_epoch,
        device, log_interval, start_epoch=0, save_model_to='/tmp/save_model_to'):
    """
    The loaders, model, loss function, and metrics should work together for a
    given task: the model must be able to process the data produced by the
    loaders, and the loss function must accept the loaders' targets together
    with the model's outputs.

    Examples: Classification: batch loader, classification model, NLL loss, accuracy metric
              Siamese network: Siamese loader, siamese model, contrastive loss
              Online triplet learning: batch loader, embedding model, online triplet loss
    """
    # Save the initial (pre-trained) weights as checkpoint 0
    save(model, 0, save_model_to)

    # When resuming, fast-forward the LR scheduler to start_epoch
    for _ in range(start_epoch):
        scheduler.step()

    for epoch in range(start_epoch, nb_epoch):
        # Train stage
        train_loss = train_epoch(train_loader, model, loss_fn, optimizer, device, log_interval)
        # Step the scheduler after the epoch's optimizer updates (PyTorch >= 1.1 convention)
        scheduler.step()

        log_dict = {'epoch': epoch + 1,
                    'epoch_total': nb_epoch,
                    'loss': float(train_loss)}
        message = 'Epoch: {}/{}. Train set: Average loss: {:.4f}'.format(epoch + 1, nb_epoch, train_loss)
        print(message)
        print(log_dict)

        # Checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            save(model, epoch + 1, save_model_to)
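
# A minimal usage sketch of fit(). All names below are hypothetical, not part
# of this file: an embedding model, a loss object exposing the interface that
# train_epoch() expects, and a standard optimizer/scheduler pair.
#
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   model = MyEmbeddingModel().to(device)
#   loss_fn = MyBlendedLoss(cross_entropy_flag=False)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
#   scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
#   fit(train_loader, model, loss_fn, optimizer, scheduler, nb_epoch=30,
#       device=device, log_interval=50)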

def train_epoch(train_loader, model, loss_fn, optimizer, device, log_interval):
    model.train()
    total_loss = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        target = target if len(target) > 0 else None
        # Normalize data to a tuple so models taking multiple inputs (e.g. siamese pairs) work
        if not isinstance(data, (tuple, list)):
            data = (data,)
        data = tuple(d.to(device) for d in data)
        if target is not None:
            target = target.to(device)

        optimizer.zero_grad()
        if loss_fn.cross_entropy_flag:
            # The model returns both an embedding and classification logits
            output_embedding, output_cross_entropy = model(*data)
            blended_loss, losses = loss_fn.calculate_loss(target, output_embedding, output_cross_entropy)
        else:
            output_embedding = model(*data)
            blended_loss, losses = loss_fn.calculate_loss(target, output_embedding)
        total_loss += blended_loss.item()

        blended_loss.backward()
        optimizer.step()

        # Print log
        if batch_idx % log_interval == 0:
            message = 'Train: [{}/{} ({:.0f}%)]'.format(
                batch_idx * len(data[0]), len(train_loader.dataset), 100. * batch_idx / len(train_loader))
            for name, value in losses.items():
                message += '\t{}: {:.6f}'.format(name, np.mean(value))
            print(message)

    # Average over the number of batches seen
    total_loss /= (batch_idx + 1)
    return total_loss
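
# train_epoch() expects loss_fn to expose a boolean cross_entropy_flag and a
# calculate_loss() method returning (blended_loss, losses_dict); the signature
# follows from the two call sites above. The internals below are an assumed
# sketch, with some_embedding_loss standing in for a metric-learning loss:
#
#   class MyBlendedLoss:
#       def __init__(self, cross_entropy_flag=False):
#           self.cross_entropy_flag = cross_entropy_flag
#           self.xent = torch.nn.CrossEntropyLoss()
#
#       def calculate_loss(self, target, embedding, logits=None):
#           blended = some_embedding_loss(embedding, target)  # hypothetical loss term
#           losses = {'embedding_loss': [blended.item()]}
#           if self.cross_entropy_flag:
#               ce = self.xent(logits, target)
#               losses['cross_entropy'] = [ce.item()]
#               blended = blended + ce
#           return blended, losses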