fix some more bugs introduced by package
lukas-blecher committed Apr 13, 2022
1 parent 243908c commit 4056b86
Showing 7 changed files with 45 additions and 25 deletions.
README.md (5 changes: 3 additions & 2 deletions)
@@ -38,12 +38,13 @@ Always double check the result carefully. You can try to redo the prediction wit
1. First we need to combine the images with their ground truth labels. I wrote a dataset class (which needs further improving) that saves the relative paths to the images with the LaTeX code they were rendered with. To generate the dataset pickle file run

```
python -m pix2tex.dataset.dataset --equations path_to_textfile --images path_to_images --tokenizer dataset/tokenizer.json --out dataset.pkl
python -m pix2tex.dataset.dataset --equations path_to_textfile --images path_to_images --out dataset.pkl
```
To use your own tokenizer pass it via `--tokenizer` (See below).

You can find my generated training data on the [Google Drive](https://drive.google.com/drive/folders/13CA4vAmOmD_I_dSbvLp-Lf0s6KiaNfuO) as well (formulae.zip - images, math.txt - labels). Repeat the step for the validation and test data. All use the same label text file.

2. Edit the `data` (and `valdata`) entry in the config file to the newly generated `.pkl` file. Change other hyperparameters if you want to. See `settings/config.yaml` for a template.
2. Edit the `data` (and `valdata`) entry in the config file to the newly generated `.pkl` file. Change other hyperparameters if you want to. See `pix2tex/model/settings/config.yaml` for a template.
3. Now for the actual training run
```
python -m pix2tex.train --config path_to_config_file
pix2tex/dataset/dataset.py (24 changes: 20 additions & 4 deletions)
@@ -1,3 +1,4 @@
from tempfile import tempdir
import albumentations as alb
from albumentations.pytorch import ToTensorV2
import torch
@@ -15,6 +16,7 @@
from transformers import PreTrainedTokenizerFast
from tqdm.auto import tqdm

from pix2tex.utils.utils import in_model_path

train_transform = alb.Compose(
    [
@@ -188,6 +190,11 @@ def load(self, filename, args=[]):
        Args:
            filename (str): Path to dataset
        """
        if not os.path.exists(filename):
            with in_model_path():
                tempf = os.path.join('..', filename)
                if os.path.exists(tempf):
                    filename = os.path.realpath(tempf)
        with open(filename, 'rb') as file:
            x = pickle.load(file)
        return x
@@ -201,7 +208,7 @@ def combine(self, x):
        for key in x.data.keys():
            if key in self.data.keys():
                self.data[key].extend(x.data[key])
                self.data[key]=list(set(self.data[key]))
                self.data[key] = list(set(self.data[key]))
            else:
                self.data[key] = x.data[key]
        self._get_size()
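
For orientation, `combine` merges one loaded dataset object into another and de-duplicates entries. A minimal usage sketch, with hypothetical file names, might look like this (writing the result back with `pickle` mirrors how `load()` reads it in, but the exact save path used by the project is not shown in this diff):

```
# Minimal sketch with hypothetical file names: merging two generated datasets.
# combine() extends each entry list in place and removes exact duplicates.
import pickle
from pix2tex.dataset.dataset import Im2LatexDataset

train = Im2LatexDataset().load('train_a.pkl')
extra = Im2LatexDataset().load('train_b.pkl')
train.combine(extra)  # 'train' now holds the union of both datasets

# Dump the merged object the same way load() unpickles it.
with open('train_combined.pkl', 'wb') as f:
    pickle.dump(train, f)
```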
@@ -230,6 +237,12 @@ def update(self, **kwargs):
                if self.min_dimensions[0] <= k[0] <= self.max_dimensions[0] and self.min_dimensions[1] <= k[1] <= self.max_dimensions[1]:
                    temp[k] = self.data[k]
            self.data = temp
        if 'tokenizer' in kwargs:
            tokenizer_file = kwargs['tokenizer']
            if not os.path.exists(tokenizer_file):
                with in_model_path():
                    tokenizer_file = os.path.realpath(tokenizer_file)
            self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizer_file)
        self._get_size()
        iter(self)

@@ -251,13 +264,16 @@ def generate_tokenizer(equations, output, vocab_size):
    parser.add_argument('-i', '--images', type=str, nargs='+', default=None, help='Image folders')
    parser.add_argument('-e', '--equations', type=str, nargs='+', default=None, help='equations text files')
    parser.add_argument('-t', '--tokenizer', default=None, help='Pretrained tokenizer file')
    parser.add_argument('-o', '--out', required=True, help='output file')
    parser.add_argument('-o', '--out', type=str, required=True, help='output file')
    parser.add_argument('-s', '--vocab-size', default=8000, type=int, help='vocabulary size when training a tokenizer')
    args = parser.parse_args()
    if args.images is None and args.equations is not None and args.tokenizer is None:
    if args.tokenizer is None:
        with in_model_path():
            args.tokenizer = os.path.realpath(os.path.join('dataset', 'tokenizer.json'))
    if args.images is None and args.equations is not None:
        print('Generate tokenizer')
        generate_tokenizer(args.equations, args.out, args.vocab_size)
    elif args.images is not None and args.equations is not None and args.tokenizer is not None:
    elif args.images is not None and args.equations is not None:
        print('Generate dataset')
        dataset = None
        for images, equations in zip(args.images, args.equations):
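
The practical effect of this change is that `--tokenizer` becomes optional: when it is omitted, the CLI resolves the tokenizer that ships inside the installed package. A rough sketch of that fallback, assuming pix2tex was installed normally:

```
# Sketch of the fallback added above: with no --tokenizer argument, the path is
# resolved relative to the packaged pix2tex/model directory via in_model_path().
import os
from pix2tex.utils.utils import in_model_path

tokenizer = None  # i.e. --tokenizer was not passed
if tokenizer is None:
    with in_model_path():  # temporarily chdir into the installed pix2tex/model directory
        tokenizer = os.path.realpath(os.path.join('dataset', 'tokenizer.json'))
print(tokenizer)  # e.g. .../site-packages/pix2tex/model/dataset/tokenizer.json
```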
pix2tex/eval.py (17 changes: 10 additions & 7 deletions)
@@ -1,6 +1,4 @@
from pix2tex.dataset.dataset import Im2LatexDataset
import os
import sys
import argparse
import logging
import yaml
@@ -90,8 +88,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test model')
    parser.add_argument('--config', default='settings/config.yaml', help='path to yaml config file', type=argparse.FileType('r'))
    parser.add_argument('-c', '--checkpoint', default='checkpoints/weights.pth', type=str, help='path to model checkpoint')
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('-c', '--checkpoint', default=None, type=str, help='path to model checkpoint')
    parser.add_argument('-d', '--data', default='dataset/data/val.pkl', type=str, help='Path to Dataset pkl file')
    parser.add_argument('--no-cuda', action='store_true', help='Use CPU')
    parser.add_argument('-b', '--batchsize', type=int, default=10, help='Batch size')
@@ -100,7 +98,10 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
    parser.add_argument('-n', '--num-batches', type=int, default=None, help='how many batches to evaluate on. Defaults to None (all)')

    parsed_args = parser.parse_args()
    with parsed_args.config as f:
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/config.yaml')
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params))
    args.testbatchsize = parsed_args.batchsize
@@ -109,8 +110,8 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
    logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
    seed_everything(args.seed if 'seed' in args else 42)
    model = get_model(args)
    if parsed_args.checkpoint is not None:
        model.load_state_dict(torch.load(parsed_args.checkpoint, args.device))
    if parsed_args.checkpoint is None:
        with in_model_path():
            parsed_args.checkpoint = os.path.realpath('checkpoints/weights.pth')
    model.load_state_dict(torch.load(parsed_args.checkpoint, args.device))
    dataset = Im2LatexDataset().load(parsed_args.data)
    valargs = args.copy()
    valargs.update(batchsize=args.testbatchsize, keep_smaller_batches=True, test=True)
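
eval.py (and, in the files below, train.py and train_resizer.py) now follows the same pattern: `--config` and `--checkpoint` are plain string paths defaulting to `None`, and the packaged `settings/config.yaml` and `checkpoints/weights.pth` are substituted when nothing is passed. A condensed sketch of the new startup logic, where `config` and `checkpoint` stand in for the parsed arguments:

```
# Condensed sketch of the default handling added to eval.py.
import os
import yaml
from pix2tex.utils.utils import in_model_path

config, checkpoint = None, None  # placeholders for parsed_args.config / parsed_args.checkpoint
if config is None:
    with in_model_path():
        config = os.path.realpath('settings/config.yaml')
if checkpoint is None:
    with in_model_path():
        checkpoint = os.path.realpath('checkpoints/weights.pth')
with open(config, 'r') as f:
    params = yaml.load(f, Loader=yaml.FullLoader)
# Note the behavioral change: the checkpoint is now always loaded, not only
# when -c was passed explicitly.
```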
pix2tex/train.py (13 changes: 5 additions & 8 deletions)
@@ -1,14 +1,10 @@
from pix2tex.dataset.dataset import Im2LatexDataset
import os
import sys
import argparse
import logging
import yaml

import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from munch import Munch
from tqdm.auto import tqdm
import wandb
@@ -72,14 +68,15 @@ def save_models(e):

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train model')
    parser.add_argument('--config', default='settings/debug.yaml', help='path to yaml config file', type=argparse.FileType('r'))
    parser.add_argument('-d', '--data', default='dataset/data/train.pkl', type=str, help='Path to Dataset pkl file')
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
    parser.add_argument('--debug', action='store_true', help='DEBUG')
    parser.add_argument('--resume', help='path to checkpoint folder', action='store_true')

    parsed_args = parser.parse_args()
    with parsed_args.config as f:
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/debug.yaml')
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params), **vars(parsed_args))
    logging.getLogger().setLevel(logging.DEBUG if parsed_args.debug else logging.WARNING)
pix2tex/train_resizer.py (7 changes: 5 additions & 2 deletions)
@@ -135,15 +135,18 @@ def train_epoch(sched=None):

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train size classification model')
    parser.add_argument('--config', default='settings/debug.yaml', help='path to yaml config file', type=argparse.FileType('r'))
    parser.add_argument('--config', default=None, help='path to yaml config file', type=str)
    parser.add_argument('--no_cuda', action='store_true', help='Use CPU')
    parser.add_argument('--lr', type=float, default=5e-4, help='learning rate')
    parser.add_argument('--resume', help='path to checkpoint folder', type=str, default='')
    parser.add_argument('--out', type=str, default='checkpoints/image_resizer.pth', help='output destination for trained model')
    parser.add_argument('--num_epochs', type=int, default=10, help='number of epochs to train')
    parser.add_argument('--batchsize', type=int, default=10)
    parsed_args = parser.parse_args()
    with parsed_args.config as f:
    if parsed_args.config is None:
        with in_model_path():
            parsed_args.config = os.path.realpath('settings/debug.yaml')
    with open(parsed_args.config, 'r') as f:
        params = yaml.load(f, Loader=yaml.FullLoader)
    args = parse_args(Munch(params), **vars(parsed_args))
    args.update(**vars(parsed_args))
pix2tex/utils/utils.py (2 changes: 1 addition & 1 deletion)
@@ -159,8 +159,8 @@ def num_model_params(model):
def in_model_path():
    from importlib.resources import path
    with path('pix2tex', 'model') as model_path:
        os.chdir(model_path)
        saved = os.getcwd()
        os.chdir(model_path)
        try:
            yield
        finally:
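
The one-line move in `in_model_path` is the core fix here: the caller's working directory has to be captured before `os.chdir`, otherwise the cleanup step restores the model path instead of where the caller started. A sketch of the full context manager as it presumably looks after the change; only the `getcwd`/`chdir` reordering is confirmed by the hunk, the decorator and the body of the `finally` block are assumptions:

```
# Assumed full shape of the corrected context manager (decorator and the
# restoring finally body are inferred, not shown in the diff above).
import os
from contextlib import contextmanager
from importlib.resources import path


@contextmanager
def in_model_path():
    with path('pix2tex', 'model') as model_path:
        saved = os.getcwd()    # remember the caller's directory first
        os.chdir(model_path)   # then switch into the packaged model directory
        try:
            yield
        finally:
            os.chdir(saved)    # restore the original directory on exit
```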
setup.py (2 changes: 1 addition & 1 deletion)
@@ -9,7 +9,7 @@

setuptools.setup(
    name='pix2tex',
    version='0.0.6',
    version='0.0.8',
    description="pix2tex: Using a ViT to convert images of equations into LaTeX code.",
    long_description=long_description,
    long_description_content_type='text/markdown',
