Skip to content

Commit

Permalink
change NN
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Jul 9, 2024
1 parent 68802d8 commit df36145
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions NN/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ def Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize):

def Step2LatentModel(latentSize, embeddingsSize):
latents = L.Input((None, latentSize))
embeddings = L.Input((None, embeddingsSize))
embeddingsInput = L.Input((None, embeddingsSize))
T = L.Input((None, 1))
embeddings = embeddingsInput[..., :1] * 0.0

stepsData = latents
intermediate = {}
Expand All @@ -115,14 +116,14 @@ def Step2LatentModel(latentSize, embeddingsSize):
continue
# # # # # # # # # # # # # # # # # # # # # # # # # # # # #
latent = sMLP(sizes=[latentSize] * 1, activation='relu')(
L.Concatenate(-1)([stepsData, temporal, encodedT, encodedT])
L.Concatenate(-1)([stepsData, temporal, encodedT, embeddings])
)
latent = CFusingBlock()([stepsData, latent])
return tf.keras.Model(
inputs={
'latent': latents,
'time': T,
'embeddings': embeddings,
'embeddings': embeddingsInput,
},
outputs={
'latent': latent,
Expand Down Expand Up @@ -195,9 +196,7 @@ def Face2LatentModel(
}
res['result'] = IntermediatePredictor(
shift=0.0 if diffusion else 0.5 # shift points to the center, if not using diffusion
)(
L.Concatenate(-1)([res['latent'], emb])
)
)(res['latent'])

if diffusion:
inputs['diffusionT'] = diffusionT
Expand Down

0 comments on commit df36145

Please sign in to comment.