Skip to content

Commit

Permalink
misc
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 12, 2024
1 parent c97ee86 commit 3a4c3e1
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions NN/Nerf2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ def _createAlgorithmInterceptor(self, interceptor, image, pos):
)
return res.interceptor()
#####################################################
def _withBlur(self, reverseArgs, B):
R = None
if 'blurRadius' in reverseArgs:
R = reverseArgs['blurRadius']
if not tf.is_tensor(R):
R = tf.convert_to_tensor(R, dtype=tf.float32)
R = tf.reshape(R, [1, 1])
R = tf.tile(R, [B, 1])
else:
R = tf.zeros((B, 1), dtype=tf.float32) # let encoder to decide needed it or not
return R

@tf.function
def _inference(
self, src, pos,
Expand All @@ -163,16 +175,7 @@ def _inference(
tf.assert_equal(tf.shape(initialValues)[:1], (B, ))
initialValues = tf.reshape(initialValues, (B, N, C))

R = None
if 'blurRadius' in reverseArgs:
R = reverseArgs['blurRadius']
if not tf.is_tensor(R):
R = tf.convert_to_tensor(R, dtype=tf.float32)
R = tf.reshape(R, [1, 1])
R = tf.tile(R, [B, 1])
else:
R = tf.zeros((B, 1), dtype=tf.float32)
pass
R = self._withBlur(reverseArgs, B)
encoded = self._encoder(src, training=False, params=encoderParams, R=R)
def getChunk(ind, sz):
posC = pos[ind:ind+sz]
Expand Down

0 comments on commit 3a4c3e1

Please sign in to comment.