forked from shibing624/text2vec
Commit 466e994 (1 parent: a64f643)
xuming06 committed Nov 12, 2019
Showing 22 changed files with 1,402 additions and 8 deletions.
@@ -1,3 +1,3 @@
numpy
jieba
scipy
scikit-learn
keras-bert==0.80.0
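Assuming this hunk updates the project's pip requirements file (the file name is not captured in this view), the pinned stack could be installed directly from the lines above, e.g. with pip install numpy jieba scipy scikit-learn keras-bert==0.80.0.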
@@ -3,13 +3,25 @@
@author:XuMing([email protected])
@description:
"""

import os
from pathlib import Path

from simtext.similarity.similarity import Similarity
from simtext.utils.logger import logger

os.environ['TF_KERAS'] = '1'  # make keras-bert use tf.keras instead of standalone Keras
import keras_bert

# Per-user working directories under ~/.simtext
USER_DIR = Path.expanduser(Path('~')).joinpath('.simtext')
if not USER_DIR.exists():
    logger.info('make dir:%s' % USER_DIR)
    USER_DIR.mkdir()
USER_DATA_DIR = USER_DIR.joinpath('datasets')
if not USER_DATA_DIR.exists():
    USER_DATA_DIR.mkdir()
USER_TUNED_MODELS_DIR = USER_DIR.joinpath('tuned_models')
USER_BERT_MODEL_DIR = USER_DIR.joinpath('bert_model')

custom_objects = keras_bert.get_custom_objects()

# Module-level similarity API backed by the BERT model
sim = Similarity('bert')
similarity_score = sim.similarity_score
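With these module-level names in place, downstream code could call the bound method directly. A minimal sketch, assuming `similarity_score` takes two texts and returns a float (the call signature and the example strings are illustrative, not taken from the diff):

from simtext import similarity_score

# hypothetical inputs; any two sentences to compare
score = similarity_score('如何更换花呗绑定银行卡', '花呗更改绑定银行卡')
print(score)  # a similarity value for the two sentences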
@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""
@@ -0,0 +1,191 @@
# -*- coding: utf-8 -*-
"""
@author:XuMing([email protected])
@description:
"""

import codecs
import os
from typing import Union, Optional, Any, List, Tuple

import numpy as np
import tensorflow as tf

from simtext.embeddings.embedding import Embedding
from simtext.processors.base_processor import BaseProcessor
from simtext.utils.logger import get_logger
from simtext.utils.non_masking_layer import NonMaskingLayer

os.environ['TF_KERAS'] = '1'
import keras_bert

logger = get_logger(__name__)


class BERTEmbedding(Embedding):
    """Pre-trained BERT embedding"""

    def info(self):
        info = super(BERTEmbedding, self).info()
        info['config'] = {
            'model_folder': self.model_folder,
            'sequence_length': self.sequence_length
        }
        return info

    def __init__(self,
                 model_folder: str,
                 layer_nums: int = 4,
                 trainable: bool = False,
                 sequence_length: Union[str, int] = 'auto',
                 processor: Optional[BaseProcessor] = None):
        """
        Args:
            model_folder: path of the pre-trained BERT checkpoint folder
            layer_nums: number of layers whose outputs will be concatenated into a single tensor,
                default `4`, output the last 4 hidden layers as the BERT paper suggests
            trainable: whether the model is trainable; default `False`. Set it to `True`
                to fine-tune this embedding layer during training
            sequence_length: a fixed sequence length as an `int`, or `'auto'`
            processor: optional pre-processor instance
        """
        self.trainable = trainable
        # The whole BERT model does not need training if only its feature output is used
        self.training = False
        self.layer_nums = layer_nums
        if isinstance(sequence_length, tuple):
            raise ValueError('BERT embedding only accepts an `int` type `sequence_length`')

        if sequence_length == 'variable':
            raise ValueError('BERT embedding only accepts sequences of equal length')

        super(BERTEmbedding, self).__init__(sequence_length=sequence_length,
                                            embedding_size=0,
                                            processor=processor)

        # Use BERT's reserved tokens for padding, unknowns, and sentence boundaries
        self.processor.token_pad = '[PAD]'
        self.processor.token_unk = '[UNK]'
        self.processor.token_bos = '[CLS]'
        self.processor.token_eos = '[SEP]'

        self.processor.add_bos_eos = True

        self.model_folder = model_folder
        self._build_token2idx_from_bert()
        self._build_model()

    def _build_token2idx_from_bert(self):
        # vocab.txt maps each BERT token to its row index
        dict_path = os.path.join(self.model_folder, 'vocab.txt')
        logger.debug('load vocab.txt from %s' % self.model_folder)
        token2idx = {}
        with codecs.open(dict_path, 'r', encoding='utf-8') as f:
            for line in f:
                token = line.strip()
                token2idx[token] = len(token2idx)

        self.bert_token2idx = token2idx
        self.tokenizer = keras_bert.Tokenizer(token2idx)
        self.processor.token2idx = self.bert_token2idx
        self.processor.idx2token = dict([(value, key) for key, value in token2idx.items()])

||
def _build_model(self, **kwargs): | ||
if self.embed_model is None: | ||
seq_len = self.sequence_length | ||
if isinstance(seq_len, tuple): | ||
seq_len = seq_len[0] | ||
if isinstance(seq_len, str): | ||
logger.warning(f"Model will be built until sequence length is determined") | ||
return | ||
config_path = os.path.join(self.model_folder, 'bert_config.json') | ||
check_point_path = os.path.join(self.model_folder, 'bert_model.ckpt') | ||
logger.debug('load bert model from:%s' % check_point_path) | ||
bert_model = keras_bert.load_trained_model_from_checkpoint(config_path, | ||
check_point_path, | ||
seq_len=seq_len, | ||
output_layer_num=self.layer_nums, | ||
training=self.training, | ||
trainable=self.trainable) | ||
|
||
self._model = tf.keras.Model(bert_model.inputs, bert_model.output) | ||
bert_seq_len = int(bert_model.output.shape[1]) | ||
if bert_seq_len < seq_len: | ||
logger.warning(f"Sequence length limit set to {bert_seq_len} by pre-trained model") | ||
self.sequence_length = bert_seq_len | ||
self.embedding_size = int(bert_model.output.shape[-1]) | ||
output_features = NonMaskingLayer()(bert_model.output) | ||
self.embed_model = tf.keras.Model(bert_model.inputs, output_features) | ||
logger.warning(f'seq_len: {self.sequence_length}') | ||
|
    def analyze_corpus(self,
                       x: Union[Tuple[List[List[str]], ...], List[List[str]]],
                       y: Union[List[List[Any]], List[Any]]):
        """
        Prepare the embedding layer and pre-processor for a labeling task
        Args:
            x: corpus inputs
            y: corpus labels
        """
        if len(self.processor.token2idx) == 0:
            self._build_token2idx_from_bert()
        super(BERTEmbedding, self).analyze_corpus(x, y)

    def embed(self,
              sentence_list: Union[Tuple[List[List[str]], ...], List[List[str]]],
              debug: bool = False) -> np.ndarray:
        """
        Batch embed sentences
        Args:
            sentence_list: sentence list to embed
            debug: show debug log
        Returns:
            vectorized sentence list, per-token vectors such as:
                print(token, predicts[i].tolist()[:4])
                [CLS] [0.24250675737857819, 0.04605229198932648, ...]
                from [0.2858668565750122, 0.12927496433258057, ...]
                that [-0.7514970302581787, 0.14548861980438232, ...]
                day [0.32245880365371704, -0.043174318969249725, ...]
                ...
        """
        if self.embed_model is None:
            raise ValueError('need to build the model before embedding sentences')

        tensor_x = self.process_x_dataset(sentence_list)
        if debug:
            logger.debug(f'sentence tensor: {tensor_x}')
        embed_results = self.embed_model.predict(tensor_x)
        return embed_results

    def process_x_dataset(self,
                          data: Union[Tuple[List[List[str]], ...], List[List[str]]],
                          subset: Optional[List[int]] = None) -> Tuple[np.ndarray, ...]:
        """
        Batch process feature data for training
        Args:
            data: target dataset
            subset: subset index list
        Returns:
            vectorized feature tensor
        """
        x1 = None
        if isinstance(data, tuple):
            if len(data) == 2:
                # Two inputs given: token ids and segment ids
                x0 = self.processor.process_x_dataset(data[0], self.sequence_length, subset)
                x1 = self.processor.process_x_dataset(data[1], self.sequence_length, subset)
            else:
                x0 = self.processor.process_x_dataset(data[0], self.sequence_length, subset)
        else:
            x0 = self.processor.process_x_dataset(data, self.sequence_length, subset)
        if x1 is None:
            # BERT also expects a segment-id input; use all zeros for single-segment data
            x1 = np.zeros(x0.shape, dtype=np.int32)
        return x0, x1
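For reference, a minimal usage sketch of the class added above. The module path, checkpoint directory, and tokenized input are assumptions for illustration; only the constructor arguments and the embed() behavior come from the diff:

# hypothetical module path; the new file's name is not shown in this view
from simtext.embeddings.bert_embedding import BERTEmbedding

# assumes a Google-style BERT checkpoint directory containing
# vocab.txt, bert_config.json, and bert_model.ckpt
embedding = BERTEmbedding(model_folder='/path/to/bert_model',
                          sequence_length=32)
vectors = embedding.embed([['今', '天', '天', '气', '很', '好']])
print(vectors.shape)  # (1, sequence_length, embedding_size)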