Support BERT-based models in release version
The new model format is set up to allow multilingual models and to
perform part-of-speech tagging inside the parser, instead of relying
on an external tagger.
nikitakit committed Dec 31, 2018
commit 0f4439f (1 parent: 35fa5a5)
Showing 6 changed files with 499 additions and 46 deletions.
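The loading code in this commit implies the following on-disk layout for the new model format: a directory containing meta.json (model metadata), model.pb (the frozen TensorFlow graph), and vocab.txt (the BERT wordpiece vocabulary). A sketch of a plausible meta.json follows; the keys are exactly the ones BaseParser.__init__ reads in the diff below, but the values are invented for illustration:

    import json
    import os

    # Hypothetical metadata for a BERT-based model directory. Keys mirror
    # what BaseParser.__init__ reads; the values here are illustrative only.
    meta = {
        "label_vocab": [[], ["S"], ["NP"], ["VP"]],  # nested label tuples
        "language_code": "en",
        "provides_tags": True,   # if True, the graph exposes a 'tags:0' tensor
        "tag_vocab": ["DT", "NN", "VBZ"],
        "bert_do_lower_case": False,
    }

    os.makedirs("example_model", exist_ok=True)
    with open("example_model/meta.json", "w") as f:
        json.dump(meta, f)
    # example_model/ would also contain model.pb and vocab.txt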
benepar/base_parser.py: 155 changes (130 additions & 25 deletions)
@@ -2,9 +2,12 @@
 import numpy as np
 import os
 import sys
+import json
+import codecs
 
 from . import chart_decoder
 from .downloader import load_model
+from .bert_tokenization import BertTokenizer
 
 #%%
@@ -23,12 +26,25 @@
                     "[": "-LSB-",
                     "]": "-RSB-"}
 
-PTB_TOKEN_UNESCAPE = {"-LRB-": "(",
+BERT_TOKEN_MAPPING = {"-LRB-": "(",
                       "-RRB-": ")",
                       "-LCB-": "{",
                       "-RCB-": "}",
                       "-LSB-": "[",
-                      "-RSB-": "]"}
+                      "-RSB-": "]",
+                      "``": '"',
+                      "''": '"',
+                      "`": "'",
+                      "“": '"',
+                      "”": '"',
+                      "‘": "'",
+                      "’": "'",
+                      "«": '"',
+                      "»": '"',
+                      "„": '"',
+                      "‹": "'",
+                      "›": "'",
+                      }
 
 # Label vocab is made immutable because it is potentially exposed to users
 # through the spacy plugin
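Aside from renaming the table, the commit folds PTB bracket escapes, backtick-style quotes, and common Unicode quotation marks down to plain ASCII quotes, which are far more likely to appear in a BERT wordpiece vocabulary. A quick illustration (not part of the commit):

    >>> [BERT_TOKEN_MAPPING.get(tok, tok) for tok in ["-LRB-", "``", "Hello", "''", "-RRB-"]]
    ['(', '"', 'Hello', '"', ')']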
@@ -147,6 +163,7 @@
     ('ADJP', 'ADVP'))
 
 SENTENCE_MAX_LEN = 300
+BERT_MAX_LEN = 512
 
 #%%
 class BaseParser(object):
@@ -155,24 +172,60 @@ def __init__(self, name, batch_size=64):
 
         with self._graph.as_default():
             if isinstance(name, str) and '/' not in name:
-                graph_def = tf.GraphDef.FromString(load_model(name))
+                model = load_model(name)
             elif not os.path.exists(name):
-                raise Exception("Argument is neither a valid module name nor a path to an existing file: {}".format(name))
+                raise Exception("Argument is neither a valid module name nor a path to an existing file/folder: {}".format(name))
             else:
-                with open(name, 'rb') as f:
-                    graph_def = tf.GraphDef.FromString(f.read())
+                if not os.path.isdir(name):
+                    with open(name, 'rb') as f:
+                        model = f.read()
+                else:
+                    model = {}
+                    with open(os.path.join(name, 'meta.json')) as f:
+                        model['meta'] = json.load(f)
+                    with open(os.path.join(name, 'model.pb'), 'rb') as f:
+                        model['model'] = f.read()
+                    with codecs.open(os.path.join(name, 'vocab.txt'), encoding='utf-8') as f:
+                        model['vocab'] = f.read()
+
+            if isinstance(model, dict):
+                graph_def = tf.GraphDef.FromString(model['model'])
+            else:
+                graph_def = tf.GraphDef.FromString(model)
             tf.import_graph_def(graph_def, name='')
 
         self._sess = tf.Session(graph=self._graph)
-        self._chars = self._graph.get_tensor_by_name('chars:0')
-        self._charts = self._graph.get_tensor_by_name('charts:0')
+        if not isinstance(model, dict):
+            # Older model format (for ELMo-based models)
+            self._chars = self._graph.get_tensor_by_name('chars:0')
+            self._charts = self._graph.get_tensor_by_name('charts:0')
+            self._label_vocab = LABEL_VOCAB
+            self._language_code = 'en'
+            self._provides_tags = False
+            self._make_feed_dict = self._make_feed_dict_elmo
+        else:
+            # Newer model format (for BERT-based models)
+            meta = model['meta']
+            # Label vocab is made immutable because it is potentially exposed to
+            # users through the spacy plugin
+            self._label_vocab = tuple([tuple(label) for label in meta['label_vocab']])
+            self._language_code = meta['language_code']
+            self._provides_tags = meta['provides_tags']
+
+            self._input_ids = self._graph.get_tensor_by_name('input_ids:0')
+            self._word_end_mask = self._graph.get_tensor_by_name('word_end_mask:0')
+            self._charts = self._graph.get_tensor_by_name('charts:0')
+            if self._provides_tags:
+                self._tag_vocab = meta['tag_vocab']
+                self._tags = self._graph.get_tensor_by_name('tags:0')
+
+            self._bert_tokenizer = BertTokenizer(
+                model['vocab'], do_lower_case=meta['bert_do_lower_case'])
+            self._make_feed_dict = self._make_feed_dict_bert
 
-        # TODO(nikita): move this out of the source code
-        self._label_vocab = LABEL_VOCAB
         self.batch_size = batch_size
 
-    def _charify(self, sentences):
+    def _make_feed_dict_elmo(self, sentences):
         padded_len = max([len(sentence) + 2 for sentence in sentences])
         if padded_len > SENTENCE_MAX_LEN:
             raise ValueError("Sentence of length {} exceeds the maximum supported length of {}".format(
@@ -203,19 +256,71 @@ def _charify(self, sentences):
             # sentence, which we don't have because batch_size=1
             all_chars[snum, :len(sentence)+2,:] += 1
 
-        return all_chars
+        return {self._chars: all_chars}
 
-    def _make_charts(self, sentences):
-        inp_val = self._charify(sentences)
-        out_val = self._sess.run(self._charts, {self._chars: inp_val})
+    def _make_feed_dict_bert(self, sentences):
+        all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
+        all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)
+
+        subword_max_len = 0
         for snum, sentence in enumerate(sentences):
-            chart_size = len(sentence) + 1
-            chart = out_val[snum,:chart_size,:chart_size,:]
-            yield chart
+            tokens = []
+            word_end_mask = []
+
+            tokens.append("[CLS]")
+            word_end_mask.append(1)
+
+            cleaned_words = []
+            for word in sentence:
+                word = BERT_TOKEN_MAPPING.get(word, word)
+                # BERT is pre-trained with a tokenizer that doesn't split off
+                # n't as its own token
+                if word == "n't" and cleaned_words:
+                    cleaned_words[-1] = cleaned_words[-1] + "n"
+                    word = "'t"
+                cleaned_words.append(word)
+
+            for word in cleaned_words:
+                word_tokens = self._bert_tokenizer.tokenize(word)
+                for _ in range(len(word_tokens)):
+                    word_end_mask.append(0)
+                word_end_mask[-1] = 1
+                tokens.extend(word_tokens)
+            tokens.append("[SEP]")
+            word_end_mask.append(1)
 
-    def _make_parse_raw(self, sentence):
-        chart_np = list(self._make_charts([sentence]))[0]
-        return chart_decoder.decode(chart_np)
+            input_ids = self._bert_tokenizer.convert_tokens_to_ids(tokens)
+            if len(sentence) + 2 > SENTENCE_MAX_LEN or len(input_ids) > BERT_MAX_LEN:
+                raise ValueError("Sentence of length {} is too long to be parsed".format(
+                    len(sentence)))
+
+            subword_max_len = max(subword_max_len, len(input_ids))
+
+            all_input_ids[snum, :len(input_ids)] = input_ids
+            all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask
+
+        all_input_ids = all_input_ids[:, :subword_max_len]
+        all_word_end_mask = all_word_end_mask[:, :subword_max_len]
+
+        return {
+            self._input_ids: all_input_ids,
+            self._word_end_mask: all_word_end_mask
+        }
+
+    def _make_charts_and_tags(self, sentences):
+        feed_dict = self._make_feed_dict(sentences)
+        if self._provides_tags:
+            charts_val, tags_val = self._sess.run((self._charts, self._tags), feed_dict)
+        else:
+            charts_val = self._sess.run(self._charts, feed_dict)
+        for snum, sentence in enumerate(sentences):
+            chart_size = len(sentence) + 1
+            chart = charts_val[snum,:chart_size,:chart_size,:]
+            if self._provides_tags:
+                tags = tags_val[snum,1:chart_size]
+            else:
+                tags = None
+            yield chart, tags
 
     def _batched_parsed_raw(self, sentence_data_pairs):
         batch_sentences = []
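To make the tokenization bookkeeping in _make_feed_dict_bert concrete, here is a self-contained sketch of the same clean-and-mask loop. The wordpiece table is a stand-in invented for this example; the real code calls BertTokenizer.tokenize with the model's vocab.txt:

    # Stand-in wordpiece splits, for illustration only.
    FAKE_WORDPIECES = {"'t": ["'", "##t"], "parsing": ["pars", "##ing"]}

    def clean_and_mask(sentence):
        cleaned_words = []
        for word in sentence:
            # PTB tokenizes "don't" as "do" + "n't"; BERT's tokenizer
            # expects "don" + "'t" instead
            if word == "n't" and cleaned_words:
                cleaned_words[-1] += "n"
                word = "'t"
            cleaned_words.append(word)

        tokens, word_end_mask = ["[CLS]"], [1]
        for word in cleaned_words:
            pieces = FAKE_WORDPIECES.get(word, [word])
            # mark only the final wordpiece of each word, as the parser does
            word_end_mask.extend([0] * (len(pieces) - 1) + [1])
            tokens.extend(pieces)
        tokens.append("[SEP]")
        word_end_mask.append(1)
        return tokens, word_end_mask

    print(clean_and_mask(["do", "n't", "stop", "parsing"]))
    # (['[CLS]', 'don', "'", '##t', 'stop', 'pars', '##ing', '[SEP]'],
    #  [1, 1, 0, 1, 1, 0, 1, 1])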
@@ -224,10 +329,10 @@ def _batched_parsed_raw(self, sentence_data_pairs):
             batch_sentences.append(sentence)
             batch_data.append(datum)
             if len(batch_sentences) >= self.batch_size:
-                for chart_np, datum in zip(self._make_charts(batch_sentences), batch_data):
-                    yield chart_decoder.decode(chart_np), datum
+                for (chart_np, tags_np), datum in zip(self._make_charts_and_tags(batch_sentences), batch_data):
+                    yield chart_decoder.decode(chart_np), tags_np, datum
                 batch_sentences = []
                 batch_data = []
         if batch_sentences:
-            for chart_np, datum in zip(self._make_charts(batch_sentences), batch_data):
-                yield chart_decoder.decode(chart_np), datum
+            for (chart_np, tags_np), datum in zip(self._make_charts_and_tags(batch_sentences), batch_data):
+                yield chart_decoder.decode(chart_np), tags_np, datum
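With this change, _batched_parsed_raw yields (decoded chart, tags, datum) triples instead of pairs; for old ELMo-format models the tags slot is simply None. A hedged sketch of a consumer, assuming (this is inferred, not stated in the diff) that the 'tags:0' tensor yields integer ids into meta['tag_vocab']:

    def print_predicted_tags(parser, sentence_data_pairs):
        # parser: a BaseParser loaded from a BERT-format model directory,
        # so parser._provides_tags is True. Assumption: tag values are ids
        # into parser._tag_vocab.
        for parse_raw, tags_np, datum in parser._batched_parsed_raw(sentence_data_pairs):
            if tags_np is not None:
                tags = [parser._tag_vocab[int(t)] for t in tags_np]
                print(datum, tags)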