Normalize chars before tokenizing, TxtBMESFormat supports max_seq_len
hankcs committed Jan 10, 2020
1 parent d43e83b commit 3a791c3
Showing 3 changed files with 28 additions and 4 deletions.
hanlp/common/vocab.py (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ def __contains__(self, key: Union[str, int]):
 
     def add(self, token: str) -> int:
         assert self.mutable, 'It is not allowed to call add on an immutable Vocab'
-        assert isinstance(token, str), 'Token type must be str'
+        assert isinstance(token, str), f'Token type must be str but got {type(token)} from {token}'
         assert token, 'Token must not be None or length 0'
         idx = self.token_to_idx.get(token, None)
         if idx is None:
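The new f-string in Vocab.add surfaces both the offending value and its type, which makes bad upstream inputs much easier to trace than the old bare 'Token type must be str'. A minimal sketch of the behavior, assuming a default-constructed Vocab is mutable:

    from hanlp.common.vocab import Vocab

    vocab = Vocab()          # assumed mutable by default
    vocab.add('中')          # OK: returns the index assigned to '中'
    try:
        vocab.add(42)        # not a str, trips the sharpened assertion
    except AssertionError as e:
        print(e)             # Token type must be str but got <class 'int'> from 42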
hanlp/components/tok.py (6 changes: 4 additions & 2 deletions)
@@ -16,6 +16,8 @@
 from hanlp.transform.tsv import TSVTaggingTransform
 from hanlp.transform.txt import extract_ngram_features_and_tags, bmes_to_words, TxtFormat, TxtBMESFormat
 from hanlp.utils.util import merge_locals_kwargs
+
+
 class BMESTokenizer(KerasComponent):
 
     def build_metrics(self, metrics, logger: logging.Logger, **kwargs):
@@ -92,8 +94,8 @@ def __init__(self, transform: RNNTokenizerTransform = None) -> None:
         super().__init__(transform)
 
     def fit(self, trn_data: str, dev_data: str = None, save_dir: str = None, embeddings=100, embedding_trainable=False,
-            rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, lower=False, logger=None,
-            loss: Union[tf.keras.losses.Loss, str] = None,
+            rnn_input_dropout=0.2, rnn_units=100, rnn_output_dropout=0.2, epochs=20, lower=False, max_seq_len=50,
+            logger=None, loss: Union[tf.keras.losses.Loss, str] = None,
             optimizer: Union[str, tf.keras.optimizers.Optimizer] = 'adam', metrics='f1', batch_size=32,
             dev_batch_size=32, lr_decay_per_epoch=None, verbose=True, **kwargs):
         return super().fit(**merge_locals_kwargs(locals(), kwargs))
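The fit signature now accepts max_seq_len (default 50), which TxtBMESFormat below uses to split overlong training sentences at punctuation. A hedged usage sketch; the RNNTokenizer class name is inferred from the RNNTokenizerTransform default in __init__, and the corpus paths are placeholders:

    from hanlp.components.tok import RNNTokenizer

    tokenizer = RNNTokenizer()
    tokenizer.fit(trn_data='data/cws/train.txt',  # placeholder paths, not shipped corpora
                  dev_data='data/cws/dev.txt',
                  save_dir='data/model/cws',
                  max_seq_len=50)  # sentences longer than 50 chars are split at delimiters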
hanlp/transform/txt.py (24 changes: 23 additions & 1 deletion)
@@ -10,6 +10,7 @@
 from hanlp.common.transform import Transform
 from hanlp.common.vocab import Vocab
 from hanlp.utils.io_util import get_resource
+from hanlp.utils.lang.zh.char_table import CharTable
 
 
 def generate_words_per_line(file_path):
@@ -66,6 +67,7 @@ def extract_ngram_features_and_tags(sentence, bigram_only=False, window_size=4,
     """
     chars, tags = bmes_of(sentence, segmented)
+    chars = CharTable.normalize_chars(chars)
     ret = []
     ret.append(chars)
     # TODO: optimize ngram generation using https://www.tensorflow.org/api_docs/python/tf/strings/ngrams
@@ -191,9 +193,29 @@ def file_to_inputs(self, filepath: str, gold=True):
 
 class TxtBMESFormat(TxtFormat, ABC):
     def file_to_inputs(self, filepath: str, gold=True):
+        max_seq_len = self.config.get('max_seq_len', False)
+        if max_seq_len:
+            delimiter = set()
+            delimiter.update('。!?:;、,,;!?、,')
         for text in super().file_to_inputs(filepath, gold):
             chars, tags = bmes_of(text, gold)
-            yield chars, tags
+            if max_seq_len and len(chars) > max_seq_len:
+                short_chars, short_tags = [], []
+                for idx, (char, tag) in enumerate(zip(chars, tags)):
+                    short_chars.append(char)
+                    short_tags.append(tag)
+                    if len(short_chars) >= max_seq_len and char in delimiter:
+                        yield short_chars, short_tags
+                        short_chars, short_tags = [], []
+                if short_chars:
+                    yield short_chars, short_tags
+            else:
+                yield chars, tags
 
     def input_is_single_sample(self, input: Union[List[str], List[List[str]]]) -> bool:
         return isinstance(input, str)
+
+    def inputs_to_samples(self, inputs, gold=False):
+        for chars, tags in inputs:
+            chars = CharTable.normalize_chars(chars)
+            yield chars, tags
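The new branch in file_to_inputs greedily fills a chunk until it reaches max_seq_len, then cuts at the next delimiter it encounters, so no chunk ends mid-clause; a chunk can therefore run past the cap when punctuation arrives late. A standalone sketch of the same rule (the function name and the 50-char cap are illustrative; the delimiters are copied from the commit):

    DELIMITERS = set('。!?:;、,,;!?、,')

    def split_long(chars, tags, max_seq_len=50):
        """Yield (chars, tags) chunks, cutting at the first delimiter at or past the cap."""
        if len(chars) <= max_seq_len:
            yield chars, tags
            return
        short_chars, short_tags = [], []
        for char, tag in zip(chars, tags):
            short_chars.append(char)
            short_tags.append(tag)
            if len(short_chars) >= max_seq_len and char in DELIMITERS:
                yield short_chars, short_tags
                short_chars, short_tags = [], []
        if short_chars:  # flush the tail even if it ends without a delimiter
            yield short_chars, short_tags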

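The inputs_to_samples override mirrors the normalization added to extract_ngram_features_and_tags above, so inference-time inputs pass through the same CharTable mapping the model saw during training. A hedged sketch; the full-width example assumes CharTable folds full-width characters to their half-width forms, per the table in hanlp/utils/lang/zh/char_table.py:

    from hanlp.utils.lang.zh.char_table import CharTable

    chars = list('HanLP,你好!')   # full-width Latin and punctuation (illustrative)
    normalized = CharTable.normalize_chars(chars)
    print(''.join(normalized))       # expected: half-width rendering (assumed mapping)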