forked from open-mmlab/mmskeleton
-
Notifications
You must be signed in to change notification settings - Fork 0
/
io.py
117 lines (93 loc) · 3.86 KB
/
io.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#!/usr/bin/env python
# pylint: disable=W0201
import sys
import argparse
import yaml
import numpy as np
# torch
import torch
import torch.nn as nn
# torchlight
import torchlight
from torchlight import str2bool
from torchlight import DictAction
from torchlight import import_class
class IO():
    """
    IO Processor.

    Parses command-line arguments (optionally merged with a YAML config
    file), builds the model via torchlight, loads pretrained weights, and
    moves all torch modules to the selected device.
    """

    def __init__(self, argv=None):
        # Full pipeline: args -> environment -> model -> weights -> device.
        self.load_arg(argv)
        self.init_environment()
        self.load_model()
        self.load_weights()
        self.gpu()

    def load_arg(self, argv=None):
        """Parse arguments with priority: command line > config file > default.

        Raises:
            ValueError: if the config file contains a key the parser
                does not define (catches typos in the YAML early).
        """
        parser = self.get_parser()

        # load arg from config file
        p = parser.parse_args(argv)
        if p.config is not None:
            # safe_load avoids the arbitrary-object construction that a
            # bare yaml.load() allows (deprecated since PyYAML 5.1)
            with open(p.config, 'r') as f:
                default_arg = yaml.safe_load(f)

            # update parser from config file; reject unknown keys with a
            # real exception (assert would be stripped under python -O)
            key = vars(p).keys()
            for k in default_arg.keys():
                if k not in key:
                    raise ValueError('Unknown Arguments: {}'.format(k))

            parser.set_defaults(**default_arg)

        self.arg = parser.parse_args(argv)

    def init_environment(self):
        """Create the torchlight IO helper, save the parsed arguments,
        and choose the compute device."""
        self.io = torchlight.IO(
            self.arg.work_dir,
            save_log=self.arg.save_log,
            print_log=self.arg.print_log)
        self.io.save_arg(self.arg)

        # gpu: restrict visibility to the requested devices and reserve them
        if self.arg.use_gpu:
            gpus = torchlight.visible_gpu(self.arg.device)
            torchlight.occupy_gpu(gpus)
            self.gpus = gpus
            # after visible_gpu, the first requested device maps to cuda:0
            self.dev = "cuda:0"
        else:
            self.dev = "cpu"

    def load_model(self):
        """Instantiate the model class named by --model with --model_args."""
        self.model = self.io.load_model(self.arg.model,
                                        **(self.arg.model_args))

    def load_weights(self):
        """Load pretrained weights if --weights was given, skipping any
        parameter names listed in --ignore_weights."""
        if self.arg.weights:
            self.model = self.io.load_weights(self.model, self.arg.weights,
                                              self.arg.ignore_weights)

    def gpu(self):
        """Move the model and every nn.Module attribute of this processor
        to the selected device; wrap in DataParallel for multi-GPU runs."""
        # move modules to gpu
        self.model = self.model.to(self.dev)
        for name, value in vars(self).items():
            # duck-typed check: any attribute whose class lives under
            # torch.nn.modules is treated as a module to relocate
            cls_name = str(value.__class__)
            if cls_name.find('torch.nn.modules') != -1:
                setattr(self, name, value.to(self.dev))

        # model parallel
        if self.arg.use_gpu and len(self.gpus) > 1:
            self.model = nn.DataParallel(self.model, device_ids=self.gpus)

    def start(self):
        """Log the parsed arguments (subclasses extend this entry point)."""
        self.io.print_log('Parameters:\n{}\n'.format(str(vars(self.arg))))

    @staticmethod
    def get_parser(add_help=False):
        """Build the argument parser shared by all processors.

        Args:
            add_help: pass True only when this parser is used standalone;
                subclass parsers set False to avoid duplicate -h options.

        Returns:
            argparse.ArgumentParser with work-dir, config, device, logging,
            and model options.
        """
        #region arguments yapf: disable
        # parameter priority: command line > config > default
        parser = argparse.ArgumentParser( add_help=add_help, description='IO Processor')

        parser.add_argument('-w', '--work_dir', default='./work_dir/tmp', help='the work folder for storing results')
        parser.add_argument('-c', '--config', default=None, help='path to the configuration file')

        # processor
        parser.add_argument('--use_gpu', type=str2bool, default=True, help='use GPUs or not')
        parser.add_argument('--device', type=int, default=0, nargs='+', help='the indexes of GPUs for training or testing')

        # visulize and debug
        parser.add_argument('--print_log', type=str2bool, default=True, help='print logging or not')
        parser.add_argument('--save_log', type=str2bool, default=True, help='save logging or not')

        # model
        parser.add_argument('--model', default=None, help='the model will be used')
        parser.add_argument('--model_args', action=DictAction, default=dict(), help='the arguments of model')
        parser.add_argument('--weights', default=None, help='the weights for network initialization')
        parser.add_argument('--ignore_weights', type=str, default=[], nargs='+', help='the name of weights which will be ignored in the initialization')
        #endregion yapf: enable

        return parser