
Commit

Merge pull request jcjohnson#73 from ChrisCummins/wip/checkpoints
(optionally) Save and restore the checkpoint iteration counter
jcjohnson committed Apr 24, 2016
2 parents f8f4d94 + 6990602 commit fa666c3
Showing 2 changed files with 10 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/flags.md
@@ -20,6 +20,7 @@ The training script `train.lua` accepts the following command-line flags:

**Model options**:
- `-init_from`: Path to a checkpoint file from a previous run of `train.lua`. Use this to continue training from an existing checkpoint; if this flag is passed then the other flags in this section will be ignored and the architecture from the existing checkpoint will be used instead.
+- `-reset_iterations`: Set this to 0 to restore the iteration counter of a previous run. Default is 1 (do not restore the iteration counter). Only applicable if the `-init_from` option is used.
- `-model_type`: The type of recurrent network to use; either `lstm` (default) or `rnn`. `lstm` is slower but better.
- `-wordvec_size`: Dimension of learned word vector embeddings; default is 64. You probably won't need to change this.
- `-rnn_size`: The number of hidden units in the RNN; default is 128. Larger values (256 or 512) are commonly used to learn more powerful models and for bigger datasets, but this will significantly slow down computation.
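For example, to continue training from an existing checkpoint and keep its iteration counter (the checkpoint path here is hypothetical), run `th train.lua -init_from cv/checkpoint_10000.t7 -reset_iterations 0`. With the default `-reset_iterations 1`, the same command loads the model weights but counts iterations from 1 again.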
11 changes: 9 additions & 2 deletions train.lua
@@ -18,6 +18,7 @@ cmd:option('-seq_length', 50)

-- Model options
cmd:option('-init_from', '')
+cmd:option('-reset_iterations', 1)
cmd:option('-model_type', 'lstm')
cmd:option('-wordvec_size', 64)
cmd:option('-rnn_size', 128)
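As a minimal sketch of how the new option reaches the rest of the script, assuming the `torch.CmdLine` parser that the surrounding `cmd:option` calls imply (the snippet is illustrative, not a quote from `train.lua`):

```lua
require 'torch'

-- Declare options; cmd:parse turns them into fields of a table.
local cmd = torch.CmdLine()
cmd:option('-init_from', '')        -- path to a previous checkpoint, or ''
cmd:option('-reset_iterations', 1)  -- 1: count from 1 again; 0: resume the counter
local opt = cmd:parse(arg)

-- opt.reset_iterations == 0 requests restoring checkpoint.i on resume.
```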
@@ -83,9 +84,14 @@ end
local opt_clone = torch.deserialize(torch.serialize(opt))
opt_clone.idx_to_token = idx_to_token
local model = nil
+local start_i = 0
if opt.init_from ~= '' then
print('Initializing from ', opt.init_from)
-model = torch.load(opt.init_from).model:type(dtype)
+local checkpoint = torch.load(opt.init_from)
+model = checkpoint.model:type(dtype)
+if opt.reset_iterations == 0 then
+  start_i = checkpoint.i
+end
else
model = nn.LanguageModel(opt_clone):type(dtype)
end
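The `checkpoint.i` read here is the iteration counter that the checkpoint-writing code further down now records. A minimal sketch of the round trip, assuming a plain Lua table handled by `torch.save`/`torch.load` (the filename and the surrounding training state are hypothetical):

```lua
require 'torch'

-- Saving side: `model` and `i` come from the training loop.
local checkpoint = {model = model, i = i}
torch.save('cv/checkpoint_10000.t7', checkpoint)

-- Loading side, mirroring the diff above:
local resumed = torch.load('cv/checkpoint_10000.t7')
local start_i = 0
if opt.reset_iterations == 0 then
  start_i = resumed.i
end
```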
@@ -162,7 +168,7 @@ local optim_config = {learningRate = opt.learning_rate}
local num_train = loader.split_sizes['train']
local num_iterations = opt.max_epochs * num_train
model:training()
-for i = 1, num_iterations do
+for i = start_i + 1, num_iterations do
local epoch = math.floor(i / num_train) + 1

-- Check if we are at the end of an epoch
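Starting the loop at `start_i + 1` means a run restored from a checkpoint saved at iteration N resumes at N + 1, and the epoch counter stays consistent with the original run. A small worked example (numbers hypothetical):

```lua
-- Resuming from a checkpoint saved at i = 10000, with 1000 training
-- batches per epoch: the loop continues at 10001 and epoch resolves to 11.
local start_i, num_train = 10000, 1000
for i = start_i + 1, start_i + 2 do
  local epoch = math.floor(i / num_train) + 1
  print(i, epoch)  --> 10001  11, then 10002  11
end
```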
@@ -219,6 +225,7 @@ for i = 1, num_iterations do
val_loss_history_it = val_loss_history_it,
forward_backward_times = forward_backward_times,
memory_usage = memory_usage,
+i = i
}
local filename = string.format('%s_%d.json', opt.checkpoint_name, i)
-- Make sure the output directory exists before we try to write it
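A minimal sketch of the directory check mentioned in the comment above, assuming the Torch `paths` package (the actual helper used by `train.lua` may differ):

```lua
local paths = require 'paths'

-- `opt.checkpoint_name` and `i` come from the surrounding script.
local filename = string.format('%s_%d.json', opt.checkpoint_name, i)
-- Create the checkpoint file's parent directory if it does not exist.
paths.mkdir(paths.dirname(filename))
```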
