adaptive gpu
yysijie committed May 24, 2020
1 parent ce0e28a commit 1d53c5b
Showing 5 changed files with 46 additions and 12 deletions.

configs/recognition/st_gcn_aaai18/kinetics-skeleton/test.yaml (6 additions, 2 deletions)

@@ -4,6 +4,9 @@ argparse_cfg:
     help: number of gpus
   batch_size:
     bind_to: processor_cfg.batch_size
+    type: int
+  gpu_batch_size:
+    bind_to: processor_cfg.gpu_batch_size
   checkpoint:
     bind_to: processor_cfg.checkpoint
     help: the checkpoint file to load from
@@ -31,5 +34,6 @@ processor_cfg:
   # debug: true
 
   # dataloader setting
-  batch_size: 64
-  gpus: 1
+  batch_size: null
+  gpu_batch_size: 64
+  gpus: -1

configs/recognition/st_gcn_aaai18/ntu-rgbd-xsub/test.yaml (6 additions, 2 deletions)

@@ -4,6 +4,9 @@ argparse_cfg:
     help: number of gpus
   batch_size:
     bind_to: processor_cfg.batch_size
+    type: int
+  gpu_batch_size:
+    bind_to: processor_cfg.gpu_batch_size
   checkpoint:
     bind_to: processor_cfg.checkpoint
     help: the checkpoint file to load from
@@ -32,5 +35,6 @@ processor_cfg:
   # debug: true
 
   # dataloader setting
-  batch_size: 64
-  gpus: 1
+  batch_size: null
+  gpu_batch_size: 64
+  gpus: -1

configs/recognition/st_gcn_aaai18/ntu-rgbd-xview/test.yaml (6 additions, 2 deletions)

@@ -4,6 +4,9 @@ argparse_cfg:
     help: number of gpus
   batch_size:
     bind_to: processor_cfg.batch_size
+    type: int
+  gpu_batch_size:
+    bind_to: processor_cfg.gpu_batch_size
   checkpoint:
     bind_to: processor_cfg.checkpoint
     help: the checkpoint file to load from
@@ -32,5 +35,6 @@ processor_cfg:
   # debug: true
 
   # dataloader setting
-  batch_size: 64
-  gpus: 1
+  batch_size: null
+  gpu_batch_size: 64
+  gpus: -1
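
All three test configs get the same change: batch_size becomes null, a per-GPU gpu_batch_size of 64 is introduced, and gpus: -1 now means "use every visible GPU". Each key under argparse_cfg is also exposed as a command-line flag by mmskl.py (its parse_cfg change is at the end of this commit), so the new setting can be overridden per run. A rough, self-contained sketch of that binding, with the yaml sections inlined as plain dicts for illustration rather than the repository's exact loading code:

import argparse

# Simplified stand-ins for the yaml sections above (normally loaded by mmskl.py).
argparse_cfg = {
    'gpu_batch_size': {'bind_to': 'processor_cfg.gpu_batch_size'},
    'gpus': {'bind_to': 'processor_cfg.gpus', 'help': 'number of gpus'},
}
processor_cfg = {'gpu_batch_size': 64, 'gpus': -1}

parser = argparse.ArgumentParser()
for key, info in argparse_cfg.items():
    default = processor_cfg[info['bind_to'].split('.')[-1]]
    kwargs = dict(default=default)
    kwargs.update({k: v for k, v in info.items() if k != 'bind_to'})
    if 'type' not in kwargs and default is not None:
        kwargs['type'] = type(default)   # infer int from the yaml default
    parser.add_argument('--' + key, **kwargs)

args = parser.parse_args(['--gpu_batch_size', '32'])
print(args.gpu_batch_size)               # 32, overriding the yaml default of 64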

mmskeleton/processor/recognition.py (26 additions, 6 deletions)

@@ -8,7 +8,21 @@
 from mmcv.parallel import MMDataParallel
 
 
-def test(model_cfg, dataset_cfg, checkpoint, batch_size=64, gpus=1, workers=4):
+def test(model_cfg,
+         dataset_cfg,
+         checkpoint,
+         batch_size=None,
+         gpu_batch_size=None,
+         gpus=-1,
+         workers=4):
+
+    # calculate batch size
+    if gpus < 0:
+        gpus = torch.cuda.device_count()
+    if (batch_size is None) and (gpu_batch_size is not None):
+        batch_size = gpu_batch_size * gpus
+    assert batch_size is not None, 'Please appoint batch_size or gpu_batch_size.'
+
     dataset = call_obj(**dataset_cfg)
     data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                               batch_size=batch_size,
@@ -50,17 +64,25 @@ def train(
         loss_cfg,
         dataset_cfg,
         optimizer_cfg,
-        batch_size,
         total_epochs,
         training_hooks,
+        batch_size=None,
+        gpu_batch_size=None,
         workflow=[('train', 1)],
-        gpus=1,
+        gpus=-1,
         log_level=0,
         workers=4,
         resume_from=None,
         load_from=None,
 ):
 
+    # calculate batch size
+    if gpus < 0:
+        gpus = torch.cuda.device_count()
+    if (batch_size is None) and (gpu_batch_size is not None):
+        batch_size = gpu_batch_size * gpus
+    assert batch_size is not None, 'Please appoint batch_size or gpu_batch_size.'
+
     # prepare data loaders
     if isinstance(dataset_cfg, dict):
         dataset_cfg = [dataset_cfg]
@@ -79,9 +101,8 @@ def train(
     else:
         model = call_obj(**model_cfg)
         model.apply(weights_init)
-    print(111, len(model.edge_importance))
+
     model = MMDataParallel(model, device_ids=range(gpus)).cuda()
-    print(222, len(model.module.edge_importance))
     loss = call_obj(**loss_cfg)
 
     # build runner
@@ -96,7 +117,6 @@
 
     # run
     workflow = [tuple(w) for w in workflow]
-    print(222, len(model.module.edge_importance))
     runner.run(data_loaders, workflow, total_epochs, loss=loss)
 
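
The same batch-size resolution block is added to both test() and train(): gpus=-1 expands to every visible device, and when only a per-GPU batch is given the total batch is scaled up so that MMDataParallel's scatter still leaves gpu_batch_size samples on each card. A standalone sketch of that arithmetic (the helper name here is illustrative, not part of the module):

import torch

def resolve_batch_size(batch_size=None, gpu_batch_size=None, gpus=-1):
    """Mirror of the block added to test() and train() above."""
    if gpus < 0:
        gpus = torch.cuda.device_count()    # gpus=-1 -> use all visible GPUs
    if (batch_size is None) and (gpu_batch_size is not None):
        batch_size = gpu_batch_size * gpus  # total batch scales with the GPU count
    assert batch_size is not None, 'Please appoint batch_size or gpu_batch_size.'
    return batch_size, gpus

# With the new config defaults on a 4-GPU host:
#   resolve_batch_size(gpu_batch_size=64)       -> (256, 4)
# An explicit total batch is left untouched:
#   resolve_batch_size(batch_size=128, gpus=2)  -> (128, 2)
# Passing neither raises the AssertionError above.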

mmskl.py (2 additions, 0 deletions)

@@ -77,6 +77,8 @@ def parse_cfg():
         if 'type' not in info:
             if default is not None:
                 info['type'] = type(default)
+        else:
+            info['type'] = eval(info['type'])
         kwargs = dict(default=default)
         kwargs.update({k: v for k, v in info.items() if k != 'bind_to'})
         parser.add_argument('--' + key, **kwargs)
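
The new else branch covers keys whose bound default is None, such as batch_size: null in the configs above: the yaml now carries an explicit type: int for those keys, which reaches parse_cfg as the string 'int' and is turned back into the built-in type with eval so argparse can cast the flag's value. A small illustrative snippet of that path (the dict literal is made up for the example):

# What parse_cfg now does for an entry like batch_size (default is None,
# so the type cannot be inferred and the yaml's `type: int` string is used).
info = {'bind_to': 'processor_cfg.batch_size', 'type': 'int'}
default = None

if 'type' not in info:
    if default is not None:
        info['type'] = type(default)
else:
    info['type'] = eval(info['type'])   # 'int' (string from yaml) -> int (builtin)

assert info['type'] is int
assert info['type']('128') == 128       # the cast argparse applies to --batch_size 128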
