forked from mila-iqia/babyai
-
Notifications
You must be signed in to change notification settings - Fork 0
/
rl_dataeff.py
73 lines (61 loc) · 2.35 KB
/
rl_dataeff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/env python3
import os
import re
import babyai.plotting as bp
import pandas
import argparse
import json
import numpy as np
from scipy import stats
# Command-line interface of the script.
# The summary text is passed as ``description``: the first positional
# parameter of ArgumentParser is ``prog``, so passing the sentence
# positionally would replace the program name in the usage line instead of
# describing the tool.
parser = argparse.ArgumentParser(
    description="Analyze data efficiency of reinforcement learning")
# --path / --regex select the runs that are analyzed; get_data() keeps only
# the models whose name matches the regex (via re.match).
parser.add_argument("--path", default='.',
                    help="path to model logs")
parser.add_argument("--regex", default='.*',
                    help="filter out some logs")
# --other / --other_regex select an optional second set of runs; when
# --other is given the two sets are compared with a t-test.
parser.add_argument("--other", default=None,
                    help="path to model logs for ttest comparison")
parser.add_argument("--other_regex", default='.*',
                    help="filter out some logs from comparison")
# Number of log rows averaged by the rolling mean in dataeff().
parser.add_argument("--window", type=int, default=100,
                    help="size of sliding window average, 10 for GoToRedBallGrey, 100 otherwise")
args = parser.parse_args()
def dataeff(df_model, window):
    """Return the number of episodes a run needed to reach a 99% success rate.

    The ``success_rate`` column is smoothed with a centered rolling mean over
    ``window`` rows, and the ``episodes`` value of the first row whose
    smoothed success rate is at least 0.99 is returned.

    Args:
        df_model: DataFrame with the log rows of a single model; it must
            provide ``success_rate`` and ``episodes`` columns.
        window: number of rows in the rolling-mean window.

    Returns:
        The ``episodes`` value of the first row that meets the threshold, or
        the sentinel ``int(1e9)`` when the smoothed success rate never
        reaches 0.99 (including logs too short to fill one window).
    """
    smoothed_sr = df_model['success_rate'].rolling(window, center=True).mean()
    best_sr = smoothed_sr.max()
    # Written as "not >=" instead of "<": with fewer rows than the window every
    # smoothed value is NaN, ``NaN < 0.99`` is False, and the row lookup below
    # would then raise IndexError on an empty selection.
    if not best_sr >= 0.99:
        print('not done, success rate is only {}% so far'.format(100 * best_sr))
        return int(1e9)
    return df_model[smoothed_sr >= 0.99].iloc[0].episodes
def get_data(path, regex):
    """Collect the data efficiency of every matching model found under ``path``.

    Args:
        path: location handed to ``bp.load_logs``.
        regex: pattern applied with ``match`` to each model name; only the
            matching models are evaluated.

    Returns:
        A tuple ``(data, fps)``: ``data`` is a numpy array with one
        ``dataeff`` value per matching model that reached the success
        threshold, and ``fps`` is the result of ``bp.get_fps`` on the loaded
        logs.
    """
    print(path)
    print(regex)
    df = pandas.concat(bp.load_logs(path), sort=True)
    # NOTE: fps is computed from all loaded logs, before the regex filter
    # below is applied.
    fps = bp.get_fps(df)
    # Compile once instead of re-parsing the pattern for every model name.
    pattern = re.compile(regex)
    models = [model for model in df['model'].unique() if pattern.match(model)]
    data = []
    for model in models:
        eff = float(dataeff(df[df['model'] == model], args.window))
        print(model, eff)
        # 1e9 is the sentinel dataeff() returns for runs that never reached
        # the threshold; those runs are left out of the statistics.
        if eff != 1e9:
            data.append(eff)
    return np.array(data), fps
def _summarize(samples, fps, z):
    """Build the summary dict that is printed for one set of runs."""
    mean = samples.mean()
    std = samples.std()
    return {'samples': len(samples), 'mean': mean, 'std': std,
            'min': mean - z * std, 'max': mean + z * std,
            'fps_mean': fps.mean(), 'fps_std': fps.std()}


# A comparison is only performed when a second log location was supplied.
has_other = args.other is not None
if has_other:
    print("is this architecture better")

# Multiplier applied to the standard deviation for the 'min'/'max' entries
# (presumably the z-score of a two-sided 99% interval - confirm).
Z = 2.576

data, fps = get_data(args.path, args.regex)
result = _summarize(data, fps, Z)
print(result)

if has_other:
    print("\nthan this one")
    data_ttest, fps = get_data(args.other, args.other_regex)
    result = _summarize(data_ttest, fps, Z)
    print(result)
    # equal_var=False selects Welch's t-test (no equal-variance assumption).
    ttest = stats.ttest_ind(data, data_ttest, equal_var=False)
    print(f"\n{ttest}")