refactor: add type hint
englishbook committed Feb 16, 2020
1 parent 07beb21 commit cdb9718
Showing 9 changed files with 587 additions and 291 deletions.
166 changes: 108 additions & 58 deletions fancy_nlp/applications/spm.py
@@ -1,8 +1,9 @@
# -*- coding: utf-8 -*-

from typing import List, Optional, Union, Tuple, Dict, Any

from absl import logging
import tensorflow as tf
import numpy as np

from fancy_nlp.preprocessors import SPMPreprocessor
from fancy_nlp.models.spm import *
@@ -15,7 +16,7 @@
class SPM(object):
"""SPM application"""

def __init__(self, use_pretrained: bool = False) -> None:
self.preprocessor = None
self.model = None
self.trainer = None
@@ -25,36 +26,35 @@ def __init__(self, use_pretrained=False):
self.load_pretrained_model()

def fit(self,
train_data: Tuple[List[str], List[str]],
train_labels: List[str],
valid_data: Optional[Tuple[List[str], List[str]]] = None,
valid_labels: Optional[List[str]] = None,
spm_model_type: str = 'siamese_cnn',
use_word: bool = True,
external_word_dict: Optional[List[str]] = None,
word_embed_type: Optional[str] = 'word2vec',
word_embed_dim: int = 300,
word_embed_trainable: bool = True,
use_char: bool = False,
char_embed_type: Optional[str] = 'word2vec',
char_embed_dim: int = 300,
char_embed_trainable: bool = True,
use_bert: bool = False,
bert_vocab_file: Optional[str] = None,
bert_config_file: Optional[str] = None,
bert_checkpoint_file: Optional[str] = None,
bert_trainable: bool = False,
max_len: Optional[int] = None,
max_word_len: Optional[int] = None,
optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam',
batch_size: int = 32,
epochs: int = 50,
callback_list: Optional[List[str]] = None,
checkpoint_dir: Optional[str] = None,
model_name: Optional[str] = None,
load_swa_model: bool = False,
**kwargs) -> None:
"""Train spm model using provided data
Args:
@@ -71,22 +71,19 @@ def fit(self,
word_embed_dim: int, dimensionality of word embedding
word_embed_trainable: boolean, whether to update word embedding during training
use_char: boolean, whether to use char as input
char_embed_type: str, similar to 'word_embed_type'
char_embed_dim: int, similar to 'word_embed_dim'
char_embed_trainable: boolean, similar to 'word_embed_trainable'
use_bert: boolean, whether to use bert embedding as input
bert_vocab_file: str, path to bert's vocabulary file
bert_config_file: str, path to bert's configuration file
bert_checkpoint_file: str, path to bert's checkpoint file
bert_trainable: boolean, whether to update bert during training
use_bert_model: boolean, whether to use the traditional bert model, which combines the two
sentences into one input
max_len: int, max sequence length. If None, we dynamically use the max length of one batch
as max_len. However, max_len must be provided when using bert as input.
max_word_len: int, max word length. If None, we dynamically use the max word length of one
batch as max_word_len.
optimizer: str or instance of `keras.optimizers.Optimizer`, indicating the optimizer to
use during training
@@ -109,6 +106,7 @@ def fit(self,
"""
use_bert_model = spm_model_type == 'bert'

# data preprocessing
self.preprocessor = SPMPreprocessor(train_data=train_data,
train_labels=train_labels,
use_word=use_word,
@@ -121,10 +119,10 @@ def fit(self,
char_embed_dim=char_embed_dim,
word_embed_type=word_embed_type,
word_embed_dim=word_embed_dim,
max_len=max_len,
max_word_len=max_word_len)

# build model
self.model = self.get_spm_model(spm_model_type=spm_model_type,
num_class=self.preprocessor.num_class,
use_word=use_word,
@@ -146,6 +144,7 @@ def fit(self,
optimizer=optimizer,
**kwargs)

# build swa model
if callback_list is not None and 'swa' in callback_list:
swa_model = self.get_spm_model(spm_model_type=spm_model_type,
num_class=self.preprocessor.num_class,
@@ -170,18 +169,20 @@ def fit(self,
else:
swa_model = None

# train model
self.trainer = SPMTrainer(self.model, self.preprocessor)
self.trainer.train_generator(train_data, train_labels, valid_data, valid_labels,
batch_size, epochs, callback_list, checkpoint_dir, model_name,
swa_model, load_swa_model)

# build predictor
self.predictor = SPMPredictor(self.model, self.preprocessor)

if valid_data is not None and valid_labels is not None:
logging.info('Evaluating on validation data...')
self.score(valid_data, valid_labels)

def score(self, valid_data: Tuple[List[str], List[str]], valid_labels: List[str]) -> float:
"""Return the f1 score of the model over validation data
Args:
@@ -196,7 +197,7 @@ def score(self, valid_data, valid_labels):
else:
logging.fatal('Trainer is None! Call fit() or load() to get trainer.')

def predict(self, test_text: Tuple[str, str]) -> str:
"""Return prediction of the model for test data
Args:
Expand All @@ -210,7 +211,7 @@ def predict(self, test_text):
else:
logging.fatal('Predictor is None! Call fit() or load() to get predictor.')

def predict_batch(self, test_texts: Tuple[List[str], List[str]]) -> List[str]:
"""Return predictions of the model for test data
Args:
Expand All @@ -224,7 +225,7 @@ def predict_batch(self, test_texts):
else:
logging.fatal('Predictor is None! Call fit() or load() to get predictor.')

def analyze(self, text: Tuple[str, str]) -> Tuple[str, np.ndarray]:
"""Analyze text and return matching result with probability.
Args:
Expand All @@ -237,7 +238,7 @@ def analyze(self, text):
else:
logging.fatal('Predictor is None! Call fit() or load() to get predictor.')

def analyze_batch(self, texts: Tuple[List[str], List[str]]) -> List[Tuple[str, np.ndarray]]:
"""Analyze text and return matching result with probability.
Args:
Expand All @@ -250,7 +251,10 @@ def analyze_batch(self, texts):
else:
logging.fatal('Predictor is None! Call fit() or load() to get predictor.')

def save(self,
preprocessor_file: str,
json_file: str,
weights_file: Optional[str] = None) -> None:
"""save spm application
Args:
Expand All @@ -273,7 +277,11 @@ def save(self, preprocessor_file, json_file, weights_file=None):
self.model.save_weights(weights_file)
logging.info('Save model weights to {}'.format(weights_file))

def load(self,
preprocessor_file: str,
json_file: str,
weights_file: str,
custom_objects: Optional[Dict[str, Any]] = None) -> None:
"""load spm application
Args:
Expand All @@ -291,7 +299,7 @@ def load(self, preprocessor_file, json_file, weights_file, custom_objects=None):
custom_objects = custom_objects or {}
custom_objects.update(get_custom_objects())
with open(json_file, 'r') as reader:
self.model = tf.keras.models.model_from_json(reader.read(), custom_objects=custom_objects)
logging.info('Load model architecture from {}'.format(json_file))

self.model.load_weights(weights_file)
@@ -301,11 +309,53 @@ def load(self, preprocessor_file, json_file, weights_file, custom_objects=None):
self.predictor = SPMPredictor(self.model, self.preprocessor)

@staticmethod
def get_spm_model(spm_model_type: str,
num_class: int,
use_word: bool,
word_embeddings: Optional[np.ndarray],
word_vocab_size: int,
word_embed_dim: int,
word_embed_trainable: bool,
use_char: bool,
char_embeddings: Optional[np.ndarray],
char_vocab_size: int,
char_embed_dim: int,
char_embed_trainable: bool,
use_bert: bool,
bert_config_file: Optional[str],
bert_checkpoint_file: Optional[str],
bert_trainable: bool,
max_len: Optional[int],
max_word_len: Optional[int],
optimizer: Union[str, tf.keras.optimizers.Optimizer],
**kwargs) -> tf.keras.models.Model:
"""build spm models by model_type
Args:
spm_model_type: str, which spm model to use
num_class: int, the number of classification classes
use_word: boolean, whether to use word embedding as input
word_embeddings: np.ndarray, word embeddings
word_vocab_size: int, the number of words in vocabulary
word_embed_dim: int, dimensionality of word embedding
word_embed_trainable: boolean, whether to update word embedding during training
use_char: boolean, whether to use char as input
char_embeddings: np.ndarray, char embeddings
char_vocab_size: int, the number of chars in vocabulary
char_embed_dim: int, dimensionality of char embedding
char_embed_trainable: boolean, similar to 'word_embed_trainable'
use_bert: boolean, whether to use bert embedding as input
bert_config_file: str, path to bert's configuration file
bert_checkpoint_file: str, path to bert's checkpoint file
bert_trainable: boolean, whether to update bert during training
max_len: int, max sequence length. If None, we dynamically use the max length of one batch
as max_len. However, max_len must be provided when using bert as input.
max_word_len: int, max word length. If None, we dynamically use the max word length of one
batch as max_word_len.
optimizer: str or instance of `keras.optimizers.Optimizer`, indicating the optimizer to
use during training
**kwargs: other argument for building spm model, such as "rnn_units", "fc_dim" etc.
"""
spm_model_all = {'siamese_cnn': SiameseCNN,
'siamese_bilstm': SiameseBiLSTM,
'siamese_bigru': SiameseBiGRU,
@@ -340,18 +390,18 @@ def get_spm_model(spm_model_type, num_class, use_word, word_embeddings, word_voc
return spm_model.build_model()

# todo: retrain the model
def load_pretrained_model(self) -> None:
cache_subdir = 'pretrained_models'

preprocessor_file = tf.keras.utils.get_file(
fname='webank_spm_siamese_cnn_word_preprocessor.pkl',
origin=MODEL_STORAGE_PREFIX + 'webank_spm_siamese_cnn_word_preprocessor.pkl',
cache_subdir=cache_subdir, cache_dir=CACHE_DIR)
json_file = tf.keras.utils.get_file(
fname='webank_spm_siamese_cnn_word.json',
origin=MODEL_STORAGE_PREFIX + 'webank_spm_siamese_cnn_word.json',
cache_subdir=cache_subdir, cache_dir=CACHE_DIR)
weights_file = tf.keras.utils.get_file(
fname='webank_spm_siamese_cnn_word.hdf5',
origin=MODEL_STORAGE_PREFIX + 'webank_spm_siamese_cnn_word.hdf5',
cache_subdir=cache_subdir, cache_dir=CACHE_DIR)
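For orientation, here is a minimal usage sketch of the annotated API above. It is an illustration only: the import path follows the file location, while the sample sentences, labels, callback names, and file names are assumptions rather than values from the repository.

# Hypothetical usage sketch of the SPM (sentence pair matching) application.
# Sample data, file names, and the checkpoint directory are illustrative.
from fancy_nlp.applications.spm import SPM

# train_data is a tuple of two parallel sentence lists; train_labels holds
# one label per sentence pair.
train_data = (['how long does the battery last', 'who has this picture'],
              ['what is the battery life', 'does anyone have this picture'])
train_labels = ['1', '0']

spm = SPM(use_pretrained=False)
spm.fit(train_data, train_labels,
        spm_model_type='siamese_cnn',  # one of the keys of spm_model_all
        use_word=True,
        batch_size=32,
        epochs=50,
        callback_list=['swa'],         # including 'swa' builds the SWA model
        checkpoint_dir='checkpoints',  # assumed directory
        model_name='spm_siamese_cnn')  # assumed name

# Single-pair and batch prediction.
label = spm.predict(('how long does the battery last', 'what is the battery life'))
labels = spm.predict_batch((['how long does the battery last'],
                            ['what is the battery life']))

# Persist and restore the whole application.
spm.save('spm_preprocessor.pkl', 'spm.json', 'spm.hdf5')
spm.load('spm_preprocessor.pkl', 'spm.json', 'spm.hdf5')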
9 changes: 6 additions & 3 deletions fancy_nlp/callbacks/metrics.py
@@ -1,13 +1,13 @@
# -*- coding: utf-8 -*-

from typing import List, Tuple

import tensorflow as tf
from seqeval import metrics
from sklearn.metrics import f1_score, precision_score, recall_score, classification_report
import numpy as np

from fancy_nlp.preprocessors import NERPreprocessor, SPMPreprocessor


class NERMetric(tf.keras.callbacks.Callback):
@@ -92,7 +92,10 @@ class SPMMetric(tf.keras.callbacks.Callback):
"""
callback for evaluating spm model
"""
def __init__(self,
preprocessor: SPMPreprocessor,
valid_data: Tuple[List[str], List[str]],
valid_labels: List[str]) -> None:
"""
Args:
preprocessor: `SPMPreprocessor` instance to help prepare input for spm model
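Likewise, a short sketch of how the SPMMetric callback defined here might be attached to a training loop. The names `preprocessor`, `model`, `x_train`, and `y_train` are assumed to exist; the callback is passed to fit() like any other tf.keras callback.

# Hypothetical wiring of SPMMetric into tf.keras training.
# `preprocessor` (an SPMPreprocessor), `model`, `x_train`, and `y_train`
# are assumed to be defined elsewhere.
from fancy_nlp.callbacks.metrics import SPMMetric

valid_data = (['first sentence of each pair'],
              ['second sentence of each pair'])
valid_labels = ['1']

spm_metric = SPMMetric(preprocessor=preprocessor,
                       valid_data=valid_data,
                       valid_labels=valid_labels)

model.fit(x_train, y_train, epochs=10, callbacks=[spm_metric])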
(Diffs for the remaining 7 changed files are not shown.)
