-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDecoder.py
61 lines (52 loc) · 1.85 KB
/
Decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import tensorflow as tf
import tensorflow.keras.layers as L
from NN.utils import sMLP
'''
Simple MLP decoder that takes condition, coords, timestep and V as input and returns corresponding V of specified 'pixels'
'''
class MLPDecoder(tf.keras.Model):
    '''
    Decoder built from a sequence of blocks that repeatedly refine a value.

    Every block is fed the concatenation (last axis) of condition, coords,
    timestep, V and the current value; the block output either replaces the
    current value or is added to it, depending on `residual`.
    '''

    def __init__(self, channels, blocks, residual, **kwargs):
        '''
        channels: size of the last axis of the tensor returned by `call`.
        blocks: callable that is given this model's name and returns the
            iterable of block callables to apply.
        residual: when truthy, block outputs are summed up; otherwise each
            block output overwrites the previous value.
        kwargs: passed through to tf.keras.Model.
        '''
        super().__init__(**kwargs)
        self._channels = channels
        self._residual = residual
        # the factory receives the model name so it can derive block names from it
        self._blocks = blocks(self.name)

    def call(self, condition, coords, timestep, V):
        '''
        Run all blocks and return the final value of shape (B, channels),
        where B is the size of the first axis of V.

        NOTE(review): all four inputs are presumably rank-2 with the same
        leading dimension, since they are concatenated along the last axis
        together with a (B, channels) tensor - confirm against callers.
        '''
        batchSize = tf.shape(V)[0]
        expectedShape = (batchSize, self._channels)
        # the value starts at zero, so the first block sees an all-zero estimate
        value = tf.zeros(expectedShape, dtype=V.dtype)
        inputs = tf.concat([condition, coords, timestep, V], axis=-1)
        for block in self._blocks:
            update = block(tf.concat([inputs, value], axis=-1))
            value = (value + update) if self._residual else update

        # guard against blocks producing an unexpected output shape
        tf.assert_equal(tf.shape(value), expectedShape)
        return value
def _mlp_from_config(config, channels):
    '''
    Return a factory `name -> model` for decoder blocks.

    Each model is a tf.keras.Sequential named `name`, made of an sMLP
    (constructed from config['sizes'] and config['activation'], named
    '<name>/mlp') followed by a Dense layer with `channels` units whose
    activation is config['final activation'] ('linear' when the key is absent).

    When config['shared'] is truthy, the returned factory builds the model
    only on its first call and hands back that same instance afterwards;
    the names passed to later calls are ignored.
    '''
    def buildBlock(name):
        hidden = sMLP(config['sizes'], activation=config['activation'], name='%s/mlp' % name)
        head = L.Dense(channels, activation=config.get('final activation', 'linear'))
        return tf.keras.Sequential([hidden, head], name=name)

    if config['shared']:
        # single instance, created lazily and reused by every caller
        instance = None

        def sharedBlock(name):
            nonlocal instance
            if instance is None:
                instance = buildBlock(name)
            return instance

        return sharedBlock

    return buildBlock
def decoder_from_config(config, channels):
    '''
    Create a decoder described by `config`.

    Only config['name'] == 'mlp' is supported: an MLPDecoder is returned with
    config['mlp blocks'] blocks, each produced from config['mlp'] and named
    '<decoder name>/MLP-<index>', and with config['residual'] as its
    residual flag.

    Raises ValueError for any other decoder name.
    '''
    if not ('mlp' == config['name']):
        raise ValueError(f"Unknown decoder name: {config['name']}")

    blockFactory = _mlp_from_config(config['mlp'], channels)

    def makeBlocks(name):
        # called by the decoder with its own name, used as the prefix of block names
        return [
            blockFactory('%s/MLP-%d' % (name, index))
            for index in range(config['mlp blocks'])
        ]

    return MLPDecoder(
        channels=channels,
        blocks=makeBlocks,
        residual=config['residual'],
    )