Skip to content

Commit

Permalink
fix re-using scheduler b/t batches, prep eval
Browse files Browse the repository at this point in the history
  • Loading branch information
EllingtonKirby committed Jun 12, 2024
1 parent 0a4f4e3 commit 8229b85
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
experiment:
id: car_gen_eval_cyclical_embeddings_1
id: car_gen_eval_positional_embeddings_1

##Data
data:
Expand Down
25 changes: 22 additions & 3 deletions lidiff/models/models_objects_full_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from lidiff.utils.metrics import ChamferDistance, PrecisionRecall
from lidiff.utils.three_d_helpers import build_two_point_clouds
from diffusers import DPMSolverMultistepScheduler
from random import shuffle

class DiffusionPoints(LightningModule):
def __init__(self, hparams:dict, data_module: LightningDataModule = None):
Expand Down Expand Up @@ -287,6 +288,16 @@ def valid_paths(self, filenames):
return np.all(skip), output_paths

def test_step(self, batch:dict, batch_idx):
self.dpm_scheduler = DPMSolverMultistepScheduler(
num_train_timesteps=self.t_steps,
beta_start=self.hparams['diff']['beta_start'],
beta_end=self.hparams['diff']['beta_end'],
beta_schedule='linear',
algorithm_type='sde-dpmsolver++',
solver_order=2,
)
self.dpm_scheduler.set_timesteps(self.s_steps)
self.scheduler_to_cuda()
self.model.eval()

viz_pcd = o3d.geometry.PointCloud()
Expand All @@ -307,7 +318,8 @@ def test_step(self, batch:dict, batch_idx):
x_uncond = torch.zeros_like(x_cond)

x_gen_evals = []
for i in tqdm(range(self.hparams['diff']['num_val_samples'])):
num_val_samples = self.hparams['diff']['num_val_samples']
for i in range(num_val_samples):
np.random.seed(i)
torch.manual_seed(i)
torch.cuda.manual_seed(i)
Expand Down Expand Up @@ -341,8 +353,15 @@ def test_step(self, batch:dict, batch_idx):
box = batch['size'][pcd_index].cpu()
cd_mean_as_pct_of_box.append((last_cd / box.mean())*100.)
curr_index = max_index
visualization = self.visualize_step_t(genrtd_pcd, object_pcd, viz_pcd)
o3d.io.write_point_cloud(f'{self.logger.log_dir}/generated_pcd/visualizations/batch_{batch_idx}_object_{pcd_index}_seed_{best_index}.ply', visualization)
if pcd_index == 0:
visualization_1 = self.visualize_step_t(genrtd_pcd, object_pcd, viz_pcd)
o3d.io.write_point_cloud(f'{self.logger.log_dir}/generated_pcd/visualizations/batch_{batch_idx}_object_{pcd_index}_seed_{best_index}_best.ply', visualization_1)
random_choices = [i for i in range(num_val_samples) if i != best_index]
shuffle(random_choices)
for i in random_choices[0:2]:
genrtd_pcd_2 = x_gen_evals[i][curr_index:max_index]
visualization_2 = self.visualize_step_t(genrtd_pcd_2, object_pcd, viz_pcd)
o3d.io.write_point_cloud(f'{self.logger.log_dir}/generated_pcd/visualizations/batch_{batch_idx}_object_{pcd_index}_seed_{i}.ply', visualization_2)

cd_mean, cd_std = self.chamfer_distance.compute()
cd_mean_as_pct_of_box = np.mean(cd_mean_as_pct_of_box)
Expand Down
4 changes: 3 additions & 1 deletion lidiff/train_objects_full_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,9 @@ def main(config, weights, checkpoint, test):
max_epochs= cfg['train']['max_epoch'],
callbacks=[lr_monitor, checkpoint_saver],
check_val_every_n_epoch=10,
num_sanity_val_steps=0
num_sanity_val_steps=0,
limit_test_batches=10,
limit_val_batches=10,
)


Expand Down

0 comments on commit 8229b85

Please sign in to comment.