Skip to content

Commit

Permalink
Support specifying per-frame positional arguments in sequence process…
Browse files Browse the repository at this point in the history
…ing test utility (NVIDIA#3901)

* Support testing of positional arguments in per-frame test utility
* Provide better context to the parameter callbacks

Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan authored May 20, 2022
1 parent 9d790c7 commit 756a706
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 124 deletions.
203 changes: 138 additions & 65 deletions dali/test/python/sequences_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import os
import random
import numpy as np
from typing import List, Union, Callable
from dataclasses import dataclass

from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
Expand All @@ -28,13 +30,55 @@
vid_file = os.path.join(data_root, 'db', 'video',
'sintel', 'sintel_trailer-720p.mp4')


@dataclass
class SampleDesc:
    """Context that the argument provider callback receives when prompted for a parameter.

    An instance is handed to every ``ArgCb`` callback so it can derive the
    argument value from the position (batch/sample/frame) and contents of the
    sample being processed.
    """
    rng: random.Random  # random number generator the callback should draw from
    frame_idx: int      # index of the frame within the sample; None for per-sample arguments (despite the int annotation)
    sample_idx: int     # index of the sample within its batch
    batch_idx: int      # index of the batch within the input data
    sample: np.ndarray  # the full input sample the argument is being generated for


@dataclass
class ArgDesc:
    """Describes how a single tensor argument is passed to the tested operator."""
    name: Union[str, int]  # keyword-argument name, or an int index for a positional input
    is_per_frame: bool     # if True, the argument data is wrapped with fn.per_frame in the pipeline
    dest_device: str       # device the argument data is placed on ("cpu" or "gpu")

    def __post_init__(self):
        # Only positional inputs may reside on the GPU; keyword arguments must stay on the CPU.
        assert self.is_positional_arg or self.dest_device == "cpu", "Named arguments on GPU are not supported"

    @property
    def is_positional_arg(self) -> bool:
        # Positional inputs are identified by an integer index rather than a name.
        return isinstance(self.name, int)


class ArgCb:
    """Binds an argument description (``ArgDesc``) to a callback that, given a
    ``SampleDesc``, produces the argument value for a single sample or frame."""

    def __init__(self, name: Union[str, int], cb: Callable[[SampleDesc], np.ndarray],
                 is_per_frame: bool, dest_device: str = "cpu"):
        self.cb = cb
        self.desc = ArgDesc(name, is_per_frame, dest_device)

    def __repr__(self):
        return "ArgCb{}".format((self.cb, self.desc))


@dataclass
class ArgData:
    """An argument description together with the generated batches of its data."""
    desc: ArgDesc
    data: List[List[np.ndarray]]  # outer list: batches; inner list: one array per sample (per-frame args carry a leading frame dimension)


class ParamsProviderBase:
"""
Computes data to be passed as argument inputs in sequence processing tests, the `compute_params` params
should return tuple of lists with data for respectively per-sample and per-frame arguments.
The `expand_params` should return corresponding unfolded/expanded arguments to be used in the
should return a lists of ArgData; number of dimensions in the `arg_desc.data` must reflect
`arg_data.desc.is_per_frame` argument.
The `expand_params` should return corresponding unfolded/expanded ArgData to be used in the
baseline pipeline.
"""

def __init__(self):
self.input_data = None
self.input_layout = None
Expand All @@ -56,43 +100,47 @@ def setup_expand(self, num_expand, unfolded_input, unfolded_input_layout):
self.unfolded_input = unfolded_input
self.unfolded_input_layout = unfolded_input_layout

def compute_params(self):
def compute_params(self) -> List[ArgData]:
raise NotImplementedError

def expand_params(self):
def expand_params(self) -> List[ArgData]:
raise NotImplementedError


class ParamsProvider(ParamsProviderBase):

def __init__(self, input_params):
"""
`input_params` : List[Tuple[str, rng -> np.array, bool]]
List describing tensor input arguments of the form: [(tensor_arg_name, single_arg_cb, is_per_frame)])]
The `single_arg_cb` should be a function that takes Python's random number generator and returns
an argument for a single sample or frame, depending on the `is_per_frame` flag."""
def __init__(self, input_params: List[ArgCb]):
super().__init__()
self.input_params = input_params
self.per_sample_params_data = None
self.per_frame_params_data = None
self.arg_input_data = None
self.expanded_params_data = None

def compute_params(self):
self.per_sample_params_data, self.per_frame_params_data = get_input_params_data(
self.input_data, self.input_layout, self.input_params, self.rng)
return self.per_sample_params_data, self.per_frame_params_data

def expand_params(self):
expanded_per_sample_params = [
(arg_name, expand_arg_input(self.input_data, self.input_layout, self.num_expand, arg_data, False))
for arg_name, arg_data in self.per_sample_params_data]
expanded_per_frame_params = [
(arg_name, expand_arg_input(self.input_data, self.input_layout, self.num_expand, arg_data, True))
for arg_name, arg_data in self.per_frame_params_data]
self.expanded_params_data = expanded_per_sample_params + expanded_per_frame_params
def compute_params(self) -> List[ArgData]:
self.arg_input_data = compute_input_params_data(
self.input_data, self.input_layout, self.rng, self.input_params)
return self.arg_input_data

def expand_params(self) -> List[ArgData]:
self.expanded_params_data = [
ArgData(
ArgDesc(arg_data.desc.name, False, arg_data.desc.dest_device),
expand_arg_input(
self.input_data, self.input_layout, self.num_expand,
arg_data.data, arg_data.desc.is_per_frame))
for arg_data in self.arg_input_data
]
return self.expanded_params_data


def arg_data_node(arg_data: ArgData):
    """Builds a DALI node that feeds `arg_data.data` via an external source,
    moved to the GPU and/or marked as per-frame according to `arg_data.desc`."""
    desc = arg_data.desc
    node = fn.external_source(dummy_source(arg_data.data))
    if desc.dest_device == "gpu":
        node = node.gpu()
    return fn.per_frame(node) if desc.is_per_frame else node


def as_batch(tensor):
if isinstance(tensor, _Tensors.TensorListGPU):
tensor = tensor.as_cpu()
Expand Down Expand Up @@ -166,40 +214,50 @@ def _test_seq_input(device, num_iters, expandable_extents, operator_fn, fixed_pa
input_layout, input_data, rng):

@pipeline_def
def pipeline(input_data, input_layout, per_sample_params_input, per_frame_params_input):
def pipeline(input_data, input_layout, args_data: List[ArgData]):
input = fn.external_source(
source=dummy_source(input_data), layout=input_layout)
if device == "gpu":
input = input.gpu()
pos_args = [
arg_data for arg_data in args_data if arg_data.desc.is_positional_arg]
pos_nodes = [None] * (len(pos_args) + 1)
for arg_data in pos_args:
assert 0 <= arg_data.desc.name < len(pos_nodes)
assert pos_nodes[arg_data.desc.name] is None
pos_nodes[arg_data.desc.name] = arg_data_node(arg_data)
[input_idx] = [i for i, pos_input in enumerate(
pos_nodes) if pos_input is None] # there should be exactly one
pos_nodes[input_idx] = input
named_args = [
arg_data for arg_data in args_data if not arg_data.desc.is_positional_arg]
arg_nodes = {
arg_name: fn.external_source(source=dummy_source(arg_data))
for arg_name, arg_data in per_sample_params_input + per_frame_params_input}
for arg_name, _ in per_frame_params_input:
arg_nodes[arg_name] = fn.per_frame(arg_nodes[arg_name])
output = operator_fn(input, **fixed_params, **arg_nodes)
arg_data.desc.name: arg_data_node(arg_data)
for arg_data in named_args}
output = operator_fn(*pos_nodes, **fixed_params, **arg_nodes)
return output

max_batch_size = max(len(batch) for batch in input_data)

params_provider = input_params if isinstance(input_params, ParamsProviderBase) else ParamsProvider(input_params)
params_provider = input_params if isinstance(
input_params, ParamsProviderBase) else ParamsProvider(input_params)
params_provider.setup(input_data, input_layout, fixed_params, rng)
per_sample_params_input, per_frame_params_input = params_provider.compute_params()
args_data = params_provider.compute_params()
seq_pipe = pipeline(input_data=input_data, input_layout=input_layout,
per_sample_params_input=per_sample_params_input,
per_frame_params_input=per_frame_params_input,
args_data=args_data,
batch_size=max_batch_size, num_threads=4,
device_id=0)

num_expand = get_layout_prefix_len(input_layout, expandable_extents)
unfolded_input = unfold_batches(input_data, num_expand)
unfolded_input_layout = input_layout[num_expand:]
params_provider.setup_expand(num_expand, unfolded_input, unfolded_input_layout)
expanded_params_data = params_provider.expand_params()
params_provider.setup_expand(
num_expand, unfolded_input, unfolded_input_layout)
expanded_args_data = params_provider.expand_params()
max_uf_batch_size = max(len(batch) for batch in unfolded_input)
baseline_pipe = pipeline(input_data=unfolded_input,
input_layout=unfolded_input_layout,
per_sample_params_input=expanded_params_data,
per_frame_params_input=[],
args_data=expanded_args_data,
batch_size=max_uf_batch_size, num_threads=4,
device_id=0)
seq_pipe.build()
Expand All @@ -216,37 +274,52 @@ def pipeline(input_data, input_layout, per_sample_params_input, per_frame_params


def get_input_arg_per_sample(input_data, param_cb, rng):
return [[param_cb(rng) for _ in batch] for batch in input_data]
return [[
param_cb(SampleDesc(rng, None, sample_idx, batch_idx, sample))
for sample_idx, sample in enumerate(batch)]
for batch_idx, batch in enumerate(input_data)]


def get_input_arg_per_frame(input_data, input_layout, param_cb, rng):
def arg_for_sample(num_frames):
if rng.randint(1, 4) == 1:
return np.array([param_cb(rng)])
return np.array([param_cb(rng) for _ in range(num_frames)])
def get_input_arg_per_frame(input_data, input_layout, param_cb, rng, check_broadcasting):
frame_idx = input_layout.find("F")
return [[arg_for_sample(sample.shape[frame_idx])
for sample in batch] for batch in input_data]

def arg_for_sample(sample_idx, batch_idx, sample):
if check_broadcasting and rng.randint(1, 4) == 1:
return np.array([param_cb(SampleDesc(rng, 0, sample_idx, batch_idx, sample))])
num_frames = sample.shape[frame_idx]
return np.array([
param_cb(SampleDesc(rng, frame_idx, sample_idx, batch_idx, sample))
for frame_idx in range(num_frames)])

return [[
arg_for_sample(sample_idx, batch_idx, sample)
for sample_idx, sample in enumerate(batch)]
for batch_idx, batch in enumerate(input_data)]

def get_input_params_data(input_data, input_layout, input_params, rng):
per_sample_args = [
(param_name, get_input_arg_per_sample(input_data, param_cb, rng))
for param_name, param_cb, is_per_frame in input_params if not is_per_frame]
per_frame_args = [
(param_name, get_input_arg_per_frame(
input_data, input_layout, param_cb, rng))
for param_name, param_cb, is_per_frame in input_params if is_per_frame]

return per_sample_args, per_frame_args
def compute_input_params_data(input_data, input_layout, rng, input_params: List[ArgCb]):
    """Produces an ArgData for every ArgCb, dispatching to the per-sample or
    per-frame data generator according to the argument description."""

    def generate(arg_cb: ArgCb):
        desc = arg_cb.desc
        if not desc.is_per_frame:
            return get_input_arg_per_sample(input_data, arg_cb.cb, rng)
        # Single-frame broadcasting is only exercised for named arguments;
        # positional per-frame inputs always get one value per frame.
        return get_input_arg_per_frame(
            input_data, input_layout, arg_cb.cb, rng, not desc.is_positional_arg)

    return [ArgData(desc=arg_cb.desc, data=generate(arg_cb)) for arg_cb in input_params]


def sequence_suite_helper(rng, expandable_extents, input_cases, ops_test_cases, num_iters=4):
for operator_fn, fixed_params, input_params in ops_test_cases:
for device in ["cpu", "gpu"]:
class OpTestCase:
def __init__(self, operator_fn, fixed_params, input_params, devices=None):
self.operator_fn = operator_fn
self.fixed_params = fixed_params
self.input_params = input_params
self.devices = ["cpu", "gpu"] if devices is None else devices
for test_case_args in ops_test_cases:
test_case = OpTestCase(*test_case_args)
for device in test_case.devices:
for (input_layout, input_data) in input_cases:
yield _test_seq_input, device, num_iters, expandable_extents, operator_fn, fixed_params, \
input_params, input_layout, input_data, rng
yield _test_seq_input, device, num_iters, expandable_extents, \
test_case.operator_fn, test_case.fixed_params, \
test_case.input_params, input_layout, input_data, rng


def get_video_input_cases(seq_layout, rng, larger_shape=(512, 288), smaller_shape=(384, 216)):
Expand Down Expand Up @@ -312,7 +385,7 @@ def vid_source(batch_size, num_batches, num_frames, width, height, seq_layout):
return batches


def video_suite_helper(ops_test_cases, test_channel_first=True, expand_channels=False):
def video_suite_helper(ops_test_cases, test_channel_first=True, expand_channels=False, rng=None):
"""
Generates suite of video test cases for a sequence processing operator.
The operator should meet the SequenceOperator assumptions, i.e.
Expand All @@ -327,20 +400,20 @@ def video_suite_helper(ops_test_cases, test_channel_first=True, expand_channels=
For testing operator with different input than the video, consider using `sequence_suite_helper` directly.
----------
`ops_test_cases` : List[Tuple[Operator, Dict[str, Any], ParamProviderBase|List[Tuple[str, rng -> np.array, bool]]]]
`ops_test_cases` : List[Tuple[Operator, Dict[str, Any], ParamProviderBase|List[ArgCb]]]
List of operators and their parameters that should be tested.
Each element is expected to be a triple of the form:
[(fn.operator, {fixed_param_name: fixed_param_value}, [(tensor_arg_name, single_arg_cb, is_per_frame)])]
[(fn.operator, {fixed_param_name: fixed_param_value}, [ArgCb(tensor_arg_name, single_arg_cb, is_per_frame, dest_device)])]
where the first element is ``fn.operator``, the second one is a dictionary of fixed arguments that should
be passed to the operator and the last one is a list of tuples describing tensor input arguments
be passed to the operator and the last one is a list of ArgCb instances describing tensor input arguments
(see `ParamsProvider` argument description) or custom params provider instance.
`test_channel_first` : bool
If True, the "FCHW" layout is tested.
`expand_channels` : bool
If True, for the "FCHW" layout the first two (and not just one) dims are expanded, and "CFHW" layout is tested.
Requires `test_channel_first` to be True.
"""
rng = random.Random(42)
rng = rng or random.Random(42)
expandable_extents = "FC" if expand_channels else "F"
layouts = ["FHWC"]
if not test_channel_first:
Expand Down
32 changes: 16 additions & 16 deletions dali/test/python/test_operator_gaussian_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from nose_utils import assert_raises
from nose.plugins.attrib import attr

from sequences_test_utils import video_suite_helper
from sequences_test_utils import video_suite_helper, ArgCb
from test_utils import get_dali_extra_path, check_batch, compare_pipelines, RandomlyShapedDataIterator, dali_type

data_root = get_dali_extra_path()
Expand Down Expand Up @@ -340,29 +340,29 @@ def test_fail_gaussian_blur():


def test_per_frame():
def window_size(rng):
return np.array(2 * rng.randint(1, 15) + 1, dtype=np.int32)
def window_size(sample_desc):
return np.array(2 * sample_desc.rng.randint(1, 15) + 1, dtype=np.int32)

def per_axis_window_size(rng):
return np.array([window_size(rng) for _ in range(2)])
def per_axis_window_size(sample_desc):
return np.array([window_size(sample_desc) for _ in range(2)])

def sigma(rng):
return np.array((rng.random() + 1) * 3., dtype=np.float32)
def sigma(sample_desc):
return np.array((sample_desc.rng.random() + 1) * 3., dtype=np.float32)

def per_axis_sigma(rng):
return np.array([sigma(rng) for _ in range(2)])
def per_axis_sigma(sample_desc):
return np.array([sigma(sample_desc) for _ in range(2)])

video_test_cases = [
(fn.gaussian_blur, {'window_size': 3}, []),
(fn.gaussian_blur, {}, [("window_size", window_size, True)]),
(fn.gaussian_blur, {}, [("window_size", per_axis_window_size, True)]),
(fn.gaussian_blur, {}, [("sigma", sigma, True)]),
(fn.gaussian_blur, {}, [ArgCb("window_size", window_size, True)]),
(fn.gaussian_blur, {}, [ArgCb("window_size", per_axis_window_size, True)]),
(fn.gaussian_blur, {}, [ArgCb("sigma", sigma, True)]),
(fn.gaussian_blur, {}, [
("window_size", per_axis_window_size, True),
("sigma", per_axis_sigma, True)]),
ArgCb("window_size", per_axis_window_size, True),
ArgCb("sigma", per_axis_sigma, True)]),
(fn.gaussian_blur, {'dtype': types.FLOAT}, [
("window_size", per_axis_window_size, False),
("sigma", per_axis_sigma, True)]),
ArgCb("window_size", per_axis_window_size, False),
ArgCb("sigma", per_axis_sigma, True)]),
]

yield from video_suite_helper(video_test_cases, expand_channels=True)
Loading

0 comments on commit 756a706

Please sign in to comment.