Add Supervised Fine Tuning Train Operator, Hook, Tests, Docs (apache#41807)

* add supervised_fine_tuning

* build fix

* build,test fix

* unit test build fix

* xcom fix

* refactor supervised tuning into generative_model module, PR feedback, tests

* minor system test fix

* update provider.yaml

* doc fix

* Update Vertex AI Documentation
CYarros10 authored Aug 30, 2024
1 parent 3f0b3d7 commit 35ce2f1
Showing 8 changed files with 293 additions and 5 deletions.
airflow/providers/google/cloud/hooks/vertex_ai/generative_model.py
@@ -19,16 +19,21 @@

from __future__ import annotations

-from typing import Sequence
+import time
+from typing import TYPE_CHECKING, Sequence

import vertexai
from deprecated import deprecated
from vertexai.generative_models import GenerativeModel, Part
from vertexai.language_models import TextEmbeddingModel, TextGenerationModel
from vertexai.preview.tuning import sft

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID, GoogleBaseHook

if TYPE_CHECKING:
from google.cloud.aiplatform_v1 import types


class GenerativeModelHook(GoogleBaseHook):
"""Hook for Google Cloud Vertex AI Generative Model APIs."""
@@ -348,3 +353,55 @@ def generative_model_generate_content(
)

return response.text

@GoogleBaseHook.fallback_to_default_project_id
def supervised_fine_tuning_train(
self,
source_model: str,
train_dataset: str,
location: str,
tuned_model_display_name: str | None = None,
validation_dataset: str | None = None,
epochs: int | None = None,
adapter_size: int | None = None,
learning_rate_multiplier: float | None = None,
project_id: str = PROVIDE_PROJECT_ID,
) -> types.TuningJob:
"""
        Use the Supervised Fine Tuning API to create a tuning job.

:param source_model: Required. A pre-trained model optimized for performing natural
language tasks such as classification, summarization, extraction, content
creation, and ideation.
:param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
to 128 characters long and can consist of any UTF-8 characters.
        :param validation_dataset: Optional. Cloud Storage URI of your validation dataset. The dataset
            must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
:param epochs: Optional. To optimize performance on a specific dataset, try using a higher
epoch value. Increasing the number of epochs might improve results. However, be cautious
about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
consider lowering the epoch number.
:param adapter_size: Optional. Adapter size for tuning.
        :param learning_rate_multiplier: Optional. Multiplier for adjusting the default learning rate.
        :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
        """
vertexai.init(project=project_id, location=location, credentials=self.get_credentials())

sft_tuning_job = sft.train(
source_model=source_model,
train_dataset=train_dataset,
validation_dataset=validation_dataset,
epochs=epochs,
adapter_size=adapter_size,
learning_rate_multiplier=learning_rate_multiplier,
tuned_model_display_name=tuned_model_display_name,
)

# Polling for job completion
while not sft_tuning_job.has_ended:
time.sleep(60)
sft_tuning_job.refresh()

return sft_tuning_job
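
A minimal sketch of calling the new hook method directly (the connection ID, dataset URI, and epoch count below are illustrative, not from this commit). Note that the call blocks until the tuning job finishes because of the polling loop above:

    from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook

    hook = GenerativeModelHook(gcp_conn_id="google_cloud_default")
    tuning_job = hook.supervised_fine_tuning_train(
        # project_id is omitted; it falls back to the connection's default project
        # via @GoogleBaseHook.fallback_to_default_project_id.
        source_model="gemini-1.0-pro-002",
        train_dataset="gs://my-bucket/sft_train_data.jsonl",  # hypothetical dataset URI
        location="us-central1",
        epochs=4,  # optional; omit to use the service default
    )
    # The returned TuningJob exposes the names the operator below logs and pushes to XCom.
    print(tuning_job.tuned_model_name, tuning_job.tuned_model_endpoint_name)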
airflow/providers/google/cloud/operators/vertex_ai/generative_model.py
@@ -22,6 +22,7 @@
from typing import TYPE_CHECKING, Sequence

from deprecated import deprecated
from google.cloud.aiplatform_v1 import types

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.vertex_ai.generative_model import GenerativeModelHook
@@ -525,7 +526,7 @@ class GenerativeModelGenerateContentOperator(GoogleCloudBaseOperator):
account from the list granting this role to the originating account (templated).
"""

template_fields = ("location", "project_id", "impersonation_chain", "contents")
template_fields = ("location", "project_id", "impersonation_chain", "contents", "pretrained_model")

def __init__(
self,
@@ -571,3 +572,93 @@ def execute(self, context: Context):
self.xcom_push(context, key="model_response", value=response)

return response


class SupervisedFineTuningTrainOperator(GoogleCloudBaseOperator):
"""
    Use the Supervised Fine Tuning API to create a tuning job.

:param source_model: Required. A pre-trained model optimized for performing natural
language tasks such as classification, summarization, extraction, content
creation, and ideation.
:param train_dataset: Required. Cloud Storage URI of your training dataset. The dataset
must be formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
:param project_id: Required. The ID of the Google Cloud project that the
service belongs to.
:param location: Required. The ID of the Google Cloud location that the service belongs to.
:param tuned_model_display_name: Optional. Display name of the TunedModel. The name can be up
to 128 characters long and can consist of any UTF-8 characters.
    :param validation_dataset: Optional. Cloud Storage URI of your validation dataset. The dataset must be
        formatted as a JSONL file. For best results, provide at least 100 to 500 examples.
:param epochs: Optional. To optimize performance on a specific dataset, try using a higher
epoch value. Increasing the number of epochs might improve results. However, be cautious
about over-fitting, especially when dealing with small datasets. If over-fitting occurs,
consider lowering the epoch number.
:param adapter_size: Optional. Adapter size for tuning.
    :param learning_rate_multiplier: Optional. Multiplier for adjusting the default learning rate.
:param gcp_conn_id: The connection ID to use connecting to Google Cloud.
:param impersonation_chain: Optional service account to impersonate using short-term
credentials, or chained list of accounts required to get the access_token
of the last account in the list, which will be impersonated in the request.
If set as a string, the account must grant the originating account
the Service Account Token Creator IAM role.
If set as a sequence, the identities from the list must grant
Service Account Token Creator IAM role to the directly preceding identity, with first
account from the list granting this role to the originating account (templated).
"""

template_fields = ("location", "project_id", "impersonation_chain", "train_dataset", "validation_dataset")

def __init__(
self,
*,
source_model: str,
train_dataset: str,
project_id: str,
location: str,
tuned_model_display_name: str | None = None,
validation_dataset: str | None = None,
epochs: int | None = None,
adapter_size: int | None = None,
learning_rate_multiplier: float | None = None,
gcp_conn_id: str = "google_cloud_default",
impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.source_model = source_model
self.train_dataset = train_dataset
self.tuned_model_display_name = tuned_model_display_name
self.validation_dataset = validation_dataset
self.epochs = epochs
self.adapter_size = adapter_size
self.learning_rate_multiplier = learning_rate_multiplier
self.project_id = project_id
self.location = location
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain

def execute(self, context: Context):
self.hook = GenerativeModelHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
response = self.hook.supervised_fine_tuning_train(
source_model=self.source_model,
train_dataset=self.train_dataset,
project_id=self.project_id,
location=self.location,
validation_dataset=self.validation_dataset,
epochs=self.epochs,
adapter_size=self.adapter_size,
learning_rate_multiplier=self.learning_rate_multiplier,
tuned_model_display_name=self.tuned_model_display_name,
)

self.log.info("Tuned Model Name: %s", response.tuned_model_name)
self.log.info("Tuned Model Endpoint Name: %s", response.tuned_model_endpoint_name)

self.xcom_push(context, key="tuned_model_name", value=response.tuned_model_name)
self.xcom_push(context, key="tuned_model_endpoint_name", value=response.tuned_model_endpoint_name)

return types.TuningJob.to_dict(response)
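
For reference, a minimal sketch of consuming the pushed XCom values from a downstream TaskFlow task. The task ID "sft_train_task" is taken from the example DAG later in this change; the downstream task itself is illustrative, not part of the commit:

    from airflow.decorators import task

    @task
    def use_tuned_model(ti=None):
        # Keys mirror the xcom_push calls in SupervisedFineTuningTrainOperator.execute.
        model_name = ti.xcom_pull(task_ids="sft_train_task", key="tuned_model_name")
        endpoint_name = ti.xcom_pull(task_ids="sft_train_task", key="tuned_model_endpoint_name")
        print(f"Tuned model {model_name} is served at {endpoint_name}")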
airflow/providers/google/provider.yaml (1 addition &amp; 1 deletion)
@@ -112,7 +112,7 @@ dependencies:
- google-api-python-client>=2.0.2
- google-auth>=2.29.0
- google-auth-httplib2>=0.0.1
-  - google-cloud-aiplatform>=1.57.0
+  - google-cloud-aiplatform>=1.63.0
- google-cloud-automl>=2.12.0
# Excluded versions contain bug https://github.com/apache/airflow/issues/39541 which is resolved in 3.24.0
- google-cloud-bigquery>=3.4.0,!=3.21.*,!=3.22.0,!=3.23.*
docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
@@ -582,7 +582,7 @@ To get a pipeline job list you can use
:start-after: [START how_to_cloud_vertex_ai_list_pipeline_job_operator]
:end-before: [END how_to_cloud_vertex_ai_list_pipeline_job_operator]

-Interacting with a Generative Model
+Interacting with Generative AI
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To generate a prediction via language model you can use
@@ -615,6 +615,16 @@ The operator returns the model's response in :ref:`XCom <concepts:xcom>` under `
:start-after: [START how_to_cloud_vertex_ai_generative_model_generate_content_operator]
:end-before: [END how_to_cloud_vertex_ai_generative_model_generate_content_operator]

To run a supervised fine tuning job you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.SupervisedFineTuningTrainOperator`.
The operator returns the tuned model's endpoint name in :ref:`XCom <concepts:xcom>` under the ``tuned_model_endpoint_name`` key.

.. exampleinclude:: /../../tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model_tuning.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
:end-before: [END how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
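
The pushed value can also be consumed through Jinja templating in a downstream task. A minimal sketch, assuming the tuning task is named ``sft_train_task`` as in the example DAG; the downstream task is illustrative, not part of this commit:

.. code-block:: python

    from airflow.operators.bash import BashOperator

    # Hypothetical downstream task; pulls the endpoint name pushed by sft_train_task.
    report_endpoint = BashOperator(
        task_id="report_endpoint",
        bash_command=(
            "echo {{ task_instance.xcom_pull("
            "task_ids='sft_train_task', key='tuned_model_endpoint_name') }}"
        ),
    )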

Reference
^^^^^^^^^

generated/provider_dependencies.json (1 addition &amp; 1 deletion)
@@ -616,7 +616,7 @@
"google-api-python-client>=2.0.2",
"google-auth-httplib2>=0.0.1",
"google-auth>=2.29.0",
"google-cloud-aiplatform>=1.57.0",
"google-cloud-aiplatform>=1.63.0",
"google-cloud-automl>=2.12.0",
"google-cloud-batch>=0.13.0",
"google-cloud-bigquery-datatransfer>=3.13.0",
tests/providers/google/cloud/hooks/vertex_ai/test_generative_model.py
@@ -70,6 +70,9 @@
TEST_MEDIA_GCS_PATH = "gs://download.tensorflow.org/example_images/320px-Felis_catus-cat_on_snow.jpg"
TEST_MIME_TYPE = "image/jpeg"

SOURCE_MODEL = "gemini-1.0-pro-002"
TRAIN_DATASET = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"

BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
GENERATIVE_MODEL_STRING = "airflow.providers.google.cloud.hooks.vertex_ai.generative_model.{}"

@@ -194,3 +197,23 @@ def test_generative_model_generate_content(self, mock_model) -> None:
generation_config=TEST_GENERATION_CONFIG,
safety_settings=TEST_SAFETY_SETTINGS,
)

@mock.patch("vertexai.preview.tuning.sft.train")
def test_supervised_fine_tuning_train(self, mock_sft_train) -> None:
self.hook.supervised_fine_tuning_train(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
source_model=SOURCE_MODEL,
train_dataset=TRAIN_DATASET,
)

        # Because sft.train is mocked, the returned MagicMock's has_ended attribute is
        # truthy, so the hook's polling loop exits immediately without sleeping.
        # Assert the arguments were passed through to sft.train unchanged.
mock_sft_train.assert_called_once_with(
source_model=SOURCE_MODEL,
train_dataset=TRAIN_DATASET,
validation_dataset=None,
epochs=None,
adapter_size=None,
learning_rate_multiplier=None,
tuned_model_display_name=None,
)
tests/providers/google/cloud/operators/vertex_ai/test_generative_model.py
@@ -35,6 +35,7 @@
PromptLanguageModelOperator,
PromptMultimodalModelOperator,
PromptMultimodalModelWithMediaOperator,
SupervisedFineTuningTrainOperator,
TextEmbeddingModelGetEmbeddingsOperator,
TextGenerationModelPredictOperator,
)
@@ -390,3 +391,41 @@ def test_execute(self, mock_hook):
safety_settings=safety_settings,
pretrained_model=pretrained_model,
)


class TestVertexAISupervisedFineTuningTrainOperator:
@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
@mock.patch("google.cloud.aiplatform_v1.types.TuningJob.to_dict")
def test_execute(
self,
to_dict_mock,
mock_hook,
):
source_model = "gemini-1.0-pro-002"
train_dataset = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"

op = SupervisedFineTuningTrainOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
source_model=source_model,
train_dataset=train_dataset,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
mock_hook.return_value.supervised_fine_tuning_train.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
source_model=source_model,
train_dataset=train_dataset,
adapter_size=None,
epochs=None,
learning_rate_multiplier=None,
tuned_model_display_name=None,
validation_dataset=None,
)
tests/system/providers/google/cloud/vertex_ai/example_vertex_ai_generative_model_tuning.py
@@ -0,0 +1,68 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Example Airflow DAG for Google Vertex AI Generative Model Tuning Tasks.
"""

from __future__ import annotations

import os
from datetime import datetime

from airflow.models.dag import DAG
from airflow.providers.google.cloud.operators.vertex_ai.generative_model import (
SupervisedFineTuningTrainOperator,
)

PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
DAG_ID = "vertex_ai_generative_model_tuning_dag"
REGION = "us-central1"
SOURCE_MODEL = "gemini-1.0-pro-002"
TRAIN_DATASET = "gs://cloud-samples-data/ai-platform/generative_ai/sft_train_data.jsonl"
TUNED_MODEL_DISPLAY_NAME = "my_tuned_gemini_model"

with DAG(
dag_id=DAG_ID,
description="Sample DAG with generative model tuning tasks.",
schedule="@once",
start_date=datetime(2024, 1, 1),
catchup=False,
tags=["example", "vertex_ai", "generative_model"],
) as dag:
# [START how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]
sft_train_task = SupervisedFineTuningTrainOperator(
task_id="sft_train_task",
project_id=PROJECT_ID,
location=REGION,
source_model=SOURCE_MODEL,
train_dataset=TRAIN_DATASET,
tuned_model_display_name=TUNED_MODEL_DISPLAY_NAME,
)
# [END how_to_cloud_vertex_ai_supervised_fine_tuning_train_operator]

from tests.system.utils.watcher import watcher

# This test needs watcher in order to properly mark success/failure
# when "tearDown" task with trigger rule is part of the DAG
list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
