Skip to content

Commit 5760ed3

Browse files
mufaddal-rohawalaXinlu Tuxinlutu2
authored andcommitted
feature: Support Amazon SageMaker AutoMLStep
Co-authored-by: Xinlu Tu <[email protected]> Co-authored-by: Xinlu Tu <[email protected]> @xinlutu2 feat: Close feature gaps between Python SageMaker SDK and CreateAutoMLJob API includes ENSEMBLING mode @xinlutu2 feature: add AutoMLStep for SageMaker Pipelines Workflows @xinlutu2 feature: add AutoMLStep integration test
1 parent a35a093 commit 5760ed3

File tree

14 files changed

+1458
-103
lines changed

14 files changed

+1458
-103
lines changed

doc/workflows/pipelines/sagemaker.workflow.pipelines.rst

+2
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,5 @@ Steps
168168
.. autoclass:: sagemaker.workflow.clarify_check_step.ClarifyCheckStep
169169

170170
.. autoclass:: sagemaker.workflow.fail_step.FailStep
171+
172+
.. autoclass:: sagemaker.workflow.automl_step.AutoMLStep

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def read_requirements(filename):
4848
# Declare minimal set for installation
4949
required_packages = [
5050
"attrs>=20.3.0,<23",
51-
"boto3>=1.20.21,<2.0",
51+
"boto3>=1.24.69,<2.0",
5252
"google-pasta",
5353
"numpy>=1.9.0,<2.0",
5454
"protobuf>=3.1,<4.0",

src/sagemaker/automl/automl.py

+258-68
Large diffs are not rendered by default.

src/sagemaker/exceptions.py

+14
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,17 @@ class UnexpectedClientError(AsyncInferenceError):
6363

6464
def __init__(self, message):
6565
super().__init__(message=message)
66+
67+
68+
class AutoMLStepInvalidModeError(Exception):
69+
"""Raised when the automl mode passed into AutoMLStep in invalid"""
70+
71+
fmt = (
72+
"Mode in AutoMLJobConfig must be defined for AutoMLStep. "
73+
"AutoMLStep currently only supports ENSEMBLING mode"
74+
)
75+
76+
def __init__(self, **kwargs):
77+
msg = self.fmt.format(**kwargs)
78+
Exception.__init__(self, msg)
79+
self.kwargs = kwargs

src/sagemaker/session.py

+69-3
Original file line numberDiff line numberDiff line change
@@ -1646,6 +1646,7 @@ def auto_ml(
16461646
job_objective=None,
16471647
generate_candidate_definitions_only=False,
16481648
tags=None,
1649+
model_deploy_config=None,
16491650
):
16501651
"""Create an Amazon SageMaker AutoML job.
16511652
@@ -1669,6 +1670,71 @@ def auto_ml(
16691670
definitions. If True, AutoML.list_candidates() cannot be called. Default: False.
16701671
tags ([dict[str,str]]): A list of dictionaries containing key-value
16711672
pairs.
1673+
model_deploy_config (dict): Specifies how to generate the endpoint name
1674+
for an automatic one-click Autopilot model deployment.
1675+
Contains "AutoGenerateEndpointName" and "EndpointName"
1676+
"""
1677+
auto_ml_job_request = self._get_auto_ml_request(
1678+
input_config=input_config,
1679+
output_config=output_config,
1680+
auto_ml_job_config=auto_ml_job_config,
1681+
role=role,
1682+
job_name=job_name,
1683+
problem_type=problem_type,
1684+
job_objective=job_objective,
1685+
generate_candidate_definitions_only=generate_candidate_definitions_only,
1686+
tags=tags,
1687+
model_deploy_config=model_deploy_config,
1688+
)
1689+
1690+
def submit(request):
1691+
LOGGER.info("Creating auto-ml-job with name: %s", job_name)
1692+
LOGGER.debug("auto ml request: %s", json.dumps(request), indent=4)
1693+
self.sagemaker_client.create_auto_ml_job(**request)
1694+
1695+
self._intercept_create_request(auto_ml_job_request, submit, self.auto_ml.__name__)
1696+
1697+
def _get_auto_ml_request(
1698+
self,
1699+
input_config,
1700+
output_config,
1701+
auto_ml_job_config,
1702+
role,
1703+
job_name,
1704+
problem_type=None,
1705+
job_objective=None,
1706+
generate_candidate_definitions_only=False,
1707+
tags=None,
1708+
model_deploy_config=None,
1709+
):
1710+
"""Constructs a request compatible for creating an Amazon SageMaker AutoML job.
1711+
1712+
Args:
1713+
input_config (list[dict]): A list of Channel objects. Each channel contains "DataSource"
1714+
and "TargetAttributeName", "CompressionType" is an optional field.
1715+
output_config (dict): The S3 URI where you want to store the training results and
1716+
optional KMS key ID.
1717+
auto_ml_job_config (dict): A dict of AutoMLJob config, containing "StoppingCondition",
1718+
"SecurityConfig", optionally contains "VolumeKmsKeyId".
1719+
role (str): The Amazon Resource Name (ARN) of an IAM role that
1720+
Amazon SageMaker can assume to perform tasks on your behalf.
1721+
job_name (str): A string that can be used to identify an AutoMLJob. Each AutoMLJob
1722+
should have a unique job name.
1723+
problem_type (str): The type of problem of this AutoMLJob. Valid values are
1724+
"Regression", "BinaryClassification", "MultiClassClassification". If None,
1725+
SageMaker AutoMLJob will infer the problem type automatically.
1726+
job_objective (dict): AutoMLJob objective, contains "AutoMLJobObjectiveType" (optional),
1727+
"MetricName" and "Value".
1728+
generate_candidate_definitions_only (bool): Indicates whether to only generate candidate
1729+
definitions. If True, AutoML.list_candidates() cannot be called. Default: False.
1730+
tags ([dict[str,str]]): A list of dictionaries containing key-value
1731+
pairs.
1732+
model_deploy_config (dict): Specifies how to generate the endpoint name
1733+
for an automatic one-click Autopilot model deployment.
1734+
Contains "AutoGenerateEndpointName" and "EndpointName"
1735+
1736+
Returns:
1737+
Dict: a automl request dict
16721738
"""
16731739
auto_ml_job_request = {
16741740
"AutoMLJobName": job_name,
@@ -1678,6 +1744,8 @@ def auto_ml(
16781744
"RoleArn": role,
16791745
"GenerateCandidateDefinitionsOnly": generate_candidate_definitions_only,
16801746
}
1747+
if model_deploy_config is not None:
1748+
auto_ml_job_request["ModelDeployConfig"] = model_deploy_config
16811749

16821750
if job_objective is not None:
16831751
auto_ml_job_request["AutoMLJobObjective"] = job_objective
@@ -1688,9 +1756,7 @@ def auto_ml(
16881756
if tags is not None:
16891757
auto_ml_job_request["Tags"] = tags
16901758

1691-
LOGGER.info("Creating auto-ml-job with name: %s", job_name)
1692-
LOGGER.debug("auto ml request: %s", json.dumps(auto_ml_job_request, indent=4))
1693-
self.sagemaker_client.create_auto_ml_job(**auto_ml_job_request)
1759+
return auto_ml_job_request
16941760

16951761
def describe_auto_ml_job(self, job_name):
16961762
"""Calls the DescribeAutoMLJob API for the given job name and returns the response.

src/sagemaker/workflow/automl_step.py

+166
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""The `AutoMLStep` definition for SageMaker Pipelines Workflows"""
14+
from __future__ import absolute_import
15+
16+
from typing import Union, Optional, List
17+
18+
from sagemaker import Session, Model
19+
from sagemaker.exceptions import AutoMLStepInvalidModeError
20+
from sagemaker.workflow.entities import RequestType
21+
22+
from sagemaker.workflow.pipeline_context import _JobStepArguments
23+
from sagemaker.workflow.properties import Properties
24+
from sagemaker.workflow.retry import RetryPolicy
25+
from sagemaker.workflow.steps import ConfigurableRetryStep, CacheConfig, Step, StepTypeEnum
26+
from sagemaker.workflow.utilities import validate_step_args_input
27+
from sagemaker.workflow.step_collections import StepCollection
28+
29+
30+
class AutoMLStep(ConfigurableRetryStep):
31+
"""`AutoMLStep` for SageMaker Pipelines Workflows."""
32+
33+
def __init__(
34+
self,
35+
name: str,
36+
step_args: _JobStepArguments,
37+
display_name: str = None,
38+
description: str = None,
39+
cache_config: CacheConfig = None,
40+
depends_on: Optional[List[Union[str, Step, StepCollection]]] = None,
41+
retry_policies: List[RetryPolicy] = None,
42+
):
43+
"""Construct a `AutoMLStep`, given a `AutoML` instance.
44+
45+
In addition to the `AutoML` instance, the other arguments are those
46+
that are supplied to the `fit` method of the `sagemaker.automl.automl.AutoML`.
47+
48+
Args:
49+
name (str): The name of the `AutoMLStep`.
50+
step_args (_JobStepArguments): The arguments for the `AutoMLStep` definition.
51+
display_name (str): The display name of the `AutoMLStep`.
52+
description (str): The description of the `AutoMLStep`.
53+
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
54+
depends_on (List[Union[str, Step, StepCollection]]): A list of `Step`/`StepCollection`
55+
names or `Step` instances or `StepCollection` instances that this `AutoMLStep`
56+
depends on.
57+
retry_policies (List[RetryPolicy]): A list of retry policies.
58+
"""
59+
super(AutoMLStep, self).__init__(
60+
name, StepTypeEnum.AUTOML, display_name, description, depends_on, retry_policies
61+
)
62+
63+
validate_step_args_input(
64+
step_args=step_args,
65+
expected_caller={Session.auto_ml.__name__},
66+
error_message="The step_args of AutoMLStep must be obtained " "from automl.fit().",
67+
)
68+
69+
self.step_args = step_args.args
70+
self.cache_config = cache_config
71+
72+
root_property = Properties(step_name=name, shape_name="DescribeAutoMLJobResponse")
73+
74+
best_candidate_properties = Properties(step_name=name, path="bestCandidateProperties")
75+
best_candidate_properties.__dict__["modelInsightsJsonReportPath"] = Properties(
76+
step_name=name, path="bestCandidateProperties.modelInsightsJsonReportPath"
77+
)
78+
best_candidate_properties.__dict__["explainabilityJsonReportPath"] = Properties(
79+
step_name=name, path="bestCandidateProperties.explainabilityJsonReportPath"
80+
)
81+
82+
root_property.__dict__["bestCandidateProperties"] = best_candidate_properties
83+
self._properties = root_property
84+
85+
@property
86+
def arguments(self) -> RequestType:
87+
"""The arguments dictionary that is used to call `create_auto_ml_job`.
88+
89+
NOTE: The `CreateAutoMLJob` request is not quite the
90+
args list that workflow needs.
91+
92+
The `AutoMLJobName`, `ModelDeployConfig` and `GenerateCandidateDefinitionsOnly`
93+
attribute cannot be included.
94+
"""
95+
request_dict = self.step_args
96+
if "AutoMLJobConfig" not in request_dict:
97+
raise AutoMLStepInvalidModeError()
98+
if (
99+
"Mode" not in request_dict["AutoMLJobConfig"]
100+
or request_dict["AutoMLJobConfig"]["Mode"] != "ENSEMBLING"
101+
):
102+
raise AutoMLStepInvalidModeError()
103+
104+
if "ModelDeployConfig" in request_dict:
105+
request_dict.pop("ModelDeployConfig", None)
106+
if "GenerateCandidateDefinitionsOnly" in request_dict:
107+
request_dict.pop("GenerateCandidateDefinitionsOnly", None)
108+
request_dict.pop("AutoMLJobName", None)
109+
return request_dict
110+
111+
@property
112+
def properties(self):
113+
"""A `Properties` object representing the `DescribeAutoMLJobResponse` data model."""
114+
return self._properties
115+
116+
def to_request(self) -> RequestType:
117+
"""Updates the dictionary with cache configuration."""
118+
request_dict = super().to_request()
119+
if self.cache_config:
120+
request_dict.update(self.cache_config.config)
121+
122+
return request_dict
123+
124+
def get_best_auto_ml_model(self, role, sagemaker_session=None):
125+
"""Get the best candidate model artifacts, image uri and env variables for the best model.
126+
127+
Args:
128+
role (str): An AWS IAM role (either name or full ARN). The Amazon
129+
SageMaker AutoML jobs and APIs that create Amazon SageMaker
130+
endpoints use this role to access training data and model
131+
artifacts.
132+
sagemaker_session (sagemaker.session.Session): A SageMaker Session
133+
object, used for SageMaker interactions.
134+
If the best model will be used as part of ModelStep, then sagemaker_session
135+
should be class:`~sagemaker.workflow.pipeline_context.PipelineSession`. Example::
136+
model = Model(sagemaker_session=PipelineSession())
137+
model_step = ModelStep(step_args=model.register())
138+
"""
139+
inference_container = self.properties.BestCandidate.InferenceContainers[0]
140+
inference_container_environment = inference_container.Environment
141+
image = inference_container.Image
142+
model_data = inference_container.ModelDataUrl
143+
model = Model(
144+
image_uri=image,
145+
model_data=model_data,
146+
env={
147+
"MODEL_NAME": inference_container_environment["MODEL_NAME"],
148+
"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": inference_container_environment[
149+
"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT"
150+
],
151+
"SAGEMAKER_SUBMIT_DIRECTORY": inference_container_environment[
152+
"SAGEMAKER_SUBMIT_DIRECTORY"
153+
],
154+
"SAGEMAKER_INFERENCE_SUPPORTED": inference_container_environment[
155+
"SAGEMAKER_INFERENCE_SUPPORTED"
156+
],
157+
"SAGEMAKER_INFERENCE_OUTPUT": inference_container_environment[
158+
"SAGEMAKER_INFERENCE_OUTPUT"
159+
],
160+
"SAGEMAKER_PROGRAM": inference_container_environment["SAGEMAKER_PROGRAM"],
161+
},
162+
sagemaker_session=sagemaker_session,
163+
role=role,
164+
)
165+
166+
return model

src/sagemaker/workflow/steps.py

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class StepTypeEnum(Enum, metaclass=DefaultEnumMeta):
7171
CLARIFY_CHECK = "ClarifyCheck"
7272
EMR = "EMR"
7373
FAIL = "Fail"
74+
AUTOML = "AutoML"
7475

7576

7677
@attr.s
+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
1080,4,setosa,versicolor,virginica
2+
5.9,3.0,4.2,1.5,1
3+
6.9,3.1,5.4,2.1,2
4+
5.1,3.3,1.7,0.5,0
5+
6.0,3.4,4.5,1.6,1
6+
5.5,2.5,4.0,1.3,1
7+
6.2,2.9,4.3,1.3,1
8+
5.5,4.2,1.4,0.2,0
9+
6.3,2.8,5.1,1.5,2
10+
5.6,3.0,4.1,1.3,1
11+
6.7,2.5,5.8,1.8,2
12+
7.1,3.0,5.9,2.1,2
13+
4.3,3.0,1.1,0.1,0
14+
5.6,2.8,4.9,2.0,2
15+
5.5,2.3,4.0,1.3,1
16+
6.0,2.2,4.0,1.0,1
17+
5.1,3.5,1.4,0.2,0
18+
5.7,2.6,3.5,1.0,1
19+
4.8,3.4,1.9,0.2,0
20+
5.1,3.4,1.5,0.2,0
21+
5.7,2.5,5.0,2.0,2
22+
5.4,3.4,1.7,0.2,0
23+
5.6,3.0,4.5,1.5,1
24+
6.3,2.9,5.6,1.8,2
25+
6.3,2.5,4.9,1.5,1
26+
5.8,2.7,3.9,1.2,1
27+
6.1,3.0,4.6,1.4,1
28+
5.2,4.1,1.5,0.1,0
29+
6.7,3.1,4.7,1.5,1
30+
6.7,3.3,5.7,2.5,2
31+
6.4,2.9,4.3,1.3,1

0 commit comments

Comments
 (0)