-
Notifications
You must be signed in to change notification settings - Fork 32
/
MLProject
42 lines (38 loc) · 1.86 KB
/
MLProject
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
name: hp_search
python_env: python_env.yaml

# Parameter values are substituted into each `command` template through the
# `{placeholder}` fields.
# NOTE(review): MLflow documents `string`, `float`, `path` and `uri` as
# parameter types. `int` is kept here as written for the integer-valued
# options; confirm the MLflow version in use accepts it as intended.
entry_points:
  # Use Hyperopt to optimize hyperparameters of a training entry point.
  # NOTE(review): presumably `model_type` selects which train_* entry point is
  # tuned (the default "hgbt" matches `train_hgbt`) -- confirm in
  # search_params.py.
  # NOTE(review): the default dataset here differs from the one used by the
  # train_* entry points below -- confirm this is intentional.
  search_params:
    parameters:
      training_data: {type: string, default: "hemanthsai7/loandefault"}
      max_runs: {type: int, default: 10}
      model_type: {type: string, default: "hgbt"}
    # Folded scalar: the lines below are joined with single spaces into one
    # command line.
    command: >-
      python -O search_params.py {training_data}
      --max-runs {max_runs}
      --model-type {model_type}

  # Train a Random Forest model with default hyperparameters.
  train_rf:
    parameters:
      dset_name: {type: string, default: "sgpjesus/bank-account-fraud-dataset-neurips-2022"}
      max_depth: {type: int, default: 5}
      max_features: {type: float, default: 0.1}
      class_weight: {type: string, default: "balanced"}
      min_samples_leaf: {type: int, default: 10}
    command: >-
      python train_rf.py {dset_name}
      --max-depth {max_depth}
      --max-features {max_features}
      --class-weight {class_weight}
      --min-samples-leaf {min_samples_leaf}

  # Train a HistGradientBoosted model with default hyperparameters.
  train_hgbt:
    parameters:
      dset_name: {type: string, default: "sgpjesus/bank-account-fraud-dataset-neurips-2022"}
      max_depth: {type: int, default: 20}
      learning_rate: {type: float, default: 0.1}
      class_weight: {type: string, default: "balanced"}
      max_leaf_nodes: {type: int, default: 31}
      # Declared as float: the default is a real number, not an integer.
      l2_regularization: {type: float, default: 1.0}
    command: >-
      python train_hgbt.py {dset_name}
      --max-depth {max_depth}
      --learning-rate {learning_rate}
      --class-weight {class_weight}
      --max-leaf-nodes {max_leaf_nodes}
      --l2-regularization {l2_regularization}