From 2d0f1c46d0db798339b46adc4e9154a04fabdd65 Mon Sep 17 00:00:00 2001 From: andreh7 Date: Thu, 13 Jul 2017 07:27:04 +0200 Subject: [PATCH] added a function makedirs() which works both for python 2 and 3 (#176) --- snli/train.py | 6 +++--- snli/util.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/snli/train.py b/snli/train.py index 60ec0e75ad..e64eb540f5 100644 --- a/snli/train.py +++ b/snli/train.py @@ -10,7 +10,7 @@ from torchtext import datasets from model import SNLIClassifier -from util import get_args +from util import get_args, makedirs args = get_args() @@ -27,7 +27,7 @@ inputs.vocab.vectors = torch.load(args.vector_cache) else: inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed) - os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True) + makedirs(os.path.dirname(args.vector_cache)) torch.save(inputs.vocab.vectors, args.vector_cache) answers.build_vocab(train) @@ -61,7 +61,7 @@ header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' dev_log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) log_template = ' '.join('{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(',')) -os.makedirs(args.save_path, exist_ok=True) +makedirs(args.save_path) print(header) for epoch in range(args.epochs): diff --git a/snli/util.py b/snli/util.py index 1ef2133de6..509fdbdf1f 100644 --- a/snli/util.py +++ b/snli/util.py @@ -1,6 +1,23 @@ import os from argparse import ArgumentParser +def makedirs(name): + """helper function for python 2 and 3 to call os.makedirs() + avoiding an error if the directory to be created already exists""" + + import os, errno + + try: + os.makedirs(name) + except OSError as ex: + if ex.errno == errno.EEXIST and os.path.isdir(name): + # ignore existing directory + pass + else: + # a different error happened + raise + + def get_args(): parser = ArgumentParser(description='PyTorch/torchtext SNLI example') parser.add_argument('--epochs', type=int, default=50)