Skip to content

Commit

Permalink
Added saving args for easier model loading.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaletap committed Oct 16, 2020
1 parent f1a8ab5 commit 5017aa6
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def main(args):
min_occ=args.min_occ
)

model = SentenceVAE(
params = dict(
vocab_size=datasets['train'].vocab_size,
sos_idx=datasets['train'].sos_idx,
eos_idx=datasets['train'].eos_idx,
Expand All @@ -44,7 +44,8 @@ def main(args):
latent_size=args.latent_size,
num_layers=args.num_layers,
bidirectional=args.bidirectional
)
)
model = SentenceVAE(**params)

if torch.cuda.is_available():
model = model.cuda()
Expand All @@ -60,6 +61,9 @@ def main(args):
save_model_path = os.path.join(args.save_model_path, ts)
os.makedirs(save_model_path)

with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
json.dump(params, f, indent=4)

def kl_anneal_function(anneal_function, step, k, x0):
if anneal_function == 'logistic':
return float(1/(1+np.exp(-k*(step-x0))))
Expand All @@ -72,7 +76,7 @@ def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):
# cut-off unnecessary padding from target, and flatten
target = target[:, :torch.max(length).item()].contiguous().view(-1)
logp = logp.view(-1, logp.size(2))

# Negative Log Likelihood
NLL_loss = NLL(logp, target)

Expand Down

0 comments on commit 5017aa6

Please sign in to comment.