simplify the network and add a transformation encoder
GreenWizard2015 committed Jun 1, 2024
1 parent 3787ab3 commit c4501c0
Showing 3 changed files with 40 additions and 72 deletions.
4 changes: 2 additions & 2 deletions NN/FaceMeshEncoder.py
@@ -1,5 +1,5 @@
 import tensorflow as tf
-from NN.Utils import sMLP, CRMLBlock
+from NN.Utils import sMLP, CFusingBlock
 from NN.CCoordsEncodingLayer import CCoordsEncodingLayer
 from Core.Utils import FACE_MESH_INVALID_VALUE

@@ -15,7 +15,7 @@ def __init__(self, latentSize, **kwargs):
     self._sMLP2 = sMLP(sizes=[latentSize], activation='relu', name='FaceMeshEncoder/sMLP-2')
 
     self._RML = [
-      CRMLBlock(
+      CFusingBlock(
         mlp=sMLP(
           sizes=[latentSize * 2] * 3,
           activation='relu', name=f'FaceMeshEncoder/RML-{i}/mlp'
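Note: this file only picks up the rename from CRMLBlock to CFusingBlock; the call site is otherwise unchanged. For orientation, a minimal construction sketch (not part of the commit; latentSize and the block name are assumed values, since the surrounding loop is collapsed in the hunk) — CFusingBlock itself is defined in NN/Utils.py below:

# Hypothetical usage sketch mirroring the call site above.
from NN.Utils import sMLP, CFusingBlock

latentSize = 64  # assumed value for illustration
block = CFusingBlock(
  mlp=sMLP(
    sizes=[latentSize * 2] * 3,  # three hidden layers, twice the latent width
    activation='relu', name='FaceMeshEncoder/RML-0/mlp'
  )
)
# The block is called on a list of tensors; x[0] is the skip/residual input.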
64 changes: 5 additions & 59 deletions NN/Utils.py
@@ -227,72 +227,17 @@ def call(self, x, training=None):
     quantized = 0.5 + tf.clip_by_value(quantized, self._minValue, self._maxValue)
     return x + tf.stop_gradient(quantized - x)
 ####################################
-class CResidualMultiplicativeLayer(tf.keras.layers.Layer):
-  def __init__(self, eps=1e-8, headsN=1, **kwargs):
-    super().__init__(**kwargs)
-    self._eps = eps
-    self._scale = tf.Variable(
-      initial_value=tf.random.normal((1, ), mean=0.0, stddev=0.1),
-      trainable=True, dtype=tf.float32,
-      name=self.name + '/_scale'
-    )
-    self._headsN = headsN
-    self._normalization = None
-    return
-
-  @property
-  def scale(self): return tf.nn.sigmoid(self._scale) * (1.0 - 2.0 * self._eps) + self._eps # [eps, 1 - eps]
-
-  def _SMNormalization(self, xhat):
-    xhat = tf.nn.softmax(xhat, axis=-1)
-    xhat = xhat - tf.reduce_mean(xhat, axis=-1, keepdims=True)
-    rng = tf.reduce_max(tf.abs(xhat), axis=-1, keepdims=True)
-    return 1.0 + tf.math.divide_no_nan(xhat, rng * self.scale) # [1 - scale, 1 + scale]
-
-  def _HeadwiseNormalizationNoPadding(self, xhat):
-    shape = tf.shape(xhat)
-    # reshape [B, ..., N * headsN] -> [B, ..., headsN, N], apply normalization, reshape back
-    xhat = tf.reshape(xhat, tf.concat([shape[:-1], [self._headsN, shape[-1] // self._headsN]], axis=-1))
-    xhat = self._SMNormalization(xhat)
-    xhat = tf.reshape(xhat, shape)
-    return xhat
-
-  def _HeadwiseNormalizationPadded(self, lastChunk):
-    def F(xhat):
-      mainPart = self._HeadwiseNormalizationNoPadding(xhat[..., :-lastChunk])
-      tailPart = self._SMNormalization(xhat[..., -lastChunk:])
-      return tf.concat([mainPart, tailPart], axis=-1)
-    return F
-
-  def build(self, input_shapes):
-    _, xhatShape = input_shapes
-    self._normalization = self._SMNormalization
-    if 1 < self._headsN:
-      assert 1 < (xhatShape[-1] // self._headsN), "too few channels for headsN"
-
-      lastChunk = xhatShape[-1] % self._headsN
-      self._normalization = self._HeadwiseNormalizationPadded(lastChunk) if 0 < lastChunk else self._HeadwiseNormalizationNoPadding
-      pass
-    return super().build(input_shapes)
-
-  def call(self, x):
-    x, xhat = x
-    # return (tf.nn.relu(x) + self._eps) * (self._normalization(xhat) + self._eps) # more general/stable version
-    # with SM normalization, relu and addition are redundant
-    return x * self._normalization(xhat)
-####################################
-class CRMLBlock(tf.keras.Model):
-  def __init__(self, mlp=None, RML=None, **kwargs):
+class CFusingBlock(tf.keras.Model):
+  def __init__(self, mlp=None, **kwargs):
     super().__init__(**kwargs)
     if mlp is None: mlp = lambda x: x
     self._mlp = mlp
-    if RML is None: RML = CResidualMultiplicativeLayer()
-    self._RML = RML
     return
 
   def build(self, input_shapes):
     xShape = input_shapes[0]
     self._lastDense = L.Dense(xShape[-1], activation='relu', name='%s/LastDense' % self.name)
+    self._combiner = L.Dense(xShape[-1], activation='relu', name='%s/Combiner' % self.name)
     return super().build(input_shapes)
 
   def call(self, x):
@@ -301,7 +246,8 @@ def call(self, x):
     xhat = self._mlp(xhat)
     xhat = self._lastDense(xhat)
     x0 = x[0]
-    return self._RML([x0, xhat])
+    x = tf.concat([x0, xhat], axis=-1)
+    return self._combiner(x)
 ####################################
 # Hacky way to provide same optimizer for all models
 def createOptimizer(config=None):
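The net effect in this file: the multiplicative residual gating (x * normalize(xhat), with a softmax-based gate squeezed into [1 - scale, 1 + scale]) is replaced by concatenation followed by a learned Dense combiner. A side-by-side sketch of the two fusion strategies, stripped of Keras plumbing (illustration only; scale stands in for the removed layer's sigmoid-bounded learned scalar, combiner for the new block's Combiner layer, and only the single-head headsN=1 case is shown):

import tensorflow as tf

def old_residual_multiplicative(x, xhat, scale=0.5):
  # Removed path: softmax-normalize xhat into a gate in [1 - scale, 1 + scale],
  # then modulate x multiplicatively.
  xhat = tf.nn.softmax(xhat, axis=-1)
  xhat = xhat - tf.reduce_mean(xhat, axis=-1, keepdims=True)
  rng = tf.reduce_max(tf.abs(xhat), axis=-1, keepdims=True)
  gate = 1.0 + tf.math.divide_no_nan(xhat, rng * scale)
  return x * gate

def new_fusing(x0, xhat, combiner):
  # New path: concatenate the skip input with the processed features and
  # let a relu Dense layer learn the fusion.
  return combiner(tf.concat([x0, xhat], axis=-1))

# combiner would be something like tf.keras.layers.Dense(x0.shape[-1], activation='relu')

The new path trades the hand-crafted bounded gate for extra parameters; with a relu combiner the output is no longer constrained to stay close to the skip input.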
44 changes: 33 additions & 11 deletions NN/networks.py
@@ -56,7 +56,7 @@ def Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize):
   # we need to combine them together and with the encodedP
   combined = encodedP # start with the face features
   for i, EFeat in enumerate(encodedEFList):
-    combined = CResidualMultiplicativeLayer(name='F2S/ResMul-%d' % i)([
+    combined = CFusingBlock(name='F2S/ResMul-%d' % i)([
       combined,
       sMLP(sizes=[latentSize] * 1, activation='relu', name='F2S/MLP-%d' % i)(
         L.Concatenate(-1)([combined, encodedP, EFeat, embeddings])
@@ -94,7 +94,7 @@ def Step2LatentModel(latentSize, embeddingsSize):
   temporal = sMLP(sizes=[latentSize] * 1, activation='relu')(
     L.Concatenate(-1)([stepsData, encodedT, embeddings])
   )
-  temporal = CResidualMultiplicativeLayer()([stepsData, temporal])
+  temporal = CFusingBlock()([stepsData, temporal])
   intermediate['S2L/enc0'] = temporal
   # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
   for blockId in range(3):
@@ -104,14 +104,14 @@ def Step2LatentModel(latentSize, embeddingsSize):
     temp = sMLP(sizes=[latentSize] * 1, activation='relu')(
       L.Concatenate(-1)([temporal, temp])
     )
-    temporal = CResidualMultiplicativeLayer()([temporal, temp])
+    temporal = CFusingBlock()([temporal, temp])
     intermediate['S2L/ResLSTM-%d' % blockId] = temporal
     continue
   # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
   latent = sMLP(sizes=[latentSize] * 1, activation='relu')(
     L.Concatenate(-1)([stepsData, temporal, encodedT, encodedT])
   )
-  latent = CResidualMultiplicativeLayer()([stepsData, latent])
+  latent = CFusingBlock()([stepsData, latent])
   return tf.keras.Model(
     inputs={
       'latent': latents,
@@ -185,6 +185,35 @@ def Face2LatentModel(
   IP = lambda x: IntermediatePredictor()(x) # own IntermediatePredictor for each output
   res['intermediate'] = {k: IP(x) for k, x in intermediate.items()}
   res['result'] = IP(res['latent'])
+  ###################################
+  # TODO: figure out whether this is helpful or not
+  # branch for global coordinates transformation
+  # predict shift, rotation, scale
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
+  emb = sMLP(sizes=[64, 64, 64, 64, 32], activation='relu')(emb[:, 0])
+  shift = L.Dense(2, name='GlobalShift')(emb)[:, None]
+  rotation = L.Dense(1, name='GlobalRotation', activation='sigmoid')(emb)[:, None] * np.pi
+  scale = L.Dense(2, name='GlobalScale')(emb)[:, None]
+
+  shifted = res['result'] + shift - 0.5 # [0.5, 0.5] -> [0, 0]
+  # Rotation matrix components
+  cos_rotation = L.Lambda(lambda x: tf.cos(x))(rotation)
+  sin_rotation = L.Lambda(lambda x: tf.sin(x))(rotation)
+  rotation_matrix = L.Lambda(lambda x: tf.stack([x[0], x[1]], axis=-1))([cos_rotation, sin_rotation])
+
+  # Apply rotation
+  rotated = L.Lambda(
+    lambda x: tf.einsum('isj,iomj->isj', x[0], x[1])
+  )([shifted, rotation_matrix]) + 0.5 # [0, 0] -> [0.5, 0.5] back
+
+  # Apply scale
+  scaled = rotated * scale
+  def clipWithGradient(x):
+    res = tf.clip_by_value(x, 0.0, 1.0)
+    return x + tf.stop_gradient(res - x)
+
+  res['result'] = L.Lambda(clipWithGradient)(scaled)
+  ###################################
 
   main = tf.keras.Model(inputs=inputs, outputs=res)
   return {
@@ -195,13 +224,6 @@ }
 }
 
 if __name__ == '__main__':
-  # autoencoder = FaceAutoencoderModel(latentSize=64, means={
-  #   'points': np.zeros((478, 2), np.float32),
-  #   'left eye': np.zeros((32, 32), np.float32),
-  #   'right eye': np.zeros((32, 32), np.float32),
-  # })['main']
-  # autoencoder.summary(expand_nested=True)
-
   X = Face2LatentModel(steps=5, latentSize=64,
     embeddings={
       'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
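Two details in the new transformation branch are worth flagging. First, clipWithGradient reuses the straight-through trick already present in CQuantizeLayer above (x + tf.stop_gradient(res - x)): the forward pass emits the clipped value while gradients flow as if through the identity, so saturated predictions still receive a learning signal. Second, the committed einsum 'isj,iomj->isj' repeats the first operand's indices in the output, so with the [B, 1, 1, 2] stack of [cos, sin] it multiplies x by cos(angle) and y by sin(angle) elementwise rather than mixing the two coordinates through a full 2x2 matrix; the branch's own TODO already marks it as experimental. For comparison only, a conventional rotation about the screen center (0.5, 0.5) could look like this sketch (an alternative formulation, not the author's code; shapes assumed to be [batch, steps, 2] points and [batch] angles):

import tensorflow as tf

def rotate_about_center(points, angle):
  # points: [B, S, 2] in [0, 1]; angle: [B] in radians.
  centered = points - 0.5                    # move the pivot to the origin
  c, s = tf.cos(angle), tf.sin(angle)        # [B]
  R = tf.stack([
    tf.stack([c, -s], axis=-1),              # first row:  [cos, -sin]
    tf.stack([s,  c], axis=-1),              # second row: [sin,  cos]
  ], axis=-2)                                # [B, 2, 2]
  # rotated[b, s, m] = sum_j centered[b, s, j] * R[b, m, j]
  rotated = tf.einsum('bsj,bmj->bsm', centered, R)
  return rotated + 0.5                       # move the pivot back

Note also that the committed scale is applied after shifting back into the [0, 1] frame, so it scales around the origin rather than around the screen center; whether any of this helps is left open by the commit itself.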
