forked from google/compare_gan

Commit: 2 changed files with 448 additions and 0 deletions.

@@ -0,0 +1,297 @@
# coding=utf-8
# Copyright 2018 Google LLC & Hwalsuk Lee.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Precision and recall computation based on samples from two distributions. | ||
Given a sample from the true and the fake distribution embedded in some feature | ||
space (say, Inception), it computes the precision and recall via the algorithm | ||
presented in [arxiv.org/abs/1806.00035]. Finally, one can plot the resulting | ||
curves for different models. | ||
Typical usage example: | ||
import prd | ||
prd_data_1 = prd.compute_prd_from_embedding(eval_feats_1, ref_feats_1) | ||
prd_data_2 = prd.compute_prd_from_embedding(eval_feats_2, ref_feats_2) | ||
prd.plot([prd_data_1, prd_data_2], ['GAN_1', 'GAN_2']) | ||
""" | ||
|
||
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from matplotlib import pyplot as plt
import numpy as np
import sklearn.cluster


def compute_prd(eval_dist, ref_dist, num_angles=1001, epsilon=1e-10):
  """Computes the PRD curve for discrete distributions.

  This function computes the PRD curve for the discrete distribution eval_dist
  with respect to the reference distribution ref_dist. This implements the
  algorithm in [arxiv.org/abs/1806.00035]. The PRD will be computed for an
  equiangular grid of num_angles values between [0, pi/2].

  Args:
    eval_dist: 1D NumPy array or list of floats with the probabilities of the
      different states under the distribution to be evaluated.
    ref_dist: 1D NumPy array or list of floats with the probabilities of the
      different states under the reference distribution.
    num_angles: Number of angles for which to compute PRD. Must be in [3, 1e6].
      The default value is 1001.
    epsilon: Angle for PRD computation in the edge cases 0 and pi/2. The PRD
      will be computed for epsilon and pi/2-epsilon, respectively.
      The default value is 1e-10.

  Returns:
    precision: NumPy array of shape [num_angles] with the precision for the
      different ratios.
    recall: NumPy array of shape [num_angles] with the recall for the different
      ratios.

  Raises:
    ValueError: If not 0 < epsilon <= 0.1.
    ValueError: If num_angles < 3 or num_angles > 1e6.
  """

  if not (epsilon > 0 and epsilon <= 0.1):
    raise ValueError('epsilon must be in (0, 0.1] but is %s.' % str(epsilon))
  if not (num_angles >= 3 and num_angles <= 1e6):
    raise ValueError('num_angles must be in [3, 1e6] but is %d.' % num_angles)

  # Compute slopes for linearly spaced angles between [0, pi/2].
  angles = np.linspace(epsilon, np.pi/2 - epsilon, num=num_angles)
  slopes = np.tan(angles)

  # Broadcast slopes so that the second dimension will be states of the
  # distribution.
  slopes_2d = np.expand_dims(slopes, 1)

  # Broadcast distributions so that the first dimension represents the angles.
  ref_dist_2d = np.expand_dims(ref_dist, 0)
  eval_dist_2d = np.expand_dims(eval_dist, 0)

  # Compute precision and recall for all angles in one step via broadcasting.
  precision = np.minimum(ref_dist_2d*slopes_2d, eval_dist_2d).sum(axis=1)
  recall = precision / slopes
  return precision, recall

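
# Editorial usage sketch (not part of the original file): compute_prd expects
# two discrete distributions over the same states, e.g. normalized histograms.
# The toy arrays and the _demo_* name below are hypothetical.
def _demo_compute_prd():
  toy_eval = np.array([0.6, 0.4, 0.0])  # mass concentrated on two states
  toy_ref = np.array([0.3, 0.3, 0.4])   # reference spreads mass over three
  precision, recall = compute_prd(toy_eval, toy_ref, num_angles=101)
  # Each (recall[i], precision[i]) pair is one point on the PRD curve.
  print(precision.shape, recall.shape)  # -> (101,) (101,)
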


def _cluster_into_bins(eval_data, ref_data, num_clusters):
  """Clusters the union of the data points and returns the cluster distribution.

  Clusters the union of eval_data and ref_data into num_clusters using minibatch
  k-means. Then, for each cluster, it computes the number of points from
  eval_data and ref_data.

  Args:
    eval_data: NumPy array of data points from the distribution to be evaluated.
    ref_data: NumPy array of data points from the reference distribution.
    num_clusters: Number of cluster centers to fit.

  Returns:
    Two NumPy arrays, each of size num_clusters, where the i-th entry is the
    fraction of points assigned to the i-th cluster (the histograms are
    normalized, so each array sums to 1).
  """

  cluster_data = np.vstack([eval_data, ref_data])
  kmeans = sklearn.cluster.MiniBatchKMeans(n_clusters=num_clusters, n_init=10)
  labels = kmeans.fit(cluster_data).labels_

  eval_labels = labels[:len(eval_data)]
  ref_labels = labels[len(eval_data):]

  # Normalized histograms over the shared cluster labels; density=True with
  # unit-width bins yields the fraction of points per cluster.
  eval_bins = np.histogram(eval_labels, bins=num_clusters,
                           range=[0, num_clusters], density=True)[0]
  ref_bins = np.histogram(ref_labels, bins=num_clusters,
                          range=[0, num_clusters], density=True)[0]
  return eval_bins, ref_bins

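
# Editorial sketch (assumption: two small Gaussian blobs stand in for real
# embeddings). Shows that _cluster_into_bins returns one normalized histogram
# per distribution over the shared cluster centers.
def _demo_cluster_into_bins():
  rng = np.random.RandomState(0)
  fake_eval = rng.normal(0.0, 1.0, size=(200, 2))
  fake_ref = rng.normal(0.5, 1.0, size=(200, 2))
  eval_bins, ref_bins = _cluster_into_bins(fake_eval, fake_ref, num_clusters=5)
  # Both histograms sum to 1 and share the same 5 bins.
  print(eval_bins.sum(), ref_bins.sum())
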


def compute_prd_from_embedding(eval_data, ref_data, num_clusters=20,
                               num_angles=1001, num_runs=10,
                               enforce_balance=True):
  """Computes PRD data from sample embeddings.

  The points from both distributions are mixed and then clustered. This leads
  to a pair of histograms of discrete distributions over the cluster centers
  on which the PRD algorithm is executed.

  The number of points in eval_data and ref_data must be equal since
  unbalanced distributions bias the clustering towards the larger dataset. The
  check can be disabled by setting the enforce_balance flag to False (not
  recommended).

  Args:
    eval_data: NumPy array of data points from the distribution to be evaluated.
    ref_data: NumPy array of data points from the reference distribution.
    num_clusters: Number of cluster centers to fit. The default value is 20.
    num_angles: Number of angles for which to compute PRD. Must be in [3, 1e6].
      The default value is 1001.
    num_runs: Number of independent runs over which to average the PRD data.
      The default value is 10.
    enforce_balance: If enabled, throws exception if eval_data and ref_data do
      not have the same length. The default value is True.

  Returns:
    precision: NumPy array of shape [num_angles] with the precision for the
      different ratios.
    recall: NumPy array of shape [num_angles] with the recall for the different
      ratios.

  Raises:
    ValueError: If len(eval_data) != len(ref_data) and enforce_balance is set
      to True.
  """

  if enforce_balance and len(eval_data) != len(ref_data):
    raise ValueError(
        'The number of points in eval_data %d is not equal to the number of '
        'points in ref_data %d. To disable this exception, set enforce_balance '
        'to False (not recommended).' % (len(eval_data), len(ref_data)))

  eval_data = np.array(eval_data, dtype=np.float64)
  ref_data = np.array(ref_data, dtype=np.float64)
  precisions = []
  recalls = []
  # Average the PRD curves over num_runs independent clusterings to reduce
  # the variance introduced by minibatch k-means.
  for _ in range(num_runs):
    eval_dist, ref_dist = _cluster_into_bins(eval_data, ref_data, num_clusters)
    precision, recall = compute_prd(eval_dist, ref_dist, num_angles)
    precisions.append(precision)
    recalls.append(recall)
  precision = np.mean(precisions, axis=0)
  recall = np.mean(recalls, axis=0)
  return precision, recall

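
# Editorial sketch: end-to-end PRD from raw embeddings. The random features
# below are placeholders for e.g. Inception activations of real and generated
# images.
def _demo_compute_prd_from_embedding():
  rng = np.random.RandomState(0)
  eval_feats = rng.normal(size=(500, 64))
  ref_feats = rng.normal(size=(500, 64))
  precision, recall = compute_prd_from_embedding(
      eval_feats, ref_feats, num_clusters=10, num_runs=2)
  print(precision.shape)  # -> (1001,)
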


def _prd_to_f_beta(precision, recall, beta=1, epsilon=1e-10):
  """Computes F_beta scores for the given precision/recall values.

  The F_beta scores for all precision/recall pairs will be computed and
  returned.

  For precision p and recall r, the F_beta score is defined as:
  F_beta = (1 + beta^2) * (p * r) / ((beta^2 * p) + r)

  Args:
    precision: 1D NumPy array of precision values in [0, 1].
    recall: 1D NumPy array of recall values in [0, 1].
    beta: Beta parameter. Must be positive. The default value is 1.
    epsilon: Small constant to avoid numerical instability caused by division
      by 0 when precision and recall are close to zero.

  Returns:
    NumPy array of same shape as precision and recall with the F_beta scores for
    each pair of precision/recall.

  Raises:
    ValueError: If any value in precision or recall is outside of [0, 1].
    ValueError: If beta is not positive.
  """

  if not ((precision >= 0).all() and (precision <= 1).all()):
    raise ValueError('All values in precision must be in [0, 1].')
  if not ((recall >= 0).all() and (recall <= 1).all()):
    raise ValueError('All values in recall must be in [0, 1].')
  if beta <= 0:
    raise ValueError('Given parameter beta %s must be positive.' % str(beta))

  return (1 + beta**2) * (precision * recall) / (
      (beta**2 * precision) + recall + epsilon)

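
# Worked example (editorial): for p = 0.8, r = 0.4 and beta = 2,
#   F_2 = (1 + 2^2) * (0.8 * 0.4) / (2^2 * 0.8 + 0.4) = 1.6 / 3.6 ≈ 0.444,
# i.e. F_beta with beta > 1 weights recall more heavily; the symmetric F_1
# for the same pair is 2 * 0.32 / 1.2 ≈ 0.533.
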


def prd_to_max_f_beta_pair(precision, recall, beta=8):
  """Computes max. F_beta and max. F_{1/beta} for precision/recall pairs.

  Computes the maximum F_beta and maximum F_{1/beta} score over all pairs of
  precision/recall values. This is useful to compress a PRD plot into a single
  pair of values which correlate with precision and recall.

  For precision p and recall r, the F_beta score is defined as:
  F_beta = (1 + beta^2) * (p * r) / ((beta^2 * p) + r)

  Args:
    precision: 1D NumPy array or list of precision values in [0, 1].
    recall: 1D NumPy array or list of recall values in [0, 1].
    beta: Beta parameter. Must be positive. The default value is 8.

  Returns:
    f_beta: Maximum F_beta score.
    f_beta_inv: Maximum F_{1/beta} score.

  Raises:
    ValueError: If any value in precision or recall is outside of [0, 1].
    ValueError: If beta is not positive.
  """

  # Accept lists (as documented) by converting to NumPy arrays before the
  # element-wise range checks below.
  precision = np.asarray(precision)
  recall = np.asarray(recall)
  if not ((precision >= 0).all() and (precision <= 1).all()):
    raise ValueError('All values in precision must be in [0, 1].')
  if not ((recall >= 0).all() and (recall <= 1).all()):
    raise ValueError('All values in recall must be in [0, 1].')
  if beta <= 0:
    raise ValueError('Given parameter beta %s must be positive.' % str(beta))

  f_beta = np.max(_prd_to_f_beta(precision, recall, beta))
  f_beta_inv = np.max(_prd_to_f_beta(precision, recall, 1/beta))
  return f_beta, f_beta_inv

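
# Editorial sketch: collapsing a PRD curve into a single (F_8, F_{1/8}) pair,
# reusing the hypothetical toy distributions from the compute_prd sketch above.
def _demo_max_f_beta_pair():
  precision, recall = compute_prd(np.array([0.6, 0.4, 0.0]),
                                  np.array([0.3, 0.3, 0.4]))
  f_8, f_1_8 = prd_to_max_f_beta_pair(precision, recall, beta=8)
  # f_8 correlates with recall, f_1_8 with precision.
  print(f_8, f_1_8)
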


def plot(precision_recall_pairs, labels=None, out_path=None,
         legend_loc='lower left', dpi=150):
  """Plots precision recall curves for distributions.

  Creates the PRD plot for the given data and stores the plot in a given path.

  Args:
    precision_recall_pairs: List of prd_data to plot. Each item in this list is
      a 2D array of precision and recall values for the same number of ratios.
    labels: Optional list of labels of same length as precision_recall_pairs.
      The default value is None.
    out_path: Output path for the resulting plot. If None, the plot will be
      opened via plt.show(). The default value is None.
    legend_loc: Location of the legend. The default value is 'lower left'.
    dpi: Dots per inch (DPI) for the figure. The default value is 150.

  Raises:
    ValueError: If labels is a list of different length than
      precision_recall_pairs.
  """

  if labels is not None and len(labels) != len(precision_recall_pairs):
    raise ValueError(
        'Length of labels %d must be identical to length of '
        'precision_recall_pairs %d.'
        % (len(labels), len(precision_recall_pairs)))

  fig = plt.figure(figsize=(3.5, 3.5), dpi=dpi)
  plot_handle = fig.add_subplot(111)
  plot_handle.tick_params(axis='both', which='major', labelsize=12)

  for i in range(len(precision_recall_pairs)):
    precision, recall = precision_recall_pairs[i]
    label = labels[i] if labels is not None else None
    plt.plot(recall, precision, label=label, alpha=0.5, linewidth=3)

  if labels is not None:
    plt.legend(loc=legend_loc)

  plt.xlim([0, 1])
  plt.ylim([0, 1])
  plt.xlabel('Recall', fontsize=12)
  plt.ylabel('Precision', fontsize=12)
  plt.tight_layout()
  if out_path is None:
    plt.show()
  else:
    plt.savefig(out_path, bbox_inches='tight', dpi=dpi)
    plt.close()
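

# Editorial sketch tying the pieces together: compute PRD for two hypothetical
# models against a shared reference and render the curves. The feature arrays
# and the file name 'prd_plot.png' are illustrative.
def _demo_plot():
  rng = np.random.RandomState(0)
  ref_feats = rng.normal(size=(300, 16))
  model_a = rng.normal(size=(300, 16))
  model_b = rng.normal(loc=0.5, size=(300, 16))
  prd_a = compute_prd_from_embedding(model_a, ref_feats, num_runs=2)
  prd_b = compute_prd_from_embedding(model_b, ref_feats, num_runs=2)
  plot([prd_a, prd_b], labels=['model_a', 'model_b'], out_path='prd_plot.png')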