diff --git a/NN/Nerf2D.py b/NN/Nerf2D.py index fd6be73..082e0d5 100644 --- a/NN/Nerf2D.py +++ b/NN/Nerf2D.py @@ -132,15 +132,12 @@ def test_step(self, images): def _createAlgorithmInterceptor(self, interceptor, image, pos): from NN.restorators.samplers.CWatcherWithExtras import CWatcherWithExtras - def interceptorFactory(algorithm): - res = interceptor(algorithm) - res = CWatcherWithExtras( - watcher=res, - converter=self._converter, - residuals=None # residuals applied in the renderer - ) - return res - return interceptorFactory + res = CWatcherWithExtras( + watcher=interceptor, + converter=self._converter, + residuals=None # residuals applied in the renderer + ) + return res.interceptor() ##################################################### @tf.function def _inference( @@ -156,15 +153,6 @@ def _inference( tf.assert_equal(tf.shape(initialValues)[:1], (B, )) initialValues = tf.reshape(initialValues, (B, N, C)) - if 'algorithmInterceptor' in reverseArgs: # update algorithm interceptor if provided - newParams = {k: v for k, v in encoderParams.items() if k != 'algorithmInterceptor'} - newParams['algorithmInterceptor'] = self._createAlgorithmInterceptor( - interceptor=reverseArgs['algorithmInterceptor'], - image=src, pos=tf.tile(pos[None], [B, 1, 1]) - ) - encoderParams = newParams - pass - encoded = self._encoder(src, training=False, params=encoderParams) def getChunk(ind, sz): posC = pos[ind:ind+sz] @@ -233,14 +221,12 @@ def call(self, reverseArgs = {k: v for k, v in reverseArgs.items() if k != 'encoder'} # add interceptors if needed if 'algorithmInterceptor' in reverseArgs: - newParams = {k: v for k, v in encoderParams.items()} + newParams = {k: v for k, v in reverseArgs.items()} newParams['algorithmInterceptor'] = self._createAlgorithmInterceptor( interceptor=reverseArgs['algorithmInterceptor'], image=src, pos=tf.tile(pos[None], [B, 1, 1]) ) - encoderParams = newParams - # remove the interceptor from the reverseArgs - reverseArgs = {k: v for k, v in reverseArgs.items() if k != 'algorithmInterceptor'} + reverseArgs = newParams pass probes = self._inference(