Skip to content

Commit

Permalink
gt values for debugging purposes
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed May 18, 2024
1 parent b1d2319 commit 8fe566b
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions NN/Nerf2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ def test_step(self, images):
src = ensure4d(src)
dest = ensure4d(dest)
# call the model itself to obtain the reconstructed image in the proper format
reconstructed = self(src, size=tf.shape(dest)[1], training=False)
# add original data for debugging purposes
B = tf.shape(src)[0]
coords = generateSquareGrid(tf.shape(dest)[1], 1.0, 0.0)
coords = tf.tile(coords[None], [B, 1, 1])
values = extractInterpolated(dest, coords)
reconstructed = self(src, size=tf.shape(dest)[1], training=False, GT=values)
return self._testMetrics(dest, reconstructed)

def _createAlgorithmInterceptor(self, interceptor, image, pos):
Expand Down Expand Up @@ -165,6 +170,7 @@ def _withBlur(self, reverseArgs, B):
def _inference(
self, src, pos,
batchSize, reverseArgs, initialValues, encoderParams,
GT=None
):
B = tf.shape(src)[0]
N = tf.shape(pos)[0]
Expand Down Expand Up @@ -206,8 +212,14 @@ def getChunk(ind, sz):
residual = self._withResidual(src, points=posCB)
residual = tf.reshape(residual, (-1, 3))
tf.assert_equal(tf.shape(residual), (flatB, 3))

reverseArgsNew = reverseArgs
if GT is not None:
# add ground truth values to copy of reverseArgs
reverseArgsNew = {**reverseArgs, 'GT': tf.reshape(GT[:, ind:ind+sz], (flatB, 3))}
pass
return dict(
latents=latents, pos=posC, reverseArgs=reverseArgs, value=value,
latents=latents, pos=posC, reverseArgs=reverseArgsNew, value=value,
residual=residual
)

Expand All @@ -226,6 +238,7 @@ def call(self,
batchSize=None, # renderers batch size
initialValues=None, # initial values for the restoration process
reverseArgs=None,
GT=None, # ground truth values for debugging purposes
):
src = ensure4d(src)
B = tf.shape(src)[0]
Expand Down Expand Up @@ -258,7 +271,8 @@ def call(self,
batchSize=batchSize,
reverseArgs=reverseArgs,
encoderParams=encoderParams,
initialValues=initialValues
initialValues=initialValues,
GT=GT
)
probes = tf.reshape(probes, sampleShape)
return probes
Expand Down

0 comments on commit 8fe566b

Please sign in to comment.