diff --git a/benchmarks/bench_plot_approximate_neighbors.py b/benchmarks/bench_plot_approximate_neighbors.py index 3354c38170e9d..fc8d394b5307c 100644 --- a/benchmarks/bench_plot_approximate_neighbors.py +++ b/benchmarks/bench_plot_approximate_neighbors.py @@ -125,20 +125,16 @@ def calc_accuracy(X, queries, n_queries, n_neighbors, exact_neighbors, # Set labels for LSHForest parameters colors = ['c', 'm', 'y'] - p1 = plt.Rectangle((0, 0), 0.1, 0.1, fc=colors[0]) - p2 = plt.Rectangle((0, 0), 0.1, 0.1, fc=colors[1]) - p3 = plt.Rectangle((0, 0), 0.1, 0.1, fc=colors[2]) + legend_rects = [plt.Rectangle((0, 0), 0.1, 0.1, fc=color) + for color in colors] - labels = ['n_estimators=' + str(params_list[0]['n_estimators']) + - ', n_candidates=' + str(params_list[0]['n_candidates']), - 'n_estimators=' + str(params_list[1]['n_estimators']) + - ', n_candidates=' + str(params_list[1]['n_candidates']), - 'n_estimators=' + str(params_list[2]['n_estimators']) + - ', n_candidates=' + str(params_list[2]['n_candidates'])] + legend_labels = ['n_estimators={n_estimators}, ' + 'n_candidates={n_candidates}'.format(**p) + for p in params_list] # Plot precision plt.figure() - plt.legend((p1, p2, p3), (labels[0], labels[1], labels[2]), + plt.legend(legend_rects, legend_labels, loc='upper left') for i in range(len(params_list)): @@ -154,7 +150,7 @@ def calc_accuracy(X, queries, n_queries, n_neighbors, exact_neighbors, # Plot speed up plt.figure() - plt.legend((p1, p2, p3), (labels[0], labels[1], labels[2]), + plt.legend(legend_rects, legend_labels, loc='upper left') for i in range(len(params_list)):