-
Notifications
You must be signed in to change notification settings - Fork 149
/
fid.py
165 lines (134 loc) · 7.16 KB
/
fid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
"""Script for calculating Frechet Inception Distance (FID)."""
import os
import click
import tqdm
import pickle
import numpy as np
import scipy.linalg
import torch
import dnnlib
from torch_utils import distributed as dist
from training import dataset
#----------------------------------------------------------------------------
def calculate_inception_stats(
    image_path, num_expected=None, seed=0, max_batch_size=64,
    num_workers=3, prefetch_factor=2, device=torch.device('cuda'),
):
    """Compute Inception-v3 feature statistics (mean and covariance) for a set of images.

    The images are split across all ranks of the default process group. Each rank
    accumulates running sums of the features and the feature outer products. The
    sums are all-reduced and converted into a mean and an unbiased covariance.

    Args:
        image_path: Path handed to `dataset.ImageFolderDataset` (the CLI advertises
            a directory or ZIP).
        num_expected: If not None, passed as the dataset's `max_size`. At least this
            many images must then be present.
        seed: Forwarded to the dataset as `random_seed`.
        max_batch_size: Upper bound on the number of images in one batch on one rank.
        num_workers: Number of DataLoader worker processes per rank.
        prefetch_factor: DataLoader prefetch factor.
        device: Device on which the detector runs and statistics are accumulated.

    Returns:
        Tuple `(mu, sigma)` of float64 NumPy arrays with shapes `[2048]` and
        `[2048, 2048]`: the feature mean and the (N - 1)-normalised covariance.

    Raises:
        click.ClickException: If fewer than `num_expected` images, or fewer than 2
            images, are found. The check runs after the load barriers, so no rank is
            left waiting in a barrier that will never be matched.
    """
    # Rank 0 goes first: the other ranks wait here until rank 0 reaches the
    # matching barrier below (presumably so downloads/caching happen only once).
    if dist.get_rank() != 0:
        torch.distributed.barrier()

    # Load Inception-v3 model.
    # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    dist.print0('Loading Inception-v3 model...')
    detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl'
    detector_kwargs = dict(return_features=True)
    feature_dim = 2048
    # NOTE: unpickling can execute arbitrary code; keep detector_url pointed at a trusted source.
    with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f:
        detector_net = pickle.load(f).to(device)

    # List images.
    dist.print0(f'Loading images from "{image_path}"...')
    dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed)

    # Other ranks follow.
    if dist.get_rank() == 0:
        torch.distributed.barrier()

    # Validate only after the barrier above. If rank 0 raised before reaching it,
    # the other ranks would block forever in the first barrier. Every rank builds
    # the dataset with the same arguments, so presumably all of them raise together.
    num_images = len(dataset_obj)
    if num_expected is not None and num_images < num_expected:
        raise click.ClickException(f'Found {num_images} images, but expected at least {num_expected}')
    if num_images < 2:
        raise click.ClickException(f'Found {num_images} images, but need at least 2 to compute statistics')

    # Divide images into batches. num_batches is rounded up to a multiple of the
    # world size, so every rank receives the same number of batches. That keeps
    # the per-batch barrier in the loop below matched across ranks. Each chunk
    # holds at most max_batch_size indices.
    world_size = dist.get_world_size()
    num_batches = ((num_images - 1) // (max_batch_size * world_size) + 1) * world_size
    all_batches = torch.arange(num_images).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: world_size]
    data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor)

    # Accumulate statistics: running sums of x and of x x^T, in float64.
    dist.print0(f'Calculating statistics for {num_images} images...')
    mu = torch.zeros([feature_dim], dtype=torch.float64, device=device)
    sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device)
    for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)):
        # Keep ranks in lockstep; hit once per batch on every rank.
        torch.distributed.barrier()
        if images.shape[0] == 0:
            continue
        if images.shape[1] == 1:
            # Single-channel input: replicate to 3 channels (assumes NCHW layout).
            images = images.repeat([1, 3, 1, 1])
        features = detector_net(images.to(device), **detector_kwargs).to(torch.float64)
        mu += features.sum(0)
        sigma += features.T @ features

    # Calculate grand totals. Sum over ranks, then turn the raw sums into the mean
    # and the unbiased covariance (sum x x^T - N mu mu^T) / (N - 1).
    torch.distributed.all_reduce(mu)
    torch.distributed.all_reduce(sigma)
    mu /= num_images
    sigma -= torch.outer(mu, mu) * num_images
    sigma /= num_images - 1
    return mu.cpu().numpy(), sigma.cpu().numpy()
#----------------------------------------------------------------------------
def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref):
    """Return the Frechet distance between two Gaussians given their means and covariances.

    FID = ||mu - mu_ref||^2 + Tr(sigma + sigma_ref - 2 * sqrtm(sigma @ sigma_ref)).
    The matrix square root may come back complex because of numerical error, so
    only the real part of the final value is kept.
    """
    mean_term = np.sum(np.square(mu - mu_ref))
    # disp=False makes sqrtm return (root, error_estimate) instead of printing a warning.
    covmean, _error_estimate = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)
    trace_term = np.trace(sigma + sigma_ref - 2 * covmean)
    return float(np.real(mean_term + trace_term))
#----------------------------------------------------------------------------
# Top-level click command group; `calc` and `ref` are registered on it below.
# Blank lines in the docstring are required: click's `\b` marker disables
# rewrapping only when it is the first line of a paragraph.
@click.group()
def main():
    """Calculate Frechet Inception Distance (FID).

    Examples:

    \b
    # Generate 50000 images and save them as fid-tmp/*/*.png
    torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\
        --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl

    \b
    # Calculate FID
    torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\
        --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz

    \b
    # Compute dataset reference statistics
    python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz
    """
    # Intentionally empty: the subcommands do all the work.
#----------------------------------------------------------------------------
@main.command()
@click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True)
@click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True)
@click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True)
@click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
def calc(image_path, ref_path, num_expected, seed, batch):
    """Calculate FID for a given set of images."""
    # Select the 'spawn' start method for subprocesses, then join the process group.
    torch.multiprocessing.set_start_method('spawn')
    dist.init()
    is_main_rank = dist.get_rank() == 0

    # Only rank 0 loads the reference statistics, because only rank 0 computes
    # and prints the final score.
    dist.print0(f'Loading dataset reference statistics from "{ref_path}"...')
    ref_stats = None
    if is_main_rank:
        with dnnlib.util.open_url(ref_path) as f:
            ref_stats = dict(np.load(f))

    # All ranks cooperate on the statistics of the images being evaluated.
    mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch)

    dist.print0('Calculating FID...')
    if is_main_rank:
        fid = calculate_fid_from_inception_stats(mu, sigma, ref_stats['mu'], ref_stats['sigma'])
        print(f'{fid:g}')
    # Hold every rank here until rank 0 has printed the score.
    torch.distributed.barrier()
#----------------------------------------------------------------------------
@main.command()
@click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True)
@click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True)
@click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True)
def ref(dataset_path, dest_path, batch):
    """Calculate dataset reference statistics needed by 'calc'."""
    torch.multiprocessing.set_start_method('spawn')
    dist.init()

    # Every rank cooperates; no num_expected is given, so the whole dataset is used.
    mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch)

    dist.print0(f'Saving dataset reference statistics to "{dest_path}"...')
    if dist.get_rank() == 0:
        # Only rank 0 writes the file; create its parent directory if one is named.
        parent_dir = os.path.dirname(dest_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        np.savez(dest_path, mu=mu, sigma=sigma)
    # Keep the other ranks from finishing before rank 0 has written the file.
    torch.distributed.barrier()
    dist.print0('Done.')
#----------------------------------------------------------------------------
if __name__ == "__main__":
    # Running this file as a script dispatches to the click command group.
    main()
#----------------------------------------------------------------------------