merge eval results in all processes. (dmlc#1160)
zheng-da authored Jan 2, 2020
1 parent dfb10db commit 7451bb2
Showing 4 changed files with 38 additions and 9 deletions.
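The commit changes multi-process evaluation so that workers no longer print their metrics independently: each worker now puts its metrics dict on a shared multiprocessing.Queue, and the parent process averages the per-worker values before printing. Below is a minimal, self-contained sketch of that pattern, assuming evenly sized evaluation shards; evaluate_shard and its metric values are illustrative stand-ins for the repo's test() function, not code from the patch.

    import multiprocessing as mp

    def evaluate_shard(shard_id, queue):
        # Stand-in for test(): evaluate one shard and report a metrics
        # dict to the parent instead of printing it locally.
        metrics = {'MRR': 0.30 + 0.01 * shard_id, 'HITS@10': 0.45}
        queue.put(metrics)

    if __name__ == '__main__':
        num_proc = 4
        queue = mp.Queue(num_proc)
        procs = [mp.Process(target=evaluate_shard, args=(i, queue))
                 for i in range(num_proc)]
        for p in procs:
            p.start()

        # Mirror the commit's aggregation: each worker's value
        # contributes v / num_proc to the running total.
        total_metrics = {}
        for _ in range(num_proc):
            metrics = queue.get()
            for k, v in metrics.items():
                total_metrics[k] = total_metrics.get(k, 0.0) + v / num_proc

        for p in procs:
            p.join()
        for k, v in total_metrics.items():
            print('Test average {}: {}'.format(k, v))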
apps/kg/eval.py (14 additions, 1 deletion)
@@ -139,13 +139,26 @@ def main(args):
     args.step = 0
     args.max_step = 0
     if args.num_proc > 1:
+        queue = mp.Queue(args.num_proc)
         procs = []
         for i in range(args.num_proc):
-            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+            proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
+                                                 'Test', queue))
             procs.append(proc)
             proc.start()
         for proc in procs:
             proc.join()
+
+        total_metrics = {}
+        for i in range(args.num_proc):
+            metrics = queue.get()
+            for k, v in metrics.items():
+                if i == 0:
+                    total_metrics[k] = v / args.num_proc
+                else:
+                    total_metrics[k] += v / args.num_proc
+        for k, v in metrics.items():
+            print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
     else:
         test(args, model, [test_sampler_head, test_sampler_tail])

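The aggregation sums v / args.num_proc across workers, which is the arithmetic mean of the per-process values; that equals the true average over all test triples only when every process evaluates an equally sized shard. Note also that the final print loop takes k and v from metrics, the last dict pulled off the queue, so it prints the last worker's values while the computed averages live in total_metrics. A tiny worked example of the intended mean, with hypothetical MRR values:

    # Two workers with hypothetical per-shard MRR values.
    num_proc = 2
    per_proc_mrr = [0.30, 0.34]
    avg = sum(v / num_proc for v in per_proc_mrr)  # 0.15 + 0.17
    print(avg)  # 0.32 (up to float rounding)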
apps/kg/train.py (17 additions, 1 deletion)
@@ -263,16 +263,32 @@ def run(args, logger):

     # test
     if args.test:
+        start = time.time()
         if args.num_proc > 1:
+            queue = mp.Queue(args.num_proc)
             procs = []
             for i in range(args.num_proc):
-                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]]))
+                proc = mp.Process(target=test, args=(args, model, [test_sampler_heads[i], test_sampler_tails[i]],
+                                                     'Test', queue))
                 procs.append(proc)
                 proc.start()
+
+            total_metrics = {}
+            for i in range(args.num_proc):
+                metrics = queue.get()
+                for k, v in metrics.items():
+                    if i == 0:
+                        total_metrics[k] = v / args.num_proc
+                    else:
+                        total_metrics[k] += v / args.num_proc
+            for k, v in metrics.items():
+                print('Test average {} at [{}/{}]: {}'.format(k, args.step, args.max_step, v))
+
             for proc in procs:
                 proc.join()
         else:
             test(args, model, [test_sampler_head, test_sampler_tail])
+        print('test:', time.time() - start)
 
 if __name__ == '__main__':
     args = ArgParser().parse_args()
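One difference from eval.py: here the parent drains the queue before proc.join(), while eval.py joins first. Draining first is the more robust ordering in general, because a child blocked in queue.put() (or flushing its queue buffer at exit) can never be joined; in this commit both orders happen to work, since the queue is created with capacity args.num_proc and each worker puts exactly one item. A small self-contained demonstration of why get-before-join matters when the queue is smaller than the number of pending items:

    import multiprocessing as mp

    def worker(q):
        q.put('done')  # blocks while the queue is full

    if __name__ == '__main__':
        q = mp.Queue(1)  # deliberately smaller than the worker count
        procs = [mp.Process(target=worker, args=(q,)) for _ in range(4)]
        for p in procs:
            p.start()
        # Drain first: each get() frees a slot so blocked workers can
        # finish their put() and exit; joining first could deadlock.
        results = [q.get() for _ in procs]
        for p in procs:
            p.join()
        print(results)  # ['done', 'done', 'done', 'done']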
apps/kg/train_mxnet.py (1 addition, 1 deletion)
@@ -61,7 +61,7 @@ def train(args, model, train_sampler, valid_samplers=None):
             # clear cache
             logs = []
 
-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     logs = []
 
     for sampler in test_samplers:
apps/kg/train_pytorch.py (6 additions, 6 deletions)
@@ -80,10 +80,9 @@ def train(args, model, train_sampler, valid_samplers=None):
             test(args, model, valid_samplers, mode='Valid')
             print('test:', time.time() - start)
 
-def test(args, model, test_samplers, mode='Test'):
+def test(args, model, test_samplers, mode='Test', queue=None):
     if args.num_proc > 1:
         th.set_num_threads(1)
-    start = time.time()
     with th.no_grad():
         logs = []
         for sampler in test_samplers:
@@ -96,9 +95,10 @@ def test(args, model, test_samplers, mode='Test'):
         if len(logs) > 0:
             for metric in logs[0].keys():
                 metrics[metric] = sum([log[metric] for log in logs]) / len(logs)
-
-        for k, v in metrics.items():
-            print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
-    print('test:', time.time() - start)
+        if queue is not None:
+            queue.put(metrics)
+        else:
+            for k, v in metrics.items():
+                print('{} average {} at [{}/{}]: {}'.format(mode, k, args.step, args.max_step, v))
     test_samplers[0] = test_samplers[0].reset()
     test_samplers[1] = test_samplers[1].reset()
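With the queue parameter, test() now has two reporting modes: called with queue=None (the single-process path) it prints its own averages, and called with a queue (the multi-process path) it ships the metrics dict to the parent. A standalone sketch of the same shape, where report() is a hypothetical stand-in for the tail of test():

    import multiprocessing as mp

    def report(metrics, mode='Test', queue=None):
        # Mirrors the new tail of test(): enqueue for the parent when a
        # queue is supplied, otherwise print locally.
        if queue is not None:
            queue.put(metrics)
        else:
            for k, v in metrics.items():
                print('{} average {}: {}'.format(mode, k, v))

    if __name__ == '__main__':
        report({'MRR': 0.32})            # single-process path: prints
        q = mp.Queue(1)
        report({'MRR': 0.32}, queue=q)   # multi-process path: enqueues
        print(q.get())                   # {'MRR': 0.32}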
