Add part-of-speech prediction
nikitakit committed Dec 20, 2018
1 parent 15efc13 commit 89e1cde
Showing 4 changed files with 166 additions and 22 deletions.
92 changes: 92 additions & 0 deletions EVALB/nk.prm
@@ -0,0 +1,92 @@
# Based on new.prm (and by extension COLLINS.prm)
# The only change from new.prm is increasing MAX_ERROR. The evaluation should be
# identical to the standard setup, except that evalb won't give up early for a
# parser that has just started training and does not yet produce good results.

##------------------------------------------##
## Debug mode ##
## 0: No debugging ##
## 1: print data for individual sentence ##
## 2: print detailed bracketing info ##
##------------------------------------------##
DEBUG 0

##------------------------------------------##
## MAX error ##
## Number of errors to stop the process. ##
## This is useful if there could be ##
## tokenization errors. ##
## The process will stop when this number##
## of errors is accumulated. ##
##------------------------------------------##
MAX_ERROR 10000

##------------------------------------------##
## Cut-off length for statistics ##
## At the end of evaluation, the ##
## statistics for the sentences of length##
## less than or equal to this number will##
## be shown, on top of the statistics ##
## for all the sentences ##
##------------------------------------------##
CUTOFF_LEN 40

##------------------------------------------##
## unlabeled or labeled bracketing ##
## 0: unlabeled bracketing ##
## 1: labeled bracketing ##
##------------------------------------------##
LABELED 1

##------------------------------------------##
## Delete labels ##
## list of labels to be ignored. ##
## If it is a pre-terminal label, delete ##
## the word along with the brackets. ##
## If it is a non-terminal label, just ##
## delete the brackets (don't delete ##
## children). ##
##------------------------------------------##
DELETE_LABEL TOP
DELETE_LABEL S1
DELETE_LABEL -NONE-
DELETE_LABEL ,
DELETE_LABEL :
DELETE_LABEL ``
DELETE_LABEL ''
DELETE_LABEL .
DELETE_LABEL ?
DELETE_LABEL !

##------------------------------------------##
## Delete labels for length calculation ##
## list of labels to be ignored for ##
## length calculation purpose ##
##------------------------------------------##
DELETE_LABEL_FOR_LENGTH -NONE-

##------------------------------------------##
## Labels to be considered for misquote ##
## (could be possessive or quote) ##
##------------------------------------------##
QUOTE_LABEL ``
QUOTE_LABEL ''
QUOTE_LABEL POS

##------------------------------------------##
## These are less common, but are ##
## occasionally output by parsers: ##
##------------------------------------------##
QUOTE_LABEL NN
QUOTE_LABEL CD
QUOTE_LABEL VBZ
QUOTE_LABEL :

##------------------------------------------##
## Equivalent labels, words ##
## the pairs are considered equivalent ##
## This is non-directional. ##
##------------------------------------------##
EQ_LABEL ADVP PRT

# EQ_WORD Example example
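For context, the evaluation code passes a parameter file like this one to the evalb binary through its -p flag. A minimal sketch of that call in Python (the paths are placeholders, not the temporary files evaluate.py actually creates):

import subprocess

# Hypothetical paths; evaluate.py writes the gold/predicted trees to temp files itself.
command = "EVALB/evalb -p EVALB/nk.prm gold_trees.txt predicted_trees.txt > output.txt"
subprocess.run(command, shell=True, check=True)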
16 changes: 12 additions & 4 deletions src/evaluate.py
@@ -7,14 +7,19 @@
import trees

class FScore(object):
def __init__(self, recall, precision, fscore):
def __init__(self, recall, precision, fscore, tagging_accuracy=100):
self.recall = recall
self.precision = precision
self.fscore = fscore
self.tagging_accuracy = tagging_accuracy

def __str__(self):
return "(Recall={:.2f}, Precision={:.2f}, FScore={:.2f})".format(
self.recall, self.precision, self.fscore)
if self.tagging_accuracy < 100:
return "(Recall={:.2f}, Precision={:.2f}, FScore={:.2f}, TaggingAccuracy={:.2f})".format(
self.recall, self.precision, self.fscore, self.tagging_accuracy)
else:
return "(Recall={:.2f}, Precision={:.2f}, FScore={:.2f})".format(
self.recall, self.precision, self.fscore)

def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None):
assert os.path.exists(evalb_dir)
@@ -23,7 +28,7 @@ def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None):
assert os.path.exists(evalb_program_path) or os.path.exists(evalb_spmrl_program_path)

if os.path.exists(evalb_program_path):
evalb_param_path = os.path.join(evalb_dir, "COLLINS.prm")
evalb_param_path = os.path.join(evalb_dir, "nk.prm")
else:
evalb_program_path = evalb_spmrl_program_path
evalb_param_path = os.path.join(evalb_dir, "spmrl.prm")
@@ -84,6 +89,9 @@ def evalb(evalb_dir, gold_trees, predicted_trees, ref_gold_path=None):
match = re.match(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)", line)
if match:
fscore.fscore = float(match.group(1))
match = re.match(r"Tagging accuracy\s+=\s+(\d+\.\d+)", line)
if match:
fscore.tagging_accuracy = float(match.group(1))
break

success = (
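As a small illustration of the parsing logic added above: the new regex pulls the tagging accuracy out of evalb's summary output. The sample text below is a hand-written stand-in for that output, not captured from a real run:

import re

sample_output = """
Bracketing Recall        =  92.10
Bracketing Precision     =  93.00
Bracketing FMeasure      =  92.55
Tagging accuracy         =  97.30
"""

tagging_accuracy = None
for line in sample_output.splitlines():
    match = re.match(r"Tagging accuracy\s+=\s+(\d+\.\d+)", line)
    if match:
        tagging_accuracy = float(match.group(1))
print(tagging_accuracy)  # 97.3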
15 changes: 13 additions & 2 deletions src/main.py
@@ -55,6 +55,8 @@ def make_hparams():
d_kv=64,
d_ff=2048,
d_label_hidden=250,
d_tag_hidden=250,
tag_loss_scale=5.0,

attention_dropout=0.2,
embedding_dropout=0.0,
@@ -67,6 +69,7 @@
use_elmo=False,
use_bert=False,
use_bert_only=False,
predict_tags=False,

d_char_emb=32, # A larger value may be better for use_chars_lstm

@@ -101,6 +104,10 @@ def run_train(args, hparams):
hparams.print()

print("Loading training trees from {}...".format(args.train_path))
if hparams.predict_tags and args.train_path.endswith('10way.clean'):
print("WARNING: The data distributed with this repository contains "
"predicted part-of-speech tags only (not gold tags!) We do not "
"recommend enabling predict_tags in this configuration.")
train_treebank = trees.load_trees(args.train_path)
if hparams.max_len_train > 0:
train_treebank = [tree for tree in train_treebank if len(list(tree.leaves())) <= hparams.max_len_train]
@@ -298,11 +305,15 @@ def check_dev():
batch_loss_value = 0.0
batch_trees = train_parse[start_index:start_index + args.batch_size]
batch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in batch_trees]
batch_num_tokens = sum(len(sentence) for sentence in batch_sentences)

for subbatch_sentences, subbatch_trees in parser.split_batch(batch_sentences, batch_trees, args.subbatch_max_tokens):
_, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)

loss = loss / len(batch_trees)
if hparams.predict_tags:
loss = loss[0] / len(batch_trees) + loss[1] / batch_num_tokens
else:
loss = loss / len(batch_trees)
loss_value = float(loss.data.cpu().numpy())
batch_loss_value += loss_value
if loss_value > 0:
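In words: with predict_tags enabled, parse_batch returns the span loss summed over trees and the tag loss summed over tokens, so the code above normalizes each by its own count before adding them. A toy numeric sketch (all values invented):

parse_loss_sum = 12.0  # summed over the trees in the batch
tag_loss_sum = 40.0    # summed over the tokens in the batch (already scaled by tag_loss_scale)
num_trees, num_tokens = 8, 150
loss = parse_loss_sum / num_trees + tag_loss_sum / num_tokens
print(loss)  # 1.5 + ~0.27, roughly 1.77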
@@ -463,7 +474,7 @@ def run_parse(args):
sentences = input_file.readlines()
sentences = [sentence.split() for sentence in sentences]

# Parser does not do tagging, so use a dummy tag when parsing from raw text
# Tags are not available when parsing from raw text, so use a dummy tag
if 'UNK' in parser.tag_vocab.indices:
dummy_tag = 'UNK'
else:
65 changes: 49 additions & 16 deletions src/parse_nk.py
@@ -557,9 +557,9 @@ def get_bert(bert_model, bert_do_lower_case):
# Avoid a hard dependency on BERT by only importing it if it's being used
from pytorch_pretrained_bert import BertTokenizer, BertModel
if bert_model.endswith('.tar.gz'):
tokenizer = BertTokenizer.from_pretrained(bert_model.replace('.tar.gz', '-vocab.txt'), bert_do_lower_case)
tokenizer = BertTokenizer.from_pretrained(bert_model.replace('.tar.gz', '-vocab.txt'), do_lower_case=bert_do_lower_case)
else:
tokenizer = BertTokenizer.from_pretrained(bert_model, bert_do_lower_case)
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=bert_do_lower_case)
bert = BertModel.from_pretrained(bert_model)
return tokenizer, bert
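The two edits above pass do_lower_case by keyword rather than positionally; passed positionally the flag does not reach the do_lower_case parameter, so the lowercasing setting never takes effect. A minimal sketch of the corrected call (the model name is only an example):

from pytorch_pretrained_bert import BertTokenizer

# do_lower_case must be given as a keyword argument to from_pretrained.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)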

@@ -731,23 +731,28 @@ def __init__(
residual_dropout=hparams.residual_dropout,
attention_dropout=hparams.attention_dropout,
)

self.f_label = nn.Sequential(
nn.Linear(hparams.d_model, hparams.d_label_hidden),
LayerNormalization(hparams.d_label_hidden),
nn.ReLU(),
nn.Linear(hparams.d_label_hidden, label_vocab.size - 1),
)
else:
self.embedding = None
self.encoder = None

self.f_label = nn.Sequential(
nn.Linear(hparams.d_model, hparams.d_label_hidden),
LayerNormalization(hparams.d_label_hidden),
self.f_label = nn.Sequential(
nn.Linear(hparams.d_model, hparams.d_label_hidden),
LayerNormalization(hparams.d_label_hidden),
nn.ReLU(),
nn.Linear(hparams.d_label_hidden, label_vocab.size - 1),
)

if hparams.predict_tags:
assert not hparams.use_tags, "use_tags and predict_tags are mutually exclusive"
self.f_tag = nn.Sequential(
nn.Linear(hparams.d_model, hparams.d_tag_hidden),
LayerNormalization(hparams.d_tag_hidden),
nn.ReLU(),
nn.Linear(hparams.d_label_hidden, label_vocab.size - 1),
nn.Linear(hparams.d_tag_hidden, tag_vocab.size),
)
self.tag_loss_scale = hparams.tag_loss_scale
else:
self.f_tag = None

if use_cuda:
self.cuda()
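A rough shape check for the new tag head (the dimensions below are invented stand-ins for the real hparams, and nn.LayerNorm stands in for the repo's LayerNormalization):

import torch
import torch.nn as nn

d_model, d_tag_hidden, n_tags = 1024, 250, 48
f_tag = nn.Sequential(
    nn.Linear(d_model, d_tag_hidden),
    nn.LayerNorm(d_tag_hidden),
    nn.ReLU(),
    nn.Linear(d_tag_hidden, n_tags),
)
annotations = torch.randn(7, d_model)  # one annotation vector per token position
print(f_tag(annotations).shape)        # torch.Size([7, 48])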
@@ -772,6 +777,8 @@ def from_spec(cls, spec, model):
hparams['use_bert'] = False
if 'use_bert_only' not in hparams:
hparams['use_bert_only'] = False
if 'predict_tags' not in hparams:
hparams['predict_tags'] = False

spec['hparams'] = nkutil.HParams(**hparams)
res = cls(**spec)
@@ -824,7 +831,7 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
batch_idxs = np.zeros(packed_len, dtype=int)
for snum, sentence in enumerate(sentences):
for (tag, word) in [(START, START)] + sentence + [(STOP, STOP)]:
tag_idxs[i] = 0 if not self.use_tags else self.tag_vocab.index_or_unk(tag, TAG_UNK)
tag_idxs[i] = 0 if (not self.use_tags and self.f_tag is None) else self.tag_vocab.index_or_unk(tag, TAG_UNK)
if word not in (START, STOP):
count = self.word_vocab.count(word)
if not count or (is_train and np.random.rand() < 1 / (1 + count)):
@@ -845,6 +852,9 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
for emb_type in self.emb_types
]

if is_train and self.f_tag is not None:
gold_tag_idxs = from_numpy(emb_idxs_map['tags'])

extra_content_annotations = None
if self.char_encoder is not None:
assert isinstance(self.char_encoder, CharacterLSTM)
@@ -991,6 +1001,9 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
annotations[:, 1::2],
], 1)

if self.f_tag is not None:
tag_annotations = annotations

fencepost_annotations = torch.cat([
annotations[:-1, :self.d_model//2],
annotations[1:, self.d_model//2:],
@@ -1002,6 +1015,13 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
features = self.project_bert(features)
fencepost_annotations_start = features.masked_select(all_word_start_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1, features.shape[-1])
fencepost_annotations_end = features.masked_select(all_word_end_mask.to(torch.uint8).unsqueeze(-1)).reshape(-1, features.shape[-1])
if self.f_tag is not None:
tag_annotations = fencepost_annotations_end

if self.f_tag is not None:
tag_logits = self.f_tag(tag_annotations)
if is_train:
tag_loss = self.tag_loss_scale * nn.functional.cross_entropy(tag_logits, gold_tag_idxs, reduction='sum')

# Note that the subtraction above creates fenceposts at sentence
# boundaries, which are not used by our parser. Hence subtract 1
@@ -1020,8 +1040,17 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
if not is_train:
trees = []
scores = []
if self.f_tag is not None:
# Note that tag_logits includes tag predictions for start/stop tokens
tag_idxs = torch.argmax(tag_logits, -1).cpu()
per_sentence_tag_idxs = torch.split_with_sizes(tag_idxs, [len(sentence) + 2 for sentence in sentences])
per_sentence_tags = [[self.tag_vocab.value(idx) for idx in idxs[1:-1]] for idxs in per_sentence_tag_idxs]

for i, (start, end) in enumerate(zip(fp_startpoints, fp_endpoints)):
tree, score = self.parse_from_annotations(fencepost_annotations_start[start:end,:], fencepost_annotations_end[start:end,:], sentences[i], golds[i])
sentence = sentences[i]
if self.f_tag is not None:
sentence = list(zip(per_sentence_tags[i], [x[1] for x in sentence]))
tree, score = self.parse_from_annotations(fencepost_annotations_start[start:end,:], fencepost_annotations_end[start:end,:], sentence, golds[i])
trees.append(tree)
scores.append(score)
return trees, scores
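A toy illustration of the tag-decoding step above: the logits cover every position of the packed batch, including the START/STOP sentinels, so they are split per sentence and the first and last positions dropped (the tag inventory here is invented):

import torch

sentences = [["The", "dog", "barks"], ["Go", "!"]]
tag_vocab = ["UNK", "DT", "NN", "VBZ", "VB", "."]

packed_len = sum(len(s) + 2 for s in sentences)       # +2 for START/STOP
tag_logits = torch.randn(packed_len, len(tag_vocab))  # stand-in for f_tag output

tag_idxs = torch.argmax(tag_logits, -1).cpu()
per_sentence_tag_idxs = torch.split_with_sizes(tag_idxs, [len(s) + 2 for s in sentences])
per_sentence_tags = [[tag_vocab[idx] for idx in idxs[1:-1]] for idxs in per_sentence_tag_idxs]
print(per_sentence_tags)  # e.g. [['NN', 'VB', '.'], ['DT', 'UNK']]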
@@ -1065,7 +1094,11 @@ def parse_batch(self, sentences, golds=None, return_label_scores_charts=False):
], 1)
cells_scores = torch.gather(cells_label_scores, 1, cells_label[:, None])
loss = cells_scores[:num_p].sum() - cells_scores[num_p:].sum() + paugment_total
return None, loss

if self.f_tag is not None:
return None, (loss, tag_loss)
else:
return None, loss

def label_scores_from_annotations(self, fencepost_annotations_start, fencepost_annotations_end):
# Note that the bias added to the final layer norm is useless because
