Skip to content

Commit

Permalink
added a function makedirs() which works both for python 2 and 3 (pyto…
Browse files Browse the repository at this point in the history
  • Loading branch information
andreh7 authored and soumith committed Jul 13, 2017
1 parent 1b26501 commit 2d0f1c4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
6 changes: 3 additions & 3 deletions snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torchtext import datasets

from model import SNLIClassifier
from util import get_args
from util import get_args, makedirs


args = get_args()
Expand All @@ -27,7 +27,7 @@
inputs.vocab.vectors = torch.load(args.vector_cache)
else:
inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed)
os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True)
makedirs(os.path.dirname(args.vector_cache))
torch.save(inputs.vocab.vectors, args.vector_cache)
answers.build_vocab(train)

Expand Down Expand Up @@ -61,7 +61,7 @@
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(','))
os.makedirs(args.save_path, exist_ok=True)
makedirs(args.save_path)
print(header)

for epoch in range(args.epochs):
Expand Down
17 changes: 17 additions & 0 deletions snli/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,23 @@
import os
from argparse import ArgumentParser

def makedirs(name):
"""helper function for python 2 and 3 to call os.makedirs()
avoiding an error if the directory to be created already exists"""

import os, errno

try:
os.makedirs(name)
except OSError as ex:
if ex.errno == errno.EEXIST and os.path.isdir(name):
# ignore existing directory
pass
else:
# a different error happened
raise


def get_args():
parser = ArgumentParser(description='PyTorch/torchtext SNLI example')
parser.add_argument('--epochs', type=int, default=50)
Expand Down

0 comments on commit 2d0f1c4

Please sign in to comment.