Skip to content

Commit

Permalink
Add parent_model param in UploadModelOperator (apache#42091)
Browse files Browse the repository at this point in the history
  • Loading branch information
jx2lee authored Sep 9, 2024
1 parent 4950c62 commit 96640a7
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 4 deletions.
15 changes: 11 additions & 4 deletions airflow/providers/google/cloud/hooks/vertex_ai/model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def upload_model(
project_id: str,
region: str,
model: Model | dict,
parent_model: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
Expand All @@ -218,18 +219,24 @@ def upload_model(
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param model: Required. The Model to create.
:param parent_model: The name of the parent model to create a new version under.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
"""
client = self.get_model_service_client(region)
parent = client.common_location_path(project_id, region)

request = {
"parent": parent,
"model": model,
}

if parent_model:
request["parent_model"] = parent_model

result = client.upload_model(
request={
"parent": parent,
"model": model,
},
request=request,
retry=retry,
timeout=timeout,
metadata=metadata,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ class UploadModelOperator(GoogleCloudBaseOperator):
:param project_id: Required. The ID of the Google Cloud project that the service belongs to.
:param region: Required. The ID of the Google Cloud region that the service belongs to.
:param model: Required. The Model to create.
:param parent_model: The name of the parent model to create a new version under.
:param retry: Designation of what errors, if any, should be retried.
:param timeout: The timeout for this request.
:param metadata: Strings which should be sent along with the request as metadata.
Expand All @@ -385,6 +386,7 @@ def __init__(
project_id: str,
region: str,
model: Model | dict,
parent_model: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
Expand All @@ -396,6 +398,7 @@ def __init__(
self.project_id = project_id
self.region = region
self.model = model
self.parent_model = parent_model
self.retry = retry
self.timeout = timeout
self.metadata = metadata
Expand All @@ -412,6 +415,7 @@ def execute(self, context: Context):
project_id=self.project_id,
region=self.region,
model=self.model,
parent_model=self.parent_model,
retry=self.retry,
timeout=self.timeout,
metadata=self.metadata,
Expand Down
37 changes: 37 additions & 0 deletions tests/providers/google/cloud/hooks/vertex_ai/test_model_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
TEST_REGION: str = "test-region"
TEST_PROJECT_ID: str = "test-project-id"
TEST_MODEL = None
TEST_PARENT_MODEL = "test-parent-model"
TEST_MODEL_NAME: str = "test_model_name"
TEST_OUTPUT_CONFIG: dict = {}

Expand Down Expand Up @@ -136,6 +137,24 @@ def test_upload_model(self, mock_client) -> None:
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_upload_model_with_parent_model(self, mock_client) -> None:
self.hook.upload_model(
project_id=TEST_PROJECT_ID, region=TEST_REGION, model=TEST_MODEL, parent_model=TEST_PARENT_MODEL
)
mock_client.assert_called_once_with(TEST_REGION)
mock_client.return_value.upload_model.assert_called_once_with(
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
model=TEST_MODEL,
parent_model=TEST_PARENT_MODEL,
),
metadata=(),
retry=DEFAULT,
timeout=None,
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_list_model_versions(self, mock_client) -> None:
self.hook.list_model_versions(
Expand Down Expand Up @@ -322,6 +341,24 @@ def test_upload_model(self, mock_client) -> None:
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_upload_model_with_parent_model(self, mock_client) -> None:
self.hook.upload_model(
project_id=TEST_PROJECT_ID, region=TEST_REGION, model=TEST_MODEL, parent_model=TEST_PARENT_MODEL
)
mock_client.assert_called_once_with(TEST_REGION)
mock_client.return_value.upload_model.assert_called_once_with(
request=dict(
parent=mock_client.return_value.common_location_path.return_value,
model=TEST_MODEL,
parent_model=TEST_PARENT_MODEL,
),
metadata=(),
retry=DEFAULT,
timeout=None,
)
mock_client.return_value.common_location_path.assert_called_once_with(TEST_PROJECT_ID, TEST_REGION)

@mock.patch(MODEL_SERVICE_STRING.format("ModelServiceHook.get_model_service_client"))
def test_list_model_versions(self, mock_client) -> None:
self.hook.list_model_versions(
Expand Down
28 changes: 28 additions & 0 deletions tests/providers/google/cloud/operators/test_vertex_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,34 @@ def test_execute(self, mock_hook, to_dict_mock):
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model=TEST_MODEL_OBJ,
parent_model=None,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)

@mock.patch(VERTEX_AI_PATH.format("model_service.model_service.UploadModelResponse.to_dict"))
@mock.patch(VERTEX_AI_PATH.format("model_service.ModelServiceHook"))
def test_execute_with_parent_model(self, mock_hook, to_dict_mock):
op = UploadModelOperator(
task_id=TASK_ID,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model=TEST_MODEL_OBJ,
parent_model=TEST_PARENT_MODEL,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
mock_hook.return_value.upload_model.assert_called_once_with(
region=GCP_LOCATION,
project_id=GCP_PROJECT,
model=TEST_MODEL_OBJ,
parent_model=TEST_PARENT_MODEL,
retry=RETRY,
timeout=TIMEOUT,
metadata=METADATA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,13 @@
model=MODEL_OBJ,
)
# [END how_to_cloud_vertex_ai_upload_model_operator]
upload_model_with_parent_model = UploadModelOperator(
task_id="upload_model_with_parent_model",
region=REGION,
project_id=PROJECT_ID,
model=MODEL_OBJ,
parent_model=MODEL_DISPLAY_NAME,
)

# [START how_to_cloud_vertex_ai_export_model_operator]
export_model = ExportModelOperator(
Expand All @@ -251,6 +258,13 @@
trigger_rule=TriggerRule.ALL_DONE,
)
# [END how_to_cloud_vertex_ai_delete_model_operator]
delete_model_with_parent_model = DeleteModelOperator(
task_id="delete_model_with_parent_model",
project_id=PROJECT_ID,
region=REGION,
model_id=upload_model_with_parent_model.output["model_id"],
trigger_rule=TriggerRule.ALL_DONE,
)

# [START how_to_cloud_vertex_ai_list_models_operator]
list_models = ListModelsOperator(
Expand Down Expand Up @@ -317,8 +331,10 @@
>> set_default_version
>> add_version_alias
>> upload_model
>> upload_model_with_parent_model
>> export_model
>> delete_model
>> delete_model_with_parent_model
>> list_models
# TEST TEARDOWN
>> delete_version_alias
Expand Down

0 comments on commit 96640a7

Please sign in to comment.