Skip to content

Commit

Permalink
add (normed) edit distance evaluation metric
Browse files Browse the repository at this point in the history
  • Loading branch information
lukas-blecher committed May 5, 2021
1 parent 95a9fae commit c7e0486
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
35 changes: 23 additions & 12 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from munch import Munch
from tqdm.auto import tqdm
import wandb
from Levenshtein import distance

from models import get_model, Model
from utils import *
Expand Down Expand Up @@ -44,7 +45,9 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
"""
assert len(dataset) > 0
device = args.device
bleus = []
log = {}
bleus, edit_dists = [], []
bleu_score, edit_distance = 0, 1
pbar = tqdm(enumerate(iter(dataset)), total=len(dataset))
for i, (seq, im) in pbar:
if seq is None or im is None:
Expand All @@ -53,27 +56,35 @@ def evaluate(model: Model, dataset: Im2LatexDataset, args: Munch, num_batches: i
encoded = model.encoder(im.to(device))
#loss = decoder(tgt_seq, mask=tgt_mask, context=encoded)
dec = model.decoder.generate(torch.LongTensor([args.bos_token]*len(encoded))[:, None].to(device), args.max_seq_len,
eos_token=args.pad_token, context=encoded, temperature=(args.temperature if 'temperature' in args else 1))
eos_token=args.pad_token, context=encoded, temperature=args.get('temperature', .2))
pred = detokenize(dec, dataset.tokenizer)
truth = detokenize(seq['input_ids'], dataset.tokenizer)
bleus.append(metrics.bleu_score(pred, [alternatives(x) for x in truth]))
pbar.set_description('BLEU: %.3f +/- %.3f' % (np.mean(bleus), np.std(bleus)))
for predi, truthi in zip(token2str(dec, dataset.tokenizer), token2str(seq['input_ids'], dataset.tokenizer)):
ts = post_process(truthi)
edit_dists.append(distance(post_process(predi), ts)/len(ts))
pbar.set_description('BLEU: %.3f, ED: %.2e' % (np.mean(bleus), np.mean(edit_dists)))
if num_batches is not None and i >= num_batches:
break
if len(bleus) > 0:
bleu_score = np.mean(bleus)
log[name+'/bleu'] = bleu_score
if len(edit_dists) > 0:
edit_distance = np.mean(edit_dists)
log[name+'/edit_distance'] = edit_distance
if args.wandb:
# samples
pred = token2str(dec, dataset.tokenizer)
truth = token2str(seq['input_ids'], dataset.tokenizer)
if args.wandb:
table = wandb.Table(columns=["Truth", "Prediction"])
for k in range(min([len(pred), args.test_samples])):
table.add_data(post_process(truth[k]), post_process(pred[k]))
wandb.log({name+'/examples': table, name+'/bleu': bleu_score})
else:
print('\n%s\n%s' % (truth, pred))
print('BLEU: %.2f' % bleu_score)
return bleu_score
table = wandb.Table(columns=["Truth", "Prediction"])
for k in range(min([len(pred), args.test_samples])):
table.add_data(post_process(truth[k]), post_process(pred[k]))
log[name+'/examples'] = table
wandb.log(log)
else:
print('\n%s\n%s' % (truth, pred))
print('BLEU: %.2f' % bleu_score)
return bleu_score, edit_distance


if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ torchtext>=0.6.0
albumentations>=0.5.2
pandas>=1.0.0
timm==0.4.5
python-Levenshtein>=0.12.2

0 comments on commit c7e0486

Please sign in to comment.