implemented ALBERT sequence labeling model
hankcs committed Jan 4, 2020
1 parent 353db5b commit 81f9d88
Showing 8 changed files with 191 additions and 29 deletions.
9 changes: 4 additions & 5 deletions hanlp/common/component.py
@@ -93,6 +93,7 @@ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128,
mode='w')
tst_data = self.transform.file_to_dataset(input_path, batch_size=batch_size)
samples = size_of_dataset(tst_data)
num_batches = math.ceil(samples / batch_size)
if warm_up:
self.model.predict_on_batch(tst_data.take(1))
if output:
@@ -104,7 +105,7 @@ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128,
else:
raise RuntimeError('output ({}) must be of type bool or str'.format(repr(output)))
timer = Timer()
loss, score, output = self.evaluate_dataset(tst_data, callbacks, output)
loss, score, output = self.evaluate_dataset(tst_data, callbacks, output, num_batches)
delta_time = timer.stop()
speed = samples / delta_time.delta_seconds

@@ -126,14 +127,12 @@ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128,
if output:
logger.info('Saving output to {}'.format(output))
with open(output, 'w', encoding='utf-8') as out:
num_batches = math.ceil(samples / batch_size)

self.evaluate_output(tst_data, out, num_batches, self.model.metrics)

return loss, score, speed

def evaluate_dataset(self, tst_data, callbacks, output):
loss, score = self.model.evaluate(tst_data, callbacks=callbacks)
def evaluate_dataset(self, tst_data, callbacks, output, num_batches):
loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)
return loss, score, output

def evaluate_output(self, tst_data, out, num_batches, metrics: List[tf.keras.metrics.Metric]):
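The component.py change computes the number of evaluation steps before calling Keras, because a generator-backed dataset has no length that model.evaluate can infer. A minimal sketch of the arithmetic, with the sample count and batch size made up for illustration:

import math

samples = 4365                                   # size_of_dataset(tst_data) in the real code
batch_size = 128
num_batches = math.ceil(samples / batch_size)    # 35 steps for these numbers
# Passing the step count tells Keras when one pass over the test set is complete:
# loss, score = self.model.evaluate(tst_data, callbacks=callbacks, steps=num_batches)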
8 changes: 8 additions & 0 deletions hanlp/common/transform.py
@@ -131,6 +131,10 @@ def file_to_dataset(self, filepath: str, gold=True, map_x=None, map_y=None, batc
"""

# debug
# for sample in self.file_to_samples(filepath):
# pass

def generator():
inputs = self.file_to_inputs(filepath, gold)
samples = self.inputs_to_samples(inputs, gold)
@@ -142,6 +146,10 @@ def generator():
def inputs_to_dataset(self, inputs, gold=False, map_x=None, map_y=None, batch_size=32, shuffle=None, repeat=None,
drop_remainder=False,
prefetch=1, cache=False, **kwargs) -> tf.data.Dataset:
# debug
# for sample in self.inputs_to_samples(inputs):
# pass

def generator():
samples = self.inputs_to_samples(inputs, gold)
yield from samples
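Both file_to_dataset and inputs_to_dataset wrap such a generator in a tf.data pipeline built from the types, shapes and padding values that the transform declares. A minimal standalone sketch of that pattern, with the sample structure simplified to (token_ids, tag_ids) for illustration:

import tensorflow as tf

def generator():
    # stands in for inputs_to_samples(...): one (token_ids, tag_ids) pair per sentence
    yield [101, 2769, 102], [0, 1, 0]
    yield [101, 3221, 4638, 102], [0, 2, 2, 0]

dataset = tf.data.Dataset.from_generator(generator, output_types=(tf.int32, tf.int32))
dataset = dataset.padded_batch(32, padded_shapes=([None], [None]), padding_values=(0, 0))
for x, y in dataset:
    print(x.shape, y.shape)   # (2, 4) (2, 4) after padding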
55 changes: 52 additions & 3 deletions hanlp/components/taggers/transformers/transformer_tagger.py
@@ -1,20 +1,26 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-29 13:55
import glob
import logging
import math
import os

import tensorflow as tf
from bert import bert_models_google

from hanlp.components.taggers.tagger import TaggerComponent
from hanlp.components.taggers.transformers.metrics import MaskedSparseCategoricalAccuracy
from hanlp.components.taggers.transformers.transformer_transform import TransformerTransform
from hanlp.layers.transformers import AutoTokenizer, TFAutoModel, TFPreTrainedModel, PreTrainedTokenizer, TFAlbertModel, \
BertTokenizer
from hanlp.losses.sparse_categorical_crossentropy import MaskedSparseCategoricalCrossentropyOverBatchFirstDim
BertTokenizer, albert_models_google
from hanlp.layers.transformers.loader import load_stock_weights
from hanlp.losses.sparse_categorical_crossentropy import MaskedSparseCategoricalCrossentropyOverBatchFirstDim, \
SparseCategoricalCrossentropyOverBatchFirstDim
from hanlp.optimizers.adamw import create_optimizer
from hanlp.utils.io_util import get_resource
from hanlp.utils.util import merge_locals_kwargs
import bert


class TransformerTaggingModel(tf.keras.Model):
@@ -33,7 +39,47 @@ def __init__(self, transform: TransformerTransform = None) -> None:
super().__init__(transform)
self.transform: TransformerTransform = transform

def build_model(self, transformer, max_seq_length, **kwargs) -> tf.keras.Model:
def build_model(self, transformer, max_seq_length, implementation, **kwargs) -> tf.keras.Model:
if implementation == 'bert-for-tf2':
if transformer in albert_models_google:
from bert.tokenization.albert_tokenization import FullTokenizer
model_url = albert_models_google[transformer]
albert = True
elif transformer in bert_models_google:
from bert.tokenization.bert_tokenization import FullTokenizer
model_url = bert_models_google[transformer]
albert = False
else:
raise ValueError(
f'Unknown model {transformer}, available ones: {list(bert_models_google.keys()) + list(albert_models_google.keys())}')
albert_dir = get_resource(model_url)
vocab = glob.glob(os.path.join(albert_dir, '*vocab*.txt'))
assert len(vocab) == 1, 'No vocab found or multiple vocabs found'
vocab = vocab[0]
# noinspection PyTypeChecker
self.transform.tokenizer = FullTokenizer(vocab_file=vocab)
bert_params = bert.params_from_pretrained_ckpt(albert_dir)
l_bert = bert.BertModelLayer.from_params(bert_params, name="bert")
l_input_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name="input_ids")
l_mask_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name="mask_ids")
l_token_type_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype='int32', name="token_type_ids")
output = l_bert([l_input_ids, l_token_type_ids], mask=l_mask_ids)
output = tf.keras.layers.Dropout(bert_params.hidden_dropout, name='hidden_dropout')(output)
logits = tf.keras.layers.Dense(len(self.transform.tag_vocab),
kernel_initializer=tf.keras.initializers.TruncatedNormal(
bert_params.initializer_range))(output)
model = tf.keras.Model(inputs=[l_input_ids, l_mask_ids, l_token_type_ids], outputs=logits)
model.build(input_shape=(None, max_seq_length))
ckpt = glob.glob(os.path.join(albert_dir, '*.index'))
assert ckpt, f'No checkpoint found under {albert_dir}'
ckpt, _ = os.path.splitext(ckpt[0])
if albert:
skipped_weight_value_tuples = load_stock_weights(l_bert, ckpt)
else:
skipped_weight_value_tuples = bert.load_bert_weights(l_bert, ckpt)
assert 0 == len(skipped_weight_value_tuples), 'failed to load pretrained model'
return model

tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(transformer)
self.transform.tokenizer = tokenizer
transformer: TFPreTrainedModel = TFAutoModel.from_pretrained(transformer, name=os.path.basename(transformer))
@@ -64,6 +110,7 @@ def fit(self, trn_data, dev_data, save_dir,
batch_size=32,
epochs=3,
metrics='accuracy',
implementation='transformers',
run_eagerly=False,
logger=None,
verbose=True,
@@ -112,6 +159,8 @@ def train_loop(self, trn_data, dev_data, epochs, num_examples, train_steps_per_e
return history

def build_loss(self, loss, **kwargs):
if self.config.implementation == 'bert-for-tf2':
return SparseCategoricalCrossentropyOverBatchFirstDim()
return MaskedSparseCategoricalCrossentropyOverBatchFirstDim()

def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
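With the new implementation switch the tagger can be trained on a bert-for-tf2 backbone instead of the transformers one. A hedged usage sketch: only transformer='albert_base_zh' and implementation='bert-for-tf2' come from this commit; the corpus paths, save directory and remaining hyper-parameters are illustrative.

from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger

tagger = TransformerTagger()
# 'albert_base_zh' is resolved through albert_models_google and downloaded on first use.
tagger.fit('data/tagger/train.tsv', 'data/tagger/dev.tsv', 'data/model/tagger_albert_base',
           transformer='albert_base_zh',
           implementation='bert-for-tf2',
           metrics='accuracy',
           epochs=3)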
55 changes: 36 additions & 19 deletions hanlp/components/taggers/transformers/transformer_transform.py
@@ -32,7 +32,10 @@ def tokenizer(self):
@tokenizer.setter
def tokenizer(self, tokenizer):
self._tokenizer = tokenizer
self.special_token_ids = tf.constant(self.tokenizer.all_special_ids, dtype=tf.int32)
if self.config.implementation == 'bert-for-tf2':
self.special_token_ids = tf.constant([tokenizer.vocab[token] for token in ['[PAD]', '[CLS]', '[SEP]']], dtype=tf.int32)
else:
self.special_token_ids = tf.constant(self.tokenizer.all_special_ids, dtype=tf.int32)

def fit(self, trn_path: str, **kwargs) -> int:
self.tag_vocab = Vocab(unk_token=None)
@@ -49,42 +52,56 @@ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
values = (0, 0, 0), self.tag_vocab.pad_idx
return types, shapes, values

def lock_vocabs(self):
super().lock_vocabs()

def inputs_to_samples(self, inputs, gold=False):
max_seq_length = self.config.get('max_seq_length', 128)
tokenizer = self._tokenizer
config = self.transformer_config
if self.config.implementation == 'bert-for-tf2':
xlnet = False
roberta = False
pad_token = '[PAD]'
cls_token = '[CLS]'
sep_token = '[SEP]'
else:
config = self.transformer_config
xlnet = config_is(config, 'xlnet')
roberta = config_is(config, 'roberta')
pad_token = tokenizer.pad_token
cls_token = tokenizer.cls_token
sep_token = tokenizer.sep_token

pad_label_idx = self.tag_vocab.pad_idx
pad_token = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
pad_token = tokenizer.convert_tokens_to_ids([pad_token])[0]
for sample in inputs:
if gold:
words, tags = sample
else:
words, tags = sample, [self.tag_vocab.pad_token] * len(sample)

input_ids, input_mask, segment_ids, label_ids = convert_examples_to_features(words, tags,
self.tag_vocab.token_to_idx,
max_seq_length, tokenizer,
cls_token_at_end=config_is(
config,
'xlnet'),
cls_token_at_end=xlnet,
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if config_is(
config,
'xlnet') else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=config_is(
config,
'roberta'),
cls_token=cls_token,
cls_token_segment_id=2 if xlnet else 0,
sep_token=sep_token,
sep_token_extra=roberta,
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=config_is(config,
'xlnet'),
pad_on_left=xlnet,
# pad on the left for xlnet
pad_token=pad_token,
pad_token_segment_id=4 if config_is(
config,
'xlnet') else 0,
pad_token_segment_id=4 if xlnet else 0,
pad_token_label_id=pad_label_idx)

if None in input_ids:
print(input_ids)
if None in input_mask:
print(input_mask)
if None in segment_ids:
print(segment_ids)
yield (input_ids, input_mask, segment_ids), label_ids

def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
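When the bert-for-tf2 back end is active, the transform no longer relies on the transformers tokenizer attributes and resolves the special tokens by name instead. A minimal sketch of that branch, assuming a WordPiece vocab file from one of the Google releases (the path is illustrative):

from bert.tokenization.albert_tokenization import FullTokenizer

tokenizer = FullTokenizer(vocab_file='albert_base_zh/vocab_chinese.txt')
# bert-for-tf2's FullTokenizer exposes a plain vocab dict, so the ids are looked up by string:
pad_id, cls_id, sep_id = (tokenizer.vocab[t] for t in ['[PAD]', '[CLS]', '[SEP]'])
# the padding id handed to convert_examples_to_features is obtained the same way:
pad_token_id = tokenizer.convert_tokens_to_ids(['[PAD]'])[0]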
9 changes: 8 additions & 1 deletion hanlp/layers/transformers/__init__.py
@@ -4,13 +4,20 @@

# mute transformers
import logging
import os

logging.getLogger('transformers.file_utils').setLevel(logging.ERROR)
logging.getLogger('transformers.tokenization_utils').setLevel(logging.ERROR)
logging.getLogger('transformers.configuration_utils').setLevel(logging.ERROR)
logging.getLogger('transformers.modeling_tf_utils').setLevel(logging.ERROR)
import os

os.environ["USE_TORCH"] = 'NO' # saves time loading transformers
from transformers import BertTokenizer, TFBertForSequenceClassification, BertConfig, PretrainedConfig, TFAutoModel, \
AutoConfig, AutoTokenizer, PreTrainedTokenizer, TFPreTrainedModel, TFAlbertModel

albert_models_google = {
'albert_base_zh': 'https://storage.googleapis.com/albert_models/albert_base_zh.tar.gz',
'albert_large_zh': 'https://storage.googleapis.com/albert_models/albert_large_zh.tar.gz',
'albert_xlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xlarge_zh.tar.gz',
'albert_xxlarge_zh': 'https://storage.googleapis.com/albert_models/albert_xxlarge_zh.tar.gz',
}
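albert_models_google mirrors bert_models_google from bert-for-tf2 and maps a model name to its Google Storage tarball. A hedged sketch of how build_model resolves a name into a local directory (what get_resource returns is taken from its use in this commit):

from hanlp.layers.transformers import albert_models_google
from hanlp.utils.io_util import get_resource

model_url = albert_models_google['albert_base_zh']
# get_resource downloads and unpacks the archive, returning the extracted directory
# that holds the vocab file and the TF checkpoint (*.index / *.data-* files).
albert_dir = get_resource(model_url)
print(albert_dir)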
63 changes: 63 additions & 0 deletions hanlp/layers/transformers/loader.py
@@ -0,0 +1,63 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2020-01-04 06:05
import tensorflow as tf
from bert import BertModelLayer
from bert.loader import _checkpoint_exists, bert_prefix
from bert.loader_albert import map_to_tfhub_albert_variable_name
from tensorflow import keras


def load_stock_weights(bert: BertModelLayer, ckpt_path):
"""
Use this method to load the weights from a pre-trained ALBERT checkpoint into a BertModelLayer instance.
:param bert: a BertModelLayer instance within a built keras model.
:param ckpt_path: checkpoint path, e.g. `albert_base_zh/albert_model.ckpt`
:return: list of weights with mismatched shapes. This can be used to extend
the segment/token_type embeddings.
"""
assert isinstance(bert, BertModelLayer), "Expecting a BertModelLayer instance as first argument"
assert _checkpoint_exists(ckpt_path), "Checkpoint does not exist: {}".format(ckpt_path)
ckpt_reader = tf.train.load_checkpoint(ckpt_path)

stock_weights = set(ckpt_reader.get_variable_to_dtype_map().keys())

prefix = bert_prefix(bert)

loaded_weights = set()
skip_count = 0
weight_value_tuples = []
skipped_weight_value_tuples = []

bert_params = bert.weights
param_values = keras.backend.batch_get_value(bert.weights)
for ndx, (param_value, param) in enumerate(zip(param_values, bert_params)):
stock_name = map_to_tfhub_albert_variable_name(param.name, prefix)

if ckpt_reader.has_tensor(stock_name):
ckpt_value = ckpt_reader.get_tensor(stock_name)

if param_value.shape != ckpt_value.shape:
print("loader: Skipping weight:[{}] as the weight shape:[{}] is not compatible "
"with the checkpoint:[{}] shape:{}".format(param.name, param.shape,
stock_name, ckpt_value.shape))
skipped_weight_value_tuples.append((param, ckpt_value))
continue

weight_value_tuples.append((param, ckpt_value))
loaded_weights.add(stock_name)
else:
print("loader: No value for:[{}], i.e.:[{}] in:[{}]".format(param.name, stock_name, ckpt_path))
skip_count += 1
keras.backend.batch_set_value(weight_value_tuples)

print("Done loading {} BERT weights from: {} into {} (prefix:{}). "
"Count of weights not found in the checkpoint was: [{}]. "
"Count of weights with mismatched shape: [{}]".format(
len(weight_value_tuples), ckpt_path, bert, prefix, skip_count, len(skipped_weight_value_tuples)))

print("Unused weights from checkpoint:",
"\n\t" + "\n\t".join(sorted(stock_weights.difference(loaded_weights))))

return skipped_weight_value_tuples # (bert_weight, value_from_ckpt)
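load_stock_weights adapts bert-for-tf2's ALBERT loader: every weight of the Keras BertModelLayer is mapped to its checkpoint name via map_to_tfhub_albert_variable_name and assigned in one batch_set_value call. A hedged usage sketch; l_bert stands for the layer built in transformer_tagger.build_model and the checkpoint prefix is illustrative:

from hanlp.layers.transformers.loader import load_stock_weights

# checkpoint prefix without the .index / .data suffixes, as produced by os.path.splitext above
ckpt = 'albert_base_zh/model.ckpt-best'
# the owning keras model must be built before loading, e.g. model.build(input_shape=(None, 128))
# skipped = load_stock_weights(l_bert, ckpt)
# assert not skipped, 'failed to load pretrained weights'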
2 changes: 1 addition & 1 deletion setup.py
@@ -34,7 +34,7 @@
keywords='corpus,machine-learning,NLU,NLP',
packages=find_packages(exclude=['docs', 'tests*']),
include_package_data=True,
install_requires=['tensorflow==2.1.0rc2', 'fasttext==0.9.1', 'transformers==2.3.0'],
install_requires=['tensorflow==2.1.0rc2', 'fasttext==0.9.1', 'transformers==2.3.0', 'bert-for-tf2==0.12.7'],
python_requires='>=3.6',
# entry_points={
# 'console_scripts': [
19 changes: 19 additions & 0 deletions tests/train/zh/train_msra_ner_albert.py
@@ -0,0 +1,19 @@
# -*- coding:utf-8 -*-
# Author: hankcs
# Date: 2019-12-28 23:15
from hanlp.components.ner import TransformerNamedEntityRecognizer
from hanlp.components.taggers.transformers.transformer_tagger import TransformerTagger
from hanlp.datasets.ner.msra import MSRA_NER_TRAIN, MSRA_NER_VALID, MSRA_NER_TEST
from tests import cdroot

cdroot()
recognizer = TransformerNamedEntityRecognizer()
save_dir = 'data/model/ner/ner_albert_base_msra'
# recognizer.fit(MSRA_NER_TRAIN, MSRA_NER_VALID, save_dir, transformer='albert_base_zh',
# implementation='bert-for-tf2',
# learning_rate=5e-5,
# metrics='f1')
recognizer.load(save_dir)
print(recognizer.predict(list('上海华安工业(集团)公司董事长谭旭光和秘书张晚霞来到美国纽约现代艺术博物馆参观。')))
recognizer.evaluate(MSRA_NER_TEST, save_dir=save_dir)
print(f'Model saved in {save_dir}')
