Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shearlets #40

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
bd4ea22
Update unet.py with exact reconstruction
kevinmicha Apr 26, 2021
baa8d4f
Back to previous version
kevinmicha Apr 26, 2021
4499a95
Merge branch 'zaccharieramzi:master' into master
kevinmicha May 12, 2021
9d88a3a
modified analysis class to allow fixed filters
kevinmicha Jun 14, 2021
fadd9af
including shearlets and exact_recon parameter
kevinmicha Jun 14, 2021
774e972
Added extra dot
kevinmicha Jun 17, 2021
ee1c145
adapted tiling part
kevinmicha Jun 17, 2021
42b7674
fixed indentation for tiling func
kevinmicha Jun 17, 2021
3b8dc98
Update learning_wavelets/training_scripts/learnlet_training.py
kevinmicha Jun 17, 2021
3e5652d
Update learning_wavelets/models/learnlet_layers.py
kevinmicha Jun 17, 2021
e8f1649
changed layer name
kevinmicha Jun 17, 2021
678b328
back to original version
kevinmicha Jul 23, 2021
0e0eb29
including synthesis filters part
kevinmicha Jul 23, 2021
f7bc4a0
optional import: cadmos
kevinmicha Jul 23, 2021
beedb0b
removed wrong indent
kevinmicha Jul 23, 2021
6f249cf
used kernel's keywords
kevinmicha Jul 23, 2021
a0589a7
added spaces back
kevinmicha Jul 23, 2021
3139c2b
added untrainable layers
kevinmicha Jul 23, 2021
9a923d3
fixed white space
kevinmicha Jul 23, 2021
8ff2877
not prepared for exact recon True
kevinmicha Jul 23, 2021
6f8c3f0
changed n_tiling. Doesn't make sense 3 as default
kevinmicha Jul 23, 2021
5af5fd2
removed last change
kevinmicha Jul 23, 2021
3001d00
changed some default values for testing matters
kevinmicha Jul 23, 2021
a61f1ef
more def values changed
kevinmicha Jul 23, 2021
07ee9b1
removing exact recon test from this branch. No sense and tests fail
kevinmicha Jul 23, 2021
5d928be
exact recon unet test is back
kevinmicha Jul 23, 2021
19b021a
Removed the part that tests exact recon. No sense
kevinmicha Jul 23, 2021
7e061ee
added fixed keyword in names
kevinmicha Jul 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
including shearlets and exact_recon parameter
  • Loading branch information
kevinmicha authored Jun 14, 2021
commit fadd9afd7b63199157c79fb448cb436983f4a446
86 changes: 74 additions & 12 deletions learning_wavelets/training_scripts/learnlet_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
import os.path as op
import time

import cadmos_lib as cl
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this import optional

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

import click
import numpy as np
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler
from tensorflow.keras.optimizers import Adam
import tensorflow as tf

from learning_wavelets.config import LOGS_DIR, CHECKPOINTS_DIR
from learning_wavelets.data.datasets import im_dataset_div2k, im_dataset_bsd500
from learning_wavelets.evaluate import keras_psnr, keras_ssim, center_keras_psnr
from learning_wavelets.keras_utils.normalisation import NormalisationAdjustment
from learning_wavelets.models.learned_wavelet import learnlet

Expand Down Expand Up @@ -72,13 +76,45 @@
type=int,
help='The number of filters in the learnlets. Defaults to 256.',
)
@click.option(
'kernel_sizes',
'--k-sizes',
nargs=2,
default=(11, 13),
type=int,
help='The analysis and synthesis kernel sizes. Defaults to [11, 13]',
)
@click.option(
'decreasing_noise_level',
'--decr-n-lvl',
is_flag=True,
help='Set if you want the noise level distribution to be non uniform, skewed towards low value.',
)
def train_learnlet(noise_std_train, noise_std_val, n_samples, source, cuda_visible_devices, denoising_activation, n_filters, decreasing_noise_level):
@click.option(
'exact_reco',
'-e',
is_flag=True,
help='Set if you want the learnlets to have exact reconstruction.',
)
@click.option(
'n_reweights',
'-nr',
default=1,
help='The number of reweights. Defaults to 1.',
)
def train_learnlet(
noise_std_train,
noise_std_val,
n_samples,
source,
cuda_visible_devices,
denoising_activation,
n_filters,
kernel_sizes,
decreasing_noise_level,
exact_reco,
n_reweights,
):
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices)
# data preparation
batch_size = 8
Expand Down Expand Up @@ -109,19 +145,23 @@ def train_learnlet(noise_std_train, noise_std_val, n_samples, source, cuda_visib
'n_tiling': n_filters,
'mixing_details': False,
'skip_connection': True,
'kernel_size': 11,
'kernel_size': analysis_kernel_size,
},
'learnlet_synthesis_kwargs': {
'res': True,
'kernel_size': 13,
'kernel_size': synthesis_kernel_size,
},
'threshold_kwargs':{
'noise_std_norm': True,
},
'wav_type': 'starlet',
'n_scales': 5,
'n_reweights_learn': n_reweights,
'exact_reconstruction': exact_reco,
'clip': False,
}

n_epochs = 500
run_id = f'learnlet_dynamic_{n_filters}_{denoising_activation}_{source}_{noise_std_train[0]}_{noise_std_train[1]}_{n_samples}_{int(time.time())}'
chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
run_id = f'learnlet_exact_recon_{exact_reco}_{n_filters}_{denoising_activation}_{source}_{noise_std_train[0]}_{noise_std_train[1]}_{n_samples}_{int(time.time())}' chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5'
print(run_id)


Expand All @@ -147,21 +187,43 @@ def l_rate_schedule(epoch):
norm_cback.on_train_batch_end = norm_cback.on_batch_end


n_channels = 1
# run distributed
# run distributed
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
model = learnlet(input_size=(None, None, n_channels), lr=1e-3, **run_params)
print(model.summary(line_length=114))

model = Learnlet(**run_params)
model.compile(
optimizer=Adam(lr=1e-3),
loss='mse',
metrics=[keras_psnr, keras_ssim],
)
inputs = [tf.zeros((1, 32, 32, 1)), tf.zeros((1, 1))]
model(inputs)

shearlets, _ = cl.get_shearlets(512, 512, 6)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why 6 ?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put the answer in comment if it makes sense

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


n_shearlets = np.shape(shearlets)[0]
total_size = np.shape(shearlets)[1]
half_filters = n_filters//2
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait I don't get it...
Does it mean that you do not always consider all the shearlet filters?
I think this is a mistake since you can not just take the shearlet filters you want, you need them all.
Surely there is a way to require less shearlet filters though

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

filter_size = analysis_kernel_size
crop_min = total_size//2 - filter_size//2
crop_max = total_size//2 + filter_size//2 + 1
resized_shearlets = np.zeros((5, filter_size, filter_size, 1, half_filters))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Surely here it's not 5 but n_scales or sthg similar.
Same for below the 6,11 shouldn't be hardcoded

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

for i in range(half_filters):
resized_shearlets[:,:,:,0,i] = shearlets[i % n_shearlets, crop_min:crop_max, crop_min:crop_max]

filters = np.copy(model.layers[0].get_weights())
for i in range(6, 11):
filters[i] = resized_shearlets[i-6,:,:,:,:]

model.layers[0].set_weights(filters)

model.fit(
im_ds_train,
steps_per_epoch=200,
epochs=n_epochs,
validation_data=im_ds_val,
validation_steps=1,
verbose=0,
verbose=1,
kevinmicha marked this conversation as resolved.
Show resolved Hide resolved
callbacks=[tboard_cback, chkpt_cback, norm_cback, lrate_cback],
shuffle=False,
)
Expand Down