Skip to content

Commit

Permalink
residual conditional
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 8, 2024
1 parent 4d942fd commit aefa449
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 0 deletions.
1 change: 1 addition & 0 deletions NN/RestorationModel/CRepeatedRestorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def train_step(self, x0, latents, positions, params, xT=None):
xT = tf.stop_gradient(trainStep['value'])

for i in range(1, self._N):
params['residual'] = xT
trainStep = self._restorator.train_step(
x0=x0,
xT=xT,
Expand Down
9 changes: 9 additions & 0 deletions NN/RestorationModel/CRestorationModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class CRestorationModel(tf.keras.Model):
def __init__(self,
decoder, restorator,
posEncoder, timeEncoder,
residualCondition=False,
**kwargs
):
assert posEncoder is not None, "posEncoder is not provided"
Expand All @@ -18,6 +19,7 @@ def __init__(self,
self._restorator = restorator
self._posEncoder = posEncoder
self._timeEncoder = timeEncoder
self._residualCondition = residualCondition
return

def _encodeTime(self, t, training):
Expand Down Expand Up @@ -60,6 +62,13 @@ def denoiser(x, t, mask=None):
args = {k: masked(v, mask) for k, v in args.items()}
residuals = masked(residual, mask)

if self._residualCondition: # add residuals to the condition
B = tf.shape(residuals)[0]
args['condition'] = tf.reshape( # ensure that we know the shape of the condition
tf.concat([args['condition'], residuals], axis=-1),
[B, -1]
)

res = self._decoder(**args, training=training)
return self._withResidual(res, residuals)

Expand Down
1 change: 1 addition & 0 deletions NN/RestorationModel/CSequentialRestorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def train_step(self, x0, latents, positions, params, xT=None):
xT = tf.stop_gradient(trainStep['value'])
# the rest of the restorators are trained sequentially
for restorator in self._restorators:
params['residual'] = xT
trainStep = restorator.train_step(
x0=x0,
xT=xT,
Expand Down
1 change: 1 addition & 0 deletions NN/RestorationModel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def restorationModel_from_config(config):
restorator=restorator,
posEncoder=encoding_from_config(config['position encoding']),
timeEncoder=encoding_from_config(config['time encoding']),
residualCondition=config.get('residual condition', False),
)

if 'repeated' == name:
Expand Down

0 comments on commit aefa449

Please sign in to comment.