remove tf.stop_gradient
GreenWizard2015 committed Apr 8, 2024
1 parent 74630f9 commit 4d942fd
Showing 2 changed files with 10 additions and 12 deletions.
11 changes: 5 additions & 6 deletions NN/RestorationModel/CRepeatedRestorator.py
@@ -28,7 +28,6 @@ def call(self, latents, pos, T, V, residual, training=None):
 
   def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
     for i in range(self._N):
-      if tf.is_tensor(value): value = tf.stop_gradient(value)
       value = self._restorator.reverse(
         latents=self._withID(latents, i, training),
         pos=pos, reverseArgs=reverseArgs,
@@ -50,11 +49,11 @@ def train_step(self, x0, latents, positions, params, xT=None):

     for i in range(1, self._N):
       trainStep = self._restorator.train_step(
-        x0=tf.stop_gradient(x0),
-        xT=tf.stop_gradient(xT),
-        latents=self._withID(tf.stop_gradient(latents), i, training=True),
-        positions=tf.stop_gradient(positions),
-        params={k: tf.stop_gradient(v) if tf.is_tensor(v) else v for k, v in params.items()}
+        x0=x0,
+        xT=xT,
+        latents=self._withID(latents, i, training=True),
+        positions=positions,
+        params=params
       )
       loss += trainStep['loss']
       xT = tf.stop_gradient(trainStep['value'])
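For context, tf.stop_gradient blocks backpropagation through its argument, so wrapping the inputs of later train steps (as the deleted lines did) cut those steps off from upstream weights. A minimal, self-contained sketch of the difference; `w` is a hypothetical stand-in for whatever upstream computation produces `latents`, not code from this repository:

import tensorflow as tf

w = tf.Variable(2.0)  # hypothetical upstream weight

# Old behaviour: the input is detached, so no gradient reaches `w`.
with tf.GradientTape() as tape:
  latents = w * 3.0
  loss = tf.stop_gradient(latents) ** 2
print(tape.gradient(loss, w))  # None

# New behaviour: the gradient flows back into `w`.
with tf.GradientTape() as tape:
  latents = w * 3.0
  loss = latents ** 2
print(tape.gradient(loss, w))  # tf.Tensor(36.0, ...), since d(9w^2)/dw = 18w
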
11 changes: 5 additions & 6 deletions NN/RestorationModel/CSequentialRestorator.py
@@ -17,7 +17,6 @@ def call(self, latents, pos, T, V, residual, training=None):
 
   def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
     for restorator in self._restorators:
-      if tf.is_tensor(value): value = tf.stop_gradient(value)
       value = restorator.reverse(
         latents=latents, pos=pos, reverseArgs=reverseArgs,
         residual=residual,
@@ -37,11 +36,11 @@ def train_step(self, x0, latents, positions, params, xT=None):
     # the rest of the restorators are trained sequentially
     for restorator in self._restorators:
       trainStep = restorator.train_step(
-        x0=tf.stop_gradient(x0),
-        xT=tf.stop_gradient(xT),
-        latents=tf.stop_gradient(latents),
-        positions=tf.stop_gradient(positions),
-        params={k: tf.stop_gradient(v) if isinstance(v, tf.Tensor) else v for k, v in params.items()}
+        x0=x0,
+        xT=xT,
+        latents=latents,
+        positions=positions,
+        params=params
       )
       loss += trainStep['loss']
       xT = tf.stop_gradient(trainStep['value'])
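Taken together, the change in both files means each chained train_step now backpropagates into the shared inputs (x0, latents, positions, params), while the value handed to the next stage is still detached by the tf.stop_gradient that remains on trainStep['value']. A rough sketch of that flow, using hypothetical Dense layers in place of the repository's encoder and restorators:

import tensorflow as tf

encoder = tf.keras.layers.Dense(4)  # stand-in for whatever produces `latents`
stages = [tf.keras.layers.Dense(4) for _ in range(3)]  # stand-ins for restorators

x = tf.random.normal((8, 4))
with tf.GradientTape() as tape:
  latents = encoder(x)       # after this commit, every stage's loss reaches here
  xT = tf.zeros_like(latents)
  loss = 0.0
  for stage in stages:
    value = stage(xT + latents)
    loss += tf.reduce_mean(tf.square(value))
    xT = tf.stop_gradient(value)  # stages stay decoupled, as in the diffs above

# Non-None gradients: each stage's loss term contributes through `latents`,
# but nothing flows backwards across stages through `xT`.
print(tape.gradient(loss, encoder.trainable_variables))
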
