train.py (forked from lukas-blecher/LaTeX-OCR)

from pix2tex.dataset.dataset import Im2LatexDataset
import os
import argparse
import logging
import yaml
import torch
from munch import Munch
from tqdm.auto import tqdm
import wandb
import torch.nn as nn
from pix2tex.eval import evaluate
from pix2tex.models import get_model
# from pix2tex.utils import *
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler, gpu_memory_check


def train(args):
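    # build the training and validation dataloaders; validation uses its own
    # batch size and keeps smaller batches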
    dataloader = Im2LatexDataset().load(args.data)
    dataloader.update(**args, test=False)
    valdataloader = Im2LatexDataset().load(args.valdata)
    valargs = args.copy()
    valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
    valdataloader.update(**valargs)
    device = args.device
    model = get_model(args)
    if torch.cuda.is_available() and not args.no_cuda:
        gpu_memory_check(model, args)
    max_bleu, max_token_acc = 0, 0
    out_path = os.path.join(args.model_path, args.name)
    os.makedirs(out_path, exist_ok=True)
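
    # load weights from an existing checkpoint if one is given in the config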
    if args.load_chkpt is not None:
        model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
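
    # helper: save the model weights and the config they were trained with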
    def save_models(e, step=0):
        torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
        yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))
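
    # optimizer and LR scheduler are looked up by name from the config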
    opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)
    scheduler = get_scheduler(args.scheduler)(opt, step_size=args.lr_step, gamma=args.gamma)
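
    # micro_batchsize enables gradient accumulation; -1 falls back to the full batch size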
    microbatch = args.get('micro_batchsize', -1)
    if microbatch == -1:
        microbatch = args.batchsize
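
    # main training loop; an interrupted run still saves a checkpoint (see the except branch)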
    try:
        for e in range(args.epoch, args.epochs):
            args.epoch = e
            dset = tqdm(iter(dataloader))
            for i, (seq, im) in enumerate(dset):
                if seq is not None and im is not None:
                    opt.zero_grad()
                    total_loss = 0
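                    # split the batch into micro-batches; the loss is scaled by
                    # microbatch/batchsize so the accumulated gradient matches a full batch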
                    for j in range(0, len(im), microbatch):
                        tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
                        loss = model.data_parallel(im[j:j+microbatch].to(device), device_ids=args.gpu_devices, tgt_seq=tgt_seq, mask=tgt_mask)*microbatch/args.batchsize
                        loss.backward()  # data parallelism: the loss may be a vector
                        total_loss += loss.item()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                    opt.step()
                    scheduler.step()
                    dset.set_description('Loss: %.4f' % total_loss)
                    if args.wandb:
                        wandb.log({'train/loss': total_loss})
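                # periodically evaluate on the validation set and checkpoint when both
                # BLEU and token accuracy improve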
                if (i+1+len(dataloader)*e) % args.sample_freq == 0:
                    bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
                    if bleu_score > max_bleu and token_accuracy > max_token_acc:
                        max_bleu, max_token_acc = bleu_score, token_accuracy
                        save_models(e, step=i)
            if (e+1) % args.save_freq == 0:
                save_models(e, step=len(dataloader))
            if args.wandb:
                wandb.log({'train/epoch': e+1})
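    # on Ctrl+C, save a checkpoint (once past the first couple of epochs) and re-raise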
    except KeyboardInterrupt:
        if e >= 2:
            save_models(e, step=i)
        raise KeyboardInterrupt
    save_models(e, step=len(dataloader))
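

# command-line entry point: read the YAML config, merge in CLI flags, seed everything,
# optionally set up Weights & Biases, then start training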
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
    parser.add_argument('--debug', action='store_true', help='enable debug logging')
    parser.add_argument('--resume', help='resume training (reuse the wandb run id from the config)', action='store_true')
    parsed_args = parser.parse_args()
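    # fall back to the bundled debug config when no --config is given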
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/debug.yaml')
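    # merge YAML settings and CLI flags into a single Munch namespace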
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params), **vars(parsed_args))
    logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
    seed_everything(args.seed)
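    # initialise wandb if enabled; --resume reuses the run id stored in the config
    # instead of generating a new one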
    if args.wandb:
        if not parsed_args.resume:
            args.id = wandb.util.generate_id()
        wandb.init(config=dict(args), resume='allow', name=args.name, id=args.id)
        args = Munch(wandb.config)
    train(args)
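
# Typical invocation (assuming the package layout of the upstream repo):
#   python -m pix2tex.train --config path/to/config.yaml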