-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Encoder.py
89 lines (75 loc) · 3.39 KB
/
Encoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import tensorflow as tf
class CEncoder(tf.keras.Model):
    """Image encoder producing a global context vector plus intermediate
    feature maps, with per-position latent sampling via `latentAt`.

    Composition:
      - BatchNormalization applied to the raw input,
      - an "encoder head" built by the `head` factory; its output is a dict
        with at least the keys 'context' and 'intermediate',
      - an optional local-feature extractor built by the `extractor` factory
        (the factory may yield None, disabling the local context),
      - an optional `blurRadiusEncoder` that embeds a per-sample blur radius
        and appends the embedding to the input as extra channels.
    """
    def __init__(self, imgWidth, channels, head, extractor, dropoutRate,
                 blurRadiusEncoder=None, **kwargs):
        """
        Args:
          imgWidth: expected spatial size (width == height) of input images.
          channels: expected number of input channels.
          head: factory `name -> model` building the encoder head.
          extractor: factory `name -> model or None` building the local-context
            extractor; when it yields None, `latentAt` returns only the global
            context.
          dropoutRate: probability of zeroing the local context per query point
            during training (context dropout, no rescaling — see `latentAt`).
          blurRadiusEncoder: optional model embedding a (B, 1) blur-radius
            tensor; when provided, `_addR` tiles the embedding over H x W and
            concatenates it to the input channels.
        """
        super().__init__(**kwargs)
        self._imgWidth = imgWidth
        self._channels = channels
        self._srcBN = tf.keras.layers.BatchNormalization(name=self.name + '/SrcBN')
        self._encoderHead = head(self.name + '/EncoderHead')
        self._extractor = extractor(self.name + '/Extractor')
        # spatial 2d dropout for the local context (applied in latentAt)
        self._dropoutRate = dropoutRate
        self._blurRadiusEncoder = blurRadiusEncoder
        print('[CEncoder] Blur: ', blurRadiusEncoder is not None)

    def _addR(self, src, R, training):
        """Append the encoded blur radius `R` to `src` as extra channels.

        No-op when no blur-radius encoder was configured. Otherwise `R` must
        be a (B, 1) tensor; its embedding is broadcast over the spatial dims
        and concatenated along the channel axis.
        """
        if self._blurRadiusEncoder is None: return src
        B, H, W, _ = [tf.shape(src)[i] for i in range(4)]
        tf.assert_equal(tf.shape(R), (B, 1))
        # embed R, then broadcast the (B, D) embedding to (B, H, W, D)
        R = self._blurRadiusEncoder(R, training=training)
        R = tf.reshape(R, [B, 1, 1, tf.shape(R)[-1]])
        R = tf.tile(R, [1, H, W, 1])
        return tf.concat([src, R], axis=-1)

    def call(self, src, training=None, params=None, R=0.0):
        """Normalize `src`, optionally append the blur-radius channels, and
        run the encoder head.

        Args:
          src: (B, H, W, C) input batch.
          training: standard Keras training flag.
          params: optional dict of ablation switches; the key
            'no intermediate K' (1-based K) zeroes the K-th intermediate
            representation.
          R: blur radius, expected (B, 1) when `blurRadiusEncoder` is set.
            NOTE(review): the default 0.0 is a scalar and would fail the
            shape assert in `_addR` if the blur encoder is enabled — callers
            must pass a proper (B, 1) tensor in that case; confirm intended.

        Returns:
          The encoder-head dict ('context' and 'intermediate' at least).
        """
        src = self._srcBN(src, training=training)
        src = self._addR(src, R, training=training)
        res = self._encoderHead(src, training=training)
        # ablation study of intermediate representations
        if params is not None:
            def applyIntermediateMask(i, x):
                if params.get('no intermediate {}'.format(i + 1), False):
                    return tf.zeros_like(x)
                return x
            res['intermediate'] = [
                applyIntermediateMask(i, x)
                for i, x in enumerate(res['intermediate'])
            ]
        return res

    def latentAt(self,
        encoded, pos, training=None,
        params=None  # parameters for ablation study
    ):
        """Sample a latent vector for each query position.

        Args:
          encoded: output of `call` ('context' plus 'intermediate').
          pos: (B, N, 2) query positions.
          training: Keras training flag; defaults to the global learning phase.
          params: optional ablation dict; 'no local context' /
            'no global context' zero the corresponding part (never both).

        Returns:
          (B * N, D) tensor — global context concatenated with local context,
          or the global context alone when the extractor is disabled.
        """
        if training is None:
            training = tf.keras.backend.learning_phase()
        B = tf.shape(pos)[0]
        N = tf.shape(pos)[1]
        tf.assert_equal(tf.shape(pos), (B, N, 2))
        # global context is always present, even if it's a dummy one;
        # repeat each sample's context once per query point -> (B * N, C)
        context = tf.repeat(encoded['context'], N, axis=0)
        tf.assert_equal(tf.shape(context)[:-1], (B * N,))
        if self._extractor is None: return context  # local context is disabled
        # local context could be absent
        localCtx = self._extractor(encoded['intermediate'], pos, training=training)
        tf.assert_equal(tf.shape(localCtx)[:-1], (B * N,))
        # ablation study
        if params is not None:
            noLocalCtx = params.get('no local context', False)
            noGlobalCtx = params.get('no global context', False)
            assert not (noLocalCtx and noGlobalCtx), \
                'can\'t drop both local and global context at the same time'
            if noLocalCtx: localCtx = tf.zeros_like(localCtx)
            if noGlobalCtx: context = tf.zeros_like(context)
        tf.assert_equal(tf.shape(context)[:-1], (B * N,))
        tf.assert_equal(tf.shape(localCtx)[:-1], (B * N,))
        # context dropout: during training, zero the local context for a random
        # subset of query points (no rescaling, unlike standard dropout)
        if training and (0.0 < self._dropoutRate):
            msk = tf.random.uniform([B * N, 1], dtype=context.dtype)
            localCtx = tf.where(msk < self._dropoutRate, tf.zeros_like(localCtx), localCtx)
        res = tf.concat([context, localCtx], axis=-1)
        tf.assert_equal(tf.shape(res)[:-1], (B * N,))
        # just to make sure that the static last dimension is correctly inferred
        return tf.ensure_shape(res, (None, context.shape[-1] + localCtx.shape[-1]))

    def get_input_shape(self):
        """Expected input batch shape: (None, imgWidth, imgWidth, channels)."""
        return (None, self._imgWidth, self._imgWidth, self._channels)
# End of CEncoder