-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Shearlets #40
base: master
Are you sure you want to change the base?
Shearlets #40
Changes from 1 commit
bd4ea22
baa8d4f
4499a95
9d88a3a
fadd9af
774e972
ee1c145
42b7674
3b8dc98
3e5652d
e8f1649
678b328
0e0eb29
f7bc4a0
beedb0b
6f249cf
a0589a7
3139c2b
9a923d3
8ff2877
6f8c3f0
5af5fd2
3001d00
a61f1ef
07ee9b1
5d928be
19b021a
7e061ee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,16 @@ | |
import os.path as op | ||
import time | ||
|
||
import cadmos_lib as cl | ||
import click | ||
import numpy as np | ||
from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, LearningRateScheduler | ||
from tensorflow.keras.optimizers import Adam | ||
import tensorflow as tf | ||
|
||
from learning_wavelets.config import LOGS_DIR, CHECKPOINTS_DIR | ||
from learning_wavelets.data.datasets import im_dataset_div2k, im_dataset_bsd500 | ||
from learning_wavelets.evaluate import keras_psnr, keras_ssim, center_keras_psnr | ||
from learning_wavelets.keras_utils.normalisation import NormalisationAdjustment | ||
from learning_wavelets.models.learned_wavelet import learnlet | ||
|
||
|
@@ -72,13 +76,45 @@ | |
type=int, | ||
help='The number of filters in the learnlets. Defaults to 256.', | ||
) | ||
@click.option( | ||
'kernel_sizes', | ||
'--k-sizes', | ||
nargs=2, | ||
default=(11, 13), | ||
type=int, | ||
help='The analysis and synthesis kernel sizes. Defaults to [11, 13]', | ||
) | ||
@click.option( | ||
'decreasing_noise_level', | ||
'--decr-n-lvl', | ||
is_flag=True, | ||
help='Set if you want the noise level distribution to be non uniform, skewed towards low value.', | ||
) | ||
def train_learnlet(noise_std_train, noise_std_val, n_samples, source, cuda_visible_devices, denoising_activation, n_filters, decreasing_noise_level): | ||
@click.option( | ||
'exact_reco', | ||
'-e', | ||
is_flag=True, | ||
help='Set if you want the learnlets to have exact reconstruction.', | ||
) | ||
@click.option( | ||
'n_reweights', | ||
'-nr', | ||
default=1, | ||
help='The number of reweights. Defaults to 1.', | ||
) | ||
def train_learnlet( | ||
noise_std_train, | ||
noise_std_val, | ||
n_samples, | ||
source, | ||
cuda_visible_devices, | ||
denoising_activation, | ||
n_filters, | ||
kernel_sizes, | ||
decreasing_noise_level, | ||
exact_reco, | ||
n_reweights, | ||
): | ||
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(cuda_visible_devices) | ||
# data preparation | ||
batch_size = 8 | ||
|
@@ -109,19 +145,23 @@ def train_learnlet(noise_std_train, noise_std_val, n_samples, source, cuda_visib | |
'n_tiling': n_filters, | ||
'mixing_details': False, | ||
'skip_connection': True, | ||
'kernel_size': 11, | ||
'kernel_size': analysis_kernel_size, | ||
}, | ||
'learnlet_synthesis_kwargs': { | ||
'res': True, | ||
'kernel_size': 13, | ||
'kernel_size': synthesis_kernel_size, | ||
}, | ||
'threshold_kwargs':{ | ||
'noise_std_norm': True, | ||
}, | ||
'wav_type': 'starlet', | ||
'n_scales': 5, | ||
'n_reweights_learn': n_reweights, | ||
'exact_reconstruction': exact_reco, | ||
'clip': False, | ||
} | ||
|
||
n_epochs = 500 | ||
run_id = f'learnlet_dynamic_{n_filters}_{denoising_activation}_{source}_{noise_std_train[0]}_{noise_std_train[1]}_{n_samples}_{int(time.time())}' | ||
chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5' | ||
run_id = f'learnlet_exact_recon_{exact_reco}_{n_filters}_{denoising_activation}_{source}_{noise_std_train[0]}_{noise_std_train[1]}_{n_samples}_{int(time.time())}' | ||
chkpt_path = f'{CHECKPOINTS_DIR}checkpoints/{run_id}' + '-{epoch:02d}.hdf5' | ||
print(run_id) | ||
|
||
|
||
|
@@ -147,21 +187,43 @@ def l_rate_schedule(epoch): | |
norm_cback.on_train_batch_end = norm_cback.on_batch_end | ||
|
||
|
||
n_channels = 1 | ||
# run distributed | ||
# run distributed | ||
mirrored_strategy = tf.distribute.MirroredStrategy() | ||
with mirrored_strategy.scope(): | ||
model = learnlet(input_size=(None, None, n_channels), lr=1e-3, **run_params) | ||
print(model.summary(line_length=114)) | ||
|
||
model = Learnlet(**run_params) | ||
model.compile( | ||
optimizer=Adam(lr=1e-3), | ||
loss='mse', | ||
metrics=[keras_psnr, keras_ssim], | ||
) | ||
inputs = [tf.zeros((1, 32, 32, 1)), tf.zeros((1, 1))] | ||
model(inputs) | ||
|
||
shearlets, _ = cl.get_shearlets(512, 512, 6) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Put the answer in comment if it makes sense There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
|
||
n_shearlets = np.shape(shearlets)[0] | ||
total_size = np.shape(shearlets)[1] | ||
half_filters = n_filters//2 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait I don't get it... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
filter_size = analysis_kernel_size | ||
crop_min = total_size//2 - filter_size//2 | ||
crop_max = total_size//2 + filter_size//2 + 1 | ||
resized_shearlets = np.zeros((5, filter_size, filter_size, 1, half_filters)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Surely here it's not There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed |
||
for i in range(half_filters): | ||
resized_shearlets[:,:,:,0,i] = shearlets[i % n_shearlets, crop_min:crop_max, crop_min:crop_max] | ||
|
||
filters = np.copy(model.layers[0].get_weights()) | ||
for i in range(6, 11): | ||
filters[i] = resized_shearlets[i-6,:,:,:,:] | ||
|
||
model.layers[0].set_weights(filters) | ||
|
||
model.fit( | ||
im_ds_train, | ||
steps_per_epoch=200, | ||
epochs=n_epochs, | ||
validation_data=im_ds_val, | ||
validation_steps=1, | ||
verbose=0, | ||
verbose=1, | ||
kevinmicha marked this conversation as resolved.
Show resolved
Hide resolved
|
||
callbacks=[tboard_cback, chkpt_cback, norm_cback, lrate_cback], | ||
shuffle=False, | ||
) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's make this import optional
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done