diff --git a/configs/recognition/st_gcn_aaai18/kinetics-skeleton/test.yaml b/configs/recognition/st_gcn_aaai18/kinetics-skeleton/test.yaml index e84435019..0a310c4b1 100644 --- a/configs/recognition/st_gcn_aaai18/kinetics-skeleton/test.yaml +++ b/configs/recognition/st_gcn_aaai18/kinetics-skeleton/test.yaml @@ -4,6 +4,9 @@ argparse_cfg: help: number of gpus batch_size: bind_to: processor_cfg.batch_size + type: int + gpu_batch_size: + bind_to: processor_cfg.gpu_batch_size checkpoint: bind_to: processor_cfg.checkpoint help: the checkpoint file to load from @@ -31,5 +34,6 @@ processor_cfg: # debug: true # dataloader setting - batch_size: 64 - gpus: 1 + batch_size: null + gpu_batch_size: 64 + gpus: -1 diff --git a/configs/recognition/st_gcn_aaai18/ntu-rgbd-xsub/test.yaml b/configs/recognition/st_gcn_aaai18/ntu-rgbd-xsub/test.yaml index b3a9d410e..8ced91486 100644 --- a/configs/recognition/st_gcn_aaai18/ntu-rgbd-xsub/test.yaml +++ b/configs/recognition/st_gcn_aaai18/ntu-rgbd-xsub/test.yaml @@ -4,6 +4,9 @@ argparse_cfg: help: number of gpus batch_size: bind_to: processor_cfg.batch_size + type: int + gpu_batch_size: + bind_to: processor_cfg.gpu_batch_size checkpoint: bind_to: processor_cfg.checkpoint help: the checkpoint file to load from @@ -32,5 +35,6 @@ processor_cfg: # debug: true # dataloader setting - batch_size: 64 - gpus: 1 + batch_size: null + gpu_batch_size: 64 + gpus: -1 diff --git a/configs/recognition/st_gcn_aaai18/ntu-rgbd-xview/test.yaml b/configs/recognition/st_gcn_aaai18/ntu-rgbd-xview/test.yaml index 87d45779f..0a90c55e2 100644 --- a/configs/recognition/st_gcn_aaai18/ntu-rgbd-xview/test.yaml +++ b/configs/recognition/st_gcn_aaai18/ntu-rgbd-xview/test.yaml @@ -4,6 +4,9 @@ argparse_cfg: help: number of gpus batch_size: bind_to: processor_cfg.batch_size + type: int + gpu_batch_size: + bind_to: processor_cfg.gpu_batch_size checkpoint: bind_to: processor_cfg.checkpoint help: the checkpoint file to load from @@ -32,5 +35,6 @@ processor_cfg: # debug: true # 
dataloader setting - batch_size: 64 - gpus: 1 + batch_size: null + gpu_batch_size: 64 + gpus: -1 diff --git a/mmskeleton/processor/recognition.py b/mmskeleton/processor/recognition.py index e489f070a..a8b167b93 100644 --- a/mmskeleton/processor/recognition.py +++ b/mmskeleton/processor/recognition.py @@ -8,7 +8,21 @@ from mmcv.parallel import MMDataParallel -def test(model_cfg, dataset_cfg, checkpoint, batch_size=64, gpus=1, workers=4): +def test(model_cfg, + dataset_cfg, + checkpoint, + batch_size=None, + gpu_batch_size=None, + gpus=-1, + workers=4): + + # calculate batch size + if gpus < 0: + gpus = torch.cuda.device_count() + if (batch_size is None) and (gpu_batch_size is not None): + batch_size = gpu_batch_size * gpus + assert batch_size is not None, 'Please specify batch_size or gpu_batch_size.' + dataset = call_obj(**dataset_cfg) data_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, @@ -50,17 +64,25 @@ def train( loss_cfg, dataset_cfg, optimizer_cfg, - batch_size, total_epochs, training_hooks, + batch_size=None, + gpu_batch_size=None, workflow=[('train', 1)], - gpus=1, + gpus=-1, log_level=0, workers=4, resume_from=None, load_from=None, ): + # calculate batch size + if gpus < 0: + gpus = torch.cuda.device_count() + if (batch_size is None) and (gpu_batch_size is not None): + batch_size = gpu_batch_size * gpus + assert batch_size is not None, 'Please specify batch_size or gpu_batch_size.' 
+ # prepare data loaders if isinstance(dataset_cfg, dict): dataset_cfg = [dataset_cfg] @@ -79,9 +101,8 @@ def train( else: model = call_obj(**model_cfg) model.apply(weights_init) - print(111, len(model.edge_importance)) + model = MMDataParallel(model, device_ids=range(gpus)).cuda() - print(222, len(model.module.edge_importance)) loss = call_obj(**loss_cfg) # build runner @@ -96,7 +117,6 @@ def train( # run workflow = [tuple(w) for w in workflow] - print(222, len(model.module.edge_importance)) runner.run(data_loaders, workflow, total_epochs, loss=loss) diff --git a/mmskl.py b/mmskl.py index da25c3454..33a6a9269 100644 --- a/mmskl.py +++ b/mmskl.py @@ -77,6 +77,8 @@ def parse_cfg(): if 'type' not in info: if default is not None: info['type'] = type(default) + else: + info['type'] = eval(info['type'])  # NOTE(review): eval() on a config-supplied string allows arbitrary code execution; prefer a whitelist map such as {'int': int, 'float': float, 'str': str, 'bool': bool} kwargs = dict(default=default) kwargs.update({k: v for k, v in info.items() if k != 'bind_to'}) parser.add_argument('--' + key, **kwargs)