Add support for BatchInfo in experimental TF DALI Dataset (NVIDIA#3468)
Propagate the batch_info option in SourceDescription
(only when it matters) and use it when converting the callback
for the TF from_generator dataset.
Add test coverage.

Signed-off-by: Krzysztof Lecki <[email protected]>
klecki authored Nov 8, 2021
1 parent 6a282ad commit 3490831
Showing 4 changed files with 61 additions and 39 deletions.
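
For context, the sketch below (not part of this commit) shows what the new option enables: with batch_info=True, a batch-mode external_source callback receives a nvidia.dali.types.BatchInfo instead of a bare iteration index. Pipeline parameters here are illustrative only.

    import numpy as np
    from nvidia.dali import fn, pipeline_def, types

    BATCH_SIZE = 8

    # With batch_info=True the callback gets a types.BatchInfo carrying the
    # iteration number and the epoch index; epoch_idx stays 0 when the
    # pipeline feeds the experimental TF DALI Dataset.
    def make_batch(info: types.BatchInfo):
        return [np.full((4,), info.iteration, dtype=np.int32)
                for _ in range(BATCH_SIZE)]

    @pipeline_def(batch_size=BATCH_SIZE, num_threads=2, device_id=0)
    def pipe():
        return fn.external_source(source=make_batch, batch=True, batch_info=True)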
30 changes: 19 additions & 11 deletions dali/python/nvidia/dali/_utils/external_source_impl.py
@@ -41,11 +41,12 @@ class SourceKind(Enum):
class SourceDescription:
"""Keep the metadata about the source parameter that was originally passed
"""
-def __init__(self, source, kind: SourceKind, has_inputs: bool, cycle: str):
+def __init__(self, source, kind: SourceKind, has_inputs: bool, cycle: str, batch_info = False):
self.source = source
self.kind = kind
self.has_inputs = has_inputs
self.cycle = cycle
+self.batch_info = batch_info

def __str__(self) -> str:
if self.kind == SourceKind.CALLABLE:
@@ -225,9 +226,11 @@ def accepted_arg_count(callable):
return callable.__code__.co_argcount - implicit_args


-def get_callback_from_source(source, cycle):
+def get_callback_from_source(source, cycle, batch_info=False):
"""Repack the source into a unified callback function. Additionally prepare the SourceDescription.
+`batch_info` is usable only with callables.
Returns
-------
callback, SourceDescription
@@ -275,7 +278,7 @@ def get_callback_from_source(source, cycle):
raise TypeError("Source must be callable, iterable or a parameterless generator function")
# We got a callable
desc = SourceDescription(source, SourceKind.CALLABLE,
-accepted_arg_count(source) > 0, cycle)
+accepted_arg_count(source) > 0, cycle, batch_info)
callback = source
else:
desc = None
@@ -300,11 +303,10 @@ def _inspect_data(data, is_batched):
return as_numpy.dtype, (None,) * as_numpy.ndim


-def get_batch_iterable_from_callback(source_desc):
+def get_batch_iterable_from_callback(source_desc: SourceDescription):
"""Transform batch callback accepting one argument into an Iterable
"""

-first = source_desc.source(0)
+first = source_desc.source(types.BatchInfo(0, 0) if source_desc.batch_info else 0)
dtype, shape = _inspect_data(first, True)

class CallableBatchIterator:
@@ -323,14 +325,20 @@ def __next__(self):
result = CallableBatchIterator.first_value
CallableBatchIterator.first_value = None
else:
-result = self.source(self.iteration)
+if source_desc.batch_info:
+    # There is no notion of epochs when iterating over DALI Dataset
+    # as the "raise" policy is not supported, so we use epoch 0 only.
+    argument = types.BatchInfo(self.iteration, 0)
+else:
+    argument = self.iteration
+result = self.source(argument)
self.iteration += 1
return batch_to_numpy(result, _tf_batch_error_msg, non_uniform_str=_tf_uniform_error_msg)

return CallableBatchIterator, dtype, shape


-def get_sample_iterable_from_callback(source_desc, batch_size):
+def get_sample_iterable_from_callback(source_desc: SourceDescription, batch_size):
"""Transform sample callback accepting one argument into an Iterable
"""
first = source_desc.source(types.SampleInfo(0, 0, 0, 0))
@@ -369,7 +377,7 @@ def __next__(self):
return CallableSampleIterator, dtype, shape


-def get_iterable_from_callback(source_desc, is_batched):
+def get_iterable_from_callback(source_desc: SourceDescription, is_batched):
"""Transform callback that doesn't accept arguments into iterable
"""
print("get_iterable_from_callback")
@@ -398,7 +406,7 @@ def __next__(self):
return CallableIterator, dtype, shape


-def get_iterable_from_iterable_or_generator(source_desc, is_batched):
+def get_iterable_from_iterable_or_generator(source_desc: SourceDescription, is_batched):
"""Wrap iterable or generator function into another iterable while peeking the first element
If the source is generator function it must be called first.
@@ -441,7 +449,7 @@ def __next__(self):
return PeekFirstGenerator, dtype, shape


-def _get_generator_from_source_desc(source_desc, batch_size, is_batched):
+def _get_generator_from_source_desc(source_desc: SourceDescription, batch_size, is_batched):
"""Based on DALI source description create a generator function, type and shape specification
compatible with TF Generator Dataset.
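
A short illustration of the argument the wrapped batch iterator above now passes to the user callback (behavior taken from the diff; the epoch index is fixed to 0):

    from nvidia.dali import types

    # batch_info=False -> the source gets the plain iteration index, as before.
    # batch_info=True  -> the source gets types.BatchInfo(iteration, 0).
    info = types.BatchInfo(3, 0)
    assert info.iteration == 3 and info.epoch_idx == 0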
16 changes: 8 additions & 8 deletions dali/python/nvidia/dali/external_source.py
@@ -398,7 +398,7 @@ def __init__(
import nvidia.dali.ops
kwargs, self._call_args = nvidia.dali.ops._separate_kwargs(kwargs)

-callback, source_desc = _get_callback_from_source(source, cycle)
+callback, source_desc = _get_callback_from_source(source, cycle, batch_info or False)

if name is not None and num_outputs is not None:
raise ValueError("`num_outputs` is not compatible with named `ExternalSource`")
@@ -440,6 +440,12 @@ def __call__(
""
from nvidia.dali.ops import _OperatorInstance

+if batch_info is None:
+    batch_info = self._batch_info or False
+elif self._batch_info is not None:
+    raise ValueError(
+        "The argument ``batch_info`` already specified in constructor.")

if source is None:
if cycle is not None:
if self._callback:
@@ -454,7 +460,7 @@ def __call__(
else:
if self._callback is not None:
raise RuntimeError("``source`` already specified in constructor.")
-callback, source_desc = _get_callback_from_source(source, cycle)
+callback, source_desc = _get_callback_from_source(source, cycle, self._batch_info)

# Keep the metadata for Pipeline inspection
self._source_desc = source_desc
@@ -479,12 +485,6 @@ def __call__(
raise ValueError(
"The argument ``prefetch_queue_depth`` already specified in constructor.")

-if batch_info is None:
-    batch_info = self._batch_info or False
-elif self._batch_info is not None:
-    raise ValueError(
-        "The argument ``batch_info`` already specified in constructor.")

if no_copy is None:
no_copy = self._no_copy
elif self._no_copy is not None:
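
The block moved above resolves batch_info from either the constructor or the call, never both. A standalone sketch of that resolution rule (hypothetical helper, not DALI API):

    def resolve_batch_info(ctor_value, call_value):
        # Call-time value omitted: fall back to the constructor's value, or False.
        if call_value is None:
            return ctor_value or False
        # Given in both places: ambiguous, so refuse, mirroring the check above.
        if ctor_value is not None:
            raise ValueError(
                "The argument ``batch_info`` already specified in constructor.")
        return call_value

    assert resolve_batch_info(None, None) is False
    assert resolve_batch_info(True, None) is True
    assert resolve_batch_info(None, True) is True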
1 change: 0 additions & 1 deletion dali/python/nvidia/dali/plugin/tf.py
@@ -564,7 +564,6 @@ def _input_lists_from_source(self, callbacked_es_map):
with tf.device('/cpu:0'):
tf_gen, dtype, shape = _get_generator_from_source_desc(
source_desc, self._batch_size, external_source._batch)
-# dataset = tf.data.Dataset.from_generator(tf_gen, output_types=dtype)
signature = _get_signature(dtype, shape)
dataset = tf.data.Dataset.from_generator(tf_gen, output_signature=signature)
if _cycle_enabled(source_desc.cycle):
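
The deleted line above was the older output_types-based from_generator call, superseded by output_signature. A minimal standalone TensorFlow sketch of the pattern kept here (generator and shapes are illustrative; unknown dimensions are spelled as None, matching _inspect_data above):

    import numpy as np
    import tensorflow as tf

    def gen():
        # Yield a few variable-length 1-D batches, as an external source would.
        for i in range(3):
            yield np.full((i % 4 + 2,), i, dtype=np.int32)

    signature = tf.TensorSpec(shape=(None,), dtype=tf.int32)
    dataset = tf.data.Dataset.from_generator(gen, output_signature=signature)

    for batch in dataset:
        print(batch.numpy())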
53 changes: 34 additions & 19 deletions dali/test/python/test_dali_tf_es_pipelines.py
@@ -46,6 +46,19 @@ def callback(x):
return np.stack(result) if dense else result
return callback

+def get_batch_one_arg_callback_with_batch_info(dtype, iter_limit=1000, batch_size=None, dense=True):
+    def callback(x):
+        if x.iteration > iter_limit:
+            raise StopIteration()
+        size = (x.iteration % 16 + 4,)
+        result = [np.full(size, x.iteration, dtype=dtype)] * batch_size
+        for i, elem in enumerate(result):
+            elem[0] = i
+            elem[1] = x.iteration
+            elem[2] = x.epoch_idx
+        return np.stack(result) if dense else result
+    return callback

def get_no_arg_callback(dtype, iter_limit=1000, batch_size=None, dense=True):
class Callable:
def __init__(self):
@@ -132,25 +145,26 @@ def generator():
return generator


-# generator, is_batched, cycle
+# generator, is_batched, cycle, batch_info
# TODO(klecki): cycle='raise' is currently not supported, and probably never will be
es_configurations = [
-    (get_sample_one_arg_callback, False, None),
-    (get_batch_one_arg_callback, True, None),
-    (get_no_arg_callback, False, None),
-    (get_no_arg_callback, True, None),
-    (get_iterable, False, False),
-    (get_iterable, False, True),
-    # (get_iterable, False, "raise"),
-    (get_iterable, True, False),
-    (get_iterable, True, True),
-    # (get_iterable, True, "raise"),
-    (get_iterable_generator, False, False),
-    (get_iterable_generator, False, True),
-    # (get_iterable_generator, False, "raise"),
-    (get_iterable_generator, True, False),
-    (get_iterable_generator, True, True),
-    # (get_iterable_generator, True, "raise"),
+    (get_sample_one_arg_callback, False, None, False),
+    (get_batch_one_arg_callback, True, None, False),
+    (get_batch_one_arg_callback_with_batch_info, True, None, True),
+    (get_no_arg_callback, False, None, False),
+    (get_no_arg_callback, True, None, False),
+    (get_iterable, False, False, False),
+    (get_iterable, False, True, False),
+    # (get_iterable, False, "raise", False),
+    (get_iterable, True, False, False),
+    (get_iterable, True, True, False),
+    # (get_iterable, True, "raise", False),
+    (get_iterable_generator, False, False, False),
+    (get_iterable_generator, False, True, False),
+    # (get_iterable_generator, False, "raise", False),
+    (get_iterable_generator, True, False, False),
+    (get_iterable_generator, True, True, False),
+    # (get_iterable_generator, True, "raise", False),
]

def get_external_source_pipe(es_args, dtype, es_device):
@@ -191,12 +205,13 @@ def get_dense_options(is_batched):

def gen_tf_with_dali_external_source(test_run):
for dtype in [np.uint8, np.int32, np.float32]:
-for get_callback, is_batched, cycle in es_configurations:
+for get_callback, is_batched, cycle, batch_info in es_configurations:
for dense in get_dense_options(is_batched):
for dev, es_dev in [("cpu", "cpu"), ("gpu", "cpu"), ("gpu", "gpu")]:
for iter_limit in [3, 9, 10, 11, 100]:
bs = 12 if is_batched else None
es_args = {'source': get_callback(dtype, iter_limit, bs, dense),
'batch': is_batched,
-'cycle': cycle}
+'cycle': cycle,
+'batch_info': batch_info}
yield test_run, dev, es_args, es_dev, tf.dtypes.as_dtype(dtype), iter_limit, dense
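
For reference, a quick manual check of the new test callback (a sketch assuming the helpers above are importable and DALI is installed); at iteration 2 of epoch 0 each sample records the iteration and epoch index:

    import numpy as np
    from nvidia.dali import types

    cb = get_batch_one_arg_callback_with_batch_info(np.int32, batch_size=4)
    batch = cb(types.BatchInfo(2, 0))          # dense=True -> one stacked array
    assert batch.shape == (4, 2 % 16 + 4)      # batch_size x (iteration % 16 + 4)
    assert batch[0][1] == 2 and batch[0][2] == 0   # iteration and epoch_idx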
