Skip to content

Commit

Permalink
remove de-macro, improve own demacro
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed Apr 19, 2022
1 parent e02c395 commit 59903b1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 44 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ In order to render the math in many different fonts we use XeLaTeX, generate a
* [XeLaTeX](https://www.ctan.org/pkg/xetex)
* [ImageMagick](https://imagemagick.org/) with [Ghostscript](https://www.ghostscript.com/index.html). (for converting pdf to png)
* [Node.js](https://nodejs.org/) to run [KaTeX](https://github.com/KaTeX/KaTeX) (for normalizing Latex code)
* [`de-macro`](https://www.ctan.org/pkg/de-macro) >= 1.4 (only for parsing arxiv papers)
* Python 3.7+ & dependencies (specified in `setup.py`)

### Fonts
Expand Down
36 changes: 14 additions & 22 deletions pix2tex/dataset/arxiv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
import sys
import argparse
import logging
import shutil
import subprocess
import tarfile
import tempfile
import chardet
Expand Down Expand Up @@ -50,7 +48,7 @@ def download(url, dir_path='./'):
return 0


def read_tex_files(file_path, demacro=True):
def read_tex_files(file_path):
tex = ''
try:
with tempfile.TemporaryDirectory() as tempdir:
Expand All @@ -59,18 +57,11 @@ def read_tex_files(file_path, demacro=True):
tf.extractall(tempdir)
tf.close()
texfiles = [os.path.abspath(x) for x in glob.glob(os.path.join(tempdir, '**', '*.tex'), recursive=True)]
# de-macro
if demacro:
ret = subprocess.run(['de-macro', *texfiles], cwd=tempdir, capture_output=True)
if ret.returncode == 0:
texfiles = glob.glob(os.path.join(tempdir, '**', '*-clean.tex'), recursive=True)
except tarfile.ReadError as e:
texfiles = [file_path] # [os.path.join(tempdir, file_path+'.tex')]
#shutil.move(file_path, texfiles[0])

for texfile in texfiles:
try:
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
tex += open(texfile, 'r', encoding=chardet.detect(open(texfile, 'br').readline())['encoding']).read()
except UnicodeDecodeError:
pass
tex = unfold(convert(tex))
Expand All @@ -86,32 +77,32 @@ def download_paper(arxiv_id, dir_path='./'):
return download(url, dir_path)


def read_paper(targz_path, delete=True, demacro=True):
def read_paper(targz_path, delete=True):
paper = ''
if targz_path != 0:
paper = read_tex_files(targz_path, demacro)
paper = read_tex_files(targz_path)
if delete:
os.remove(targz_path)
return paper


def parse_arxiv(id, demacro=True):
def parse_arxiv(id):
tempdir = tempfile.gettempdir()
text = read_paper(download_paper(id, tempdir), demacro=demacro)
text = read_paper(download_paper(id, tempdir))
#print(text, file=open('paper.tex', 'w'))
#linked = list(set([l for l in re.findall(arxiv_id, text)]))

return find_math(text, wiki=False), []


if __name__ == '__main__':
# logging.getLogger().setLevel(logging.DEBUG)
parser = argparse.ArgumentParser(description='Extract math from arxiv')
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'id', 'dir'],
parser.add_argument('-m', '--mode', default='top100', choices=['top100', 'ids', 'dir'],
help='Where to extract code from. top100: current 100 arxiv papers, id: specific arxiv ids. \
Usage: `python arxiv.py -m id id001 id002`, dir: a folder full of .tar.gz files. Usage: `python arxiv.py -m dir directory`')
parser.add_argument(nargs='+', dest='args', default=[])
parser.add_argument(nargs='*', dest='args', default=[])
parser.add_argument('-o', '--out', default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'data'), help='output directory')
parser.add_argument('-d', '--no-demacro', dest='demacro', action='store_false', help='Use de-macro (Slows down extraction but improves quality)')
args = parser.parse_args()
if '.' in args.out:
args.out = os.path.dirname(args.out)
Expand All @@ -123,7 +114,7 @@ def parse_arxiv(id, demacro=True):
if args.mode == 'ids':
visited, math = recursive_search(parse_arxiv, args.args, skip=skip, unit='paper')
elif args.mode == 'top100':
url = 'https://arxiv.org/list/hep-th/2012?skip=0&show=100' # https://arxiv.org/list/hep-th/2012?skip=0&show=100
url = 'https://arxiv.org/list/physics/pastweek?skip=0&show=100' #'https://arxiv.org/list/hep-th/2203?skip=0&show=100'
ids = get_all_arxiv_ids(requests.get(url).text)
math, visited = [], ids
for id in tqdm(ids):
Expand All @@ -134,15 +125,16 @@ def parse_arxiv(id, demacro=True):
math, visited = [], []
for f in tqdm(dirs):
try:
text = read_paper(os.path.join(args.args[0], f), False, args.demacro)
text = read_paper(os.path.join(args.args[0], f), False)
math.extend(find_math(text, wiki=False))
visited.append(os.path.basename(f))
visited.append(os.path.basename(f))
except Exception as e:
logging.debug(e)
pass
else:
raise NotImplementedError

print('\n'.join(math))
sys.exit(0)
for l, name in zip([visited, math], ['visited_arxiv.txt', 'math_arxiv.txt']):
f = os.path.join(args.out, name)
if not os.path.exists(f):
Expand Down
67 changes: 51 additions & 16 deletions pix2tex/dataset/demacro.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@

import argparse
import re
from pix2tex.dataset.extract_latex import remove_labels


def main():
args = parse_command_line()
data = read(args.input)
data = convert(data)
if args.demacro:
data = unfold(data)
write(args.output, data)
data = unfold(data)
if args.output is not None:
write(args.output, data)
else:
print(data)


def parse_command_line():
parser = argparse.ArgumentParser(description='Replace \\def with \\newcommand where possible.')
parser.add_argument('input', help='TeX input file with \\def')
parser.add_argument('--output', '-o', required=True, help='TeX output file with \\newcommand')
parser.add_argument('--demacro', action='store_true', help='replace all commands with their definition')

parser.add_argument('--output', '-o', default=None, help='TeX output file with \\newcommand')
return parser.parse_args()


Expand All @@ -37,27 +38,61 @@ def convert(data):
)


def unfold(t):
cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}\n', t)
cmds = sorted(cmds, key=lambda x: len(x[0]))
# print(cmds)
def bracket_replace(string: str) -> str:
'''
replaces all layered brackets with special symbols
'''
layer = 0
out = list(string)
for i, c in enumerate(out):
if c == '{':
if layer > 0:
out[i] = 'Ḋ'
layer += 1
elif c == '}':
layer -= 1
if layer > 0:
out[i] = 'Ḍ'
return ''.join(out)


def undo_bracket_replace(string):
return string.replace('Ḋ', '{').replace('Ḍ', '}')


def sweep(t, cmds):
num_matches = 0
for c in cmds:
nargs = int(c[1][1]) if c[1] != r'' else 0
# print(c)
optional = c[2] != r''
if nargs == 0:
#t = t.replace(r'\\%s' % c[0], c[-1])
t = re.sub(r'\\%s([\W_^\d])' % c[0], r'%s\1' % c[-1].replace('\\', r'\\'), t)
else:
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if c[2] != r'' else 0))+r')', t)
# print(matches)
matches = re.findall(r'(\\%s(?:\[(.+?)\])?' % c[0]+r'{(.+?)}'*(nargs-(1 if optional else 0))+r')', t)
num_matches += len(matches)
for i, m in enumerate(matches):
r = c[-1]
if m[1] == r'':
matches[i] = (m[0], c[2][1:-1], *m[2:])
for j in range(1, nargs+1):
r = r.replace(r'#%i' % j, matches[i][j])
r = r.replace(r'#%i' % j, matches[i][j+int(not optional)])
t = t.replace(matches[i][0], r)
return t
return t, num_matches


def unfold(t):
t = remove_labels(t).replace('\n', 'Ċ')

cmds = re.findall(r'\\(?:re)?newcommand\*?{\\(.+?)}\s*(\[\d\])?(\[.+?\])?{(.+?)}Ċ', t)
cmds = sorted(cmds, key=lambda x: len(x[0]))
for _ in range(10):
# check for up to 10 nested commands
t = bracket_replace(t)
t, N = sweep(t, cmds)
t = undo_bracket_replace(t)
if N == 0:
break
return t.replace('Ċ', '\n')


def replace(match):
Expand Down
11 changes: 6 additions & 5 deletions pix2tex/dataset/extract_latex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
displaymath = re.compile(r'(\\displaystyle)(.{%i,%i}?)(\}(?:<|"))' % (1, MAX_CHARS))
outer_whitespace = re.compile(
r'^\\,|\\,$|^~|~$|^\\ |\\ $|^\\thinspace|\\thinspace$|^\\!|\\!$|^\\:|\\:$|^\\;|\\;$|^\\enspace|\\enspace$|^\\quad|\\quad$|^\\qquad|\\qquad$|^\\hspace{[a-zA-Z0-9]+}|\\hspace{[a-zA-Z0-9]+}$|^\\hfill|\\hfill$')

label_names = [re.compile(r'\\%s\s?\{(.*?)\}' % s) for s in ['ref', 'cite', 'label', 'caption', 'eqref']]

def check_brackets(s):
a = []
Expand Down Expand Up @@ -39,17 +39,18 @@ def check_brackets(s):
else:
return s

def remove_labels(string):
for s in label_names:
string = re.sub(s, '', string)
return string

def clean_matches(matches, min_chars=MIN_CHARS):
template = r'\\%s\s?\{(.*?)\}'
sub = [re.compile(template % s) for s in ['ref', 'cite', 'label', 'caption']]
faulty = []
for i in range(len(matches)):
if 'tikz' in matches[i]: # do not support tikz at the moment
faulty.append(i)
continue
for s in sub:
matches[i] = re.sub(s, '', matches[i])
matches[i] = remove_labels(matches[i])
matches[i] = matches[i].replace('\n', '').replace(r'\notag', '').replace(r'\nonumber', '')
matches[i] = re.sub(outer_whitespace, '', matches[i])
if len(matches[i]) < min_chars:
Expand Down

0 comments on commit 59903b1

Please sign in to comment.