Skip to content

Commit

Permalink
working gaussian processes, il perform+ttest+fps, rl dataeff+ttest+fps
Browse files Browse the repository at this point in the history
  • Loading branch information
dyth committed Jun 26, 2020
1 parent ce722b0 commit c55c574
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 423 deletions.
8 changes: 7 additions & 1 deletion babyai/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ def model_num_samples(model):
return int(re.findall('_([0-9]+)', model)[0])


def get_fps(df):
data = df['FPS']
data = data.tolist()
return np.array(data)


def best_within_normal_time(df, regex, patience, limit='epochs', window=1, normal_time=None, summary_path=None):
"""
Compute the best success rate that is achieved in all runs within the normal time.
Expand Down Expand Up @@ -254,7 +260,7 @@ def estimate_sample_efficiency(df, visualize=False, figure_path=None):
y = (y - success_threshold) * 100

# fit an RBF GP
kernel = 1.0 * RBF() + WhiteKernel(noise_level_bounds=(1e-10, 3))
kernel = 1.0 * RBF() + WhiteKernel(noise_level_bounds=(1e-10, 10))
gp = GaussianProcessRegressor(kernel=kernel, alpha=0, normalize_y=False).fit(x[:, None], y)
print("Kernel:", gp.kernel_)
print("Marginal likelihood:", gp.log_marginal_likelihood_value_)
Expand Down
3 changes: 1 addition & 2 deletions scripts/il_dataeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import pandas
import os
import json
import shutil

from babyai import plotting

Expand All @@ -19,8 +18,8 @@
args = parser.parse_args()

if os.path.exists(args.report):
shutil.rmtree(args.report)
raise ValueError("report directory already exists")
os.mkdir(args.report)

summary_path = os.path.join(args.report, 'summary.csv')
figure_path = os.path.join(args.report, 'visualization.png')
Expand Down
127 changes: 0 additions & 127 deletions scripts/il_fps.py

This file was deleted.

61 changes: 61 additions & 0 deletions scripts/il_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#!/usr/bin/env python3
import argparse
import pandas
import os
import json
import re
import numpy as np
from scipy import stats

from babyai import plotting as bp


parser = argparse.ArgumentParser("Analyze performance of imitation learning")
parser.add_argument("--path", default='.',
help="path to model logs")
parser.add_argument("--regex", default='.*',
help="filter out some logs")
parser.add_argument("--ttest", default=None,
help="path to model logs for comparison")
parser.add_argument("--ttest_regex", default='.*',
help="filter out some logs from comparison")
parser.add_argument("--window", type=int, default=100,
help="size of sliding window average, 10 for GoToRedBallGrey, 100 otherwise")
args = parser.parse_args()


def get_data(path, regex):
df = pandas.concat(bp.load_logs(path), sort=True)
fps = bp.get_fps(df)
models = df['model'].unique()
models = [model for model in df['model'].unique() if re.match(regex, model)]

maxes = []
for model in models:
df_model = df[df['model'] == model]
success_rate = df_model['validation_success_rate']
success_rate = success_rate.rolling(args.window, center=True).mean()
success_rate = max(success_rate[np.logical_not(np.isnan(success_rate))])
print(model, success_rate)
maxes.append(success_rate)
return np.array(maxes), fps



if args.ttest is not None:
print("is this architecture better")
print(args.regex)
maxes, fps = get_data(args.path, args.regex)
result = {'samples': len(maxes), 'mean': maxes.mean(), 'std': maxes.std(),
'fps_mean': fps.mean(), 'fps_std': fps.std()}
print(result)

if args.ttest is not None:
print("\nthan this one")
maxes_ttest, fps = get_data(args.ttest, args.ttest_regex)
result = {'samples': len(maxes_ttest),
'mean': maxes_ttest.mean(), 'std': maxes_ttest.std(),
'fps_mean': fps.mean(), 'fps_std': fps.std()}
print(result)
ttest = stats.ttest_ind(maxes, maxes_ttest, equal_var=False)
print(f"\n{ttest}")
78 changes: 0 additions & 78 deletions scripts/il_performance.py

This file was deleted.

Loading

0 comments on commit c55c574

Please sign in to comment.