added predicate input to demo
t-li committed Oct 25, 2020
1 parent 49e54c7 commit bb170b4
Showing 4 changed files with 69 additions and 43 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -246,5 +246,5 @@ F1 scores with ``*``: trained and evaluated without gold predicate (i.e. ``--use

# TODO
- [x] Upload more models to HuggingFace hub
- [ ] extend demo interface to accept predicate
- [x] Extend demo interface to accept predicate
- [x] Make a separate predicate classifier
64 changes: 48 additions & 16 deletions hf/demo.py
@@ -1,6 +1,5 @@
import sys
import argparse
import spacy
import h5py
import numpy as np
import torch
@@ -12,7 +11,12 @@
from hf.roberta_for_srl import *
import traceback

spacy_nlp = spacy.load('en')
#import spacy
#spacy_nlp = spacy.load('en')
# use nltk instead as it has better token-char mapping
import nltk
from nltk.tokenize import TreebankWordTokenizer
tb_tokenizer = TreebankWordTokenizer()


parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
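
The switch from spaCy to NLTK above is driven by span_tokenize, which returns character offsets alongside the tokens. A quick illustration of the token-to-character mapping (a sketch, not part of the commit):

from nltk.tokenize import TreebankWordTokenizer

tb = TreebankWordTokenizer()
seq = "He said he knows it."
spans = list(tb.span_tokenize(seq))    # [(0, 2), (3, 7), (8, 10), (11, 16), (17, 19), (19, 20)]
print([seq[s:e] for s, e in spans])    # ['He', 'said', 'he', 'knows', 'it', '.']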
@@ -33,19 +37,28 @@
parser.add_argument('--num_frame', help="The number of frames for each proposition", type=int, default=38)


def process(opt, tokenizer, seq):
def process(opt, tokenizer, seq, predicates):
bos_tok, eos_tok = get_special_tokens(tokenizer)
ws = spacy_nlp(seq)
sent_subtoks = [tokenizer.tokenize(t.text) for t in ws]

char_spans = list(tb_tokenizer.span_tokenize(seq))
orig_toks = [seq[s:e] for s, e in char_spans]

v_label = [next((i for i, span in enumerate(char_spans) if span == (seq.index(p), seq.index(p)+len(p))), None) for p in predicates if p in seq]
v_label = [i for i in v_label if i is not None]

if len(v_label) != len(predicates):
print('valid predicates: ', ','.join([orig_toks[i] for i in v_label]))

sent_subtoks = [tokenizer.tokenize(t) for t in orig_toks]
tok_l = [len(subtoks) for subtoks in sent_subtoks]
toks = [p for subtoks in sent_subtoks for p in subtoks] # flattening
orig_toks = [t.text for t in ws]

# pad for CLS and SEP
CLS, SEP = tokenizer.cls_token, tokenizer.sep_token
toks = [CLS] + toks + [SEP]
tok_l = [1] + tok_l + [1]
orig_toks = [CLS] + orig_toks + [SEP]
v_label = [l+1 for l in v_label]

tok_idx = np.array(tokenizer.convert_tokens_to_ids(toks), dtype=int)
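
To make the predicate lookup above concrete: each user-supplied predicate string is located by its first character offset and matched against the tokenizer's spans. A standalone sketch of the same logic (illustrative only):

seq = "He said he knows it."
char_spans = [(0, 2), (3, 7), (8, 10), (11, 16), (17, 19), (19, 20)]
predicates = ['said', 'knows']

v_label = []
for p in predicates:
    if p not in seq:
        continue
    start = seq.index(p)                  # note: first occurrence only
    target = (start, start + len(p))
    idx = next((i for i, span in enumerate(char_spans) if span == target), None)
    if idx is not None:
        v_label.append(idx)
print(v_label)                            # [1, 3] -> 'said', 'knows'

One caveat: because seq.index(p) returns the first occurrence, a predicate that repeats in the sentence (or appears as a substring of an earlier word) only ever maps to its first exact-span match.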

@@ -60,7 +73,7 @@ def process(opt, tokenizer, seq):
acc += l
sub2tok_idx = pad(sub2tok_idx, len(tok_idx), [-1 for _ in range(opt.max_num_subtok)])
sub2tok_idx = np.array(sub2tok_idx, dtype=int)
return tok_idx, sub2tok_idx, toks, orig_toks
return tok_idx, sub2tok_idx, toks, orig_toks, v_label


def fix_opt(opt):
@@ -71,6 +84,7 @@ def fix_opt(opt):
opt.param_init_type = 'xavier_uniform'
return opt


def pretty_print_pred(opt, shared, m, pred_idx):
batch_l = shared.batch_l
orig_l = shared.orig_seq_l
@@ -86,17 +100,23 @@ def pretty_print_pred(opt, shared, m, pred_idx):
return pred_log


def run(opt, shared, m, tokenizer, seq):
tok_idx, sub2tok_idx, toks, orig_toks = process(opt, tokenizer, seq)
def run(opt, shared, m, tokenizer, seq, predicates=[]):
tok_idx, sub2tok_idx, toks, orig_toks, v_label = process(opt, tokenizer, seq, predicates)

m.update_context(orig_seq_l=to_device(torch.tensor([len(orig_toks)]).int(), opt.gpuid),
sub2tok_idx=to_device(torch.tensor([sub2tok_idx]).int(), opt.gpuid),
res_map={'orig_tok_grouped': [orig_toks]})

tok_idx = to_device(Variable(torch.tensor([tok_idx]), requires_grad=False), opt.gpuid)

if len(v_label) != 0:
v_l = to_device(torch.Tensor([len(v_label)]).long().view(1), opt.gpuid)
v_label = to_device(torch.Tensor(v_label).long().view(1, -1), opt.gpuid)
else:
v_label, v_l = None, None

with torch.no_grad():
pred_idx = m.forward(tok_idx)
pred_idx = m.forward(tok_idx, v_label, v_l)

log = pretty_print_pred(opt, shared, m, pred_idx)[0]
return orig_toks[1:-1], log
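
For the gold-predicate path above, the tensors are shaped for a batch of one. A shape check for the example sentence, with token indices already shifted by +1 for CLS as done in process() (illustrative):

import torch

v_label_list = [2, 4]   # 'said', 'knows' after the CLS shift
v_l = torch.Tensor([len(v_label_list)]).long().view(1)    # tensor([2])
v_label = torch.Tensor(v_label_list).long().view(1, -1)   # tensor([[2, 4]])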
@@ -108,7 +128,7 @@ def init(opt):

opt = complete_opt(opt)

tokenizer = AutoTokenizer.from_pretrained(opt.bert_type)
tokenizer = AutoTokenizer.from_pretrained(opt.bert_type, add_special_tokens=False, use_fast=True)
m = RobertaForSRL.from_pretrained(opt.load_file, overwrite_opt = opt, shared=shared)

if opt.gpuid != -1:
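
The extra tokenizer kwargs are presumably there because process() inserts CLS and SEP by hand; the demo only ever calls tokenize() and convert_tokens_to_ids(), neither of which injects special tokens. A reading of the surrounding code, not something the commit documents:

toks = tokenizer.tokenize("He said he knows it.")   # no <s>/</s> added here
ids = tokenizer.convert_tokens_to_ids([tokenizer.cls_token] + toks + [tokenizer.sep_token])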
@@ -122,19 +142,31 @@ def main(args):
opt, shared, m, tokenizer = init(opt)

seq = "He said he knows it."
orig_toks, log = run(opt, shared, m, tokenizer, seq)
predicates = ['said', 'knows']
#predicates = []
orig_toks, log = run(opt, shared, m, tokenizer, seq, predicates)

print('###################################')
print('Here is a sample prediction for input:')
print('>>', seq)
print('***********************************')
print('>> Input: ', seq)
print('>> Predicates: ', ','.join(predicates)) # may be empty
print(' '.join(orig_toks))
print(log)

print('###################################')
print('# Instructions #')
print('###################################')
print('>> Enter an input sequence as prompted.')
print('>> You may also specify ground-truth predicates, or leave the field empty.')

while True:
try:
print('###################################')
seq = input("Enter a sequence: ")
orig_toks, log = run(opt, shared, m, tokenizer, seq)
print('***********************************')
predicates = input('Enter predicates: ')
predicates = predicates.strip().split(',')

orig_toks, log = run(opt, shared, m, tokenizer, seq, predicates)
print(' '.join(orig_toks))
print(log)

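
One edge case in the new prompt loop: pressing enter at "Enter predicates:" does not yield an empty list. A trace of what happens (illustrative):

predicates = ''.strip().split(',')   # [''] -- a single empty string, not []
# '' is a substring of any seq, but seq.index('') == 0 gives the span (0, 0),
# which matches no token span, so it is filtered out in process();
# v_label comes back empty and run() falls back to the model's own
# predicate heuristic (v_label, v_l = None, None).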
12 changes: 2 additions & 10 deletions hf/roberta_for_srl.py
@@ -39,24 +39,16 @@ def update_context(self, orig_seq_l, sub2tok_idx, res_map=None):
self.shared.sub2tok_idx = sub2tok_idx
self.shared.res_map = res_map

# loss context is only visible to the pipeline during loss computation (to avoid accidental contamination)
# update the contextual info of current batch for loss calculation
def update_loss_context(self, v_label, v_l, role_label, v_roleset_id):
self._loss_context.v_label = v_label
self._loss_context.v_l = v_l
self._loss_context.role_label = role_label
self._loss_context.v_roleset_id = v_roleset_id

# shared: a namespace or a Holder instance that contains information for the current input batch
# such as predicate labels, subtok-to-tok index mapping, etc.
def forward(self, input_ids):
def forward(self, input_ids, v_label = None, v_l = None):
self.shared.batch_l = input_ids.shape[0]
self.shared.seq_l = input_ids.shape[1]

enc = self.roberta(input_ids)[0]

log_pa, score, extra = self.classifier(enc)

pred, _ = self.crf_loss.decode(log_pa, score)
pred, _ = self.crf_loss.decode(log_pa, score, v_label, v_l)

return pred
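
The net effect of the signature change: forward() now threads optional gold predicates through to CRF decoding. Both call patterns, as exercised by hf/demo.py (a sketch):

pred = m.forward(tok_idx)                  # no gold predicates; decode uses the B-V heuristic
pred = m.forward(tok_idx, v_label, v_l)    # gold predicates constrain decoding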
34 changes: 18 additions & 16 deletions loss/crf_loss.py
@@ -26,7 +26,7 @@ def __init__(self, opt, shared):
self.gold_log = []
self.pred_log = []

def decode(self, log_pa, score):
def decode(self, log_pa, score, v_label=None, v_l=None):
batch_l, source_l, _, _ = score.shape
orig_l = self.shared.orig_seq_l
max_orig_l = orig_l.max()
@@ -36,24 +36,26 @@ def decode(self, log_pa, score):
a_score = []
a_mask = []

v_label = to_device(torch.zeros(batch_l, max_orig_l).long(), self.opt.gpuid)
v_l = to_device(torch.zeros(batch_l).long(), self.opt.gpuid)

# use heuristic to get predicates
for i in range(batch_l):
max_v_idx = (score[i].argmax(-1) == bv_idx).diagonal().nonzero().view(-1)
# if no predicate candidate found, just take the one with the max score on B-V
if max_v_idx.numel() == 0:
max_v_idx = score[i, :, :, bv_idx].max(-1)[0].argmax(-1).view(1)
max_v_idx = max_v_idx[:max_num_v]
v_l[i] = max_v_idx.shape[0]
v_label[i, :v_l[i]] = max_v_idx

if v_label is None:
v_label = to_device(torch.zeros(batch_l, max_orig_l).long(), self.opt.gpuid)
v_l = to_device(torch.zeros(batch_l).long(), self.opt.gpuid)

# use heuristic to get predicates
for i in range(batch_l):
max_v_idx = (score[i].argmax(-1) == bv_idx).diagonal().nonzero().view(-1)
# if no predicate candidate found, just take the one with the max score on B-V
if max_v_idx.numel() == 0:
max_v_idx = score[i, :, :, bv_idx].max(-1)[0].argmax(-1).view(1)
max_v_idx = max_v_idx[:max_num_v]
v_l[i] = max_v_idx.shape[0]
v_label[i, :v_l[i]] = max_v_idx
else:
v_label = to_device(v_label, self.opt.gpuid)
v_l = to_device(v_l, self.opt.gpuid)

# pack everything into (batch_l*acc_orig_l, max_orig_l, ...)
for i in range(batch_l):
v_i = v_label[i, :v_l[i]]


a_mask_i = torch.zeros(v_l[i], max_orig_l).byte()
a_mask_i[:, :orig_l[i]] = True
a_mask.append(a_mask_i)
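
The heuristic branch kept above is unchanged, only re-indented under the new "if v_label is None:" guard. A toy run of its B-V detection on a single batch item (a sketch; score here has shape [L, L, n_labels] and bv_idx marks the B-V tag):

import torch

L, n_labels, bv_idx = 5, 4, 1
score = torch.randn(L, L, n_labels)

# token t is a predicate candidate if the argmax label at cell (t, t) is B-V
cand = (score.argmax(-1) == bv_idx).diagonal().nonzero().view(-1)
if cand.numel() == 0:
    # fallback: the position whose best B-V score is highest
    cand = score[:, :, bv_idx].max(-1)[0].argmax(-1).view(1)
print(cand)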
