Multi-GPU training (data parallelism) + pure ViT
TITC committed May 16, 2022
1 parent 720978d commit dffa9f9
Showing 9 changed files with 214 additions and 16 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -138,3 +138,5 @@ pix2tex/model/checkpoints/**
!**/.gitkeep
.vscode
.DS_Store
test/*

2 changes: 1 addition & 1 deletion docs/index.rst
@@ -14,7 +14,7 @@ Welcome to LaTeX-OCR's documentation!
pix2tex.gui
pix2tex.api
pix2tex.dataset
pix2tex.models
pix2tex.structures.hybrid
pix2tex.utils


4 changes: 2 additions & 2 deletions docs/pix2tex.models.rst
@@ -1,7 +1,7 @@
pix2tex.models package
pix2tex.structures.hybrid package
=================================

.. automodule:: pix2tex.models
.. automodule:: pix2tex.structures.hybrid
:members:
:no-undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion pix2tex/cli.py
@@ -17,7 +17,7 @@
from timm.models.layers import StdConv2dSame

from pix2tex.dataset.latex2png import tex2pil
from pix2tex.models import get_model
from pix2tex.structures.hybrid import get_model
from pix2tex.utils import *
from pix2tex.model.checkpoints.get_latest_checkpoint import download_checkpoints

5 changes: 3 additions & 2 deletions pix2tex/eval.py
@@ -11,7 +11,7 @@
import wandb
from Levenshtein import distance

from pix2tex.models import get_model, Model
from pix2tex.structures.hybrid import get_model, Model
from pix2tex.utils import *


@@ -52,7 +52,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
continue
encoded = model.encoder(im.to(device))
#loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
generate = model.decoder.module.generate if torch.cuda.device_count() > 1 else model.decoder.generate
dec = generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
pred = detokenize(dec, dataset.tokenizer)
truth = detokenize(seq['input_ids'], dataset.tokenizer)
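
The change above routes generation through `model.decoder.module.generate` whenever more than one GPU is visible, because `nn.DataParallel` hides the wrapped module behind a `.module` attribute and does not forward its methods. A minimal sketch of that unwrapping pattern; the helper name `unwrap` is illustrative and not part of the repository:

```python
import torch.nn as nn

def unwrap(module: nn.Module) -> nn.Module:
    """Return the wrapped module if DataParallel is in use, else the module itself."""
    return module.module if isinstance(module, nn.DataParallel) else module

# roughly equivalent to the device_count() check used in eval.py:
# generate = unwrap(model.decoder).generate
```
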
49 changes: 49 additions & 0 deletions pix2tex/model/settings/config-vit.yaml
@@ -0,0 +1,49 @@
betas:
- 0.9
- 0.999
batchsize: 64
bos_token: 1
channels: 1
data: dataset/data/train.pkl
debug: false
decoder_args:
attn_on_attn: true
cross_attend: true
ff_glu: true
rel_pos_bias: false
use_scalenorm: false
dim: 256
encoder_depth: 4
eos_token: 2
epochs: 10
gamma: 0.9995
heads: 8
id: null
load_chkpt: null
lr: 0.0005
lr_step: 30
max_height: 192
max_seq_len: 512
max_width: 672
min_height: 32
min_width: 32
micro_batchsize: 64
model_path: checkpoints_add
name: pix2tex-vit
num_layers: 4
num_tokens: 8000
optimizer: Adam
output_path: outputs
pad: false
pad_token: 0
patch_size: 16
sample_freq: 1000
save_freq: 5
scheduler: StepLR
seed: 42
temperature: 0.2
test_samples: 5
testbatchsize: 20
tokenizer: dataset/tokenizer.json
valbatches: 100
valdata: dataset/data/val.pkl
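
The new `config-vit.yaml` parameterizes the pure ViT model defined in `pix2tex/structures/vit.py` (`patch_size`, `num_layers`, `dim`, `heads`). A rough usage sketch for loading it into the `Munch`-style `args` object that `get_model` expects; the explicit `device` and `wandb` fields set here are assumptions for a standalone run, not the repository's own entry point:

```python
import torch
import yaml
from munch import Munch

from pix2tex.structures.vit import get_model

with open('pix2tex/model/settings/config-vit.yaml') as f:
    args = Munch(yaml.safe_load(f))

# fields assumed for this sketch; the project normally derives them elsewhere
args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
args.wandb = False   # get_model checks args.wandb before calling wandb.watch

model = get_model(args, training=False)
```
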
16 changes: 13 additions & 3 deletions pix2tex/models.py → pix2tex/structures/hybrid.py
@@ -1,3 +1,4 @@
# taken and modified from https://github.com/lukas-blecher/LaTeX-OCR/blob/720978d8c469780ed070d041d5795c55b705ac1b/pix2tex/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -126,7 +127,7 @@ def embed_layer(**x):
depth=args.encoder_depth,
num_heads=args.heads,
embed_layer=embed_layer
).to(args.device)
)

decoder = CustomARWrapper(
TransformerWrapper(
@@ -139,10 +140,19 @@ def embed_layer(**x):
**args.decoder_args
)),
pad_value=args.pad_token
).to(args.device)
)
# to device
available_gpus = torch.cuda.device_count()
if available_gpus > 1:
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
encoder.to(args.device)
decoder.to(args.device)
if 'wandb' in args and args.wandb:
import wandb
wandb.watch((encoder, decoder.net.attn_layers))
en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
wandb.watch((en_attn_layers, de_attn_layers))
model = Model(encoder, decoder, args)
if training:
# check if largest batch can be handled by system
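
The hunk above removes `.to(args.device)` from the encoder and decoder constructors so that, on a multi-GPU machine, both can first be wrapped in `nn.DataParallel` and only then moved to the device, and so that `wandb.watch` can reach the attention layers through `.module` once wrapped. A self-contained sketch of that ordering, assuming a toy `SmallEncoder` stand-in rather than the project's hybrid/ViT encoder:

```python
import torch
import torch.nn as nn

class SmallEncoder(nn.Module):
    """Toy stand-in for the project's encoder."""
    def __init__(self, dim: int = 256):
        super().__init__()
        self.attn_layers = nn.Linear(dim, dim)  # placeholder for an x_transformers Encoder

    def forward(self, x):
        return self.attn_layers(x)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = SmallEncoder()

available_gpus = torch.cuda.device_count()
if available_gpus > 1:
    encoder = nn.DataParallel(encoder)  # replicate the module, scatter batches across GPUs
encoder.to(device)                      # move parameters after wrapping

# attributes of the original module are only reachable through .module once wrapped
attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
```
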
129 changes: 129 additions & 0 deletions pix2tex/structures/vit.py
@@ -0,0 +1,129 @@
# taken and modified from https://github.com/lukas-blecher/LaTeX-OCR/blob/844bc219a9469fa7e9dfc8626f74a705bd194d69/models.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from x_transformers import *
from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
from einops import rearrange, repeat


class ViTransformerWrapper(nn.Module):
def __init__(
self,
*,
max_width,
max_height,
patch_size,
attn_layers,
channels=1,
num_classes=None,
dropout=0.,
emb_dropout=0.
):
super().__init__()
assert isinstance(
attn_layers, Encoder), 'attention layers must be an Encoder'
assert max_width % patch_size == 0 and max_height % patch_size == 0, 'image dimensions must be divisible by the patch size'
dim = attn_layers.dim
num_patches = (max_width // patch_size)*(max_height // patch_size)
patch_dim = channels * patch_size ** 2

self.patch_size = patch_size
self.max_width = max_width
self.max_height = max_height

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
self.patch_to_embedding = nn.Linear(patch_dim, dim)
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
self.dropout = nn.Dropout(emb_dropout)

self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
#self.mlp_head = FeedForward(dim, dim_out = num_classes, dropout = dropout) if exists(num_classes) else None

def forward(self, img, **kwargs):
p = self.patch_size

x = rearrange(
img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape

cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
x = torch.cat((cls_tokens, x), dim=1)
h, w = torch.tensor(img.shape[2:])//p
pos_emb_ind = repeat(torch.arange(
h)*(self.max_width//p-w), 'h -> (h w)', w=w)+torch.arange(h*w)
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind+1), dim=0).long()
x += self.pos_embedding[:, pos_emb_ind]
x = self.dropout(x)

x = self.attn_layers(x, **kwargs)
x = self.norm(x)

return x


class Model(nn.Module):
def __init__(self, encoder: Encoder, decoder: AutoregressiveWrapper, args):
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.args = args

def forward(self, x: torch.Tensor):
return self.decoder.generate(torch.LongTensor([self.args.bos_token]*len(x)).to(x.device), self.args.max_seq_len, eos_token=self.args.eos_token, context=self.encoder(x))


def get_model(args, training=False):
encoder = ViTransformerWrapper(
max_width=args.max_width,
max_height=args.max_height,
channels=args.channels,
patch_size=args.patch_size,
attn_layers=Encoder(
dim=args.dim,
depth=args.num_layers,
heads=args.heads,
)
)

decoder = AutoregressiveWrapper(
TransformerWrapper(
num_tokens=args.num_tokens,
max_seq_len=args.max_seq_len,
attn_layers=Decoder(
dim=args.dim,
depth=args.num_layers,
heads=args.heads,
cross_attend=True
)),
pad_value=args.pad_token
)
available_gpus = torch.cuda.device_count()
if available_gpus > 1:
encoder = nn.DataParallel(encoder)
decoder = nn.DataParallel(decoder)
encoder.to(args.device)
decoder.to(args.device)
if args.wandb:
import wandb
en_attn_layers = encoder.module.attn_layers if available_gpus > 1 else encoder.attn_layers
de_attn_layers = decoder.module.net.attn_layers if available_gpus > 1 else decoder.net.attn_layers
wandb.watch((en_attn_layers, de_attn_layers))
model = Model(encoder, decoder, args)
if training:
# check if largest batch can be handled by system
batchsize = args.batchsize if args.get(
'micro_batchsize', -1) == -1 else args.micro_batchsize
im = torch.empty(batchsize, args.channels, args.max_height,
args.min_height, device=args.device).float()
seq = torch.randint(0, args.num_tokens, (batchsize,
args.max_seq_len), device=args.device).long()
decoder(seq, context=encoder(im)).sum().backward()
model.zero_grad()
torch.cuda.empty_cache()
del im, seq
return model
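
The positional-embedding lookup in `ViTransformerWrapper.forward` maps the patches of a smaller-than-maximum image onto the indices those patches would occupy in the full `max_width` by `max_height` grid, so a single learned table covers every input size. A small numeric check of that index arithmetic, standalone and using the same formula as the code above:

```python
import torch
from einops import repeat

max_width, patch_size = 672, 16   # values from config-vit.yaml
W = max_width // patch_size       # 42 patch columns in the full grid

h, w = 2, 4                       # e.g. a 32x64-pixel image -> 2x4 patch grid

pos_emb_ind = repeat(torch.arange(h) * (W - w), 'h -> (h w)', w=w) + torch.arange(h * w)
pos_emb_ind = torch.cat((torch.zeros(1), pos_emb_ind + 1), dim=0).long()

print(pos_emb_ind)
# tensor([ 0,  1,  2,  3,  4, 43, 44, 45, 46])
```

Index 0 is the CLS token's slot; the remaining entries are the row-major positions of the image's patches inside the 42-column full grid, shifted by one to account for the CLS offset.
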
21 changes: 14 additions & 7 deletions pix2tex/train.py
@@ -10,8 +10,10 @@
import wandb

from pix2tex.eval import evaluate
from pix2tex.models import get_model
from pix2tex.utils import *
from pix2tex.structures.hybrid import get_model
# from pix2tex.utils import *
from pix2tex.utils import in_model_path, parse_args, seed_everything, get_optimizer, get_scheduler



def train(args):
@@ -26,12 +28,12 @@ def train(args):
if args.load_chkpt is not None:
model.load_state_dict(torch.load(args.load_chkpt, map_location=device))
encoder, decoder = model.encoder, model.decoder

max_bleu, max_token_acc = 0, 0
out_path = os.path.join(args.model_path, args.name)
os.makedirs(out_path, exist_ok=True)

def save_models(e):
torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d.pth' % (args.name, e+1)))
def save_models(e, step=0):
torch.save(model.state_dict(), os.path.join(out_path, '%s_e%02d_step%02d.pth' % (args.name, e+1, step)))
yaml.dump(dict(args), open(os.path.join(out_path, 'config.yaml'), 'w+'))

opt = get_optimizer(args.optimizer)(model.parameters(), args.lr, betas=args.betas)
@@ -40,6 +42,7 @@ def save_models(e):
microbatch = args.get('micro_batchsize', -1)
if microbatch == -1:
microbatch = args.batchsize

try:
for e in range(args.epoch, args.epochs):
args.epoch = e
@@ -52,7 +55,8 @@ def save_models(e):
tgt_seq, tgt_mask = seq['input_ids'][j:j+microbatch].to(device), seq['attention_mask'][j:j+microbatch].bool().to(device)
encoded = encoder(im[j:j+microbatch].to(device))
loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)*microbatch/args.batchsize
loss.backward()
# loss.backward()
loss.mean().backward()  # data parallelism: the gathered loss is a vector
total_loss += loss.mean().item()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
opt.step()
@@ -61,7 +65,10 @@ def save_models(e):
if args.wandb:
wandb.log({'train/loss': total_loss})
if (i+1) % args.sample_freq == 0:
evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
bleu_score, edit_distance, token_accuracy = evaluate(model, valdataloader, args, num_batches=int(args.valbatches*e/args.epochs), name='val')
if bleu_score > max_bleu and token_accuracy > max_token_acc:
max_bleu, max_token_acc = bleu_score, token_accuracy
save_models(e, step=i+1)
if (e+1) % args.save_freq == 0:
save_models(e)
if args.wandb:
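
Under `nn.DataParallel` the decoder's loss comes back as one value per replica, gathered into a 1-D tensor, so the previous `loss.backward()` would fail because gradients can only be created implicitly for scalar outputs; the new `loss.mean().backward()` averages across replicas first. A toy reproduction of that behaviour, with a two-element tensor standing in for a two-GPU gather (no GPUs required):

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

# single-device case: a scalar loss, backward() works directly
(w ** 2).sum().backward()

# DataParallel-style case: one loss per replica, gathered into a vector
w.grad = None
per_replica_loss = torch.stack([w[0] ** 2, w[1] ** 2])  # shape (2,)
# per_replica_loss.backward()  # RuntimeError: grad can be implicitly created only for scalar outputs
per_replica_loss.mean().backward()                      # reduce to a scalar first
```
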
