test_vae.py
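"""Measure VAE reconstruction quality.

Round-trips every image in a folder through a diffusers AutoencoderKL and
reports the average PSNR and LPIPS between the originals and their
reconstructions (the rFID computation is stubbed out below).
"""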
import argparse
import os

from PIL import Image
import torch
from torchvision.transforms import Resize, ToTensor
from torchvision import transforms
from diffusers import AutoencoderKL
from pytorch_fid import fid_score
from skimage.metrics import peak_signal_noise_ratio as psnr
import lpips
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_images(folder_path):
    images = []
    for filename in os.listdir(folder_path):
        if filename.lower().endswith(('.png', '.jpg', '.jpeg')):
            img_path = os.path.join(folder_path, filename)
            images.append(img_path)
    return images

def parameter_count(model):
    # Count every element stored in the model's state dict (weights and buffers).
    state_dict = model.state_dict()
    count = 0
    for key in state_dict:
        count += torch.numel(state_dict[key])
    return int(count)

def calculate_metrics(vae, images, max_imgs=-1):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae = vae.to(device)
    lpips_model = lpips.LPIPS(net='alex').to(device)
    rfid_scores = []
    psnr_scores = []
    lpips_scores = []
    # transform = transforms.Compose([
    #     transforms.Resize(256, antialias=True),
    #     transforms.CenterCrop(256)
    # ])
    to_tensor = ToTensor()
    if max_imgs > 0 and len(images) > max_imgs:
        images = images[:max_imgs]
    for img_path in tqdm(images):
        try:
            img = Image.open(img_path).convert('RGB')
            # img_tensor = to_tensor(transform(img)).unsqueeze(0).to(device)
            img_tensor = to_tensor(img).unsqueeze(0).to(device)
            # The VAE and LPIPS both expect values in [-1, 1]
            img_tensor = 2 * img_tensor - 1
            # If width or height is not divisible by 8 (the VAE downsampling factor), crop it
            if img_tensor.shape[2] % 8 != 0 or img_tensor.shape[3] % 8 != 0:
                img_tensor = img_tensor[:, :, :img_tensor.shape[2] // 8 * 8, :img_tensor.shape[3] // 8 * 8]
        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            continue
        with torch.no_grad():
            # Round-trip the image through the VAE: encode, sample the latent, decode
            reconstructed = vae.decode(vae.encode(img_tensor).latent_dist.sample()).sample
            # Calculate rFID
            # rfid = fid_score.calculate_frechet_distance(vae, img_tensor, reconstructed)
            # rfid_scores.append(rfid)
            # Calculate PSNR (inputs are in [-1, 1], so the data range is 2)
            psnr_val = psnr(img_tensor.cpu().numpy(), reconstructed.cpu().numpy(), data_range=2.0)
            psnr_scores.append(psnr_val)
            # Calculate LPIPS
            lpips_val = lpips_model(img_tensor, reconstructed).item()
            lpips_scores.append(lpips_val)
    # avg_rfid = sum(rfid_scores) / len(rfid_scores)
    avg_rfid = 0
    avg_psnr = sum(psnr_scores) / len(psnr_scores)
    avg_lpips = sum(lpips_scores) / len(lpips_scores)
    return avg_rfid, avg_psnr, avg_lpips

def main():
    parser = argparse.ArgumentParser(description="Calculate average rFID, PSNR, and LPIPS for VAE reconstructions")
    parser.add_argument("--vae_path", type=str, required=True, help="Path to the VAE model")
    parser.add_argument("--image_folder", type=str, required=True, help="Path to the folder containing images")
    parser.add_argument("--max_imgs", type=int, default=-1, help="Max num of images. Default is -1 for all images.")
    args = parser.parse_args()

    # Load the VAE from a single checkpoint file or from a diffusers model directory / Hub id
    if os.path.isfile(args.vae_path):
        vae = AutoencoderKL.from_single_file(args.vae_path)
    else:
        vae = AutoencoderKL.from_pretrained(args.vae_path)
    vae.eval()
    vae = vae.to(device)
    print(f"Model has {parameter_count(vae)} parameters")

    images = load_images(args.image_folder)
    avg_rfid, avg_psnr, avg_lpips = calculate_metrics(vae, images, args.max_imgs)
    # print(f"Average rFID: {avg_rfid}")
    print(f"Average PSNR: {avg_psnr}")
    print(f"Average LPIPS: {avg_lpips}")


if __name__ == "__main__":
    main()
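
# Example invocation (the model id and paths below are placeholders, not from the original script):
#   python test_vae.py --vae_path stabilityai/sd-vae-ft-mse --image_folder ./val_images --max_imgs 100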