Skip to content

Commit

Permalink
修复save_path在dev为空的时候的bug
Browse files Browse the repository at this point in the history
  • Loading branch information
yhcc committed Sep 19, 2020
1 parent 720d7b8 commit 40bec21
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 14 deletions.
6 changes: 3 additions & 3 deletions fastNLP/core/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,16 @@ class LossFunc(LossBase):
r"""
提供给用户使用自定义损失函数的类
:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect
:param func: 用户自行定义的损失函数,应当为一个函数。
:param dict key_map: 参数映射表。键为Model/DataSet参数名,值为损失函数参数名。
fastNLP的trainer将在训练时从模型返回值或者训练数据DataSet的target=True的field中
找到相对应的参数名为value的参数,并传入func中作为参数名为key的参数
:param kwargs: 除了参数映射表以外可以用key word args的方式设置参数映射关系
使用方法::
func = torch.nn.CrossEntropyLoss()
loss_func = LossFunc(func, input="pred", target="label")
import torch.nn.functional as F
loss_func = LossFunc(F.cross_entropy, input="pred", target="label")
# 这表示构建了一个损失函数类,由func计算损失函数,其中将从模型返回值或者DataSet的target=True的field
# 当中找到一个参数名为`pred`的参数传入func一个参数名为`input`的参数;找到一个参数名为`label`的参数
# 传入func作为一个名为`target`的参数
Expand Down
5 changes: 5 additions & 0 deletions fastNLP/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,11 @@ def train(self, load_best_model=True, on_exception='auto'):
self.logger.info("Reloaded the best model.")
else:
self.logger.info("Fail to reload best model.")

if self.dev_data is None and self.save_path is not None:
model_name = "_".join([self.model.__class__.__name__, self.start_time])
self._save_model(self.model, model_name)

finally:
if self.dev_data is not None and self.best_dev_perf is not None:
self.logger.info(
Expand Down
4 changes: 2 additions & 2 deletions fastNLP/embeddings/bert_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased'
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS]
来进行分类的任务将auto_truncate置为True。
:param kwargs:
int min_freq: 小于该次数的词会被unk代替
int min_freq: 小于该次数的词会被unk代替, 默认为1
"""
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

Expand All @@ -110,7 +110,7 @@ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased'
if '[CLS]' in vocab:
self._word_cls_index = vocab['CLS']

min_freq = kwargs.get('min_freq', 2)
min_freq = kwargs.get('min_freq', 1)
self._min_freq = min_freq
self.model = _BertWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep,
Expand Down
4 changes: 2 additions & 2 deletions fastNLP/embeddings/gpt2_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str

only_use_pretrain_bpe = kwargs.get('only_use_pretrain_bpe', False)
truncate_embed = kwargs.get('truncate_embed', True)
min_freq = kwargs.get('min_freq', 2)
min_freq = kwargs.get('min_freq', 1)

self.lm_loss =language_model
self.model = _GPT2Model(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
Expand Down Expand Up @@ -315,7 +315,7 @@ def get_lm_loss(self, release=True):

class _GPT2Model(nn.Module):
def __init__(self, model_dir_or_name, vocab, layers, pool_method='first', auto_truncate=True, language_model=False,
only_use_pretrain_bpe=False, min_freq=2, truncate_embed=False):
only_use_pretrain_bpe=False, min_freq=1, truncate_embed=False):
super().__init__()

self.tokenzier = GPT2Tokenizer.from_pretrained(model_dir_or_name)
Expand Down
6 changes: 3 additions & 3 deletions fastNLP/embeddings/roberta_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str
word pieces后的内容,并将第512个word piece置为</s>。超过长度的部分的encode结果直接全部置零。一般仅有只使用<s>
来进行分类的任务将auto_truncate置为True。
:param kwargs:
int min_freq: 小于该次数的词会被unk代替
int min_freq: 小于该次数的词会被unk代替, 默认为1
"""
super().__init__(vocab, word_dropout=word_dropout, dropout=dropout)

Expand All @@ -93,7 +93,7 @@ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str
if '<s>' in vocab:
self._word_cls_index = vocab['<s>']

min_freq = kwargs.get('min_freq', 2)
min_freq = kwargs.get('min_freq', 1)
self._min_freq = min_freq

self.model = _RobertaWordModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers,
Expand Down Expand Up @@ -464,7 +464,7 @@ def save(self, folder):

os.makedirs(os.path.join(folder, ROBERTA_ENCODER_FOLDER), exist_ok=True)
self.model.save(os.path.join(folder, ROBERTA_ENCODER_FOLDER))
logger.debug(f"BertWordPieceEncoder has been saved in {folder}")
logger.debug(f"RobertaWordPieceEncoder has been saved in {folder}")

@classmethod
def load(cls, folder):
Expand Down
4 changes: 2 additions & 2 deletions fastNLP/embeddings/static_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __init__(self, vocab: Vocabulary, model_dir_or_name: Union[str, None] = 'en'
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。
:param dict kwargs:
bool only_train_min_freq: 仅对train中的词语使用min_freq筛选;
bool only_norm_found_vector: 是否仅对在预训练中找到的词语使用normalize;
bool only_use_pretrain_word: 仅使用出现在pretrain词表中的词,如果该词没有在预训练的词表中出现则为unk。如果embedding不需要更新建议设置为True。
bool only_norm_found_vector: 默认为False, 是否仅对在预训练中找到的词语使用normalize;
bool only_use_pretrain_word: 默认为False, 仅使用出现在pretrain词表中的词,如果该词没有在预训练的词表中出现则为unk。如果embedding不需要更新建议设置为True。
"""
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if embedding_dim > 0:
Expand Down
2 changes: 1 addition & 1 deletion fastNLP/modules/generator/seq2seq_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_
max_len_eos_mask = max_lengths.eq(cur_len+1)
eos_scores = scores[:, _eos_token_id]
# 如果已经达到最大长度,就把eos的分数加大
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+100, eos_scores)
scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores)

if do_sample:
if temperature > 0 and temperature != 1:
Expand Down
13 changes: 12 additions & 1 deletion test/core/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,19 @@ def test_save_path(self):
use_tqdm=True, check_code_level=2)
trainer.train()
import os
import shutil
self.assertTrue(os.path.exists(save_path))
if os.path.exists(save_path):
shutil.rmtree(save_path)

# 无dev_data的训练
trainer = Trainer(train_set, model, optimizer=SGD(lr=0.1), loss=BCELoss(pred="predict", target="y"),
batch_size=32, n_epochs=10, print_every=50, dev_data=None,
metrics=None, validate_every=-1, save_path=save_path,
use_tqdm=True, check_code_level=2)
trainer.train()
self.assertTrue(os.path.exists(save_path))
if os.path.exists(save_path):
import shutil
shutil.rmtree(save_path)

def test_trainer_suggestion1(self):
Expand Down

0 comments on commit 40bec21

Please sign in to comment.