From 4c457e99248acf5e0d9384013a480d0632bf9877 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Mon, 26 Oct 2020 13:22:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=20nltk=20=E4=BB=8E=E4=BE=9D=E8=B5=96?= =?UTF-8?q?=E4=B8=AD=E5=88=A0=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/io/pipe/classification.py | 124 +++++++++++++++--------------- requirements.txt | 1 - 2 files changed, 64 insertions(+), 61 deletions(-) diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index c59ffe5d..9475a092 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -17,7 +17,11 @@ import re import warnings -from nltk import Tree +try: + from nltk import Tree +except: + # only nltk in some versions can run + pass from .pipe import Pipe from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize @@ -32,12 +36,12 @@ class CLSBasePipe(Pipe): - - def __init__(self, lower: bool=False, tokenizer: str='spacy', lang='en'): + + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', lang='en'): super().__init__() self.lower = lower self.tokenizer = get_tokenizer(tokenizer, lang=lang) - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): r""" 将DataBundle中的数据进行tokenize @@ -50,9 +54,9 @@ def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) - + return data_bundle - + def process(self, data_bundle: DataBundle): r""" 传入的DataSet应该具备如下的结构 @@ -73,15 +77,15 @@ def process(self, data_bundle: DataBundle): data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT) # 建立词表并index data_bundle = _indexize(data_bundle=data_bundle) - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.INPUT) - + data_bundle.set_input(Const.INPUT, Const.INPUT_LEN) data_bundle.set_target(Const.TARGET) - + return data_bundle - + def process_from_file(self, paths) -> DataBundle: r""" 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` @@ -151,7 +155,7 @@ def process(self, data_bundle): """ if self.tag_map is not None: data_bundle = _granularize(data_bundle, self.tag_map) - + data_bundle = super().process(data_bundle) return data_bundle @@ -231,7 +235,7 @@ class AGsNewsPipe(CLSBasePipe): +-------------+-----------+--------+-------+---------+ """ - + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): r""" @@ -239,7 +243,7 @@ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ super().__init__(lower=lower, tokenizer=tokenizer, lang='en') - + def process_from_file(self, paths=None): r""" :param str paths: @@ -272,7 +276,7 @@ class DBPediaPipe(CLSBasePipe): +-------------+-----------+--------+-------+---------+ """ - + def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): r""" @@ -280,7 +284,7 @@ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 """ super().__init__(lower=lower, tokenizer=tokenizer, lang='en') - + def process_from_file(self, paths=None): r""" :param str paths: @@ -369,7 +373,7 @@ def process(self, data_bundle: DataBundle): instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) ds.append(instance) data_bundle.set_dataset(ds, name) - + # 根据granularity设置tag data_bundle = _granularize(data_bundle, tag_map=self.tag_map) @@ -525,6 +529,7 @@ class ChnSentiCorpPipe(Pipe): +-------------+-----------+--------+-------+---------+ """ + def __init__(self, bigrams=False, trigrams=False): r""" @@ -536,10 +541,10 @@ def __init__(self, bigrams=False, trigrams=False): data_bundle.get_vocab('trigrams')获取. """ super().__init__() - + self.bigrams = bigrams self.trigrams = trigrams - + def _tokenize(self, data_bundle): r""" 将DataSet中的"复旦大学"拆分为["复", "旦", "大", "学"]. 未来可以通过扩展这个函数实现分词。 @@ -549,8 +554,8 @@ def _tokenize(self, data_bundle): """ data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT) return data_bundle - - def process(self, data_bundle:DataBundle): + + def process(self, data_bundle: DataBundle): r""" 可以处理的DataSet应该具备以下的field @@ -565,9 +570,9 @@ def process(self, data_bundle:DataBundle): :return: """ _add_chars_field(data_bundle, lower=False) - + data_bundle = self._tokenize(data_bundle) - + input_field_names = [Const.CHAR_INPUT] if self.bigrams: for name, dataset in data_bundle.iter_datasets(): @@ -580,21 +585,21 @@ def process(self, data_bundle:DataBundle): zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + # index _indexize(data_bundle, input_field_names, Const.TARGET) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET] - + for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(Const.CHAR_INPUT) - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None): r""" @@ -604,7 +609,7 @@ def process_from_file(self, paths=None): # 读取数据 data_bundle = ChnSentiCorpLoader().load(paths) data_bundle = self.process(data_bundle) - + return data_bundle @@ -637,26 +642,26 @@ class THUCNewsPipe(CLSBasePipe): 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 data_bundle.get_vocab('trigrams')获取. """ - + def __init__(self, bigrams=False, trigrams=False): super().__init__() - + self.bigrams = bigrams self.trigrams = trigrams - + def _chracter_split(self, sent): return list(sent) # return [w for w in sent] - + def _raw_split(self, sent): return sent.split() - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle: DataBundle): r""" 可处理的DataSet应具备如下的field @@ -673,14 +678,14 @@ def process(self, data_bundle: DataBundle): # 根据granularity设置tag tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map) - + # clean,lower - + # CWS(tokenize) data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') - + input_field_names = [Const.CHAR_INPUT] - + # n-grams if self.bigrams: for name, dataset in data_bundle.iter_datasets(): @@ -693,22 +698,22 @@ def process(self, data_bundle: DataBundle): zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + # index data_bundle = _indexize(data_bundle=data_bundle, input_field_names=Const.CHAR_INPUT) - + # add length for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET] - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None): r""" :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 @@ -749,22 +754,22 @@ class WeiboSenti100kPipe(CLSBasePipe): 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 data_bundle.get_vocab('trigrams')获取. """ - + def __init__(self, bigrams=False, trigrams=False): super().__init__() - + self.bigrams = bigrams self.trigrams = trigrams - + def _chracter_split(self, sent): return list(sent) - + def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None): new_field_name = new_field_name or field_name for name, dataset in data_bundle.datasets.items(): dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name) return data_bundle - + def process(self, data_bundle: DataBundle): r""" 可处理的DataSet应具备以下的field @@ -779,12 +784,12 @@ def process(self, data_bundle: DataBundle): :return: """ # clean,lower - + # CWS(tokenize) data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') - + input_field_names = [Const.CHAR_INPUT] - + # n-grams if self.bigrams: for name, dataset in data_bundle.iter_datasets(): @@ -797,22 +802,22 @@ def process(self, data_bundle: DataBundle): zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)], field_name=Const.CHAR_INPUT, new_field_name='trigrams') input_field_names.append('trigrams') - + # index data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') - + # add length for name, dataset in data_bundle.datasets.items(): dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN) - + input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names target_fields = [Const.TARGET] - + data_bundle.set_input(*input_fields) data_bundle.set_target(*target_fields) - + return data_bundle - + def process_from_file(self, paths=None): r""" :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 @@ -822,4 +827,3 @@ def process_from_file(self, paths=None): data_bundle = data_loader.load(paths) data_bundle = self.process(data_bundle) return data_bundle - diff --git a/requirements.txt b/requirements.txt index 242301be..81fb307c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ numpy>=1.14.2 torch>=1.0.0 tqdm>=4.28.1 -nltk>=3.4.1 prettytable>=0.7.2 requests spacy