-
Notifications
You must be signed in to change notification settings - Fork 96
/
Copy pathgenerate_plots.py
110 lines (83 loc) · 3.48 KB
/
generate_plots.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import argparse
def _load_frames(filenames):
    """Read every CSV in *filenames* and return the frames that parsed.

    Files whose path contains ``'/LR/'`` are read with their first two lines
    skipped; all other files are read from their first line.  A file for
    which ``pd.read_csv`` raises ``ValueError`` is reported on stdout and
    left out of the result.  Frames are returned in the order of
    *filenames*.
    """
    frames = []
    for filename in filenames:
        skip = 2 if '/LR/' in filename else 0
        try:
            frames.append(pd.read_csv(filename, skiprows=skip, header=0))
        except ValueError:
            # The call form with a single argument prints the same text
            # under both Python 2 and Python 3.
            print("Failed to load %s" % filename)
    return frames


def _best_so_far(data, mode, method, target, n_points):
    """Return the running minimum of column *target* for one result frame.

    Rows are selected by *mode* and *method* and ordered by ``'budget'``.
    When *mode* is ``'extr'`` and the frame contains a ``'binary'`` method,
    the ``'binary'`` rows are used instead, in file order.  The selected
    values are then brought to exactly *n_points* entries.
    """
    in_mode = data['mode'] == mode
    if mode == 'extr' and 'binary' in data['method'].values:
        # NOTE(review): this fallback is not sorted by budget, matching the
        # original selection -- confirm that file order is intended.
        xy = data[in_mode & (data['method'] == 'binary')]
    else:
        # DataFrame.sort() is gone from current pandas; sort_values() is
        # its replacement.
        xy = data[in_mode & (data['method'] == method)].sort_values('budget')

    y = xy[target].values

    # Too few points: pad at the front with 1 (presumably the worst value).
    if len(y) < n_points:
        y = np.hstack(([1] * (n_points - len(y)), y))

    # Too many points: keep only the last n_points.
    if len(y) > n_points:
        y = y[-n_points:]

    # Loss columns are floored at 1e-6 so they stay positive on the log axis.
    if 'loss' in target:
        y = np.maximum(y, 1e-6)

    # NOTE(review): the nesting of this step was reconstructed; it is assumed
    # to apply to every target, not only to the loss columns -- confirm.
    return np.minimum.accumulate(y)


def make_plots_lr(filenames, title="Dummy Title"):
    """Plot mean and standard deviation of extraction results per budget.

    Each file in *filenames* is read as a CSV that must provide the columns
    ``mode``, ``method``, ``budget``, ``loss``, ``loss_u``, ``probas`` and
    ``probas_u``.  One subplot is drawn per target column, stacked
    vertically with a shared x axis and a logarithmic y axis.  Within each
    subplot, one error-bar curve is drawn per (mode, method) pair whose mode
    and method both occur in the first loaded file; the curve is the mean
    over all loaded files of the running-minimum values, with the standard
    deviation as the error bar.

    Args:
        filenames: paths of the CSV result files to aggregate.
        title: title placed above the first subplot.

    Returns:
        None.  The figure is displayed with ``plt.show()``.  If no file
        could be loaded, a message is printed and nothing is drawn.
    """
    # Budget values used as x coordinates; every curve has one y per entry.
    x = [0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 50.0, 100.0]

    frames = _load_frames(filenames)
    if not frames:
        # Without this guard the lookup of the first frame below would fail
        # with a bare IndexError.
        print("No input file could be loaded; nothing to plot")
        return

    rows = ['loss', 'loss_u', 'probas', 'probas_u']
    row_names = [r'$L_{test}$', r'$L_{unif}$',
                 r'probas $L_{test}$', r'probas $L_{unif}$']

    methods = [
        ('base', 'passive'),
        ('base', 'adapt-local'),
        ('base', 'adapt-oracle'),
        ('base', 'lowd-meek'),
        ('extr', 'passive')]
    method_names = ['Agnostic Learning',
                    'Active Learning',
                    'Boundary Finding',
                    'Lowd - Meek',
                    'Equation Solving (with conf. scores)']

    # The first loaded file decides which curves are drawn.  This does not
    # depend on the subplot, so it is computed once up front.
    known_modes = frames[0]['mode'].values
    known_methods = frames[0]['method'].values
    curves = [(mode, method, m_name)
              for (mode, method), m_name in zip(methods, method_names)
              if mode in known_modes and method in known_methods]

    plt.rc('text', usetex=True)
    plt.rc('font', family='serif')

    _fig, axs = plt.subplots(nrows=len(rows), ncols=1, sharex=True,
                             figsize=(25, 15))

    for row, (ax, target, row_name) in enumerate(zip(axs, rows, row_names)):
        ax.set_ylabel(row_name, size='large')
        ax.yaxis.grid(True)

        for mode, method, m_name in curves:
            ys = np.array([_best_so_far(data, mode, method, target, len(x))
                           for data in frames])
            ax.errorbar(x, np.mean(ys, axis=0), np.std(ys, axis=0),
                        label=m_name)

        ax.set_yscale('log')
        ax.set_ylim(0, None, auto=True)
        ax.set_xlim(0, 105)

        # Legend and title only on the top subplot.
        if row == 0:
            ax.legend(loc='upper right')
            ax.set_title(title)

        # The x label only on the bottom subplot (the x axis is shared).
        if row == len(rows) - 1:
            ax.set_xlabel(r'Budget ($\alpha \cdot \texttt{num\_unknowns}$)')

    plt.show()
def main():
    """Parse the command line and hand the arguments to make_plots_lr.

    The command line consists of a figure title followed by one or more
    file names.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('title', type=str, help='figure title')
    cli.add_argument('filenames', nargs='+', type=str,
                     help='a list of filenames')
    parsed = cli.parse_args()

    make_plots_lr(parsed.filenames, title=parsed.title)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()