Skip to content

Commit

Permalink
Add plot and function for test loss and accuracy
Browse files Browse the repository at this point in the history
  • Loading branch information
steber97 committed Jun 12, 2020
1 parent 1843bac commit 8984d3a
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 11 deletions.
63 changes: 52 additions & 11 deletions notebooks/visu_stat.ipynb

Large diffs are not rendered by default.

75 changes: 75 additions & 0 deletions visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,23 @@ def filter_df(df):
Args:
df (pd.DataFrame): info issue from the json file
"""
df = df[df['test_losses'].isnull()]
drop_indexes = df[df.val_losses.apply(lambda x:any([None in y for y in x]))].index
return df.drop(drop_indexes)


def filter_df_test(df):
    """
    return the rows of df that carry complete test losses

    Rows whose 'test_losses' entry is null are removed, and so is every row
    in which one of the inner sequences of 'test_losses' contains None.
    Args:
        df (pd.DataFrame): info issue from the json file
    Returns:
        pd.DataFrame: the remaining rows, with their original index labels
    """
    df = df[df['test_losses'].notnull()]
    # Select with a boolean mask instead of dropping by index label: a
    # label-based drop would also discard valid rows that happen to share a
    # label with an incomplete one. astype(bool) keeps the mask boolean even
    # when no rows are left after the null filter.
    incomplete = df['test_losses'].apply(
        lambda runs: any(None in run for run in runs)
    ).astype(bool)
    return df[~incomplete]


def plot_losses_fits(losses,
ax,
colors,
Expand Down Expand Up @@ -148,6 +161,21 @@ def get_concat_losses(df,train : bool):
concat_losses = concat_losses + losses
return concat_losses


def get_concat_test_losses(df, train: bool):
    """
    concat the losses for the given df into a single list

    Each row of the selected column holds a sequence; the sequences are
    flattened by one level, in row order.
    Args:
        df (pd.DataFrame): info issue from the json file
        train (Bool): True to read 'train_losses', False to read 'test_losses'
    Returns:
        list: the elements of every row's sequence, concatenated
    """
    tot_losses = df.train_losses if train else df.test_losses
    # Single flattening pass; repeated `list + list` re-copies the
    # accumulated list on every row.
    return [entry for losses in tot_losses.values for entry in losses]

def plot_grid_search(df,
ax,
plot_SGD=True,
Expand Down Expand Up @@ -192,3 +220,50 @@ def plot_grid_search(df,
bottom,top = ax.get_ylim()
ax.set_yticks(np.linspace(bottom,top,5))
ax.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())


def plot_grid_search_test(df,
                          ax,
                          plot_SGD=True,
                          plot_Adam=True,
                          plot_AdamW=True,
                          plot_runs=True,
                          plot_fit=False,
                          plot_mean=False,
                          fit_type="log",
                          train=True):
    """
    plot the train or test runs of the computed grid search according to many display parameters

    For every enabled optimizer, the rows of df with that optimizer are
    selected, their losses are concatenated with get_concat_test_losses and
    handed to plot_losses_fits. The y axis is then given five evenly spaced
    ticks with a scalar formatter.
    Args:
        df (pd.DataFrame): info issue from the json file
        ax (plt.ax): matplotlib ax
        plot_XXX (bool): whether to plot the XXX optimizer
        plot_runs (bool): plot the runs
        plot_fit (bool): plot the fitted curve
        plot_mean (bool): plot the mean of the runs
        fit_type (str): forwarded to plot_losses_fits; selects the space the
            regression is fitted in (default "log")
        train (Bool): True to plot train losses, False to plot test losses
    """
    args = [plot_runs, plot_fit, plot_mean, fit_type]

    def _plot_optimizer(name, colors):
        # Plot the concatenated losses of one optimizer's runs.
        losses = get_concat_test_losses(df[df.optimizer == name], train)
        plot_losses_fits(losses, ax, colors, *args, label=name)

    # The module-level colours are looked up only for enabled optimizers.
    if plot_SGD:
        _plot_optimizer("SGD", (colSGD, "#00ff3c"))
    if plot_Adam:
        _plot_optimizer("Adam", (coladam, "#08f0fc"))
    if plot_AdamW:
        _plot_optimizer("AdamW", (coladamW, "yellow"))

    # Re-tick the y axis over whatever range the plotted curves produced.
    bottom, top = ax.get_ylim()
    ax.set_yticks(np.linspace(bottom, top, 5))
    ax.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())

0 comments on commit 8984d3a

Please sign in to comment.