Skip to content

Commit

Permalink
新增了 DataBundle 的 apply_more 和 apply_field_more 方法。需要进一步试用和测试。
Browse files Browse the repository at this point in the history
  • Loading branch information
WillQvQ committed Jun 6, 2020
1 parent 1f27d00 commit 8732dfd
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 12 deletions.
9 changes: 4 additions & 5 deletions fastNLP/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,10 +861,9 @@ def apply_field(self, func, field_name, new_field_name=None, **kwargs):
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
if field_name not in self:
if not self.has_field(field_name=field_name):
raise KeyError("DataSet has no field named `{}`.".format(field_name))
return self.apply(func, new_field_name, _apply_field=field_name, **kwargs)

Expand All @@ -888,10 +887,10 @@ def apply_field_more(self, func, field_name, modify_fields=True, **kwargs):
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
:return Dict[int:Field]: 返回一个字典
:return Dict[str:Field]: 返回一个字典
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
if field_name not in self:
if not self.has_field(field_name=field_name):
raise KeyError("DataSet has no field named `{}`.".format(field_name))
return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs)

Expand Down Expand Up @@ -950,7 +949,7 @@ def apply_more(self, func, modify_fields=True, **kwargs):
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
:return Dict[int:Field]: 返回一个字典
:return Dict[str:Field]: 返回一个字典
"""
# 返回 dict , 检查是否一直相同
assert callable(func), "The func you provide is not callable."
Expand Down
70 changes: 63 additions & 7 deletions fastNLP/io/data_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, i
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
return self

def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True):
r"""
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val.
Expand Down Expand Up @@ -282,7 +282,7 @@ def get_dataset_names(self) -> List[str]:
"""
return list(self.datasets.keys())

def get_vocab_names(self)->List[str]:
def get_vocab_names(self) -> List[str]:
r"""
返回DataBundle中Vocabulary的名称
Expand All @@ -304,9 +304,9 @@ def iter_vocabs(self) -> Union[str, Vocabulary]:
for field_name, vocab in self.vocabs.items():
yield field_name, vocab

def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs):
r"""
对DataBundle中所有的dataset使用apply_field方法
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法
:param callable func: input是instance中名为 `field_name` 的field的内容。
:param str field_name: 传入func的是哪个field。
Expand All @@ -329,8 +329,41 @@ def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_da
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self

def apply(self, func, new_field_name:str, **kwargs):
def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法
.. note::
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
``apply`` 区别的介绍。
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param str field_name: 传入func的是哪个field。
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:param optional kwargs: 支持输入is_input, is_target, ignore_type
1. is_input: bool, 如果为True则将被修改的field设置为input
2. is_target: bool, 如果为True则将被修改的field设置为target
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
"""
res = {}
for name, dataset in self.datasets.items():
if dataset.has_field(field_name=field_name):
res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name} .")
return res

def apply(self, func, new_field_name: str, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法
对DataBundle中所有的dataset使用apply方法
:param callable func: input是instance中名为 `field_name` 的field的内容。
Expand All @@ -348,6 +381,31 @@ def apply(self, func, new_field_name:str, **kwargs):
dataset.apply(func, new_field_name=new_field_name, **kwargs)
return self

def apply_more(self, func, modify_fields=True, **kwargs):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法
.. note::
``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与
``apply`` 区别的介绍。
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param optional kwargs: 支持输入is_input,is_target,ignore_type
1. is_input: bool, 如果为True则将被修改的的field设置为input
2. is_target: bool, 如果为True则将被修改的的field设置为target
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
"""
res = {}
for name, dataset in self.datasets.items():
res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
return res

def add_collate_fn(self, fn, name=None):
r"""
向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明.
Expand Down Expand Up @@ -380,5 +438,3 @@ def __repr__(self):
for name, vocab in self.vocabs.items():
_str += '\t{} has {} entries.\n'.format(name, len(vocab))
return _str


0 comments on commit 8732dfd

Please sign in to comment.