Skip to content

Commit

Permalink
updated gridsearch
Browse files Browse the repository at this point in the history
  • Loading branch information
Axel Montout committed Mar 28, 2024
1 parent b17b89a commit 2156ca3
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 39 deletions.
5 changes: 2 additions & 3 deletions bluepebble.sh

Large diffs are not rendered by default.

47 changes: 24 additions & 23 deletions eval_regularisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,30 @@


def plot_heatmap(df, col, out_dir, title=""):
scores = df[col].values
scores = np.array(scores).reshape(len(df["C"].unique()), len(df["gamma"].unique()))
#plt.figure(figsize=(8, 6))
#plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
fig, ax = plt.subplots()
im = ax.imshow(scores[::-1, :], interpolation='nearest')
# im = ax.imshow(scores, interpolation='nearest',
# norm=MidpointNormalize(vmin=-.2, midpoint=0.5))
ax.set_xlabel('gamma')
ax.set_ylabel('C')
fig.colorbar(im)
ax.set_xticks(np.arange(len(df["gamma"].unique())),
[np.format_float_scientific(i, 1) for i in df["gamma"].unique()], rotation=45)
ax.set_yticks(np.arange(len(df["C"].unique()))[::-1],
[np.format_float_scientific(i, ) for i in df["C"].unique()])
ax.set_title(f'Regularisation AUC\n{title}')
fig.tight_layout()
fig.show()
out_dir.mkdir(parents=True, exist_ok=True)
filename = f"heatmap_{col}_{title}.png".replace(":", "_").replace(" ", "_")
filepath = out_dir / filename
print(filepath)
fig.savefig(filepath)
df = df.fillna("linear")
for g in df["gamma"].unique():
df_ = df[df["gamma"] == g]
scores = df_[col].values
scores = np.array(scores).reshape(len(df_["C"].unique()), len(df_["gamma"].unique()))
#plt.figure(figsize=(8, 6))
#plt.subplots_adjust(left=.2, right=0.95, bottom=0.15, top=0.95)
fig, ax = plt.subplots()
im = ax.imshow(scores[::-1, :], interpolation='nearest')
# im = ax.imshow(scores, interpolation='nearest',
# norm=MidpointNormalize(vmin=-.2, midpoint=0.5))
# ax.set_xlabel('gamma')
ax.set_ylabel('C')
fig.colorbar(im)
ax.set_xticks(np.arange(len(df_["gamma"].unique())), [i for i in df_["gamma"].unique()], rotation=45)
ax.set_yticks(np.arange(len(df_["C"].unique()))[::-1],
[np.format_float_scientific(i, ) for i in df_["C"].unique()])
ax.set_title(f'Regularisation AUC\n{title}')
fig.tight_layout()
fig.show()
out_dir.mkdir(parents=True, exist_ok=True)
filename = f"heatmap_{g}_{col}_{title}.png".replace(":", "_").replace(" ", "_")
filepath = out_dir / filename
print(filepath)


def plot_fig(df, col, out_dir, title=""):
Expand Down
12 changes: 6 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ def main(
exp_temporal: bool = False,
exp_cross_farm: bool = False,
weather_exp: bool = False,
regularisation_exp: bool = False,
regularisation_exp: bool = True,
output_dir: Path = Path("output"),
delmas_dir_mrnn: Path = Path("datasets/delmas_dataset4_mrnn_7day"),
cedara_dir_mrnn: Path = Path("datasets/cedara_datasetmrnn7_23"),
n_job: int = 5,
enable_regularisation: bool = False,
export_hpc_string: bool = False,
enable_regularisation: bool = True,
export_hpc_string: bool = True,
plot_2d_space: bool = False
):
"""Thesis script runs all key experiments for data exploration chapter
Expand All @@ -33,8 +33,8 @@ def main(
for steps in steps_list:
slug = "_".join(steps)
for clf in ["rbf"]:
for i_day in [6]:
for a_day in [7]:
for i_day in [1,2,3,4,5,6]:
for a_day in [1,2,3,4,5,6,7]:
if i_day >= a_day:
continue
for cv in ["RepeatedKFold"]:
Expand Down Expand Up @@ -63,7 +63,7 @@ def main(
],
study_id=farm_id,
export_fig_as_pdf=False,
plot_2d_space=True,
plot_2d_space=plot_2d_space,
pre_visu=False,
export_hpc_string=export_hpc_string,
skip=False,
Expand Down
6 changes: 3 additions & 3 deletions model/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,9 +916,9 @@ def cross_validate_svm_fast(
if kernel in ["linear", "rbf"]:
if enable_regularisation:
print("Grid search...")
svc = SVC(kernel=kernel, probability=True)
parameters["kernel"] = [kernel]
clf = GridSearchCV(svc, parameters, return_train_score=True, cv=10, scoring='roc_auc', n_jobs=-1)
svc = SVC(probability=True)
# parameters["kernel"] = [kernel]
clf = GridSearchCV(svc, parameters, return_train_score=True, cv=5, scoring='roc_auc', n_jobs=-1)
else:
if C is not None and gamma is not None:
clf = SVC(kernel=kernel, probability=True, C=C, gamma=gamma)
Expand Down
12 changes: 8 additions & 4 deletions var.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,15 @@
# ],
# }

parameters = {
"C": [1e-100, 1.7782794100389228e-75, 3.1622776601683795e-50, 5.623413251903491e-25, 10.0],
"gamma": [1e-100, 1e-25, 1e+50, 1e+125, 1e+200]
}
parameters = [
{'C': [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000], 'kernel': ['linear']},
{'C': [0.000001, 0.00001, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100, 1000], 'gamma': ['scale'], 'kernel': ['rbf']},
]

# parameters = [
# {'C': [0.0001, 100], 'kernel': ['linear']},
# {'C': [0.0001, 100], 'gamma': ['scale'], 'kernel': ['rbf']},
# ]

transponders_delmas = [
"40101310316",
Expand Down

0 comments on commit 2156ca3

Please sign in to comment.