Skip to content

Commit

Permalink
feat: add Hook Level Lineage support for GCSHook (apache#42507)
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda authored Oct 23, 2024
1 parent 69af185 commit cc76229
Show file tree
Hide file tree
Showing 9 changed files with 372 additions and 7 deletions.
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,7 @@
"google": {
"deps": [
"PyOpenSSL>=23.0.0",
"apache-airflow-providers-common-compat>=1.1.0",
"apache-airflow-providers-common-compat>=1.2.1",
"apache-airflow-providers-common-sql>=1.7.2",
"apache-airflow>=2.8.0",
"asgiref>=3.5.2",
Expand Down
45 changes: 45 additions & 0 deletions providers/src/airflow/providers/google/assets/gcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from airflow.providers.common.compat.assets import Asset
from airflow.providers.google.cloud.hooks.gcs import _parse_gcs_url

if TYPE_CHECKING:
from urllib.parse import SplitResult

from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset


def create_asset(*, bucket: str, key: str, extra: dict | None = None) -> Asset:
    """Build an Asset whose URI is the ``gs://<bucket>/<key>`` form of the GCS object."""
    object_uri = f"gs://{bucket}/{key}"
    return Asset(uri=object_uri, extra=extra)


def sanitize_uri(uri: SplitResult) -> SplitResult:
    """Validate a split ``gs://`` URI and return it unchanged.

    :param uri: Pre-split URI to validate.
    :raises ValueError: If the URI carries no netloc (bucket) component.
    """
    if uri.netloc:
        return uri
    raise ValueError("URI format gs:// must contain a bucket name")


def convert_asset_to_openlineage(asset: Asset, lineage_context) -> OpenLineageDataset:
    """Translate Asset with valid AIP-60 uri to OpenLineage with assistance from the hook."""
    from airflow.providers.common.compat.openlineage.facet import Dataset as OpenLineageDataset

    gcs_bucket, gcs_blob = _parse_gcs_url(asset.uri)
    # A bucket-only URI has no object key; OpenLineage uses "/" as the root name.
    dataset_name = gcs_blob or "/"
    return OpenLineageDataset(namespace=f"gs://{gcs_bucket}", name=dataset_name)
52 changes: 50 additions & 2 deletions providers/src/airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from requests import Session

from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector
from airflow.providers.google.cloud.utils.helpers import normalize_directory_path
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import (
Expand Down Expand Up @@ -214,6 +215,16 @@ def copy(
destination_object = source_bucket.copy_blob( # type: ignore[attr-defined]
blob=source_object, destination_bucket=destination_bucket, new_name=destination_object
)
get_hook_lineage_collector().add_input_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined]
)
get_hook_lineage_collector().add_output_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": destination_bucket.name, "key": destination_object.name}, # type: ignore[union-attr]
)

self.log.info(
"Object %s in bucket %s copied to object %s in bucket %s",
Expand Down Expand Up @@ -267,6 +278,16 @@ def rewrite(
).rewrite(source=source_object, token=token)

self.log.info("Total Bytes: %s | Bytes Written: %s", total_bytes, bytes_rewritten)
get_hook_lineage_collector().add_input_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": source_bucket.name, "key": source_object.name}, # type: ignore[attr-defined]
)
get_hook_lineage_collector().add_output_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": destination_bucket.name, "key": destination_object}, # type: ignore[attr-defined]
)
self.log.info(
"Object %s in bucket %s rewritten to object %s in bucket %s",
source_object.name, # type: ignore[attr-defined]
Expand Down Expand Up @@ -345,9 +366,18 @@ def download(

if filename:
blob.download_to_filename(filename, timeout=timeout)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)
get_hook_lineage_collector().add_output_asset(
context=self, scheme="file", asset_kwargs={"path": filename}
)
self.log.info("File downloaded to %s", filename)
return filename
else:
get_hook_lineage_collector().add_input_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)
return blob.download_as_bytes()

except GoogleCloudError:
Expand Down Expand Up @@ -555,6 +585,9 @@ def _call_with_retry(f: Callable[[], None]) -> None:
_call_with_retry(
partial(blob.upload_from_filename, filename=filename, content_type=mime_type, timeout=timeout)
)
get_hook_lineage_collector().add_input_asset(
context=self, scheme="file", asset_kwargs={"path": filename}
)

if gzip:
os.remove(filename)
Expand All @@ -576,6 +609,10 @@ def _call_with_retry(f: Callable[[], None]) -> None:
else:
raise ValueError("'filename' and 'data' parameter missing. One is required to upload to gcs.")

get_hook_lineage_collector().add_output_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)

def exists(self, bucket_name: str, object_name: str, retry: Retry = DEFAULT_RETRY) -> bool:
"""
Check for the existence of a file in Google Cloud Storage.
Expand Down Expand Up @@ -691,6 +728,9 @@ def delete(self, bucket_name: str, object_name: str) -> None:
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_name=object_name)
blob.delete()
get_hook_lineage_collector().add_input_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": blob.name}
)

self.log.info("Blob %s deleted.", object_name)

Expand Down Expand Up @@ -1198,9 +1238,17 @@ def compose(self, bucket_name: str, source_objects: List[str], destination_objec
client = self.get_conn()
bucket = client.bucket(bucket_name)
destination_blob = bucket.blob(destination_object)
destination_blob.compose(
sources=[bucket.blob(blob_name=source_object) for source_object in source_objects]
source_blobs = [bucket.blob(blob_name=source_object) for source_object in source_objects]
destination_blob.compose(sources=source_blobs)
get_hook_lineage_collector().add_output_asset(
context=self, scheme="gs", asset_kwargs={"bucket": bucket.name, "key": destination_blob.name}
)
for single_source_blob in source_blobs:
get_hook_lineage_collector().add_input_asset(
context=self,
scheme="gs",
asset_kwargs={"bucket": bucket.name, "key": single_source_blob.name},
)

self.log.info("Completed successfully.")

Expand Down
14 changes: 11 additions & 3 deletions providers/src/airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ versions:

dependencies:
- apache-airflow>=2.8.0
- apache-airflow-providers-common-compat>=1.1.0
- apache-airflow-providers-common-compat>=1.2.1
- apache-airflow-providers-common-sql>=1.7.2
- asgiref>=3.5.2
- dill>=0.2.3
Expand Down Expand Up @@ -777,15 +777,23 @@ asset-uris:
- schemes: [gcp]
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
handler: airflow.providers.google.assets.bigquery.sanitize_uri
- schemes: [gs]
handler: airflow.providers.google.assets.gcs.sanitize_uri
factory: airflow.providers.google.assets.gcs.create_asset
to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage

# dataset has been renamed to asset in Airflow 3.0
# This is kept for backward compatibility.
dataset-uris:
- schemes: [gcp]
handler: null
- schemes: [bigquery]
handler: airflow.providers.google.datasets.bigquery.sanitize_uri
handler: airflow.providers.google.assets.bigquery.sanitize_uri
- schemes: [gs]
handler: airflow.providers.google.assets.gcs.sanitize_uri
factory: airflow.providers.google.assets.gcs.create_asset
to_openlineage_converter: airflow.providers.google.assets.gcs.convert_asset_to_openlineage

hooks:
- integration-name: Google Ads
Expand Down
2 changes: 1 addition & 1 deletion providers/tests/google/assets/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

import pytest

from airflow.providers.google.datasets.bigquery import sanitize_uri
from airflow.providers.google.assets.bigquery import sanitize_uri


def test_sanitize_uri_pass() -> None:
Expand Down
74 changes: 74 additions & 0 deletions providers/tests/google/assets/test_gcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import urllib.parse

import pytest

from airflow.providers.common.compat.assets import Asset
from airflow.providers.google.assets.gcs import convert_asset_to_openlineage, create_asset, sanitize_uri


def test_sanitize_uri():
    """A valid gs:// URI passes through unchanged, and sanitizing is idempotent.

    The original test silently ran sanitize_uri twice; keep the second pass but
    make it an explicit idempotency assertion instead of a redundant call.
    """
    result = sanitize_uri(urllib.parse.urlsplit("gs://bucket/dir/file.txt"))
    # Sanitizing an already-sanitized result must be a no-op.
    assert sanitize_uri(result) == result
    assert result.scheme == "gs"
    assert result.netloc == "bucket"
    assert result.path == "/dir/file.txt"


def test_sanitize_uri_no_netloc():
    """A gs:// URI with no bucket component must be rejected."""
    bucket_less = urllib.parse.urlsplit("gs://")
    with pytest.raises(ValueError):
        sanitize_uri(bucket_less)


def test_sanitize_uri_no_path():
    """A bucket-only gs:// URI (empty path) is accepted unchanged.

    The original test silently ran sanitize_uri twice; keep the second pass but
    make it an explicit idempotency assertion instead of a redundant call.
    """
    result = sanitize_uri(urllib.parse.urlsplit("gs://bucket"))
    # Sanitizing an already-sanitized result must be a no-op.
    assert sanitize_uri(result) == result
    assert result.scheme == "gs"
    assert result.netloc == "bucket"
    assert result.path == ""


def test_create_asset():
    """create_asset composes the gs://<bucket>/<key> URI for flat and nested keys."""
    cases = [
        ("test-path", "gs://test-bucket/test-path"),
        ("test-dir/test-path", "gs://test-bucket/test-dir/test-path"),
    ]
    for key, expected_uri in cases:
        assert create_asset(bucket="test-bucket", key=key) == Asset(uri=expected_uri)


def test_sanitize_uri_trailing_slash():
    """A gs:// URI whose path is only a trailing slash is accepted unchanged.

    The original test silently ran sanitize_uri twice; keep the second pass but
    make it an explicit idempotency assertion instead of a redundant call.
    """
    result = sanitize_uri(urllib.parse.urlsplit("gs://bucket/"))
    # Sanitizing an already-sanitized result must be a no-op.
    assert sanitize_uri(result) == result
    assert result.scheme == "gs"
    assert result.netloc == "bucket"
    assert result.path == "/"


def test_convert_asset_to_openlineage_valid():
    """A full object URI maps to namespace gs://<bucket> and name <key>."""
    asset = Asset(uri="gs://bucket/dir/file.txt")
    ol_dataset = convert_asset_to_openlineage(asset=asset, lineage_context=None)
    assert (ol_dataset.namespace, ol_dataset.name) == ("gs://bucket", "dir/file.txt")


@pytest.mark.parametrize("uri", ("gs://bucket", "gs://bucket/"))
def test_convert_asset_to_openlineage_no_path(uri):
    """Bucket-only URIs (with or without trailing slash) map to the root name '/'."""
    asset = Asset(uri=uri)
    ol_dataset = convert_asset_to_openlineage(asset=asset, lineage_context=None)
    assert (ol_dataset.namespace, ol_dataset.name) == ("gs://bucket", "/")
Loading

0 comments on commit cc76229

Please sign in to comment.