-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
__init__.py
48 lines (42 loc) · 1.92 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from .CRestorationModel import CRestorationModel
from .CSequentialRestorator import CSequentialRestorator
from .CRepeatedRestorator import CRepeatedRestorator
from NN.Decoder import decoder_from_config
from NN.encoding import encoding_from_config
from NN.restorators import restorator_from_config
from .Embeddings import CEncodedEmbeddings, CNumberEmbeddings
from NN.encoding import encoding_from_config
def embeddings_from_config(config, N):
    """Build an embeddings object from its config dict.

    Args:
        config: dict whose ``'name'`` entry (case-insensitive) selects the kind:
            ``'embeddings'`` -> ``CNumberEmbeddings(N=N, D=config['dim'])``;
            ``'encoded'`` -> ``CEncodedEmbeddings`` using the encoding built by
            ``encoding_from_config(config['encoding'])``.
        N: forwarded unchanged as the ``N`` argument of the chosen class.

    Returns:
        The constructed embeddings object.

    Raises:
        TypeError: if ``config`` is not a dict.
        ValueError: if ``config['name']`` is not a known embeddings kind.
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not isinstance(config, dict):
        raise TypeError(
            f"embeddings config must be a dict, got {type(config).__name__}"
        )

    name = config['name'].lower()
    if name == 'embeddings':
        return CNumberEmbeddings(N=N, D=config['dim'])
    if name == 'encoded':
        return CEncodedEmbeddings(N=N, encoding=encoding_from_config(config['encoding']))
    raise ValueError(f"Unknown embeddings name: {config['name']}")
def restorationModel_from_config(config):
    """Build a restoration model from its config dict.

    The ``'name'`` entry (case-insensitive) selects what is built:

    * ``'sequential'`` -- a ``CSequentialRestorator`` over the models built
      recursively from each entry of ``config['restorators']``.
    * ``'basic'`` -- a ``CRestorationModel`` assembled from
      ``config['restorator']``, ``config['decoder']``,
      ``config['position encoding']`` and ``config['time encoding']``.
      ``'blur radius encoding'`` is optional (no encoder when absent), and
      ``'residual condition'`` defaults to ``False``.
    * ``'repeated'`` -- a ``CRepeatedRestorator`` wrapping the model built
      recursively from ``config['restorator']``, with IDs built by
      ``embeddings_from_config(config['IDs'], N=config['N'])``; ``config['N']``
      is also passed to the restorator itself.

    Raises:
        ValueError: if the name matches none of the kinds above.
    """
    kind = config['name'].lower()

    if kind == 'sequential':
        children = [restorationModel_from_config(child) for child in config['restorators']]
        return CSequentialRestorator(children)

    if kind == 'basic':
        core = restorator_from_config(config['restorator'])
        # Optional conditioning on blur radius; only built when the key exists.
        blurRadiusEncoder = (
            encoding_from_config(config['blur radius encoding'])
            if 'blur radius encoding' in config
            else None
        )
        # The decoder is sized from the restorator's channel count.
        decoder = decoder_from_config(config['decoder'], channels=core.channels)
        posEncoder = encoding_from_config(config['position encoding'])
        timeEncoder = encoding_from_config(config['time encoding'])
        return CRestorationModel(
            decoder=decoder,
            restorator=core,
            posEncoder=posEncoder,
            timeEncoder=timeEncoder,
            blurRadiusEncoder=blurRadiusEncoder,
            residualCondition=config.get('residual condition', False),
        )

    if kind == 'repeated':
        inner = restorationModel_from_config(config['restorator'])
        idsConfig = config['IDs']
        count = config['N']
        ids = embeddings_from_config(idsConfig, N=count)
        return CRepeatedRestorator(restorator=inner, IDs=ids, N=count)

    raise ValueError(f"Unknown restoration model name: {config['name']}")