Skip to content

Commit

Permalink
cpu support
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Jun 27, 2018
1 parent 1238694 commit 22eb7cb
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
18 changes: 11 additions & 7 deletions processor/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,13 @@ def init_environment(self):
self.io.save_arg(self.arg)

# gpu
gpus = torchlight.visible_gpu(self.arg.device)
torchlight.occupy_gpu(gpus)
self.gpus = gpus
self.dev = "cuda:0"
if self.arg.use_gpu:
gpus = torchlight.visible_gpu(self.arg.device)
torchlight.occupy_gpu(gpus)
self.gpus = gpus
self.dev = "cuda:0"
else:
self.dev = "cpu"

def load_model(self):
self.model = self.io.load_model(self.arg.model,
Expand All @@ -73,14 +76,14 @@ def load_weights(self):

def gpu(self):
# move modules to gpu
self.model = self.model.cuda()
self.model = self.model.to(self.dev)
for name, value in vars(self).items():
cls_name = str(value.__class__)
if cls_name.find('torch.nn.modules') != -1:
setattr(self, name, value.cuda())
setattr(self, name, value.to(self.dev))

# model parallel
if len(self.gpus) > 1:
if self.arg.use_gpu and len(self.gpus) > 1:
self.model = nn.DataParallel(self.model, device_ids=self.gpus)

def start(self):
Expand All @@ -97,6 +100,7 @@ def get_parser(add_help=False):
parser.add_argument('-c', '--config', default=None, help='path to the configuration file')

# processor
parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')

# visulize and debug
Expand Down
1 change: 1 addition & 0 deletions processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ def get_parser(add_help=False):
parser.add_argument('--save_result', type=str2bool, default=False, help='if ture, the output of the model will be stored')
parser.add_argument('--start_epoch', type=int, default=0, help='start training from which epoch')
parser.add_argument('--num_epoch', type=int, default=80, help='stop training in which epoch')
parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')

# visulize and debug
Expand Down

0 comments on commit 22eb7cb

Please sign in to comment.