-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtable_S2.py
154 lines (120 loc) · 5.03 KB
/
table_S2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# -*- coding: utf-8 -*-
"""
Created on Fri Jan 31 12:10:10 2025
@author: ducros
"""
# %%
# Imports
# --------------------------------------------------------------------
from pathlib import Path
from typing import OrderedDict
import torch
import torch.nn as nn
import spyrit.core.meas as meas
import spyrit.core.prep as prep
import spyrit.core.noise as noise
import spyrit.core.recon as recon
import spyrit.misc.statistics as stats
import spyrit.external.drunet as drunet
import utility_dpgd as dpgd
# %%
# General Parameters
# --------------------------------------------------------------------
img_size = 128  # image size
batch_size = 128  # number of images per batch
n_batches = 3  # number of batches iterated over to compute the statistics

# Experimental folders, relative to the current working directory
val_folder = "data/ILSVRC2012/val/"  # images used for the statistical analysis
model_folder = "model/"  # reconstruction models
stat_folder = "stat/"  # statistics
recon_folder = "recon/table/"  # table output

# Absolute paths, all rooted at the current working directory
_root = Path.cwd()
val_folder_full = _root / val_folder
model_folder_full = _root / model_folder
stat_folder_full = _root / stat_folder
recon_folder_full = _root / recon_folder

# Make sure the validation and output folders exist. Note that only the
# (possibly empty) validation folder is created: the images themselves must
# be provided separately.
for _folder in (val_folder_full, recon_folder_full):
    _folder.mkdir(parents=True, exist_ok=True)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
# %%
# Load images
# --------------------------------------------------------------------
# /!\ spyrit v3 works with images in [0,1]. Therefore, we use normalize=False
# Both folder arguments point to the validation set; only the "val" loader
# of the returned mapping is kept.
_loaders = stats.data_loaders_ImageNet(
    val_folder_full,
    val_folder_full,
    img_size,
    batch_size,
    normalize=False,
)
dataloader = _loaders["val"]
del _loaders
# %%
# Simulate measurements for three image intensities
# --------------------------------------------------------------------
# Measurement parameters
alpha_list = [2, 10, 50]  # Poisson law parameter for noisy image acquisitions
n_alpha = len(alpha_list)

# Sampling mask: keep only the top-left (img_size // 2) x (img_size // 2)
# quadrant of the 2D Hadamard domain (ones = kept, zeros = discarded).
_half = img_size // 2
Ord_rec = torch.zeros(img_size, img_size)
Ord_rec[:_half, :_half] = 1

# Number of measurements = number of coefficients kept by the mask (1/4 of
# the pixels for an even img_size). Derived from the mask geometry so that
# M and Ord_rec stay consistent whenever img_size changes.
M = _half * _half

# Measurement and noise operators. alpha_list[0] is only the initial
# intensity: the evaluation loop overrides alpha on the network.
noise_op = noise.Poisson(alpha_list[0])
meas_op = meas.HadamSplit2d(img_size, M, Ord_rec, noise_model=noise_op, device=device)
prep_op = prep.UnsplitRescale(alpha_list[0])
# Maps values from [0, 1] to [-1, 1].
# NOTE(review): not referenced elsewhere in this script.
rerange = prep.Rerange((0, 1), (-1, 1))
# %% DPGD-PnP
# Load the pretrained denoiser used as the plug-and-play prior
n_channel = 1  # number of input channels
n_feature = 100  # features per layer
n_layer = 20  # number of layers
model_name = "DFBNet_l1_patchsize=50_varnoise0.1_feat_100_layers_20.pth"
model_path = model_folder_full / model_name
denoi = dpgd.load_model(
    pth=model_path.as_posix(),
    n_ch=n_channel,
    features=n_feature,
    num_of_layers=n_layer,
)
# `.module` suggests the model is wrapped (e.g. nn.DataParallel) -- TODO confirm.
# The (1, 50, 50) shape appears to match the patch size in the file name.
denoi.module.update_lip((1, 50, 50))
denoi.eval()

# Reconstruction hyperparameters, forwarded to dpgd.DualPGD
gamma = 1 / img_size**2
max_iter = 101
crit_norm = 1e-4
# One list of candidate mu values per entry of alpha_list (same order)
mu_list = [
    [2000, 4000, 4500, 5000, 5500, 6000, 8000, 10000],  # 2 photons
    [1000, 2000, 2500, 3000, 3500, 4000, 5000],  # 10 photons
    [500, 1000, 1200, 1500, 1800, 2000, 3000, 4000],  # 50 photons
]

# Init. The evaluation loop assigns `dpgdnet.mu` before every evaluation, so
# the value passed here is only an initial one.
# NOTE(review): mu_list[0] is a list, not a scalar -- confirm DualPGD accepts it.
dpgdnet = dpgd.DualPGD(
    meas_op, prep_op, denoi, gamma, mu_list[0], max_iter, crit_norm
).to(device)
print("\nDPGD-PnP reconstruction metrics")

# LaTeX table with one header row per alpha and one row per mu.
# The file is opened once, in write mode, so that every run regenerates the
# table (append mode, reopened per alpha, duplicated it on each rerun).
metric_file = recon_folder_full / "table_S2.tex"

with torch.no_grad(), open(metric_file, "w", encoding="utf-8") as f:
    for ii, alpha in enumerate(alpha_list):
        # Header row for this image intensity
        f.write("\n")
        f.write(f"$\\alpha={alpha}$ & & \\\\ \n")

        # Set alpha for the simulation of the measurements
        # (dpgdnet.acqu_modules.acqu.noise_model.alpha works too)
        dpgdnet.acqu.noise_model.alpha = alpha
        # Set alpha for reconstruction
        # (dpgdnet.recon_modules.prep.alpha works too)
        dpgdnet.prep.alpha = alpha

        for mu in mu_list[ii]:
            # Reset the RNG so every (alpha, mu) evaluation starts from the
            # same random state (reproducibility)
            torch.manual_seed(0)
            dpgdnet.mu = mu
            print(f"For alpha={alpha} and mu={mu}")

            # PSNR. .item() yields a Python float, which formats with any
            # spec regardless of whether the result is 0-d or shape (1,).
            mean_psnr, var_psnr = stats.stat_psnr(
                dpgdnet, dataloader, device, num_batchs=n_batches, img_dyn=1.0
            )
            mean_psnr = mean_psnr.item()
            std_psnr = torch.sqrt(var_psnr).item()
            print(f"psnr = {mean_psnr:.2f} +/- {std_psnr:.2f} dB")

            # SSIM
            mean_ssim, var_ssim = stats.stat_ssim(
                dpgdnet, dataloader, device, num_batchs=n_batches, img_dyn=1.0
            )
            mean_ssim = mean_ssim.item()
            std_ssim = torch.sqrt(var_ssim).item()
            print(f"ssim = {mean_ssim:.3f} +/- {std_ssim:.3f}")

            # Table row: mu & mean PSNR (std) & mean SSIM (std)
            f.write(
                f"$\\mu={mu}$ & {mean_psnr:.2f} ({std_psnr:.2f}) & "
                f"{mean_ssim:.3f} ({std_ssim:.3f}) \\\\ \n"
            )
            # Push each row to disk right away: evaluations are slow, so
            # partial results stay visible during the run
            f.flush()

del dpgdnet
del denoi