Skip to content

Commit

Permalink
fix bugs in visualization script
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuziyi616 committed Feb 18, 2023
1 parent bad9621 commit 920162c
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions scripts/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@
from tqdm import tqdm

import torch
import torch.nn as nn

from multi_part_assembly.datasets import build_dataloader
from multi_part_assembly.models import build_model
from multi_part_assembly.utils import trans_rmat_to_pmat, trans_quat_to_pmat, \
quaternion_to_rmat, save_pc
quaternion_to_rmat, save_pc, Rotation3D


@torch.no_grad()
Expand All @@ -23,7 +22,7 @@ def visualize(cfg):
model = build_model(cfg)
ckp = torch.load(cfg.exp.weight_file, map_location='cpu')
model.load_state_dict(ckp['state_dict'])
model = nn.DataParallel(model).cuda().eval()
model = model.cuda().eval()

# Initialize dataloaders
_, val_loader = build_dataloader(cfg)
Expand All @@ -34,7 +33,10 @@ def visualize(cfg):
for batch in tqdm(val_loader):
batch = {k: v.float().cuda() for k, v in batch.items()}
out_dict = model(batch) # trans/rot: [B, P, 3/4/(3, 3)]
loss_dict, _ = model.module._calc_loss(out_dict, batch) # loss is [B]
# compute loss to measure the quality of the predictions
batch['part_rot'] = Rotation3D(
batch['part_quat'], rot_type='quat').convert(model.rot_type)
loss_dict, _ = model._calc_loss(out_dict, batch) # loss is [B]
# the criterion to cherry-pick examples
loss = loss_dict['rot_pt_l2_loss'] + loss_dict['trans_mae']
# convert all the rotations to quaternion for simplicity
Expand Down

0 comments on commit 920162c

Please sign in to comment.