Skip to content

Commit

Permalink
update depth testing file
Browse files Browse the repository at this point in the history
  • Loading branch information
akar43 committed Oct 31, 2017
1 parent 442c09e commit 1c4c673
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
28 changes: 20 additions & 8 deletions depth/test_dlsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ops import conv_rnns
from skimage.io import imsave
from shapenet import ShapeNet
from evaluate import eval_l1_err, print_depth_stats
from utils import init_logging, process_args, get_session_config, mkdir_p


Expand Down Expand Up @@ -74,6 +75,9 @@ def run(args):
qsize=32,
nthreads=args.prefetch_threads)

# Init stats
l1_err = []

# Testing loop
pbar = tqdm(desc='Testing', total=len(mids))
deq_mids, deq_sids, deq_view_idx = [], [], []
Expand All @@ -94,9 +98,12 @@ def run(args):
}

pred = sess.run(net.depth_out, feed_dict=feed_dict)
vis_depth(pred, batch_data['depth'], batch_data['shape_id'],
batch_data['model_id'], batch_data['view_idx'])
# Update iou dict
batch_err = eval_l1_err(pred[:num_batch_items],
batch_data['depth'][:num_batch_items])
if args.vis:
vis_depth(pred, batch_data['depth'], batch_data['shape_id'],
batch_data['model_id'], batch_data['view_idx'])
l1_err.extend(batch_err)
pbar.update(num_batch_items)
except Exception, e:
logger.error(repr(e))
Expand All @@ -113,12 +120,16 @@ def run(args):
sort_idx = np.argsort(deq_mids)
deq_mids = deq_mids[sort_idx]
deq_sids = deq_sids[sort_idx]
l1_err = l1_err[sort_idx]
deq_view_idx = deq_view_idx[sort_idx, :]
with open(args.test_set_file, 'w') as f:
for ix in range(len(deq_mids)):
f.write(deq_sids[ix] + '\t' + deq_mids[ix] + '\t' +
' '.join(map(str, deq_view_idx[ix].tolist())) + '\n')
print('Test set file: {:s}'.format(args.test_set_file))
stats, stats_table = print_depth_stats(zip(deq_sids, deq_mids), l1_err)
print(stats_table)
if args.test_set_file is not None:
with open(args.test_set_file, 'w') as f:
for ix in range(len(deq_mids)):
f.write(deq_sids[ix] + '\t' + deq_mids[ix] + '\t' +
' '.join(map(str, deq_view_idx[ix].tolist())) + '\n')
print('Test set file: {:s}'.format(args.test_set_file))


def parse_args():
Expand All @@ -134,6 +145,7 @@ def parse_args():
'--test_split_file', type=str, default='data/splits.json')
parser.add_argument('--prefetch_threads', type=int, default=2)
parser.add_argument('--savedir', type=str, default=None)
parser.add_argument('--vis', action="store_true")
args = process_args(parser)
return args

Expand Down
1 change: 0 additions & 1 deletion depth/val_dlsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def validate(args, checkpoint):
batch_err = eval_l1_err(pred[:num_batch_items],
batch_data['depth'][:num_batch_items])

# Update iou dict
l1_err.extend(batch_err)
pbar.update(num_batch_items)
except Exception, e:
Expand Down

0 comments on commit 1c4c673

Please sign in to comment.