Skip to content

Commit

Permalink
py3.6-3.5
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Jul 6, 2018
1 parent 2136bc8 commit 9cc38aa
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions torchlight/torchlight/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def load_weights(self, model, weights_path, ignore_weights=None):
if isinstance(ignore_weights, str):
ignore_weights = [ignore_weights]

self.print_log(f'Load weights from {weights_path}.')
self.print_log('Load weights from {}.'.format(weights_path))
weights = torch.load(weights_path)
weights = OrderedDict([[k.split('module.')[-1],
v.cpu()] for k, v in weights.items()])
Expand All @@ -73,42 +73,42 @@ def load_weights(self, model, weights_path, ignore_weights=None):
ignore_name.append(w)
for n in ignore_name:
weights.pop(n)
self.print_log(f'Filter [{i}] remove weights [{n}].')
self.print_log('Filter [{}] remove weights [{}].'.format(i,n))

for w in weights:
self.print_log(f'Load weights [{w}].')
self.print_log('Load weights [{}].'.format(w))

try:
model.load_state_dict(weights)
except (KeyError, RuntimeError):
state = model.state_dict()
diff = list(set(state.keys()).difference(set(weights.keys())))
for d in diff:
self.print_log(f'Can not find weights [{d}].')
self.print_log('Can not find weights [{}].'.format(d))
state.update(weights)
model.load_state_dict(state)
return model

def save_pkl(self, result, filename):
with open(f'{self.work_dir}/{filename}', 'wb') as f:
with open('{}/{}'.format(self.work_dir, filename), 'wb') as f:
pickle.dump(result, f)

def save_h5(self, result, filename):
with h5py.File(f'{self.work_dir}/{filename}', 'w') as f:
with h5py.File('{}/{}'.format(self.work_dir, filename), 'w') as f:
for k in result.keys():
f[k] = result[k]

def save_model(self, model, name):
model_path = f'{self.work_dir}/{name}'
model_path = '{}/{}'.format(self.work_dir, name)
state_dict = model.state_dict()
weights = OrderedDict([[''.join(k.split('module.')),
v.cpu()] for k, v in state_dict.items()])
torch.save(weights, model_path)
self.print_log(f'The model has been saved as {model_path}.')
self.print_log('The model has been saved as {}.'.format(model_path))

def save_arg(self, arg):

self.session_file = f'{self.work_dir}/config.yaml'
self.session_file = '{}/config.yaml'.format(self.work_dir)

# save arg
arg_dict = vars(arg)
Expand All @@ -126,7 +126,7 @@ def print_log(self, str, print_time=True):
if self.print_to_screen:
print(str)
if self.save_log:
with open(f'{self.work_dir}/log.txt', 'a') as f:
with open('{}/log.txt'.format(self.work_dir), 'a') as f:
print(str, file=f)

def init_timer(self, *name):
Expand All @@ -147,13 +147,14 @@ def split_time(self):

def print_timer(self):
proportion = {
k: f'{int(round(v * 100 / sum(self.split_timer.values()))):02d}%'
k: '{:02d}%'.format(int(round(v * 100 / sum(self.split_timer.values()))))
for k, v in self.split_timer.items()
}
self.print_log(f'Time consumption:')
self.print_log('Time consumption:')
for k in proportion:
self.print_log(
f'\t[{k}][{proportion[k]}]: {self.split_timer[k]:.4f}')
'\t[{}][{}]: {:.4f}'.format(k, proportion[k],self.split_timer[k])
)


def str2bool(v):
Expand All @@ -166,7 +167,7 @@ def str2bool(v):


def str2dict(v):
return eval(f'dict({v})') #pylint: disable=W0123
return eval('dict({})'.format(v)) #pylint: disable=W0123


def _import_class_0(name):
Expand Down Expand Up @@ -195,7 +196,7 @@ def __init__(self, option_strings, dest, nargs=None, **kwargs):
super(DictAction, self).__init__(option_strings, dest, **kwargs)

def __call__(self, parser, namespace, values, option_string=None):
input_dict = eval(f'dict({values})') #pylint: disable=W0123
input_dict = eval('dict({})'.format(values)) #pylint: disable=W0123
output_dict = getattr(namespace, self.dest)
for k in input_dict:
output_dict[k] = input_dict[k]
Expand Down

0 comments on commit 9cc38aa

Please sign in to comment.