forked from ShenghaiRong/BECO
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodules.py
25 lines (22 loc) · 790 Bytes
/
modules.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.nn as nn
def freeze(network: nn.Module) -> None:
    """Freeze ``network`` in place: no parameter gradients, inference mode.

    Args:
        network: Module whose parameters (including those of all submodules,
            as yielded by ``network.parameters()``) stop requiring gradients.
    """
    # Turn off gradient tracking on every parameter of the module tree.
    for p in network.parameters():
        p.requires_grad_(False)
    # Clear the training flag on the module and all its children.
    network.eval()
def unfreeze(network: nn.Module) -> None:
    """Unfreeze ``network`` in place: all parameters trainable, training mode.

    Args:
        network: Module whose parameters (including those of all submodules,
            as yielded by ``network.parameters()``) start requiring gradients.
    """
    # Re-enable gradient tracking on every parameter of the module tree.
    for p in network.parameters():
        p.requires_grad_(True)
    # Set the training flag on the module and all its children.
    network.train()
def init_weight(module: nn.Module, a=0, mode='fan_in', nonlinearity='relu'):
    """Initialise the conv and normalisation layers of ``module`` in place.

    Every module yielded by ``module.modules()`` (``module`` itself and all
    descendants) is visited:

    * ``nn.Conv2d``: ``weight`` is re-drawn with
      ``nn.init.kaiming_normal_``. ``bias`` is left untouched.
    * ``nn.BatchNorm2d`` / ``nn.SyncBatchNorm`` / ``nn.GroupNorm``: the affine
      ``weight`` is set to 1 and ``bias`` to 0. Layers built with
      ``affine=False`` have no affine parameters and are skipped.

    All other module types are left as they are.

    Args:
        module: Root module to initialise.
        a: Forwarded to ``nn.init.kaiming_normal_`` (negative slope of the
            rectifier, relevant for ``nonlinearity='leaky_relu'``).
        mode: Forwarded to ``nn.init.kaiming_normal_`` (``'fan_in'`` or
            ``'fan_out'``).
        nonlinearity: Forwarded to ``nn.init.kaiming_normal_``.
    """
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, a=a, mode=mode, nonlinearity=nonlinearity)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm, nn.GroupNorm)):
            # With affine=False these layers register weight/bias as None;
            # passing None to constant_ would raise, so skip missing params.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)