From 8732dfd979aaff0960f063d1a9005f95130313d4 Mon Sep 17 00:00:00 2001 From: ChenXin Date: Sat, 6 Jun 2020 11:06:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E4=BA=86=20DataBundle=20?= =?UTF-8?q?=E7=9A=84=20apply=5Fmore=20=E5=92=8C=20apply=5Ffield=5Fmore=20?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E3=80=82=E9=9C=80=E8=A6=81=E8=BF=9B=E4=B8=80?= =?UTF-8?q?=E6=AD=A5=E8=AF=95=E7=94=A8=E5=92=8C=E6=B5=8B=E8=AF=95=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/dataset.py | 9 +++-- fastNLP/io/data_bundle.py | 70 +++++++++++++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 12 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 464a6446..5e80a6fb 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -861,10 +861,9 @@ def apply_field(self, func, field_name, new_field_name=None, **kwargs): 3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度 - """ assert len(self) != 0, "Null DataSet cannot use apply_field()." - if field_name not in self: + if not self.has_field(field_name=field_name): raise KeyError("DataSet has no field named `{}`.".format(field_name)) return self.apply(func, new_field_name, _apply_field=field_name, **kwargs) @@ -888,10 +887,10 @@ def apply_field_more(self, func, field_name, modify_fields=True, **kwargs): 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 - :return Dict[int:Field]: 返回一个字典 + :return Dict[str:Field]: 返回一个字典 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." - if field_name not in self: + if not self.has_field(field_name=field_name): raise KeyError("DataSet has no field named `{}`.".format(field_name)) return self.apply_more(func, modify_fields, _apply_field=field_name, **kwargs) @@ -950,7 +949,7 @@ def apply_more(self, func, modify_fields=True, **kwargs): 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 - :return Dict[int:Field]: 返回一个字典 + :return Dict[str:Field]: 返回一个字典 """ # 返回 dict , 检查是否一直相同 assert callable(func), "The func you provide is not callable." diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index bcb8a211..e911a26f 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -166,7 +166,7 @@ def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, i dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) return self - def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): + def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): r""" 将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. @@ -282,7 +282,7 @@ def get_dataset_names(self) -> List[str]: """ return list(self.datasets.keys()) - def get_vocab_names(self)->List[str]: + def get_vocab_names(self) -> List[str]: r""" 返回DataBundle中Vocabulary的名称 @@ -304,9 +304,9 @@ def iter_vocabs(self) -> Union[str, Vocabulary]: for field_name, vocab in self.vocabs.items(): yield field_name, vocab - def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): + def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_dataset=True, **kwargs): r""" - 对DataBundle中所有的dataset使用apply_field方法 + 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 :param callable func: input是instance中名为 `field_name` 的field的内容。 :param str field_name: 传入func的是哪个field。 @@ -329,8 +329,41 @@ def apply_field(self, func, field_name: str, new_field_name: str, ignore_miss_da raise KeyError(f"{field_name} not found DataSet:{name}.") return self - def apply(self, func, new_field_name:str, **kwargs): + def apply_field_more(self, func, field_name, modify_fields=True, ignore_miss_dataset=True, **kwargs): r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 + + .. note:: + ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply`` 区别的介绍。 + + :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param str field_name: 传入func的是哪个field。 + :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; + 如果为False,则报错 + :param optional kwargs: 支持输入is_input, is_target, ignore_type + + 1. is_input: bool, 如果为True则将被修改的field设置为input + + 2. is_target: bool, 如果为True则将被修改的field设置为target + + 3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型 + + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + """ + res = {} + for name, dataset in self.datasets.items(): + if dataset.has_field(field_name=field_name): + res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs) + elif not ignore_miss_dataset: + raise KeyError(f"{field_name} not found DataSet:{name} .") + return res + + def apply(self, func, new_field_name: str, **kwargs): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 + 对DataBundle中所有的dataset使用apply方法 :param callable func: input是instance中名为 `field_name` 的field的内容。 @@ -348,6 +381,31 @@ def apply(self, func, new_field_name:str, **kwargs): dataset.apply(func, new_field_name=new_field_name, **kwargs) return self + def apply_more(self, func, modify_fields=True, **kwargs): + r""" + 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 + + .. note:: + ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 + ``apply`` 区别的介绍。 + + :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True + :param optional kwargs: 支持输入is_input,is_target,ignore_type + + 1. is_input: bool, 如果为True则将被修改的的field设置为input + + 2. is_target: bool, 如果为True则将被修改的的field设置为target + + 3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型 + + :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + """ + res = {} + for name, dataset in self.datasets.items(): + res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs) + return res + def add_collate_fn(self, fn, name=None): r""" 向所有DataSet增加collate_fn, collate_fn详见 :class:`~fastNLP.DataSet` 中相关说明. @@ -380,5 +438,3 @@ def __repr__(self): for name, vocab in self.vocabs.items(): _str += '\t{} has {} entries.\n'.format(name, len(vocab)) return _str - -