Skip to content

Commit

Permalink
Change default argument save_as_state_dict in ModelCheckpoint to True (
Browse files Browse the repository at this point in the history
…pytorch#416)

* Keyword arg save_as_state_dict now defaults to True

* Adapt tests to reflect the new default
  • Loading branch information
Fabian Schilling authored and vfdev-5 committed Jan 26, 2019
1 parent 96c9e6d commit 79e5233
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(self, dirname, filename_prefix,
n_saved=1,
atomic=True, require_empty=True,
create_dir=True,
save_as_state_dict=False):
save_as_state_dict=True):

self._dirname = os.path.expanduser(dirname)
self._fname_prefix = filename_prefix
Expand Down
23 changes: 15 additions & 8 deletions tests/ignite/handlers/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def test_args_validation(dirname):


def test_simple_recovery(dirname):
h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, save_interval=1)
h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, save_interval=1,
save_as_state_dict=False)
h(None, {'obj': 42})

fname = os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1))
Expand All @@ -70,7 +71,8 @@ def test_simple_recovery_from_existing_non_empty(dirname):
with open(previous_fname, 'w') as f:
f.write("test")

h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=False, save_interval=1)
h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=False,
save_interval=1, save_as_state_dict=False)
h(None, {'obj': 42})

fname = os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, 'obj', 1))
Expand All @@ -87,11 +89,12 @@ def _test_existance(atomic, name, obj, expected):
atomic=atomic,
create_dir=False,
require_empty=False,
save_interval=1)
save_interval=1,
save_as_state_dict=False)

try:
h(None, {name: obj})
except:
except Exception:
pass

fname = os.path.join(dirname, '{}_{}_{}.pth'.format(_PREFIX, name, 1))
Expand All @@ -105,7 +108,8 @@ def _test_existance(atomic, name, obj, expected):


def test_last_k(dirname):
h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2, save_interval=2)
h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2,
save_interval=2, save_as_state_dict=False)
to_save = {'name': 42}

for _ in range(8):
Expand All @@ -124,7 +128,8 @@ def score_function(engine):
return next(scores)

h = ModelCheckpoint(dirname, _PREFIX, create_dir=False,
n_saved=2, score_function=score_function)
n_saved=2, score_function=score_function,
save_as_state_dict=False)

to_save = {'name': 42}
for _ in range(4):
Expand All @@ -144,7 +149,8 @@ def score_function(engine):
return next(scores_iter)

h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2,
score_function=score_function, score_name="val_loss")
score_function=score_function, score_name="val_loss",
save_as_state_dict=False)

to_save = {'name': 42}
for _ in range(4):
Expand All @@ -164,7 +170,8 @@ def update_fn(engine, batch):
name = 'model'
engine = Engine(update_fn)
handler = ModelCheckpoint(dirname, _PREFIX, create_dir=False,
n_saved=2, save_interval=1)
n_saved=2, save_interval=1,
save_as_state_dict=False)

engine.add_event_handler(Events.EPOCH_COMPLETED, handler, {name: 42})
engine.run([0], max_epochs=4)
Expand Down

0 comments on commit 79e5233

Please sign in to comment.