
Commit 967ca07

[master] add grad_clip
1 parent 159413f

File tree

2 files changed: +18 −12 lines changed

examples/torch-starter/trainval.py
jactorch/train/env.py

examples/torch-starter/trainval.py (+14 −11)
@@ -112,17 +112,6 @@ def main():
     else:
         args.tb_dir = None
 
-    mldash.init(
-        desc_name=args.series_name + '/' + args.desc_name,
-        expr_name=args.expr,
-        run_name=args.run_name,
-        args=args,
-        highlight_args=parser,
-        configs=configs,
-    )
-
-    mldash.update(metainfo_file=args.meta_file, log_file=args.log_file, meter_file=args.meter_file, tb_dir=args.tb_dir)
-
     if not args.debug:
         logger.critical('Writing logs to file: "{}".'.format(args.log_file))
         set_output_file(args.log_file)
@@ -194,8 +183,22 @@ def main():
     meters = GroupMeters()
 
     if not args.debug:
+        logger.critical('Writing metainfo to file: "{}".'.format(args.meta_file))
+        with open(args.meta_file, 'w') as f:
+            f.write(dump_metainfo(args=args.__dict__, configs=configs))
         logger.critical('Writing meter logs to file: "{}".'.format(args.meter_file))
 
+        logger.critical('Initializing MLDash.')
+        mldash.init(
+            desc_name=args.series_name + '/' + args.desc_name,
+            expr_name=args.expr,
+            run_name=args.run_name,
+            args=args,
+            highlight_args=parser,
+            configs=configs,
+        )
+        mldash.update(metainfo_file=args.meta_file, log_file=args.log_file, meter_file=args.meter_file, tb_dir=args.tb_dir)
+
     if args.embed:
         from IPython import embed; embed()
 
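In short, this hunk moves the MLDash setup into the `if not args.debug:` branch and defers it until after the metainfo file has been written, so that `mldash.update()` points at files that already exist on disk and debug runs never register with the dashboard. A condensed sketch of the resulting flow (all names come from the diff above; the surrounding `main()` setup is assumed):

    if not args.debug:
        # Persist run metadata first, so the dashboard has a file to link to.
        with open(args.meta_file, 'w') as f:
            f.write(dump_metainfo(args=args.__dict__, configs=configs))

        # Only then register the run with MLDash and attach the just-written files.
        mldash.init(
            desc_name=args.series_name + '/' + args.desc_name,
            expr_name=args.expr,
            run_name=args.run_name,
            args=args,
            highlight_args=parser,
            configs=configs,
        )
        mldash.update(metainfo_file=args.meta_file, log_file=args.log_file,
                      meter_file=args.meter_file, tb_dir=args.tb_dir)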
jactorch/train/env.py (+4 −1)
@@ -116,7 +116,7 @@ def decay_learning_rate(self, decay):
         for param_group in self._optimizer.param_groups:
             param_group['lr'] *= decay
 
-    def step(self, feed_dict, reduce_func=default_reduce_func, cast_tensor=False, measure_time=False):
+    def step(self, feed_dict, grad_clip=0., reduce_func=default_reduce_func, cast_tensor=False, measure_time=False):
         if hasattr(self.model, 'train_step'):
             return self.model.train_step(self.optimizer, feed_dict)
 
@@ -153,6 +153,9 @@ def step(self, feed_dict, reduce_func=default_reduce_func, cast_tensor=False, measure_time=False):
         self.trigger_event('backward:before', self, feed_dict, loss, monitors, output_dict)
         if loss.requires_grad:
             loss.backward()
+            if grad_clip > 0:
+                from torch.nn.utils.clip_grad import clip_grad_norm_
+                clip_grad_norm_(self.model.parameters(), grad_clip)
 
         if measure_time:
             extra['time/backward'] = cuda_time() - end_time
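For context, the new `grad_clip` argument implements gradient clipping by global L2 norm: after the backward pass, all gradients are rescaled in place so that their combined norm does not exceed `grad_clip`; the default of `0.` disables clipping. A minimal, self-contained sketch of the same pattern outside jactorch (the model, optimizer, and data here are hypothetical stand-ins; only the clipping logic mirrors the commit):

    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_

    # Hypothetical stand-ins for the model and batch that TrainerEnv would manage.
    model = nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    inputs, targets = torch.randn(16, 8), torch.randn(16, 1)

    grad_clip = 1.0  # same semantics as the new step() argument; 0 disables clipping

    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    if grad_clip > 0:
        # Rescale all gradients in place so their global L2 norm is at most grad_clip.
        clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

With this commit, a caller would presumably get the same effect via something like `env.step(feed_dict, grad_clip=1.0)`. Note that clipping is bypassed when the model defines its own `train_step`, since `step()` returns early in that case, and it only applies when `loss.requires_grad` is true.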
