Skip to content

Commit

Permalink
vary parameters of the flows
Browse files Browse the repository at this point in the history
  • Loading branch information
AgatheSenellart committed Jan 12, 2023
1 parent 83be83f commit bee9c24
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/bivae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@
# pretrained_joint_path = '../experiments/clean_mnist_svhn/2022-06-29/2022-06-29T11:41:41.132687__5qri92/'
# pretrained_joint_path = '../experiments/jmvae/2022-06-28/2022-06-28T17:25:01.03903846svjh2d/'
# pretrained_joint_path = '../experiments/celeba/2022-10-13/2022-10-13T13:54:42.595068mmpybk9u/'
pretrained_joint_path = '../experiments/joint_encoders/'+ args.experiment.split('/')[1] + '/'
pretrained_joint_path = '../experiments/joint_encoders/'+ args.experiment.split('/')[-1] + '/'

min_epoch = 1

Expand Down
16 changes: 11 additions & 5 deletions src/bivae/models/jmvae_nf/jmvae_nf_mnist_svhn_dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,21 @@ def __init__(self, params):
hidden_dim = 512
pre_configs = [VAEConfig((1, 28, 28), 20), VAEConfig((3, 32, 32), 20)]
joint_encoder = DoubleHeadJoint(hidden_dim, pre_configs[0], pre_configs[1],Encoder_VAE_MLP, Encoder_VAE_SVHN,params)
# joint_encoder = MultipleHeadJoint(hidden_dim,pre_configs,
# [Encoder_VAE_MLP ,
# Encoder_VAE_SVHN],
# params)



# Define the unimodal encoders config
vae_config1 = vae_config((1, 28, 28), params.latent_dim)
vae_config2 = vae_config((3, 32, 32), params.latent_dim)

if hasattr(params,'n_made_blocks'):
vae_config1.n_made_blocks = params.n_made_blocks
vae_config2.n_made_blocks = params.n_made_blocks
print(f'Using {params.n_made_blocks} in the flows')

wandb.config.update(vae_config1.to_dict())
wandb.config.update(vae_config2.to_dict())



if params.dcca :
# First load the DCCA encoders
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"comment":"The JMVAE model with added flows but no dcca pre-embeddings. No reconstruction and maf flows",


"experiment": "flow_tuning/jmvae_nf_1/mnist_svhn",
"model": "jnf_mnist_svhn_dcca",
"obj": "jmvae_nf",


"K": 1,
"recon_losses" : ["normal", "normal"],
"looser": false,
"llik_scaling": 0,
"batch_size": 128, "learning_rate" : 1e-3,
"epochs": 200,
"latent_dim": 20,
"num_hidden_layers": 1,
"use_pretrain": "",
"learn_prior": false,
"logp": false,
"print_freq": 0,
"no_analytics": false,
"no_cuda": false,
"seed": 1,
"dist": "normal",
"data_path": "../data/",
"skip_warmup": true,
"warmup": 100,
"no_nf": false,
"beta_prior": 1,
"beta_kl": 1,
"decrease_beta_kl": 1,
"fix_decoders": true,
"fix_jencoder": true,
"no_recon": true,
"freq_analytics": 5,
"dcca": false,
"device": "cuda",
"wandb_experiment": "mnist_svhn",
"wandb_mode" : "offline",
"flow" : "maf",
"use_gen" : false,
"save_joint":false,
"linear_warmup":false,


"n_made_blocks":1


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"comment":"The JMVAE model with added flows but no dcca pre-embeddings. No reconstruction and maf flows",


"experiment": "flow_tuning/jmvae_nf_2/mnist_svhn",
"model": "jnf_mnist_svhn_dcca",
"obj": "jmvae_nf",


"K": 1,
"recon_losses" : ["normal", "normal"],
"looser": false,
"llik_scaling": 0,
"batch_size": 128, "learning_rate" : 1e-3,
"epochs": 200,
"latent_dim": 20,
"num_hidden_layers": 1,
"use_pretrain": "",
"learn_prior": false,
"logp": false,
"print_freq": 0,
"no_analytics": false,
"no_cuda": false,
"seed": 1,
"dist": "normal",
"data_path": "../data/",
"skip_warmup": true,
"warmup": 100,
"no_nf": false,
"beta_prior": 1,
"beta_kl": 1,
"decrease_beta_kl": 1,
"fix_decoders": true,
"fix_jencoder": true,
"no_recon": true,
"freq_analytics": 5,
"dcca": false,
"device": "cuda",
"wandb_experiment": "mnist_svhn",
"wandb_mode" : "offline",
"flow" : "maf",
"use_gen" : false,
"save_joint":false,
"linear_warmup":false,


"n_made_blocks":2


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"comment":"The JMVAE model with added flows but no dcca pre-embeddings. No reconstruction and maf flows",


"experiment": "flow_tuning/jmvae_nf_3/mnist_svhn",
"model": "jnf_mnist_svhn_dcca",
"obj": "jmvae_nf",


"K": 1,
"recon_losses" : ["normal", "normal"],
"looser": false,
"llik_scaling": 0,
"batch_size": 128, "learning_rate" : 1e-3,
"epochs": 200,
"latent_dim": 20,
"num_hidden_layers": 1,
"use_pretrain": "",
"learn_prior": false,
"logp": false,
"print_freq": 0,
"no_analytics": false,
"no_cuda": false,
"seed": 1,
"dist": "normal",
"data_path": "../data/",
"skip_warmup": true,
"warmup": 100,
"no_nf": false,
"beta_prior": 1,
"beta_kl": 1,
"decrease_beta_kl": 1,
"fix_decoders": true,
"fix_jencoder": true,
"no_recon": true,
"freq_analytics": 5,
"dcca": false,
"device": "cuda",
"wandb_experiment": "mnist_svhn",
"wandb_mode" : "offline",
"flow" : "maf",
"use_gen" : false,
"save_joint":false,
"linear_warmup":false,


"n_made_blocks":3


}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
{
"comment":"The JMVAE model with added flows but no dcca pre-embeddings. No reconstruction and maf flows",


"experiment": "flow_tuning/jmvae_nf_5/mnist_svhn",
"model": "jnf_mnist_svhn_dcca",
"obj": "jmvae_nf",


"K": 1,
"recon_losses" : ["normal", "normal"],
"looser": false,
"llik_scaling": 0,
"batch_size": 128, "learning_rate" : 1e-3,
"epochs": 200,
"latent_dim": 20,
"num_hidden_layers": 1,
"use_pretrain": "",
"learn_prior": false,
"logp": false,
"print_freq": 0,
"no_analytics": false,
"no_cuda": false,
"seed": 1,
"dist": "normal",
"data_path": "../data/",
"skip_warmup": true,
"warmup": 100,
"no_nf": false,
"beta_prior": 1,
"beta_kl": 1,
"decrease_beta_kl": 1,
"fix_decoders": true,
"fix_jencoder": true,
"no_recon": true,
"freq_analytics": 5,
"dcca": false,
"device": "cuda",
"wandb_experiment": "mnist_svhn",
"wandb_mode" : "offline",
"flow" : "maf",
"use_gen" : false,
"save_joint":false,
"linear_warmup":false,


"n_made_blocks":5


}

0 comments on commit bee9c24

Please sign in to comment.