Commit

fix bugs
Wuziyi616 committed Jul 1, 2022
1 parent 377641a commit 0d297b6
Showing 10 changed files with 34 additions and 18 deletions.
@@ -16,6 +16,7 @@
 _C = CN()
 _C.noise_dim = 0 # no stochastic
 _C.num_rot = 1 # rotate GT to match the predictions
+
 _C.trans_loss_w = 1.
 _C.rot_pt_cd_loss_w = 10.
 _C.transform_pt_cd_loss_w = 10.
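
Context for this config hunk and the next: `noise_dim = 0` keeps the pose regressor deterministic, while the stochastic variant sets `noise_dim = 32` and conditions the prediction on sampled noise. A minimal sketch of that conditioning, with invented shapes and no claim to match the repo's real module interfaces:

import torch

# Hypothetical per-part features: [B, num_parts, feat_dim].
feat = torch.randn(2, 20, 256)
noise_dim = 32  # 0 would skip the concatenation entirely
if noise_dim > 0:
    noise = torch.randn(feat.shape[0], feat.shape[1], noise_dim)
    feat = torch.cat([feat, noise], dim=-1)  # [B, num_parts, feat_dim + noise_dim]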
@@ -12,6 +12,7 @@
 _C = CN()
 _C.noise_dim = 32 # stochastic PoseRegressor
 _C.sample_iter = 5 # MoN loss sampling
+
 _C.trans_loss_w = 1.
 _C.rot_pt_cd_loss_w = 10.
 _C.transform_pt_cd_loss_w = 10.
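
`sample_iter = 5` controls MoN (Minimum-of-N) sampling: the stochastic model is evaluated on N independent noise draws and only the best draw is penalized, which avoids averaging over incompatible modes. A sketch under the assumption that `predict(noise)` returns a per-sample loss of shape [B]; the name is illustrative, not the repo's API:

import torch

def mon_loss(predict, noise_dim=32, sample_iter=5, batch_size=8):
    # Draw `sample_iter` noise vectors and keep the minimum loss per sample;
    # the gradient flows only through each sample's best draw.
    losses = torch.stack([
        predict(torch.randn(batch_size, noise_dim))
        for _ in range(sample_iter)
    ], dim=0)  # [sample_iter, B]
    return losses.min(dim=0).values.mean()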
@@ -15,6 +15,8 @@
 
 _C.exp = CN()
 _C.exp.num_epochs = 400
+_C.exp.batch_size = 8 # GPU memory limit on RTX6000 with 24GB memory
+_C.exp.num_workers = 4
 
 _C.optimizer = CN()
 _C.optimizer.warmup_ratio = 0.05
@@ -15,6 +15,8 @@
 
 _C.exp = CN()
 _C.exp.num_epochs = 400
+_C.exp.batch_size = 8 # GPU memory limit on RTX6000 with 24GB memory
+_C.exp.num_workers = 4
 
 _C.optimizer = CN()
 _C.optimizer.warmup_ratio = 0.05
15 changes: 12 additions & 3 deletions multi_part_assembly/models/modules/base_model.py
@@ -71,7 +71,9 @@ def validation_step(self, data_dict, batch_idx):
     def validation_epoch_end(self, outputs):
         # avg_loss among all data
        # we need to consider different batch_size
-        batch_sizes = torch.tensor([
+        func = torch.tensor if \
+            isinstance(outputs[0]['batch_size'], int) else torch.stack
+        batch_sizes = func([
             output.pop('batch_size') for output in outputs
         ]).type_as(outputs[0]['loss']) # [num_batches]
         losses = {
@@ -91,7 +93,9 @@ def test_step(self, data_dict, batch_idx):
     def test_epoch_end(self, outputs):
         # avg_loss among all data
         # we need to consider different batch_size
-        batch_sizes = torch.tensor([
+        func = torch.tensor if \
+            isinstance(outputs[0]['batch_size'], int) else torch.stack
+        batch_sizes = func([
             output.pop('batch_size') for output in outputs
         ]).type_as(outputs[0]['loss']) # [num_batches]
         losses = {
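
The change in both epoch hooks handles `batch_size` arriving as a Python int on a single device but as a tensor when a data-parallel strategy gathers one value per GPU, hence the switch between `torch.tensor` and `torch.stack`. A toy check of the batch-size-weighted average this feeds (numbers invented):

import torch

outputs = [
    {'loss': torch.tensor(1.0), 'batch_size': 8},
    {'loss': torch.tensor(2.0), 'batch_size': 4},
]
func = torch.tensor if isinstance(outputs[0]['batch_size'], int) else torch.stack
batch_sizes = func([o.pop('batch_size') for o in outputs]).float()  # [num_batches]
losses = torch.stack([o['loss'] for o in outputs])
avg = (losses * batch_sizes).sum() / batch_sizes.sum()
print(avg)  # 4/3: the weighted mean, not the naive mean 1.5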
@@ -247,11 +251,16 @@ def _match_rotation(self, pred_trans, pred_rot, gt_trans, gt_rot, valids):
         Returns:
             GT poses after rearrangement
         """
+        if self.num_rot == 1:
+            return gt_trans.detach().clone(), gt_rot.detach().clone()
         P = pred_trans.shape[1]
         # uniform rotation along z-axis
         if not hasattr(self, '_uniform_z_rot'):
             z_angles = 360. / self.num_rot * np.arange(self.num_rot)
-            z_rot = [R.from_euler('z', a, degrees=True) for a in z_angles]
+            z_rot = [
+                R.from_euler('z', angle, degrees=True).as_matrix()
+                for angle in z_angles
+            ]
             self._uniform_z_rot = torch.from_numpy(np.stack(z_rot, 0))[None]
         z_rot = self._uniform_z_rot.type_as(gt_trans) # [1, n, 3, 3]
         # rotate `gt_trans`, [B, n, P, 3]
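
Two fixes in this hunk: an early return when `num_rot == 1` (a single candidate needs no matching), and the previously missing `.as_matrix()`, without which `np.stack` was applied to scipy `Rotation` objects instead of 3x3 arrays and `torch.from_numpy` could not consume the result. The cached buffer can be reproduced standalone (a sketch with `num_rot = 4`):

import numpy as np
import torch
from scipy.spatial.transform import Rotation as R

num_rot = 4
z_angles = 360. / num_rot * np.arange(num_rot)  # [0., 90., 180., 270.]
z_rot = np.stack([
    R.from_euler('z', angle, degrees=True).as_matrix() for angle in z_angles
], 0)  # [num_rot, 3, 3]
uniform_z_rot = torch.from_numpy(z_rot)[None]  # [1, num_rot, 3, 3]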
10 changes: 5 additions & 5 deletions multi_part_assembly/models/modules/regressor.py
@@ -82,15 +82,15 @@ def __init__(self, feat_dim, rot_type='rmat'):
 
         # for rotation
         self.vn_fc_layers = nn.Sequential(
-            VNLinear(feat_dim, 256),
-            VNLeakyReLU(256, negative_slope=0.2),
-            VNLinear(256, 128),
-            VNLeakyReLU(128, negative_slope=0.2),
+            VNLinear(feat_dim, 256, dim=3),
+            VNLeakyReLU(256, dim=3, negative_slope=0.2),
+            VNLinear(256, 128, dim=3),
+            VNLeakyReLU(128, dim=3, negative_slope=0.2),
         )
 
         # Rotation prediction head
         # we use the 6D representation from the CVPR'19 paper
-        self.rot_head = VNLinear(128, 2) # [2, 3] --> 6
+        self.rot_head = VNLinear(128, 2, dim=3) # [2, 3] --> 6
 
         # for translation
         self.in_feats = VNInFeature(feat_dim, dim=3)
2 changes: 1 addition & 1 deletion multi_part_assembly/models/pn_transformer/transformer.py
@@ -128,7 +128,7 @@ def __init__(
             relu=relu,
             dropout=dropout,
         )
-        self.out_fc = VNLinear(d_model, out_dim) if \
+        self.out_fc = VNLinear(d_model, out_dim, dim=4) if \
             out_dim is not None else nn.Identity()
 
     def forward(self, tokens, valid_masks):
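
Both vector-neuron hunks now pass an explicit `dim`, presumably the rank of the incoming tensor: per-part features in the regressor are rank-3 (`dim=3`, e.g. `[B, C, 3]`), while the transformer tokens carry an extra axis (`dim=4`). As a rough illustration of why such layers are rotation-equivariant, here is a toy stand-in, not the repo's `VNLinear`:

import torch
import torch.nn as nn

class ToyVNLinear(nn.Module):
    # Mixes vector channels with one shared linear map and never touches the
    # xyz axis, so rotating the input and rotating the output commute.
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels) * 0.02)

    def forward(self, x):  # x: [B, C_in, 3]
        return torch.einsum('oc,bcd->bod', self.weight, x)  # [B, C_out, 3]

layer = ToyVNLinear(256, 128)
x = torch.randn(2, 256, 3)
rot, _ = torch.linalg.qr(torch.randn(3, 3))  # random orthogonal 3x3
assert torch.allclose(layer(x @ rot), layer(x) @ rot, atol=1e-5)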
6 changes: 2 additions & 4 deletions multi_part_assembly/utils/eval_utils.py
@@ -282,12 +282,10 @@ def _min_error(errors, min_idx=None):
         """Take the canonical pose with the min error."""
         # errors: [B*P], should mask out invalid parts
         errors = errors.reshape(B, P)
-        if min_idx is not None:
-            min_errors = torch.gather(errors, dim=1, index=min_idx)
-        else:
+        if min_idx is None:
             shift_errors = errors + 1e9 * (1. - valids)
             min_idx = shift_errors.argmin(dim=1, keepdim=True) # [B, 1]
-            min_errors = torch.gather(errors, dim=1, index=min_idx)
+        min_errors = torch.gather(errors, dim=1, index=min_idx)
         return min_errors, min_idx
 
     metric_dict = {}
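
The refactor removes the duplicated `torch.gather`: `min_idx` is computed only when absent (masking invalid parts with a large constant first), then a single gather selects each sample's best canonical pose. A small check with invented numbers:

import torch

errors = torch.tensor([[3.0, 1.0, 2.0],
                       [0.5, 4.0, 4.0]])  # [B=2, P=3] per-pose errors
valids = torch.ones_like(errors)  # all poses valid in this toy case
shift_errors = errors + 1e9 * (1. - valids)  # pushes invalid poses out of the argmin
min_idx = shift_errors.argmin(dim=1, keepdim=True)  # [[1], [0]]
min_errors = torch.gather(errors, dim=1, index=min_idx)  # [[1.0], [0.5]]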
2 changes: 1 addition & 1 deletion script/dup_run_sbatch.sh
@@ -27,7 +27,7 @@ do
   cfg="${CFG:0:(-3)}-dup${repeat_idx}.py"
   cp $CFG $cfg
   job_name="${JOB_NAME}-dup${repeat_idx}"
-  cmd="./script/sbatch_run.sh $PARTITION $job_name $PY_FILE --cfg_file $CFG $PY_ARGS"
+  cmd="./script/sbatch_run.sh $PARTITION $job_name $PY_FILE --cfg_file $cfg $PY_ARGS"
   echo $cmd
   eval $cmd
 done
11 changes: 7 additions & 4 deletions script/train.py
@@ -66,10 +66,7 @@ def main(cfg):
     trainer = pl.Trainer(
         logger=logger,
         gpus=all_gpus,
-        # TODO: very strange, I still cannot train DDP on Vector...
-        # TODO: modify this line if you can run DDP on the cluster
-        # strategy='ddp' if len(all_gpus) > 1 else None,
-        strategy='dp' if len(all_gpus) > 1 else None,
+        strategy=parallel_strategy if len(all_gpus) > 1 else None,
         max_epochs=cfg.exp.num_epochs,
         callbacks=callbacks,
         precision=16 if args.fp16 else 32, # FP16 training
@@ -116,7 +113,13 @@ def main(cfg):
     cfg = importlib.import_module(os.path.basename(args.cfg_file)[:-3])
     cfg = cfg.get_cfg_defaults()
 
+    # TODO: very strange, I still cannot train DDP on Vector...
+    # TODO: modify this line if you can run DDP on the cluster
+    parallel_strategy = 'dp' # 'ddp'
     cfg.exp.gpus = args.gpus
+    if len(cfg.exp.gpus) > 1 and parallel_strategy == 'dp':
+        cfg.exp.batch_size *= len(cfg.exp.gpus)
+        cfg.exp.num_workers *= len(cfg.exp.gpus)
     if args.category:
         cfg.data.category = args.category
     if args.weight:
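
The scaling added above matches Lightning's `'dp'` semantics: DataParallel loads one global batch and splits it across devices, so the configured single-GPU `batch_size` (and `num_workers`) must be multiplied by the GPU count to keep the per-GPU load unchanged; under `'ddp'` each process owns its own DataLoader and no scaling would be needed. A toy illustration:

n_gpus = 4
batch_size = 8  # tuned for one 24GB RTX6000, per the config comment
parallel_strategy = 'dp'
if n_gpus > 1 and parallel_strategy == 'dp':
    batch_size *= n_gpus  # DP splits this back into 8 samples per device
print(batch_size)  # 32 globally, still 8 per GPU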
