Skip to content

Commit

Permalink
save vocabulary file after training
Browse files — browse the repository at this point in the history
  • Loading branch information
markus-eberts committed Feb 18, 2020
1 parent 8585782 commit bfa93b9
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
2 changes: 1 addition & 1 deletion spert/spert_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def train(self, train_path: str, valid_path: str, types_path: str, input_reader_
# save final model
extra = dict(epoch=args.epochs, updates_epoch=updates_epoch, epoch_iteration=0)
global_iteration = args.epochs * updates_epoch
self._save_model(self._save_path, model, global_iteration,
self._save_model(self._save_path, model, self._tokenizer, global_iteration,
optimizer=optimizer if self.args.save_optimizer else None, extra=extra,
include_iteration=False, name='final_model')

Expand Down
17 changes: 13 additions & 4 deletions spert/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.nn import DataParallel
from torch.optim import Optimizer
from transformers import PreTrainedModel
from transformers import PreTrainedTokenizer

from spert import util
from spert.opt import tensorboardX
Expand Down Expand Up @@ -91,16 +92,18 @@ def _log_csv(self, dataset_label: str, data_label: str, *data: Tuple[object]):
logs = self._log_paths[dataset_label]
util.append_csv(logs[data_label], *data)

def _save_best(self, model, optimizer, accuracy, iteration, label, extra=None):
def _save_best(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, optimizer: Optimizer,
accuracy: float, iteration: int, label: str, extra=None):
if accuracy > self._best_results[label]:
self._logger.info("[%s] Best model in iteration %s: %s%% accuracy" % (label, iteration, accuracy))
self._save_model(self._save_path, model, iteration,
self._save_model(self._save_path, model, tokenizer, iteration,
optimizer=optimizer if self.args.save_optimizer else None,
save_as_best=True, name='model_%s' % label, extra=extra)
self._best_results[label] = accuracy

def _save_model(self, save_path: str, model: PreTrainedModel, iteration: int, optimizer: Optimizer = None,
save_as_best: bool = False, extra: dict = None, include_iteration: int = True, name: str = 'model'):
def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
iteration: int, optimizer: Optimizer = None, save_as_best: bool = False,
extra: dict = None, include_iteration: int = True, name: str = 'model'):
extra_state = dict(iteration=iteration)

if optimizer:
Expand All @@ -117,10 +120,16 @@ def _save_model(self, save_path: str, model: PreTrainedModel, iteration: int, op

util.create_directories_dir(dir_path)

# save model
if isinstance(model, DataParallel):
model.module.save_pretrained(dir_path)
else:
model.save_pretrained(dir_path)

# save vocabulary
tokenizer.save_pretrained(dir_path)

# save extra
state_path = os.path.join(dir_path, 'extra.state')
torch.save(extra_state, state_path)

Expand Down

0 comments on commit bfa93b9

Please sign in to comment.