distributed.py · 35 lines (27 loc) · 1.06 KB
""" Distributed training/validation utils
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import distributed as dist
from models.model_ema import ModelEma
def unwrap_model(model):
if isinstance(model, ModelEma):
return unwrap_model(model.ema)
else:
return model.module if hasattr(model, 'module') else model
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n
return rt
def distribute_bn(model, world_size, reduce=False):
# ensure every node has the same running bn stats
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
bn_buf /= float(world_size)
else:
# broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0)
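
Below is a minimal usage sketch of how these helpers might be wired into a DistributedDataParallel training loop. Everything in it is illustrative rather than part of this file: the process group is assumed to be initialized already, `model` is assumed to be wrapped in DDP, and `loader`, `criterion`, `optimizer`, `device`, `world_size`, and the checkpoint path are hypothetical names chosen for the example; the import path `distributed` is also an assumption.

# Usage sketch (assumes torch.distributed.init_process_group has been called
# and that this module is importable as `distributed`).
import torch
from torch import distributed as dist

from distributed import distribute_bn, reduce_tensor, unwrap_model


def train_one_epoch(model, loader, criterion, optimizer, device, world_size):
    model.train()
    for images, targets in loader:
        images, targets = images.to(device), targets.to(device)
        loss = criterion(model(images), targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # average the per-rank loss so every process logs the same value
        loss_avg = reduce_tensor(loss.detach(), world_size)
        if dist.get_rank() == 0:
            print(f'loss: {loss_avg.item():.4f}')

    # after each epoch, sync BatchNorm running stats across ranks:
    # reduce=True averages them, reduce=False broadcasts rank 0's stats
    distribute_bn(model, world_size, reduce=True)

    # when checkpointing, unwrap DDP (and EMA) wrappers to save the bare state dict
    # (the checkpoint path here is illustrative)
    if dist.get_rank() == 0:
        torch.save(unwrap_model(model).state_dict(), 'checkpoint.pth')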