fix ddp test (almost same with trainer.py)
wZuck committed Nov 6, 2021
1 parent 146a8e0 commit fb5a741
Showing 2 changed files with 82 additions and 30 deletions.
87 changes: 65 additions & 22 deletions core/test.py
@@ -30,12 +30,15 @@ class Test(object):
     Build a tester from config dict, set up model from a saved checkpoint, etc. Test and log.
     """
 
-    def __init__(self, config, result_path=None):
+    def __init__(self, rank, config, result_path=None):
+        self.rank = rank
         self.config = config
+        self.config["rank"] = rank
         self.result_path = result_path
+        self.distribute = self.config["n_gpu"] > 1
         self.viz_path, self.state_dict_path = self._init_files(config)
         self.logger = self._init_logger()
-        self.device, self.list_ids = self._init_device(config)
+        self.device, self.list_ids = self._init_device(rank, config)
         self.writer = TensorboardWriter(self.viz_path)
         self.test_meter = self._init_meter()
         print(config)
@@ -75,7 +78,10 @@ def _validate(self, epoch_idx):
         """
         # switch to evaluate mode
         self.model.eval()
-        self.model.reverse_setting_info()
+        if self.distribute:
+            self.model.module.reverse_setting_info()
+        else:
+            self.model.reverse_setting_info()
         meter = self.test_meter
         meter.reset()
         episode_size = self.config["episode_size"]
@@ -123,7 +129,10 @@ def _validate(self, epoch_idx):
                        )
                    )
                    print(info_str)
-        self.model.reverse_setting_info()
+        if self.distribute:
+            self.model.module.reverse_setting_info()
+        else:
+            self.model.reverse_setting_info()
         return meter.avg("acc"), accuracies
 
     def _init_files(self, config):
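
The two if/else branches added in _validate do the same work; only the attribute path differs, because DistributedDataParallel wraps the user model and exposes it as .module, so custom methods such as reverse_setting_info() live one level down. A minimal sketch of the unwrapping idiom (the helper name unwrap is ours, not the repo's):

import torch.nn as nn

def unwrap(model: nn.Module) -> nn.Module:
    # DDP hides the original network behind `.module`;
    # single-GPU models are returned unchanged.
    if isinstance(model, nn.parallel.DistributedDataParallel):
        return model.module
    return model

# usage: unwrap(self.model).reverse_setting_info() would cover both branches
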
@@ -190,10 +199,28 @@ def _init_dataloader(self, config):
         Returns:
             tuple: A tuple of (train_loader, val_loader and test_loader).
         """
-        test_loader = get_dataloader(config, "test", self.model_type)
+        self._check_data_config()
+        distribute = self.distribute
+        test_loader = get_dataloader(config, "test", self.model_type, distribute)
 
         return test_loader
 
+    def _check_data_config(self):
+        """
+        Check the config params.
+        """
+        # check: episode_size >= n_gpu and episode_size != 0
+        assert (
+            self.config["episode_size"] >= self.config["n_gpu"] and self.config["episode_size"] != 0
+        )
+
+        # check: episode_size % n_gpu == 0
+        assert self.config["episode_size"] % self.config["n_gpu"] == 0
+
+        # check: episode_num % episode_size == 0
+        assert self.config["train_episode"] % self.config["episode_size"] == 0
+        assert self.config["test_episode"] % self.config["episode_size"] == 0
+
     def _init_model(self, config):
         """
         Init model(backbone+classifier) from the config dict and load the best checkpoint, then parallel if necessary .
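
The new _check_data_config guards make sure episodes can be split evenly across processes. With the VAR_DICT values from run_test.py below (n_gpu=2, episode_size=2, test_episode=600) every assertion holds: 2 >= 2, 2 % 2 == 0, and 600 % 2 == 0. A standalone restatement of the constraints:

# the same checks with the run_test.py values plugged in
n_gpu, episode_size, test_episode = 2, 2, 600

assert episode_size >= n_gpu and episode_size != 0  # at least one episode per GPU
assert episode_size % n_gpu == 0                    # episodes divide evenly across ranks
assert test_episode % episode_size == 0             # the loader yields whole episode batches
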
@@ -224,24 +251,30 @@ def _init_model(self, config):
         state_dict = torch.load(self.state_dict_path, map_location="cpu")
         model.load_state_dict(state_dict)
 
-        model = model.to(self.device)
-        if len(self.list_ids) > 1:
-            parallel_list = self.config["parallel_part"]
-            if parallel_list is not None:
-                for parallel_part in parallel_list:
-                    if hasattr(model, parallel_part):
-                        setattr(
-                            model,
-                            parallel_part,
-                            nn.DataParallel(
-                                getattr(model, parallel_part),
-                                device_ids=self.list_ids,
-                            ),
-                        )
+        if self.distribute:
+            # higher order grad of BN in multi gpu will conflict with syncBN
+            # FIXME MAML with multi GPU is conflict with syncBN
+            if not (self.config["classifier"]["name"] in ["MAML"] and self.config["n_gpu"] > 1):
+                model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            else:
+                print(
+                    "{} with multi GPU will conflict with syncBN".format(
+                        self.config["classifier"]["name"]
+                    ),
+                    level="warn",
+                )
+            model = model.to(self.rank)
+            model = nn.parallel.DistributedDataParallel(
+                model, device_ids=[self.rank], output_device=self.rank, find_unused_parameters=True
+            )
 
-        return model, model.model_type
+            return model, model.module.model_type
+        else:
+            model = model.to(self.device)
 
-    def _init_device(self, config):
+            return model, model.model_type
+
+    def _init_device(self, rank, config):
         """
         Init the devices from the config file.
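
The distributed branch in _init_model follows the standard DDP recipe: convert BatchNorm layers to SyncBatchNorm (skipped for MAML, whose higher-order gradients through BN conflict with syncBN, per the FIXME above), move the model to this process's GPU, then wrap it. A self-contained sketch of that recipe, minus the MAML special case (the helper name wrap_ddp is ours):

import torch.nn as nn

def wrap_ddp(model: nn.Module, rank: int, sync_bn: bool = True) -> nn.Module:
    # assumes torch.distributed is already initialized (see _init_device)
    if sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # share BN stats across ranks
    model = model.to(rank)  # each spawned process owns exactly one GPU
    return nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], output_device=rank, find_unused_parameters=True
    )
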
@@ -252,7 +285,17 @@
             tuple: A tuple of device and list_ids.
         """
         init_seed(config["seed"], config["deterministic"])
-        device, list_ids = prepare_device(config["device_ids"], config["n_gpu"])
+        device, list_ids = prepare_device(
+            rank,
+            config["device_ids"],
+            config["n_gpu"],
+            backend="nccl" if "dist_backend" not in self.config else self.config["dist_backend"],
+            dist_url="tcp://127.0.0.1:32512"
+            if "dist_url" not in self.config
+            else self.config["dist_url"],
+        )
+        torch.cuda.set_device(self.rank)
 
         return device, list_ids
 
     def _init_meter(self):
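
prepare_device now receives the rank plus a backend and a rendezvous URL, exactly the information torch.distributed.init_process_group needs, and torch.cuda.set_device(rank) pins each process to its own GPU. A guess at the DDP-relevant part of the helper; the real prepare_device lives elsewhere in the repo and may differ, so treat this as an assumption:

import torch
import torch.distributed as dist

def prepare_device_sketch(rank, n_gpu, backend="nccl",
                          dist_url="tcp://127.0.0.1:32512"):
    # hypothetical reconstruction of what prepare_device likely does under DDP
    dist.init_process_group(backend=backend, init_method=dist_url,
                            world_size=n_gpu, rank=rank)
    device = torch.device("cuda", rank)
    return device, list(range(n_gpu))
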
25 changes: 17 additions & 8 deletions run_test.py
@@ -4,21 +4,30 @@
 sys.dont_write_bytecode = True
 
 import os
+
+import torch
 from core.config import Config
 from core import Test
 
-PATH = "./results/RelationNet-miniImageNet--ravi-resnet18-5-1-Aug-22-2021_21-22-02"
-
+PATH = "./results/ckpt"
 VAR_DICT = {
     "test_epoch": 5,
-    "device_ids": "2",
-    "n_gpu": 1,
+    "device_ids": "1,2",
+    "n_gpu": 2,
     "test_episode": 600,
-    "episode_size": 1,
-    "test_way": 6,
+    "episode_size": 2,
+    "test_way": 5,
 }
 
+
+def main(rank, config):
+    test = Test(rank, config, PATH)
+    test.test_loop()
+
+
 if __name__ == "__main__":
     config = Config(os.path.join(PATH, "config.yaml"), VAR_DICT).get_config_dict()
-    test = Test(config, PATH)
-    test.test_loop()
+    if config["n_gpu"] > 1:
+        torch.multiprocessing.spawn(main, nprocs=config["n_gpu"], args=(config,))
+    else:
+        main(0, config)
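
torch.multiprocessing.spawn launches nprocs worker processes and invokes the target as fn(i, *args) with i ranging over [0, nprocs), which is why main takes rank as its first parameter and why the single-GPU path calls main(0, config) directly. A tiny runnable illustration of that calling convention:

import torch.multiprocessing as mp

def worker(rank, msg):
    # spawn passes the process index as the first positional argument
    print(f"rank {rank}: {msg}")

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2, args=("hello",))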
