Skip to content

Commit

Permalink
Pass conn ID to ObjectStoragePath via URI (apache#35913)
Browse files Browse the repository at this point in the history
This enables an alternative ObjectStoragePath init syntax: the user-info
(username) component of the URI can supply the conn ID instead of a separate
keyword argument. The explicit keyword argument is honored if supplied.
  • Loading branch information
uranusjr authored Dec 1, 2023
1 parent 9f212d4 commit ab87cd0
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 20 deletions.
2 changes: 1 addition & 1 deletion airflow/example_dags/tutorial_objectstorage.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
}

# [START create_object_storage_path]
base = ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default")
base = ObjectStoragePath("s3://aws_default@airflow-tutorial-data/")
# [END create_object_storage_path]


Expand Down
10 changes: 9 additions & 1 deletion airflow/io/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def __new__(
cls: type[PT],
*args: str | os.PathLike,
scheme: str | None = None,
conn_id: str | None = None,
**kwargs: typing.Any,
) -> PT:
args_list = list(args)
Expand Down Expand Up @@ -137,7 +138,14 @@ def __new__(
else:
args_list.insert(0, parsed_url.path)

return cls._from_parts(args_list, url=parsed_url, **kwargs) # type: ignore
# This matches the parsing logic in urllib.parse; see:
# https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
if have_info:
conn_id = conn_id or userinfo or None
parsed_url = parsed_url._replace(netloc=hostinfo)

return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs) # type: ignore

@functools.lru_cache
def __hash__(self) -> int:
Expand Down
2 changes: 1 addition & 1 deletion airflow/io/store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def attach(

if not alias:
alias = f"{protocol}-{conn_id}" if conn_id else protocol
if store := _STORE_CACHE.get(alias, None):
if store := _STORE_CACHE.get(alias):
return store

_STORE_CACHE[alias] = store = ObjectStore(protocol=protocol, conn_id=conn_id, fs=fs)
Expand Down
16 changes: 9 additions & 7 deletions docs/apache-airflow/core-concepts/objectstorage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,22 @@ object you want to interact with. For example, to point to a bucket in s3, you w
from airflow.io.path import ObjectStoragePath
base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default") # conn_id is optional
base = ObjectStoragePath("s3://aws_default@my-bucket/")
The username part of the URI is optional. It can alternatively be passed in as a separate keyword argument:

.. code-block:: python
# Equivalent to the previous example.
base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default")
Listing file-objects:

.. code-block:: python
@task
def list_files() -> list(ObjectStoragePath):
files = []
for f in base.iterdir():
if f.is_file():
files.append(f)
def list_files() -> list[ObjectStoragePath]:
files = [f for f in base.iterdir() if f.is_file()]
return files
Expand Down
18 changes: 14 additions & 4 deletions docs/apache-airflow/tutorial/objectstorage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ analytical database. You can do this by running ``pip install duckdb``. The tuto
makes use of S3 Object Storage. This requires that the amazon provider is installed
including ``s3fs`` by running ``pip install apache-airflow-providers-amazon[s3fs]``.
If you would like to use a different storage provider, you can do so by changing the
url in the ``create_object_storage_path`` function to the appropriate url for your
URL in the ``create_object_storage_path`` function to the appropriate URL for your
provider, for example by replacing ``s3://`` with ``gs://`` for Google Cloud Storage.
You will also need the right provider to be installed then. Finally, you will need
``pandas``, which can be installed by running ``pip install pandas``.
Expand All @@ -49,9 +49,19 @@ It is the fundamental building block of the Object Storage API.
:start-after: [START create_object_storage_path]
:end-before: [END create_object_storage_path]

The ObjectStoragePath constructor can take an optional connection id. If supplied
it will use the connection to obtain the right credentials to access the backend.
Otherwise it will revert to the default for that backend.
The username part of the URL given to ObjectStoragePath should be a connection ID.
The specified connection will be used to obtain the right credentials to access
the backend. If it is omitted, the default connection for the backend will be used.

The connection ID can alternatively be passed in with a keyword argument:

.. code-block:: python
ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default")
This is useful when reusing a URL defined for another purpose (e.g. Dataset),
which generally does not contain a username part. The explicit keyword argument
takes precedence over the URL's username value if both are specified.

It is safe to instantiate an ObjectStoragePath at the root of your DAG. Connections
will not be created until the path is used. This means that you can create the
Expand Down
28 changes: 22 additions & 6 deletions tests/io/test_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ def _strip_protocol(cls, path) -> str:


class TestFs:
def setup_class(self):
    """Snapshot the global store cache so each test can restore it afterwards."""
    # NOTE(review): pytest passes the class object to setup_class, so this
    # snapshot is stored as a class attribute shared by all tests in TestFs.
    self._store_cache = dict(_STORE_CACHE)

def teardown(self):
    """Reset the global store cache to the snapshot taken in setup_class."""
    saved = self._store_cache
    _STORE_CACHE.clear()
    _STORE_CACHE.update(saved)

def test_alias(self):
store = attach("file", alias="local")
assert isinstance(store.fs, LocalFileSystem)
Expand Down Expand Up @@ -100,6 +107,19 @@ def test_ls(self):

assert not o.exists()

@pytest.fixture()
def fake_fs(self):
    """Mock filesystem whose protocol stripping always resolves to the root path."""
    mock_fs = mock.Mock()
    mock_fs.conn_id = "fake"
    # _strip_protocol is what fsspec uses to turn a URL into a bare path;
    # pinning it to "/" keeps path resolution trivial for these tests.
    mock_fs._strip_protocol.return_value = "/"
    return mock_fs

def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs):
    """A conn ID embedded as the URI username should reach the attached store."""
    fake_fs.stat.return_value = {"stat": "result"}
    attach(protocol="fake", conn_id="fake", fs=fake_fs)
    # The "fake@" user-info component supplies the conn ID here — no
    # conn_id keyword argument is passed.
    path = ObjectStoragePath("fake://fake@bucket/path")
    result = path.stat()
    assert result == {"stat": "result", "conn_id": "fake", "protocol": "fake"}

@pytest.mark.parametrize(
"fn, args, fn2, path, expected_args, expected_kwargs",
[
Expand All @@ -124,12 +144,8 @@ def test_ls(self):
),
],
)
def test_standard_extended_api(self, monkeypatch, fn, args, fn2, path, expected_args, expected_kwargs):
_fs = mock.Mock()
_fs._strip_protocol.return_value = "/"
_fs.conn_id = "fake"

store = attach(protocol="file", conn_id="fake", fs=_fs)
def test_standard_extended_api(self, fake_fs, fn, args, fn2, path, expected_args, expected_kwargs):
store = attach(protocol="file", conn_id="fake", fs=fake_fs)
o = ObjectStoragePath(path, conn_id="fake")

getattr(o, fn)(**args)
Expand Down

0 comments on commit ab87cd0

Please sign in to comment.