Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 3, 2024
1 parent 0bf7637 commit 6ea65fb
Showing 1 changed file with 8 additions and 22 deletions.
30 changes: 8 additions & 22 deletions NN/Nerf2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,12 @@ def test_step(self, images):

def _createAlgorithmInterceptor(self, interceptor, image, pos):
from NN.restorators.samplers.CWatcherWithExtras import CWatcherWithExtras
def interceptorFactory(algorithm):
res = interceptor(algorithm)
res = CWatcherWithExtras(
watcher=res,
converter=self._converter,
residuals=None # residuals applied in the renderer
)
return res
return interceptorFactory
res = CWatcherWithExtras(
watcher=interceptor,
converter=self._converter,
residuals=None # residuals applied in the renderer
)
return res.interceptor()
#####################################################
@tf.function
def _inference(
Expand All @@ -156,15 +153,6 @@ def _inference(
tf.assert_equal(tf.shape(initialValues)[:1], (B, ))
initialValues = tf.reshape(initialValues, (B, N, C))

if 'algorithmInterceptor' in reverseArgs: # update algorithm interceptor if provided
newParams = {k: v for k, v in encoderParams.items() if k != 'algorithmInterceptor'}
newParams['algorithmInterceptor'] = self._createAlgorithmInterceptor(
interceptor=reverseArgs['algorithmInterceptor'],
image=src, pos=tf.tile(pos[None], [B, 1, 1])
)
encoderParams = newParams
pass

encoded = self._encoder(src, training=False, params=encoderParams)
def getChunk(ind, sz):
posC = pos[ind:ind+sz]
Expand Down Expand Up @@ -233,14 +221,12 @@ def call(self,
reverseArgs = {k: v for k, v in reverseArgs.items() if k != 'encoder'}
# add interceptors if needed
if 'algorithmInterceptor' in reverseArgs:
newParams = {k: v for k, v in encoderParams.items()}
newParams = {k: v for k, v in reverseArgs.items()}
newParams['algorithmInterceptor'] = self._createAlgorithmInterceptor(
interceptor=reverseArgs['algorithmInterceptor'],
image=src, pos=tf.tile(pos[None], [B, 1, 1])
)
encoderParams = newParams
# remove the interceptor from the reverseArgs
reverseArgs = {k: v for k, v in reverseArgs.items() if k != 'algorithmInterceptor'}
reverseArgs = newParams
pass

probes = self._inference(
Expand Down

0 comments on commit 6ea65fb

Please sign in to comment.