-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: hypernetwork.py
106 lines (82 loc) · 3.47 KB
/
hypernetwork.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# copy from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/c344ba3b325459abbf9b0df2c1b18f7bf99805b2/modules/hypernetworks/hypernetwork.py
import os
import torch
from torch.nn.init import normal_
class HypernetworkModule(torch.nn.Module):
    """A small residual MLP that perturbs a fixed-width feature vector.

    The network computes ``x + f(x) * multiplier`` where ``f`` is a
    dim -> 4*dim -> 4*dim -> dim stack of Linear layers with LeakyReLU
    activations, so with ``multiplier == 0`` the module is the identity.
    """

    # Class-level output scale shared by all instances unless shadowed
    # on an instance: 1.0 = full effect, 0.0 = identity passthrough.
    multiplier = 1.0

    def __init__(self, dim, state_dict=None):
        """Build the MLP for feature width ``dim``.

        Args:
            dim: Width of the input/output feature vectors.
            state_dict: Optional state dict to restore (strict load).
                When ``None``, Linear weights are initialized from
                N(0, 0.01) and biases are set to exactly 0 (std=0).
        """
        super().__init__()
        # NOTE: the attribute must stay named ``linear`` — saved state
        # dicts key their tensors as ``linear.*``.
        self.linear = torch.nn.Sequential(
            torch.nn.Linear(dim, dim * 4),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(dim * 4, dim * 4),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(dim * 4, dim),
        )
        if state_dict is not None:
            self.load_state_dict(state_dict, strict=True)
        else:
            for layer in self.linear:
                # isinstance instead of type(...) == ... — also accepts
                # subclasses, and is the idiomatic type check.
                if isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                    normal_(layer.weight.data, mean=0.0, std=0.01)
                    # std=0 draws the mean exactly, i.e. zero-fills bias.
                    normal_(layer.bias.data, mean=0.0, std=0)

    def forward(self, x):
        """Return ``x + self.linear(x) * self.multiplier`` (residual)."""
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        """Return the list of weight/bias tensors to optimize."""
        return [
            param
            for layer in self.linear
            if isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm))
            for param in (layer.weight, layer.bias)
        ]
class Hypernetwork:
    """A named collection of paired HypernetworkModules plus checkpoint metadata.

    ``layers`` maps a key to a ``(module_k, module_v)`` pair of
    HypernetworkModule instances of the same width. Keys created by
    ``__init__`` have the form ``f"{key}_{size}"``; legacy files loaded
    by ``load()`` may instead be keyed by the bare int width.
    """

    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None):
        """Create an (optionally named) hypernetwork.

        Args:
            name: Human-readable name stored alongside the weights.
            enable_sizes: Optional iterable of ``(key, size)`` pairs;
                each creates a fresh pair of modules of width ``size``
                under the key ``f"{key}_{size}"``.
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        for key, size in enable_sizes or []:
            self.layers[f"{key}_{size}"] = (HypernetworkModule(size), HypernetworkModule(size))

    def to(self, device):
        """Move every module to ``device``; returns self (torch convention)."""
        for layers in self.layers.values():
            for layer in layers:
                layer.to(device)
        return self

    def train(self):
        """Put every module into training mode."""
        for layers in self.layers.values():
            for layer in layers:
                layer.train()

    def weights(self):
        """Return all trainable tensors, flipping modules to train mode.

        NOTE(review): the train() side effect is kept from the upstream
        webui code — callers appear to rely on weights() preparing the
        modules for optimization; confirm before removing.
        """
        res = []
        for layers in self.layers.values():
            for layer in layers:
                layer.train()
                res += layer.trainables()
        return res

    def save(self, filename):
        """Serialize layer state dicts plus metadata to ``filename``."""
        state_dict = {}
        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())
        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        torch.save(state_dict, filename)

    def load(self, filename):
        """Restore layers and metadata from a file written by ``save()``.

        Accepts both key formats: legacy entries keyed by the bare int
        width, and entries keyed ``f"{key}_{size}"`` as written by
        ``save()``. (The original code only handled int keys, so a
        network saved by this class loaded back with no layers.)
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]
        state_dict = torch.load(filename, map_location='cpu')
        for key, sd in state_dict.items():
            if isinstance(key, int):
                # Legacy format: keyed directly by layer width.
                self.layers[key] = (HypernetworkModule(key, sd[0]), HypernetworkModule(key, sd[1]))
            elif isinstance(key, str) and isinstance(sd, (tuple, list)):
                # save() format: "name_size" -> (sd_k, sd_v). Metadata
                # keys ('step', 'name', ...) fail the isdigit check.
                _, sep, size_str = key.rpartition('_')
                if sep and size_str.isdigit():
                    size = int(size_str)
                    self.layers[key] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))
        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)