examples compat to ERNIE tiny
Meiyim committed Nov 20, 2019
1 parent 72e2123 commit f889492
Showing 4 changed files with 75 additions and 21 deletions.
46 changes: 44 additions & 2 deletions ernie/utils/data.py
@@ -4,6 +4,7 @@
from propeller import log
import itertools
from propeller.paddle.data import Dataset
import pickle

import six

@@ -101,7 +102,7 @@ def __call__(self, sen):


class CharTokenizer(object):
def __init__(self, vocab, lower=True):
def __init__(self, vocab, lower=True, sentencepiece_style_vocab=False):
"""
char tokenizer (wordpiece english)
normed txt(space seperated or not) => list of word-piece
@@ -110,6 +111,7 @@ def __init__(self, vocab, lower=True):
#self.pat = re.compile(r'([,.!?\u3002\uff1b\uff0c\uff1a\u201c\u201d\uff08\uff09\u3001\uff1f\u300a\u300b]|[\u4e00-\u9fa5]|[a-zA-Z0-9]+)')
self.pat = re.compile(r'([a-zA-Z0-9]+|\S)')
self.lower = lower
self.sentencepiece_style_vocab = sentencepiece_style_vocab

def __call__(self, sen):
if len(sen) == 0:
@@ -119,11 +121,51 @@ def __call__(self, sen):
sen = sen.lower()
res = []
for match in self.pat.finditer(sen):
words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]')
words, _ = wordpiece(match.group(0), vocab=self.vocab, unk_token='[UNK]', sentencepiece_style_vocab=self.sentencepiece_style_vocab)
res.extend(words)
return res


class WSSPTokenizer(object):
def __init__(self, sp_model_dir, word_dict, ws=True, lower=True):
self.ws = ws
self.lower = lower
self.dict = pickle.load(open(word_dict, 'rb'), encoding='utf8')
import sentencepiece as spm
self.sp_model = spm.SentencePieceProcessor()
self.window_size = 5
self.sp_model.Load(sp_model_dir)

def cut(self, chars):
words = []
idx = 0
while idx < len(chars):
matched = False
for i in range(self.window_size, 0, -1):
cand = chars[idx: idx+i]
if cand in self.dict:
words.append(cand)
matched = True
break
if not matched:
i = 1
words.append(chars[idx])
idx += i
return words

def __call__(self, sen):
sen = sen.decode('utf8')
if self.ws:
sen = [s for s in self.cut(sen) if s != ' ']
else:
sen = sen.split(' ')
if self.lower:
sen = [s.lower() for s in sen]
sen = ' '.join(sen)
ret = self.sp_model.EncodeAsPieces(sen)
return ret


def build_2_pair(seg_a, seg_b, max_seqlen, cls_id, sep_id):
token_type_a = np.ones_like(seg_a, dtype=np.int64) * 0
token_type_b = np.ones_like(seg_b, dtype=np.int64) * 1
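
The heart of the new WSSPTokenizer is a greedy longest-match word segmentation (a sliding window of at most 5 characters) over a pickled word dictionary; the segmented words are space-joined and then encoded by SentencePiece. Below is a minimal, self-contained sketch of that segmentation step, using a toy in-memory dictionary instead of the pickled dict and sp model the class loads from disk.

    # Greedy longest-match segmentation as in WSSPTokenizer.cut(),
    # demonstrated with a toy dictionary (hypothetical example data).
    def cut(chars, word_dict, window_size=5):
        words = []
        idx = 0
        while idx < len(chars):
            matched = False
            for i in range(window_size, 0, -1):   # try the longest candidate first
                cand = chars[idx: idx + i]
                if cand in word_dict:
                    words.append(cand)
                    matched = True
                    break
            if not matched:                       # fall back to a single character
                i = 1
                words.append(chars[idx])
            idx += i
        return words

    print(cut(u'百度是一家公司', {u'百度', u'公司'}))
    # ['百度', '是', '一', '家', '公司']
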
13 changes: 10 additions & 3 deletions example/finetune_classifier.py
@@ -55,7 +55,7 @@ def forward(self, features):
pos_ids = L.cast(pos_ids, 'int64')
pos_ids.stop_gradient = True
input_mask.stop_gradient = True
task_ids = L.zeros_like(src_ids) + self.hparam.task_id #this shit wont use at the moment
task_ids = L.zeros_like(src_ids) + self.hparam.task_id
task_ids.stop_gradient = True

ernie = ErnieModel(
@@ -128,6 +128,8 @@ def metrics(self, predictions, label):
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--do_predict', action='store_true')
parser.add_argument('--warm_start_from', type=str)
parser.add_argument('--sentence_piece_model', type=str, default=None)
parser.add_argument('--word_dict', type=str, default=None)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
@@ -138,7 +140,12 @@ def metrics(self, predictions, label):
cls_id = vocab['[CLS]']
unk_id = vocab['[UNK]']

tokenizer = utils.data.CharTokenizer(vocab.keys())
if args.sentence_piece_model is not None:
if args.word_dict is None:
raise ValueError('--word_dict not specified for subword model')
tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
else:
tokenizer = utils.data.CharTokenizer(vocab.keys())

def tokenizer_func(inputs):
'''avoid pickle error'''
@@ -179,7 +186,7 @@ def after(sentence, segments, label):
dev_ds.data_shapes = shapes
dev_ds.data_types = types

varname_to_warmstart = re.compile('encoder.*|pooled.*|.*embedding|pre_encoder_.*')
varname_to_warmstart = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')
warm_start_dir = args.warm_start_from
ws = propeller.WarmStartSetting(
predicate_fn=lambda v: varname_to_warmstart.match(v.name) and os.path.exists(os.path.join(warm_start_dir, v.name)),
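
Besides the tokenizer switch, the warm-start pattern is tightened: instead of matching broad name prefixes (encoder.*, pooled.*, .*embedding, pre_encoder_.*), it now anchors on the exact parameter-name suffixes of the pretrained checkpoint, so task-specific head variables are not warm-started. A quick sketch of what the new pattern accepts; the variable names below are only illustrative.

    import re

    pat = re.compile(r'^encoder.*[wb]_0$|^.*embedding$|^.*bias$|^.*scale$|^pooled_fc.[wb]_0$')
    names = ['encoder_layer_0_multi_head_att_query_fc.w_0',  # pretrained weight: matched
             'word_embedding',                               # embedding table: matched
             'pooled_fc.b_0',                                # pooler bias: matched
             'cls_out_w']                                    # task head: not matched
    for name in names:
        print(name, bool(pat.match(name)))
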
29 changes: 15 additions & 14 deletions example/finetune_ner.py
@@ -32,7 +32,6 @@

from model.ernie import ErnieModel
from optimization import optimization
import tokenization
import utils.data

from propeller import log
@@ -121,7 +120,7 @@ def metrics(self, predictions, label):
def make_sequence_label_dataset(name, input_files, label_list, tokenizer, batch_size, max_seqlen, is_train):
label_map = {v: i for i, v in enumerate(label_list)}
no_entity_id = label_map['O']
delimiter = ''
delimiter = b''

def read_bio_data(filename):
ds = propeller.data.Dataset.from_file(filename)
@@ -132,10 +131,10 @@ def gen():
while 1:
line = next(iterator)
cols = line.rstrip(b'\n').split(b'\t')
tokens = cols[0].split(delimiter)
labels = cols[1].split(delimiter)
if len(cols) != 2:
continue
tokens = tokenization.convert_to_unicode(cols[0]).split(delimiter)
labels = tokenization.convert_to_unicode(cols[1]).split(delimiter)
if len(tokens) != len(labels) or len(tokens) == 0:
continue
yield [tokens, labels]
@@ -151,7 +150,8 @@ def gen():
ret_tokens = []
ret_labels = []
for token, label in zip(tokens, labels):
sub_token = tokenizer.tokenize(token)
sub_token = tokenizer(token)
label = label.decode('utf8')
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
@@ -179,7 +179,7 @@ def gen():
labels = labels[: max_seqlen - 2]

tokens = ['[CLS]'] + tokens + ['[SEP]']
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids = [vocab[t] for t in tokens]
label_ids = [no_entity_id] + [label_map[x] for x in labels] + [no_entity_id]
token_type_ids = [0] * len(token_ids)
input_seqlen = len(token_ids)
@@ -211,7 +211,7 @@ def after(*features):


def make_sequence_label_dataset_from_stdin(name, tokenizer, batch_size, max_seqlen):
delimiter = ''
delimiter = b''

def stdin_gen():
if six.PY3:
@@ -232,9 +232,9 @@ def gen():
while 1:
line, = next(iterator)
cols = line.rstrip(b'\n').split(b'\t')
tokens = cols[0].split(delimiter)
if len(cols) != 1:
continue
tokens = tokenization.convert_to_unicode(cols[0]).split(delimiter)
if len(tokens) == 0:
continue
yield tokens,
@@ -247,7 +247,7 @@ def gen():
tokens, = next(iterator)
ret_tokens = []
for token in tokens:
sub_token = tokenizer.tokenize(token)
sub_token = tokenizer(token)
if len(sub_token) == 0:
continue
ret_tokens.extend(sub_token)
@@ -266,7 +266,7 @@ def gen():
tokens = tokens[: max_seqlen - 2]

tokens = ['[CLS]'] + tokens + ['[SEP]']
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids = [vocab[t] for t in tokens]
token_type_ids = [0] * len(token_ids)
input_seqlen = len(token_ids)

@@ -296,13 +296,15 @@ def after(*features):
parser.add_argument('--data_dir', type=str, required=True)
parser.add_argument('--vocab_file', type=str, required=True)
parser.add_argument('--do_predict', action='store_true')
parser.add_argument('--use_sentence_piece_vocab', action='store_true')
parser.add_argument('--warm_start_from', type=str)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)

tokenizer = tokenization.FullTokenizer(args.vocab_file)
vocab = tokenizer.vocab

vocab = {j.strip().split('\t')[0]: i for i, j in enumerate(open(args.vocab_file, 'r', encoding='utf8'))}
tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=args.use_sentence_piece_vocab)
sep_id = vocab['[SEP]']
cls_id = vocab['[CLS]']
unk_id = vocab['[UNK]']
@@ -358,7 +360,7 @@ def after(*features):
from_dir=warm_start_dir
)

best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
propeller.train.train_and_eval(
model_class_or_model_fn=SequenceLabelErnieModel,
params=hparams,
@@ -387,7 +389,6 @@ def after(*features):
predict_ds.data_types = types

rev_label_map = {i: v for i, v in enumerate(label_list)}
best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
learner = propeller.Learner(SequenceLabelErnieModel, run_config, hparams)
for pred, _ in learner.predict(predict_ds, ckpt=-1):
pred_str = ' '.join([rev_label_map[idx] for idx in np.argmax(pred, 1).tolist()])
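
On the data side, the NER example no longer depends on the removed tokenization module: the vocab is read directly from the TSV vocab file, token ids are looked up from that dict, and tokenization goes through CharTokenizer, optionally with the sentence-piece-style vocab used by ERNIE tiny. A hedged sketch of that setup; the paths are placeholders.

    import utils.data

    vocab_file = './pretrained/vocab.txt'   # placeholder path
    vocab = {line.strip().split('\t')[0]: i
             for i, line in enumerate(open(vocab_file, 'r', encoding='utf8'))}

    tokenizer = utils.data.CharTokenizer(vocab, sentencepiece_style_vocab=True)
    tokens = ['[CLS]'] + tokenizer('ERNIE tiny') + ['[SEP]']
    token_ids = [vocab[t] for t in tokens]   # ids fed to the model, as in the dataset builder above
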
8 changes: 6 additions & 2 deletions example/finetune_ranker.py
@@ -146,6 +146,7 @@ def backward(self, loss):
parser.add_argument('--data_dir', type=str, required=True)
parser.add_argument('--warm_start_from', type=str)
parser.add_argument('--sentence_piece_model', type=str, default=None)
parser.add_argument('--word_dict', type=str, default=None)
args = parser.parse_args()
run_config = propeller.parse_runconfig(args)
hparams = propeller.parse_hparam(args)
@@ -157,7 +158,9 @@ def backward(self, loss):
unk_id = vocab['[UNK]']

if args.sentence_piece_model is not None:
tokenizer = utils.data.JBSPTokenizer(args.sentence_piece_model, jb=True, lower=True)
if args.word_dict is None:
raise ValueError('--word_dict not specified for subword model')
tokenizer = utils.data.WSSPTokenizer(args.sentence_piece_model, args.word_dict, ws=True, lower=True)
else:
tokenizer = utils.data.CharTokenizer(vocab.keys())

@@ -218,7 +221,7 @@ def after(sentence, segments, qid, label):
from_dir=warm_start_dir
)

best_exporter = propeller.train.exporter.BestExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
best_exporter = propeller.train.exporter.BestInferenceModelExporter(os.path.join(run_config.model_dir, 'best'), cmp_fn=lambda old, new: new['dev']['f1'] > old['dev']['f1'])
propeller.train_and_eval(
model_class_or_model_fn=RankingErnieModel,
params=hparams,
@@ -258,6 +261,7 @@ def after(sentence, segments, qid):
est = propeller.Learner(RankingErnieModel, run_config, hparams)
for qid, res in est.predict(predict_ds, ckpt=-1):
print('%d\t%d\t%.5f\t%.5f' % (qid[0], np.argmax(res), res[0], res[1]))

#for i in predict_ds:
# sen = i[0]
# for ss in np.squeeze(sen):
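
Finally, both the NER and the ranker examples now export through BestInferenceModelExporter rather than BestExporter, keeping a new inference model only when its dev-set F1 improves on the previous best. The comparator is the lambda shown in the diff; the toy metric dicts below just illustrate its behaviour.

    cmp_fn = lambda old, new: new['dev']['f1'] > old['dev']['f1']

    print(cmp_fn({'dev': {'f1': 0.80}}, {'dev': {'f1': 0.83}}))   # True  -> export the new model
    print(cmp_fn({'dev': {'f1': 0.83}}, {'dev': {'f1': 0.80}}))   # False -> keep the previous best
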
