
Commit

add PySparkOvertimeMonitor to avoid exceeding time budget (microsoft#923)

* merging

* clean commit

* Delete mylearner.py

This file is not needed.

* fix py4j import error

* more tolerant cancelling time

* fix problems following suggestions

* Update flaml/tune/spark/utils.py

Co-authored-by: Li Jiang <[email protected]>

* remove redundant model

* Update test/spark/custom_mylearner.py

Co-authored-by: Chi Wang <[email protected]>

* add docstr

* reverse change in gitignore

* Update test/spark/custom_mylearner.py

Co-authored-by: Chi Wang <[email protected]>

---------

Co-authored-by: Li Jiang <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
3 people authored Feb 24, 2023
1 parent 4118c8e commit c6a2440
Showing 6 changed files with 246 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1,4 +1,4 @@
# Project
# Project
/.vs
.vscode

10 changes: 10 additions & 0 deletions flaml/automl/automl.py
@@ -676,6 +676,8 @@ def custom_metric(
on disk when deleting automl. By default the checkpoint is preserved.
early_stop: boolean, default=False | Whether to stop early if the
search is considered to converge.
force_cancel: boolean, default=False | Whether to forcibly cancel Spark jobs if the
search time exceeds the time budget.
append_log: boolean, default=False | Whether to directly append the log
records to the input log file if it exists.
auto_augment: boolean, default=True | Whether to automatically
@@ -785,6 +787,7 @@ def custom_metric(
settings["keep_search_state"] = settings.get("keep_search_state", False)
settings["preserve_checkpoint"] = settings.get("preserve_checkpoint", True)
settings["early_stop"] = settings.get("early_stop", False)
settings["force_cancel"] = settings.get("force_cancel", False)
settings["append_log"] = settings.get("append_log", False)
settings["min_sample_size"] = settings.get("min_sample_size", MIN_SAMPLE_TRAIN)
settings["use_ray"] = settings.get("use_ray", False)
@@ -2207,6 +2210,7 @@ def fit(
keep_search_state=None,
preserve_checkpoint=True,
early_stop=None,
force_cancel=None,
append_log=None,
auto_augment=None,
min_sample_size=None,
@@ -2396,6 +2400,7 @@ def custom_metric(
on disk when deleting automl. By default the checkpoint is preserved.
early_stop: boolean, default=False | Whether to stop early if the
search is considered to converge.
force_cancel: boolean, default=False | Whether to forcibly cancel the PySpark job if it runs over the time budget.
append_log: boolean, default=False | Whether to directly append the log
records to the input log file if it exists.
auto_augment: boolean, default=True | Whether to automatically
@@ -2598,6 +2603,9 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
early_stop = (
self._settings.get("early_stop") if early_stop is None else early_stop
)
force_cancel = (
self._settings.get("force_cancel") if force_cancel is None else force_cancel
)
# no search budget is provided?
no_budget = time_budget < 0 and max_iter is None and not early_stop
append_log = (
@@ -2648,6 +2656,7 @@ def cv_score_agg_func(val_loss_folds, log_metrics_folds):
self._n_concurrent_trials = n_concurrent_trials
self._early_stop = early_stop
self._use_spark = use_spark
self._force_cancel = force_cancel
self._use_ray = use_ray
# use the following condition if we have an estimation of average_trial_time and average_trial_overhead
# self._use_ray = use_ray or n_concurrent_trials > ( average_trial_time + average_trial_overhead) / (average_trial_time)
@@ -3174,6 +3183,7 @@ def _search_parallel(self):
verbose=max(self.verbose - 2, 0),
use_ray=False,
use_spark=True,
force_cancel=self._force_cancel,
# raise_on_failed_trial=False,
# keep_checkpoints_num=1,
# checkpoint_score_attr="min-val_loss",
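For orientation (not part of the diff), a minimal sketch of how the new flag would typically be enabled from user code; the dataset, time budget, and other settings below are illustrative assumptions:

```python
from sklearn.datasets import load_iris

from flaml import AutoML

# Illustrative sketch: run Spark-parallel trials and let FLAML forcibly cancel
# still-running Spark jobs once the time budget is exhausted.
X, y = load_iris(return_X_y=True)
automl = AutoML()
automl.fit(
    X_train=X,
    y_train=y,
    task="classification",
    time_budget=30,  # seconds; illustrative
    n_concurrent_trials=2,
    use_spark=True,
    force_cancel=True,  # the flag introduced in this commit
)
```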
124 changes: 121 additions & 3 deletions flaml/tune/spark/utils.py
@@ -1,23 +1,28 @@
import os
import logging
from functools import partial, lru_cache
import os
import textwrap
import threading
import time
from functools import lru_cache, partial


logger = logging.getLogger(__name__)
logger_formatter = logging.Formatter(
"[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s", "%m-%d %H:%M:%S"
)

try:
import pyspark
from pyspark.sql import SparkSession
from pyspark.util import VersionUtils
import pyspark
import py4j

_have_spark = True
_spark_major_minor_version = VersionUtils.majorMinorVersion(pyspark.__version__)
except ImportError as e:
logger.debug("Could not import pyspark: %s", e)
_have_spark = False
py4j = None
_spark_major_minor_version = (0, 0)


@@ -187,3 +192,116 @@ def get_broadcast_data(broadcast_data):
    if _have_spark and isinstance(broadcast_data, pyspark.broadcast.Broadcast):
        broadcast_data = broadcast_data.value
    return broadcast_data


class PySparkOvertimeMonitor:
    """A context manager class to monitor whether the PySpark job is overtime.
    Example:
    ```python
    with PySparkOvertimeMonitor(time_start, time_budget_s, force_cancel, parallel=parallel):
        results = parallel(
            delayed(evaluation_function)(trial_to_run.config)
            for trial_to_run in trials_to_run
        )
    ```
    """

    def __init__(
        self,
        start_time,
        time_budget_s,
        force_cancel=False,
        cancel_func=None,
        parallel=None,
        sc=None,
    ):
        """Constructor.
        Specify the time budget and start time of the PySpark job, and specify how to cancel it.
        Args:
            Args related to monitoring:
                start_time: float | The start time of the PySpark job.
                time_budget_s: float | The time budget of the PySpark job in seconds.
                force_cancel: boolean, default=False | Whether to forcibly cancel the PySpark job if it runs over the time budget.
            Args related to how to cancel the PySpark job:
                (Only one of the following args takes effect; priorities from top to bottom.)
                cancel_func: function | A function to cancel the PySpark job.
                parallel: joblib.parallel.Parallel | Specify this if using joblib_spark as a parallel backend. It will call parallel._backend.terminate() to cancel the jobs.
                sc: pyspark.SparkContext object | You can pass a specific SparkContext.
                If all three args are None, the monitor will call pyspark.SparkContext.getOrCreate().cancelAllJobs() to cancel the jobs.
        """
        self._time_budget_s = time_budget_s
        self._start_time = start_time
        self._force_cancel = force_cancel
        # TODO: add support for non-spark scenario
        if self._force_cancel and _have_spark:
            self._monitor_daemon = None
            self._finished_flag = False
            self._cancel_flag = False
            self.sc = None
            if cancel_func:
                self.__cancel_func = cancel_func
            elif parallel:
                self.__cancel_func = parallel._backend.terminate
            elif sc:
                self.sc = sc
                self.__cancel_func = self.sc.cancelAllJobs
            else:
                self.__cancel_func = pyspark.SparkContext.getOrCreate().cancelAllJobs
            # logger.info(self.__cancel_func)

    def _monitor_overtime(self):
        """The lifecycle function for the monitor thread."""
        if self._time_budget_s is None:
            self.__cancel_func()
            self._cancel_flag = True
            return
        while time.time() - self._start_time <= self._time_budget_s:
            time.sleep(0.01)
            if self._finished_flag:
                return
        self.__cancel_func()
        self._cancel_flag = True
        return

    def _setLogLevel(self, level):
        """Set the log level of the Spark context.
        Setting the level to OFF suppresses Spark's warning messages."""
        if self.sc:
            self.sc.setLogLevel(level)
        else:
            pyspark.SparkContext.getOrCreate().setLogLevel(level)

    def __enter__(self):
        """Enter the context manager.
        This will start a monitor thread if Spark is available and force_cancel is True."""
        if self._force_cancel and _have_spark:
            self._monitor_daemon = threading.Thread(target=self._monitor_overtime)
            # logger.setLevel("INFO")
            logger.info("monitor started")
            self._setLogLevel("OFF")
            self._monitor_daemon.start()

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Exit the context manager.
        This will wait for the monitor thread to exit cleanly."""
        if self._force_cancel and _have_spark:
            self._finished_flag = True
            self._monitor_daemon.join()
            if self._cancel_flag:
                print()
                logger.warning("Time exceeded, canceled jobs")
            # self._setLogLevel("WARN")
            if not exc_type:
                return True
            elif exc_type == py4j.protocol.Py4JJavaError:
                return True
            else:
                return False
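Beyond the docstring example (which cancels through a joblib `parallel` backend), here is a minimal sketch of the explicit-SparkContext path described in the constructor docstring; the workload and the 10-second budget are illustrative assumptions, not from this commit:

```python
import time

import pyspark

from flaml.tune.spark.utils import PySparkOvertimeMonitor

# Illustrative sketch: pass an explicit SparkContext so the monitor cancels
# jobs via sc.cancelAllJobs() rather than looking up a context itself.
sc = pyspark.SparkContext.getOrCreate()
start = time.time()
with PySparkOvertimeMonitor(start, time_budget_s=10, force_cancel=True, sc=sc):
    # Any Spark job launched here is cancelled once the 10-second budget passes;
    # the resulting Py4JJavaError is suppressed by __exit__.
    sc.parallelize(range(8)).map(lambda x: time.sleep(5) or x).collect()
```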
18 changes: 13 additions & 5 deletions flaml/tune/tune.py
@@ -24,6 +24,7 @@
from .trial import Trial
from .result import DEFAULT_METRIC
import logging
from flaml.tune.spark.utils import PySparkOvertimeMonitor

logger = logging.getLogger(__name__)
logger.propagate = False
@@ -246,6 +247,7 @@ def run(
use_incumbent_result_in_evaluation: Optional[bool] = None,
log_file_name: Optional[str] = None,
lexico_objectives: Optional[dict] = None,
force_cancel: Optional[bool] = False,
**ray_args,
):
"""The trigger for HPO.
@@ -730,10 +732,14 @@ def easy_objective(config):
logger.debug(
f"Configs of Trials to run: {[trial_to_run.config for trial_to_run in trials_to_run]}"
)
results = parallel(
delayed(evaluation_function)(trial_to_run.config)
for trial_to_run in trials_to_run
)
results = None
with PySparkOvertimeMonitor(
time_start, time_budget_s, force_cancel, parallel=parallel
):
results = parallel(
delayed(evaluation_function)(trial_to_run.config)
for trial_to_run in trials_to_run
)
# results = [evaluation_function(trial_to_run.config) for trial_to_run in trials_to_run]
while results:
result = results.pop(0)
@@ -803,7 +809,9 @@ def easy_objective(config):
num_trials += 1
if verbose:
logger.info(f"trial {num_trials} config: {trial_to_run.config}")
result = evaluation_function(trial_to_run.config)
result = None
with PySparkOvertimeMonitor(time_start, time_budget_s, force_cancel):
result = evaluation_function(trial_to_run.config)
if result is not None:
if isinstance(result, dict):
if result:
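The change above only matters when `force_cancel` is forwarded through `tune.run`; a minimal sketch of calling it directly with the Spark backend, where the toy objective, search space, and 20-second budget are illustrative assumptions:

```python
from flaml import tune


def easy_objective(config):
    # Toy objective; with use_spark=True each call runs as a joblib-spark task.
    return {"score": (config["x"] - 0.5) ** 2}


analysis = tune.run(
    easy_objective,
    config={"x": tune.uniform(0, 1)},
    metric="score",
    mode="min",
    num_samples=-1,  # no trial limit; rely on the time budget
    time_budget_s=20,
    use_spark=True,
    force_cancel=True,  # the new argument wired in by this commit
)
print(analysis.best_config)
```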
31 changes: 31 additions & 0 deletions test/spark/custom_mylearner.py
@@ -2,6 +2,7 @@

custom_code = """
from flaml import tune
import time
from flaml.automl.model import LGBMEstimator, XGBoostSklearnEstimator, SKLearnEstimator
from flaml.automl.data import CLASSIFICATION, get_output_from_log
@@ -91,6 +92,7 @@ def search_space(cls, **params):
}
def custom_metric(
X_val,
y_val,
@@ -119,6 +121,35 @@ def custom_metric(
"train_loss": train_loss,
"pred_time": pred_time,
}
def lazy_metric(
    X_val,
    y_val,
    estimator,
    labels,
    X_train,
    y_train,
    weight_val=None,
    weight_train=None,
    config=None,
    groups_val=None,
    groups_train=None,
):
    from sklearn.metrics import log_loss

    # Deliberately slow metric: the sleep makes each evaluation take at least
    # 2 seconds, so a search can easily run past its time budget.
    time.sleep(2)
    start = time.time()
    y_pred = estimator.predict_proba(X_val)
    pred_time = (time.time() - start) / len(X_val)
    val_loss = log_loss(y_val, y_pred, labels=labels, sample_weight=weight_val)
    y_pred = estimator.predict_proba(X_train)
    train_loss = log_loss(y_train, y_pred, labels=labels, sample_weight=weight_train)
    alpha = 0.5
    return val_loss * (1 + alpha) - alpha * train_loss, {
        "val_loss": val_loss,
        "train_loss": train_loss,
        "pred_time": pred_time,
    }
"""

_ = broadcast_code(custom_code=custom_code)
70 changes: 70 additions & 0 deletions test/spark/test_overtime.py
@@ -0,0 +1,70 @@
import os
import time

import numpy as np
import pyspark
import pytest
from sklearn.datasets import load_iris

from flaml import AutoML
from flaml.tune.spark.utils import check_spark

try:
    from test.spark.custom_mylearner import *
except ImportError:
    from custom_mylearner import *

from flaml.tune.spark.mylearner import lazy_metric

os.environ["FLAML_MAX_CONCURRENT"] = "10"

spark = pyspark.sql.SparkSession.builder.appName("App4OvertimeTest").getOrCreate()
spark_available, _ = check_spark()
skip_spark = not spark_available

pytestmark = pytest.mark.skipif(
    skip_spark, reason="Spark is not installed. Skip all spark tests."
)


def test_overtime():
    time_budget = 15
    df, y = load_iris(return_X_y=True, as_frame=True)
    df["label"] = y
    automl_experiment = AutoML()
    automl_settings = {
        "dataframe": df,
        "label": "label",
        "time_budget": time_budget,
        "eval_method": "cv",
        "metric": lazy_metric,
        "task": "classification",
        "log_file_name": "test/iris_custom.log",
        "log_training_metric": True,
        "log_type": "all",
        "n_jobs": 1,
        "model_history": True,
        "sample_weight": np.ones(len(y)),
        "pred_time_limit": 1e-5,
        "estimator_list": ["lgbm"],
        "n_concurrent_trials": 2,
        "use_spark": True,
        "force_cancel": True,
    }
    start_time = time.time()
    automl_experiment.fit(**automl_settings)
    elapsed_time = time.time() - start_time
    print(
        "time budget: {:.2f}s, actual elapsed time: {:.2f}s".format(
            time_budget, elapsed_time
        )
    )
    assert abs(elapsed_time - time_budget) < 2
    print(automl_experiment.predict(df))
    print(automl_experiment.model)
    print(automl_experiment.best_iteration)
    print(automl_experiment.best_estimator)


if __name__ == "__main__":
    test_overtime()
