# figure_3.py
# %%
# Imports
# --------------------------------------------------------------------
from pathlib import Path
from collections import OrderedDict
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import spyrit.core.meas as meas
import spyrit.core.noise as noise
import spyrit.core.prep as prep
import spyrit.core.recon as recon
import spyrit.core.nnet as nnet
import spyrit.core.train as train
import spyrit.external.drunet as drunet
import utility_dpgd as dpgd
# %%
# General
# --------------------------------------------------------------------
# Data, model, and output folders
image_folder = "data/images/"  # images used to simulate measurements
model_folder = "model/"  # trained reconstruction models
stat_folder = "stat/"  # measurement statistics (covariance matrix)
recon_folder = "recon/figure_3/"  # reconstructed images
# Full paths
image_folder_full = Path.cwd() / Path(image_folder)
model_folder_full = Path.cwd() / Path(model_folder)
stat_folder_full = Path.cwd() / Path(stat_folder)
recon_folder_full = Path.cwd() / Path(recon_folder)
recon_folder_full.mkdir(parents=True, exist_ok=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# %%
# Load images
# --------------------------------------------------------------------
img_size = 128 # image size
print("Loading image...")
image_path = image_folder_full / "cropped/ILSVRC2012_val_00000003_crop.JPEG"
x = torchvision.io.read_image(str(image_path), torchvision.io.ImageReadMode.GRAY)
# Resize image
x = torchvision.transforms.functional.resize(x, (img_size, img_size)).reshape(
    1, 1, img_size, img_size
)
# Detach and rescale to [0, 1]
x = x.detach().clone()
x = x / 255
b, c, h, w = x.shape
print(f"Shape of input image: {x.shape}")
# Save the original image
plt.imsave(recon_folder_full / "original.png", x[0, 0, :, :].numpy(), cmap="gray")
# %%
# Simulate measurements for three image intensities
# --------------------------------------------------------------------
# Measurement parameters
alpha_list = [2, 10, 50] # Poisson law parameter for noisy image acquisitions
n_alpha = len(alpha_list)
M = 128 * 128 // 4 # Number of measurements (here, 1/4 of the pixels)
# Sampling map: keep the top-left quarter of the transformed coefficients
Ord_rec = torch.ones(img_size, img_size)
Ord_rec[:, img_size // 2 :] = 0
Ord_rec[img_size // 2 :, :] = 0
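# Sanity check (illustrative): the retained top-left block holds
# (img_size // 2) ** 2 = 4096 coefficients, i.e. exactly the M measurements
# requested above.
assert int(Ord_rec.sum()) == M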
# Measurement, noise, and preprocessing operators
noise_model = noise.Poisson(alpha_list[0])
meas_op = meas.HadamSplit2d(h, M, Ord_rec, noise_model=noise_model, device=device)
prep_op = prep.UnsplitRescale(alpha_list[0])
rerange = prep.Rerange((0, 1), (-1, 1))
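# Rerange maps image values from [0, 1] to [-1, 1], and rerange.inverse() maps
# back; it wraps the U-Net denoisers below (which expect [-1, 1] inputs),
# whereas DRUNet (Pinv-PnP section) is applied directly to [0, 1] images.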
x = x.to(device)
# Measurement vectors
y_shape = torch.Size([n_alpha]) + meas_op(x).shape
y = torch.zeros(y_shape, device=device)
for ii, alpha in enumerate(alpha_list):
    torch.manual_seed(0)  # for reproducibility
    noise_model.alpha = alpha
    y[ii, ...] = meas_op(x)
reconstruct_size = torch.Size([n_alpha]) + x.shape
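# Illustrative check: y stacks one noisy measurement vector per intensity alpha
print(f"Shape of simulated measurements: {y.shape}")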
# %%
# Pinv
# ====================================================================
# Init
pinv = recon.PinvNet(meas_op, prep_op, device=device)
# Use GPU if available
# pinv = pinv.to(device)
# Reconstruct
x_pinv = torch.zeros(reconstruct_size, device=device)
with torch.no_grad():
    for ii, alpha in enumerate(alpha_list):
        pinv.prep.alpha = alpha
        x_pinv[ii] = pinv.reconstruct(y[ii, ...])
        filename = f"pinv_alpha_{alpha:02}.png"
        full_path = recon_folder_full / filename
        plt.imsave(
            full_path, x_pinv[ii, 0, 0, :, :].cpu().detach().numpy(), cmap="gray"
        )
# %%
# Pinv-Net
# ====================================================================
model_name = "pinv-net_unet_imagenet_N0_10_m_hadam-split_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_512_reg_1e-07_retrained_light.pth"
denoiser = OrderedDict(
    {"rerange": rerange, "denoi": nnet.Unet(), "rerange_inv": rerange.inverse()}
)
denoiser = nn.Sequential(denoiser)
# train.load_net loads the weights into the '.denoi' submodule of its second
# argument; it fails if no '.denoi' submodule is found.
train.load_net(model_folder_full / model_name, denoiser, device, False)
# Init
pinvnet = recon.PinvNet(meas_op, prep_op, denoiser, device=device)
pinvnet.eval()
# Reconstruct
x_pinvnet = torch.zeros(reconstruct_size, device=device)
with torch.no_grad():
    for ii, alpha in enumerate(alpha_list):
        pinvnet.prep.alpha = alpha
        x_pinvnet[ii, ...] = pinvnet.reconstruct(y[ii, ...])
        filename = f"pinvnet_alpha_{alpha:02}.png"
        full_path = recon_folder_full / filename
        plt.imsave(
            full_path, x_pinvnet[ii, 0, 0, :, :].cpu().detach().numpy(), cmap="gray"
        )
del denoiser
del pinvnet
torch.cuda.empty_cache()
# %%
# LPGD
# ====================================================================
model_name = "lpgd_unet_imagenet_N0_10_m_hadam-split_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_128_reg_1e-07_uit_3_sdec0-9_light.pth"
denoiser = OrderedDict(
    {"rerange": rerange, "denoi": nnet.Unet(), "rerange_inv": rerange.inverse()}
)
denoiser = nn.Sequential(denoiser)
# train.load_net loads the weights into the '.denoi' submodule of its second
# argument; it fails if no '.denoi' submodule is found.
train.load_net(model_folder_full / model_name, denoiser, device, False)
# Initialize network
lpgd = recon.LearnedPGD(meas_op, prep_op, denoiser, step_decay=0.9)
lpgd.eval()
# Use GPU if available
lpgd = lpgd.to(device)
# Reconstruct and save
x_lpgd = torch.zeros(reconstruct_size, device=device)
with torch.no_grad():
    for ii, alpha in enumerate(alpha_list):
        lpgd.prep.alpha = alpha
        x_lpgd[ii, ...] = lpgd.reconstruct(y[ii, ...])
        # save
        filename = f"lpgd_alpha_{alpha:02}.png"
        full_path = recon_folder_full / filename
        plt.imsave(
            full_path, x_lpgd[ii, 0, 0, :, :].cpu().detach().numpy(), cmap="gray"
        )
del denoiser
del lpgd
torch.cuda.empty_cache()
# %%
# DC-Net
# ====================================================================
model_name = "dc-net_unet_imagenet_rect_N0_10_N_128_M_4096_epo_30_lr_0.001_sss_10_sdr_0.5_bs_256_reg_1e-07_light.pth"
denoiser = OrderedDict(
    {"rerange": rerange, "denoi": nnet.Unet(), "rerange_inv": rerange.inverse()}
)
denoiser = nn.Sequential(denoiser)
# train.load_net loads the weights into the '.denoi' submodule of its second
# argument; it fails if no '.denoi' submodule is found.
train.load_net(model_folder_full / model_name, denoiser, device, False)
# Load covariance prior
cov_name = stat_folder_full / f"Cov_8_{img_size}x{img_size}.pt"
Cov = torch.load(cov_name, weights_only=True).to(device)
# divide by 4 because the measurement covariance has been computed on images
# with values in [-1, 1] (total span 2) whereas our image is in [0, 1] (total
# span 1). The covariance is thus 2^2 = 4 times larger than expected.
Cov /= 4
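# Toy illustration (synthetic data, not used below): rescaling a signal by a
# factor s scales its covariance by s**2; the image span is halved (2 -> 1),
# i.e. s = 1/2, hence the division by 4.
_z = torch.randn(10_000, 2)
assert torch.allclose(torch.cov((_z / 2).T), torch.cov(_z.T) / 4)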
# Init
#prep_op = prep.UnsplitRescaleEstim(meas_op, use_fast_pinv=True)
dcnet = recon.DCNet(meas_op, prep_op, Cov, denoiser, device=device)
dcnet.eval()
# Reconstruct
x_dcnet = torch.zeros(reconstruct_size, device=device)
with torch.no_grad():
    for ii, alpha in enumerate(alpha_list):
        dcnet.prep.alpha = alpha
        x_dcnet[ii, ...] = dcnet.reconstruct(y[ii, ...])
        filename = f"dcnet_alpha_{alpha:02}.png"
        full_path = recon_folder_full / filename
        plt.imsave(
            full_path, x_dcnet[ii, 0, 0, :, :].cpu().detach().numpy(), cmap="gray"
        )
del denoiser
del dcnet
torch.cuda.empty_cache()
# %%
# Pinv-PnP
# ====================================================================
model_name = "drunet_gray.pth"
noise_levels = [130, 50, 20] # noise levels from 0 to 255 for each alpha
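# noise_levels is index-aligned with alpha_list: the noisier the acquisition
# (smaller alpha), the higher the noise level given to DRUNet
# (alpha = 2 -> nu = 130, alpha = 50 -> nu = 20)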
denoiser = OrderedDict(
    {
        # No rerange() needed with normalize=False
        # "rerange": rerange,
        "denoi": drunet.DRUNet(normalize=False),
        # No rerange.inverse() here as DRUNet works for images in [0, 1]
        # "rerange_inv": rerange.inverse(),
    }
)
denoiser = nn.Sequential(denoiser)
# Initialize network
pinvpnp = recon.PinvNet(meas_op, prep_op, denoiser, device=device)
pinvpnp.denoi.denoi.load_state_dict(
    torch.load(model_folder_full / model_name, weights_only=True), strict=False
)
pinvpnp.eval()
# Reconstruct and save
x_pinvpnp = torch.zeros(reconstruct_size, device=device)
with torch.no_grad():
    for ii, alpha in enumerate(alpha_list):
        # set the intensity for the preprocessing operator and the noise level
        # for the PnP denoiser
        pinvpnp.prep.alpha = alpha
        nu = noise_levels[ii]
        pinvpnp.denoi.denoi.set_noise_level(nu)
        x_pinvpnp[ii, ...] = pinvpnp.reconstruct(y[ii, ...])
        # save
        filename = f"pinv_pnp_alpha_{alpha:02}_nu_{nu:03}.png"
        full_path = recon_folder_full / filename
        plt.imsave(
            full_path, x_pinvpnp[ii, 0, 0, :, :].cpu().detach().numpy(), cmap="gray"
        )
del denoiser
del pinvpnp
torch.cuda.empty_cache()
# %%
# DPGD-PnP
# ====================================================================
# load denoiser
n_channel, n_feature, n_layer = 1, 100, 20
model_name = "DFBNet_l1_patchsize=50_varnoise0.1_feat_100_layers_20.pth"
denoi = dpgd.load_model(
    pth=(model_folder_full / model_name).as_posix(),
    n_ch=n_channel,
    features=n_feature,
    num_of_layers=n_layer,
)
denoi.module.update_lip((1, 50, 50))
denoi.eval()
# Reconstruction hyperparameters
gamma = 1 / img_size**2
max_iter = 101
mu_list = [6000, 3500, 1200]
crit_norm = 1e-4
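# mu_list is index-aligned with alpha_list: mu decreases as the photon level
# alpha increases (alpha = 2 -> mu = 6000, alpha = 50 -> mu = 1200)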
# Init
dpgdnet = dpgd.DualPGD(meas_op, prep_op, denoi, gamma, mu_list[0], max_iter, crit_norm)
dpgdnet = dpgdnet.to(device)
x_dpgd = torch.zeros(reconstruct_size, device=device)
with torch.no_grad():
    for ii, alpha in enumerate(alpha_list):
        dpgdnet.prep.alpha = alpha
        dpgdnet.mu = mu_list[ii]
        x_dpgd[ii, ...] = dpgdnet.reconstruct(y[ii, ...])
        # save
        filename = f"dpgd_alpha_{alpha:02}.png"
        full_path = recon_folder_full / filename
        plt.imsave(
            full_path, x_dpgd[ii, 0, 0, :, :].cpu().detach().numpy(), cmap="gray"
        )
del denoi
del dpgdnet
torch.cuda.empty_cache()