Commit
Remove nltk from the dependencies
WillQvQ committed Oct 26, 2020
1 parent 13480d6 commit 4c457e9
Showing 2 changed files with 64 additions and 61 deletions.
124 changes: 64 additions & 60 deletions fastNLP/io/pipe/classification.py
@@ -17,7 +17,11 @@
import re
import warnings

from nltk import Tree
try:
from nltk import Tree
except:
# nltk is optional; some nltk versions fail even to import, so ignore the error here
pass

from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize
@@ -32,12 +36,12 @@


class CLSBasePipe(Pipe):

def __init__(self, lower: bool=False, tokenizer: str='spacy', lang='en'):
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', lang='en'):
super().__init__()
self.lower = lower
self.tokenizer = get_tokenizer(tokenizer, lang=lang)

def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
r"""
Tokenize the data in the DataBundle.
@@ -50,9 +54,9 @@ def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)

return data_bundle

def process(self, data_bundle: DataBundle):
r"""
The DataSet passed in should have the following structure
@@ -73,15 +77,15 @@ def process(self, data_bundle: DataBundle):
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# build the vocabulary and index the fields
data_bundle = _indexize(data_bundle=data_bundle)

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)

data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)

return data_bundle

def process_from_file(self, paths) -> DataBundle:
r"""
Takes file path(s) and produces a processed DataBundle object. The path formats supported by paths are described in :meth:`fastNLP.io.Loader.load()`
@@ -151,7 +155,7 @@ def process(self, data_bundle):
"""
if self.tag_map is not None:
data_bundle = _granularize(data_bundle, self.tag_map)

data_bundle = super().process(data_bundle)

return data_bundle
@@ -231,15 +235,15 @@ class AGsNewsPipe(CLSBasePipe):
+-------------+-----------+--------+-------+---------+
"""

def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
r"""
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on spaces.
"""
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')

def process_from_file(self, paths=None):
r"""
:param str paths:
@@ -272,15 +276,15 @@ class DBPediaPipe(CLSBasePipe):
+-------------+-----------+--------+-------+---------+
"""

def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
r"""
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on spaces.
"""
super().__init__(lower=lower, tokenizer=tokenizer, lang='en')

def process_from_file(self, paths=None):
r"""
:param str paths:
@@ -369,7 +373,7 @@ def process(self, data_bundle: DataBundle):
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
ds.append(instance)
data_bundle.set_dataset(ds, name)

# set the tags according to granularity
data_bundle = _granularize(data_bundle, tag_map=self.tag_map)

@@ -525,6 +529,7 @@ class ChnSentiCorpPipe(Pipe):
+-------------+-----------+--------+-------+---------+
"""

def __init__(self, bigrams=False, trigrams=False):
r"""
@@ -536,10 +541,10 @@ def __init__(self, bigrams=False, trigrams=False):
obtained via data_bundle.get_vocab('trigrams').
"""
super().__init__()

self.bigrams = bigrams
self.trigrams = trigrams

def _tokenize(self, data_bundle):
r"""
Split "复旦大学" in the DataSet into ["复", "旦", "大", "学"]. Word segmentation could be supported later by extending this function.
@@ -549,8 +554,8 @@ def _tokenize(self, data_bundle):
"""
data_bundle.apply_field(list, field_name=Const.CHAR_INPUT, new_field_name=Const.CHAR_INPUT)
return data_bundle

def process(self, data_bundle:DataBundle):
def process(self, data_bundle: DataBundle):
r"""
A DataSet that can be processed should have the following fields
@@ -565,9 +570,9 @@ def process(self, data_bundle:DataBundle):
:return:
"""
_add_chars_field(data_bundle, lower=False)

data_bundle = self._tokenize(data_bundle)

input_field_names = [Const.CHAR_INPUT]
if self.bigrams:
for name, dataset in data_bundle.iter_datasets():
@@ -580,21 +585,21 @@ def process(self, data_bundle:DataBundle):
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')

# index
_indexize(data_bundle, input_field_names, Const.TARGET)

input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET]

for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)

data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)

return data_bundle

def process_from_file(self, paths=None):
r"""
@@ -604,7 +609,7 @@ def process_from_file(self, paths=None):
# load the data
data_bundle = ChnSentiCorpLoader().load(paths)
data_bundle = self.process(data_bundle)

return data_bundle


@@ -637,26 +642,26 @@ class THUCNewsPipe(CLSBasePipe):
If set to True, the returned DataSet will have a column named trigrams that has already been converted to indices and set as input; the corresponding vocab can be
obtained via data_bundle.get_vocab('trigrams').
"""

def __init__(self, bigrams=False, trigrams=False):
super().__init__()

self.bigrams = bigrams
self.trigrams = trigrams

def _chracter_split(self, sent):
return list(sent)
# return [w for w in sent]

def _raw_split(self, sent):
return sent.split()

def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
return data_bundle

def process(self, data_bundle: DataBundle):
r"""
A DataSet that can be processed should have the following fields
@@ -673,14 +678,14 @@ def process(self, data_bundle: DataBundle):
# set the tags according to granularity
tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9}
data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map)

# clean,lower

# CWS(tokenize)
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars')

input_field_names = [Const.CHAR_INPUT]

# n-grams
if self.bigrams:
for name, dataset in data_bundle.iter_datasets():
@@ -693,22 +698,22 @@ def process(self, data_bundle: DataBundle):
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')

# index
data_bundle = _indexize(data_bundle=data_bundle, input_field_names=Const.CHAR_INPUT)

# add length
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN)

input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET]

data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)

return data_bundle

def process_from_file(self, paths=None):
r"""
:param paths: for the supported path types, see the load function of :class:`fastNLP.io.loader.Loader`.
@@ -749,22 +754,22 @@ class WeiboSenti100kPipe(CLSBasePipe):
If set to True, the returned DataSet will have a column named trigrams that has already been converted to indices and set as input; the corresponding vocab can be
obtained via data_bundle.get_vocab('trigrams').
"""

def __init__(self, bigrams=False, trigrams=False):
super().__init__()

self.bigrams = bigrams
self.trigrams = trigrams

def _chracter_split(self, sent):
return list(sent)

def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self._chracter_split, field_name=field_name, new_field_name=new_field_name)
return data_bundle

def process(self, data_bundle: DataBundle):
r"""
A DataSet that can be processed should have the following fields
@@ -779,12 +784,12 @@ def process(self, data_bundle: DataBundle):
:return:
"""
# clean,lower

# CWS(tokenize)
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars')

input_field_names = [Const.CHAR_INPUT]

# n-grams
if self.bigrams:
for name, dataset in data_bundle.iter_datasets():
@@ -797,22 +802,22 @@ def process(self, data_bundle: DataBundle):
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')

# index
data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars')

# add length
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(field_name=Const.CHAR_INPUT, new_field_name=Const.INPUT_LEN)

input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET]

data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)

return data_bundle

def process_from_file(self, paths=None):
r"""
:param paths: for the supported path types, see the load function of :class:`fastNLP.io.loader.Loader`.
Expand All @@ -822,4 +827,3 @@ def process_from_file(self, paths=None):
data_bundle = data_loader.load(paths)
data_bundle = self.process(data_bundle)
return data_bundle
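
Note that the guard added at the top of this file only silences the import: when nltk is absent, Tree is simply never defined and the tree-parsing SST pipe will fail with a NameError the first time it tries to build a tree. A minimal sketch of a stricter alternative is shown below; it is not part of this commit, and the helper name _require_nltk_tree is hypothetical.

def _require_nltk_tree():
    # Deferred import: nltk became optional in this commit, so only ask for it
    # when a pipe actually needs to parse constituency trees.
    try:
        from nltk import Tree
    except ImportError as exc:
        raise ImportError(
            "nltk is required to parse the tree-formatted SST data; "
            "install it with `pip install nltk>=3.4.1`."
        ) from exc
    return Tree

A tree-parsing pipe could then call Tree = _require_nltk_tree() at the start of its process method instead of relying on the module-level import.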

1 change: 0 additions & 1 deletion requirements.txt
@@ -1,7 +1,6 @@
numpy>=1.14.2
torch>=1.0.0
tqdm>=4.28.1
nltk>=3.4.1
prettytable>=0.7.2
requests
spacy
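
With nltk dropped from requirements.txt, installing fastNLP no longer pulls it in, so projects that still use the tree-parsing SST pipe have to install it themselves. A small pre-flight check along these lines (purely illustrative, not part of this commit) makes the now-optional dependency explicit:

try:
    import nltk
except ImportError:
    # nltk>=3.4.1 was the version previously pinned in requirements.txt
    raise SystemExit("The SST pipe requires nltk: pip install 'nltk>=3.4.1'")
else:
    print(f"nltk {nltk.__version__} is available")  # the tree-parsing pipe can be used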
