Skip to content

Commit

Permalink
[Datasets] Add missing passthrough args to read_images() (ray-project#32942)
Browse files Browse the repository at this point in the history

This PR adds missing passthrough args to read_images(), such as ray_remote_args, arrow_open_file_args, and other misc. args popped in the base FileBasedDatasource such as compression. This PR also adds a **read_args catch-all to ImageDatasource._read_file(), which should add support for using the local:// protocol.
  • Loading branch information
clarkzinzow authored Mar 29, 2023
1 parent dc0cee4 commit 047abf5
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 5 deletions.
8 changes: 4 additions & 4 deletions python/ray/data/datasource/image_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_reader(
size: Optional[Tuple[int, int]] = None,
mode: Optional[str] = None,
include_paths: bool = False,
**kwargs,
**reader_args,
) -> "Reader[T]":
if size is not None and len(size) != 2:
raise ValueError(
Expand All @@ -63,7 +63,7 @@ def create_reader(
_check_import(self, module="PIL", package="Pillow")

return _ImageDatasourceReader(
self, size=size, mode=mode, include_paths=include_paths, **kwargs
self, size=size, mode=mode, include_paths=include_paths, **reader_args
)

def _convert_block_to_tabular_block(
Expand All @@ -82,10 +82,11 @@ def _read_file(
size: Optional[Tuple[int, int]],
mode: Optional[str],
include_paths: bool,
**reader_args,
) -> "pyarrow.Table":
from PIL import Image

records = super()._read_file(f, path, include_paths=True)
records = super()._read_file(f, path, include_paths=True, **reader_args)
assert len(records) == 1
path, data = records[0]

Expand Down Expand Up @@ -145,7 +146,6 @@ def __init__(
paths=paths,
filesystem=filesystem,
schema=None,
open_stream_args=None,
meta_provider=meta_provider,
partition_filter=partition_filter,
partitioning=partitioning,
Expand Down
7 changes: 7 additions & 0 deletions python/ray/data/read_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,8 @@ def read_images(
filesystem: Optional["pyarrow.fs.FileSystem"] = None,
parallelism: int = -1,
meta_provider: BaseFileMetadataProvider = _ImageFileMetadataProvider(),
ray_remote_args: Dict[str, Any] = None,
arrow_open_file_args: Optional[Dict[str, Any]] = None,
partition_filter: Optional[
PathPartitionFilter
] = ImageDatasource.file_extension_filter(),
Expand Down Expand Up @@ -630,6 +632,9 @@ def read_images(
limited by the number of files of the dataset.
meta_provider: File metadata provider. Custom metadata providers may
be able to resolve file metadata more quickly and/or accurately.
ray_remote_args: kwargs passed to ray.remote in the read tasks.
arrow_open_file_args: kwargs passed to
``pyarrow.fs.FileSystem.open_input_file``.
partition_filter: Path-based partition filter, if any. Can be used
with a custom callback to read only selected partitions of a dataset.
By default, this filters out any file paths whose file extension does not
Expand Down Expand Up @@ -662,6 +667,8 @@ def read_images(
filesystem=filesystem,
parallelism=parallelism,
meta_provider=meta_provider,
ray_remote_args=ray_remote_args,
open_stream_args=arrow_open_file_args,
partition_filter=partition_filter,
partitioning=partitioning,
size=size,
Expand Down
24 changes: 23 additions & 1 deletion python/ray/data/tests/test_dataset_image.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from typing import Dict
from unittest.mock import patch, ANY

import numpy as np
import pyarrow as pa
Expand All @@ -8,7 +9,7 @@
from fsspec.implementations.local import LocalFileSystem

import ray
from ray.data.datasource import Partitioning
from ray.data.datasource import Partitioning, PathPartitionFilter
from ray.data.datasource.file_meta_provider import FastFileMetadataProvider
from ray.data.datasource.image_datasource import (
_ImageDatasourceReader,
Expand Down Expand Up @@ -235,6 +236,27 @@ def test_dynamic_block_split(ray_start_regular_shared):
ctx.target_max_block_size = target_max_block_size
ctx.block_splitting_enabled = block_splitting_enabled

def test_args_passthrough(ray_start_regular_shared):
    """Verify that every keyword accepted by ``read_images()`` is forwarded
    to ``read_datasource()``, with ``arrow_open_file_args`` renamed to
    ``open_stream_args`` along the way.
    """
    call_kwargs = dict(
        paths="foo",
        filesystem=pa.fs.LocalFileSystem(),
        parallelism=20,
        meta_provider=FastFileMetadataProvider(),
        ray_remote_args={"resources": {"bar": 1}},
        arrow_open_file_args={"foo": "bar"},
        partition_filter=PathPartitionFilter.of(lambda x: True),
        partitioning=Partitioning("hive"),
        size=(2, 2),
        mode="foo",
        include_paths=True,
        ignore_missing_paths=True,
    )
    # Intercept the downstream call so nothing is actually read.
    with patch("ray.data.read_api.read_datasource") as mock_read:
        ray.data.read_images(**call_kwargs)
    # read_images() passes arrow_open_file_args through under the
    # open_stream_args name; mirror that rename before comparing.
    call_kwargs["open_stream_args"] = call_kwargs.pop("arrow_open_file_args")
    mock_read.assert_called_once_with(ANY, **call_kwargs)
    datasource_arg = mock_read.call_args[0][0]
    assert isinstance(datasource_arg, ImageDatasource)


if __name__ == "__main__":
import sys
Expand Down

0 comments on commit 047abf5

Please sign in to comment.