Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Update setup_env and config to use PathManager
Browse files Browse the repository at this point in the history
Summary: Enables use of PathManager in pycls and updates config to make use of it

Reviewed By: vaibhava0

Differential Revision: D24163964

fbshipit-source-id: b9075e5eb4a167065351d348696bf923e65b1c92
  • Loading branch information
theschnitz authored and facebook-github-bot committed Jan 13, 2021
1 parent ca89a79 commit b5063a6
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 37 deletions.
13 changes: 10 additions & 3 deletions pycls/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import sys

from iopath.common.file_io import g_pathmgr
from pycls.core.io import cache_url
from yacs.config import CfgNode as CfgNode

Expand Down Expand Up @@ -377,17 +378,23 @@ def cache_cfg_urls():
_C.TEST.WEIGHTS = cache_url(_C.TEST.WEIGHTS, _C.DOWNLOAD_CACHE)


def merge_from_file(cfg_file):
with g_pathmgr.open(cfg_file, "r") as f:
cfg = _C.load_cfg(f)
_C.merge_from_other_cfg(cfg)


def dump_cfg():
"""Dumps the config to the output directory."""
cfg_file = os.path.join(_C.OUT_DIR, _C.CFG_DEST)
with open(cfg_file, "w") as f:
with g_pathmgr.open(cfg_file, "w") as f:
_C.dump(stream=f)


def load_cfg(out_dir, cfg_dest="config.yaml"):
"""Loads config from specified output directory."""
cfg_file = os.path.join(out_dir, cfg_dest)
_C.merge_from_file(cfg_file)
merge_from_file(cfg_file)


def reset_cfg():
Expand All @@ -406,5 +413,5 @@ def load_cfg_fom_args(description="Config file options."):
parser.print_help()
sys.exit(1)
args = parser.parse_args()
_C.merge_from_file(args.cfg_file)
merge_from_file(args.cfg_file)
_C.merge_from_list(args.opts)
35 changes: 35 additions & 0 deletions pycls/core/env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import random

import numpy as np
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.logging as logging
import torch
from iopath.common.file_io import g_pathmgr
from pycls.core.config import cfg


logger = logging.get_logger(__name__)


def setup_env():
"""Sets up environment for training or testing."""
if dist.is_master_proc():
# Ensure that the output dir exists
g_pathmgr.mkdirs(cfg.OUT_DIR)
# Save the config
config.dump_cfg()
# Setup logging
logging.setup_logging()
# Log torch, cuda, and cudnn versions
version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()]
logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
# Log the config as both human readable and as a json
logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else ()
logger.info(logging.dump_log_data(cfg, "cfg", None))
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
np.random.seed(cfg.RNG_SEED)
torch.manual_seed(cfg.RNG_SEED)
random.seed(cfg.RNG_SEED)
# Configure the CUDNN backend
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK
37 changes: 5 additions & 32 deletions pycls/core/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,12 @@
# LICENSE file in the root directory of this source tree.

"""Tools for training and testing a model."""

import os
import random

import numpy as np
import pycls.core.benchmark as benchmark
import pycls.core.builders as builders
import pycls.core.checkpoint as cp
import pycls.core.config as config
import pycls.core.distributed as dist
import pycls.core.env as env
import pycls.core.logging as logging
import pycls.core.meters as meters
import pycls.core.net as net
Expand All @@ -29,29 +25,6 @@
logger = logging.get_logger(__name__)


def setup_env():
"""Sets up environment for training or testing."""
if dist.is_master_proc():
# Ensure that the output dir exists
os.makedirs(cfg.OUT_DIR, exist_ok=True)
# Save the config
config.dump_cfg()
# Setup logging
logging.setup_logging()
# Log torch, cuda, and cudnn versions
version = [torch.__version__, torch.version.cuda, torch.backends.cudnn.version()]
logger.info("PyTorch Version: torch={}, cuda={}, cudnn={}".format(*version))
# Log the config as both human readable and as a json
logger.info("Config:\n{}".format(cfg)) if cfg.VERBOSE else ()
logger.info(logging.dump_log_data(cfg, "cfg", None))
# Fix the RNG seeds (see RNG comment in core/config.py for discussion)
np.random.seed(cfg.RNG_SEED)
torch.manual_seed(cfg.RNG_SEED)
random.seed(cfg.RNG_SEED)
# Configure the CUDNN backend
torch.backends.cudnn.benchmark = cfg.CUDNN.BENCHMARK


def setup_model():
"""Sets up a model for training or testing and log the results."""
# Build the model
Expand Down Expand Up @@ -145,7 +118,7 @@ def test_epoch(loader, model, meter, cur_epoch):
def train_model():
"""Trains the model."""
# Setup training/testing environment
setup_env()
env.setup_env()
# Construct the model, loss_fun, and optimizer
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
Expand Down Expand Up @@ -194,7 +167,7 @@ def train_model():
def test_model():
"""Evaluates a trained model."""
# Setup training/testing environment
setup_env()
env.setup_env()
# Construct the model
model = setup_model()
# Load model weights
Expand All @@ -210,7 +183,7 @@ def test_model():
def time_model():
"""Times model."""
# Setup training/testing environment
setup_env()
env.setup_env()
# Construct the model and loss_fun
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
Expand All @@ -221,7 +194,7 @@ def time_model():
def time_model_and_loader():
"""Times model and data loader."""
# Setup training/testing environment
setup_env()
env.setup_env()
# Construct the model and loss_fun
model = setup_model()
loss_fun = builders.build_loss_fun().cuda()
Expand Down
5 changes: 3 additions & 2 deletions pycls/datasets/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(self, data_path, split):
splits = ["train", "test"]
assert split in splits, "Split '{}' not supported for cifar".format(split)
logger.info("Constructing CIFAR-10 {}...".format(split))
self._im_size = cfg.TRAIN.IM_SIZE
self._data_path, self._split = data_path, split
self._inputs, self._labels = self._load_data()

Expand All @@ -52,7 +53,7 @@ def _load_data(self):
labels += data[b"labels"]
# Combine and reshape the inputs
inputs = np.vstack(inputs).astype(np.float32)
inputs = inputs.reshape((-1, 3, cfg.TRAIN.IM_SIZE, cfg.TRAIN.IM_SIZE))
inputs = inputs.reshape((-1, 3, self._im_size, self._im_size))
return inputs, labels

def _prepare_im(self, im):
Expand All @@ -62,7 +63,7 @@ def _prepare_im(self, im):
im[i] = (im[i] - _MEAN[i]) / _STD[i]
if self._split == "train":
# Randomly flip and crop center patch from CHW image
size = cfg.TRAIN.IM_SIZE
size = self._im_size
im = im[:, :, ::-1] if np.random.uniform() < 0.5 else im
im = np.pad(im, ((0, 0), (4, 4), (4, 4)), mode="constant")
y = np.random.randint(0, im.shape[1] - size)
Expand Down

0 comments on commit b5063a6

Please sign in to comment.