Skip to content

Commit

Permalink
GPU PythonFunction operator (#1655)
Browse files Browse the repository at this point in the history
Extend PythonFunctionOperator to support GPU
Signed-off-by: Rafal <[email protected]>
  • Loading branch information
banasraf authored Jan 15, 2020
1 parent 3884d77 commit 94a3146
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 9 deletions.
10 changes: 8 additions & 2 deletions dali/operators/python_function/python_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,19 @@ namespace dali {

DALI_SCHEMA(PythonFunctionBase)
.AddArg("function",
R"code(Function object consuming and producing numpy arrays.)code",
R"code(Function object.)code",
DALI_PYTHON_OBJECT)
.AddOptionalArg("num_outputs", R"code(Number of outputs)code", 1)
.MakeInternal();

DALI_SCHEMA(PythonFunction)
.DocStr("Executes a python function.")
.DocStr("Executes a python function. \n"
"The operator can be used to execute custom python code within the DALI pipeline. "
"The called function will get tensors' data as numpy arrays for CPU operators"
" or as cupy arrays for GPU operators and should return results in the same format."
"For now, this operator can be used only in pipelines with "
"`exec_async=False` and `exec_pipelined=False` specified. Due to "
"inferior performance, it is intended mostly for prototyping and debugging.")
.NumInput(0, 256)
.AllowSequences()
.SupportVolumetric()
Expand Down
47 changes: 40 additions & 7 deletions dali/python/nvidia/dali/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@
import nvidia.dali.libpython_function_plugin


# Placeholder for the lazily imported cupy module; stays None until
# _setup_cupy() has run successfully.
cupy = None


def _setup_cupy():
    """Import cupy on first use and bind it to the module-level ``cupy`` name.

    Calls made after the module has been bound are no-ops. The import is
    deferred to this function so that loading this module does not itself
    import cupy.
    """
    global cupy
    if cupy is not None:
        return
    import cupy


class _EdgeReference(object):
def __init__(self, name, device="cpu", source=None):
self.name = name
Expand Down Expand Up @@ -746,7 +753,9 @@ def _dlpack_from_array(array):

class PythonFunction(PythonFunctionBase):
global _cpu_ops
global _gpu_ops
_cpu_ops = _cpu_ops.union({'PythonFunction'})
_gpu_ops = _gpu_ops.union({'PythonFunction'})

@staticmethod
def current_stream():
Expand Down Expand Up @@ -779,22 +788,46 @@ def function_wrapper_batch(function, from_dlpack, to_dlpack, *dlpack_inputs):
def _function_wrapper_cpu(batch_processing, function, *dlpack_inputs):
if batch_processing:
return PythonFunction.function_wrapper_batch(function, _dlpack_to_array,
_dlpack_from_array, *dlpack_inputs)
_dlpack_from_array, *dlpack_inputs)
else:
return PythonFunction.function_wrapper_per_sample(function, _dlpack_to_array,
_dlpack_from_array,
*dlpack_inputs)
_dlpack_from_array,
*dlpack_inputs)

@staticmethod
def _cupy_stream_wrapper(function, *inputs):
    """Call ``function(*inputs)`` inside a cupy stream context.

    A cupy ``Stream`` object is created with ``null=True`` and its ``ptr``
    is temporarily overwritten with the ``ptr`` of
    ``PythonFunction.current_stream()``, so that work issued by ``function``
    inside the ``with`` block uses that stream.

    Args:
        function: Callable invoked with ``*inputs``.
        *inputs: Positional arguments forwarded to ``function`` unchanged.

    Returns:
        Whatever ``function(*inputs)`` returns.

    Raises:
        Any exception raised by ``function`` propagates to the caller; the
        borrowed stream pointer is cleared before it does.
    """
    stream = cupy.cuda.Stream(null=True)
    stream.ptr = PythonFunction.current_stream().ptr
    try:
        with stream:
            return function(*inputs)
    finally:
        # Always drop the borrowed handle, even when `function` raises.
        # NOTE(review): presumably this keeps the cupy Stream object from
        # releasing a stream it does not own when it is finalized - confirm
        # against cupy's Stream implementation.
        stream.ptr = 0

@staticmethod
def _function_wrapper_gpu(batch_processing, function, *dlpack_inputs):
    """Run ``function`` on DLPack inputs, converting through cupy arrays.

    Inputs are converted with ``cupy.fromDlpack`` and results are converted
    back with their ``toDlpack()`` method. The user function itself is
    executed through ``PythonFunction._cupy_stream_wrapper``.

    Args:
        batch_processing: When truthy, dispatch through
            ``function_wrapper_batch``; otherwise through
            ``function_wrapper_per_sample``.
        function: The user-supplied callable.
        *dlpack_inputs: DLPack inputs forwarded to the selected wrapper.

    Returns:
        The value returned by the selected wrapper.
    """
    def run_in_stream(*arrays):
        return PythonFunction._cupy_stream_wrapper(function, *arrays)

    def export_dlpack(array):
        return array.toDlpack()

    if batch_processing:
        dispatch = PythonFunction.function_wrapper_batch
    else:
        dispatch = PythonFunction.function_wrapper_per_sample
    return dispatch(run_in_stream, cupy.fromDlpack, export_dlpack, *dlpack_inputs)

def __init__(self, function, num_outputs=1, device='cpu', batch_processing=False, **kwargs):
if device == 'gpu':
_setup_cupy()
func = (lambda *ts: PythonFunction._function_wrapper_cpu(batch_processing, function, *ts))\
if device == 'cpu' else \
(lambda *ts: PythonFunction._function_wrapper_gpu(batch_processing, function, *ts))
super(PythonFunction, self).__init__(impl_name="DLTensorPythonFunctionImpl",
function=lambda *ins:
PythonFunction._function_wrapper_cpu(batch_processing,
function, *ins),
function=func,
num_outputs=num_outputs, device=device,
synchronize_stream=False,
batch_processing=batch_processing, **kwargs)



class DLTensorPythonFunction(PythonFunctionBase):
global _cpu_ops
_cpu_ops = _cpu_ops.union({'DLTensorPythonFunction'})
Expand Down
77 changes: 77 additions & 0 deletions dali/test/python/test_gpu_python_function_operator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import cupy
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.ops as ops
import nvidia.dali.types as types
import numpy as np
import test_utils
import random
import os


def random_seed():
    """Return a pseudo-random integer in the range [0, 2**32)."""
    upper_bound = 1 << 32
    return int(upper_bound * random.random())


# Root of the test data set, taken from the DALI_EXTRA_PATH environment
# variable. The lookup raises KeyError at import time when it is not set.
test_data_root = os.environ['DALI_EXTRA_PATH']
# Directory handed to the file reader as `file_root`.
images_dir = os.path.join(test_data_root, 'db', 'single', 'jpeg')


# Settings shared by every pipeline built in this module. DEVICE_ID,
# BATCH_SIZE and NUM_WORKERS are passed positionally to the Pipeline
# constructor; SEED is passed as `seed`, so the CPU and GPU pipelines use
# the same value. ITERS is forwarded to test_utils.compare_pipelines.
DEVICE_ID = 0
BATCH_SIZE = 8
ITERS = 32
SEED = random_seed()
NUM_WORKERS = 6


class PythonFunctionPipeline(Pipeline):
    """Pipeline that reads and decodes JPEGs, normalizes them and runs a PythonFunction.

    File reading and image decoding are built without the ``device`` argument
    (the decoder is pinned to ``'cpu'``); normalization and the python
    function use the requested ``device``. The pipeline is created with
    ``exec_async=False`` and ``exec_pipelined=False``.
    """

    def __init__(self, function, device, num_outputs=1):
        """Build the operators.

        Args:
            function: Callable handed to ``ops.PythonFunction``.
            device: Device string for the normalize and python-function
                operators; compared against ``'cpu'`` in ``define_graph``.
            num_outputs: Number of outputs declared for ``function``.
        """
        super(PythonFunctionPipeline, self).__init__(
            BATCH_SIZE, NUM_WORKERS, DEVICE_ID,
            seed=SEED, exec_async=False, exec_pipelined=False)
        self.device = device
        self.reader = ops.FileReader(file_root=images_dir)
        self.decode = ops.ImageDecoder(device='cpu', output_type=types.RGB)
        self.norm = ops.CropMirrorNormalize(
            std=255., mean=0., device=device, output_layout="HWC")
        self.func = ops.PythonFunction(
            device=device, function=function, num_outputs=num_outputs)

    def define_graph(self):
        """Connect the operators and return the PythonFunction outputs."""
        encoded, _ = self.reader()
        decoded = self.decode(encoded)
        # Decoding happens on the CPU; move the data over only when the
        # rest of the graph runs on another device.
        if self.device == 'cpu':
            images = decoded
        else:
            images = decoded.gpu()
        normalized = self.norm(images)
        # The same normalized batch is supplied as both function inputs.
        return self.func(normalized, normalized)


def validate_cpu_vs_gpu(gpu_fun, cpu_fun, num_outputs=1):
    """Build a GPU pipeline and a CPU pipeline and compare their outputs.

    Args:
        gpu_fun: Function run by the pipeline created with device ``'gpu'``.
        cpu_fun: Function run by the pipeline created with device ``'cpu'``.
        num_outputs: Number of outputs each function produces.
    """
    pipe_on_gpu = PythonFunctionPipeline(function=gpu_fun, device='gpu',
                                         num_outputs=num_outputs)
    pipe_on_cpu = PythonFunctionPipeline(function=cpu_fun, device='cpu',
                                         num_outputs=num_outputs)
    test_utils.compare_pipelines(pipe_on_gpu, pipe_on_cpu, BATCH_SIZE, ITERS)


def arrays_arithmetic(in1, in2):
    """Return the pair ``(in1 + in2, in1 - in2 / 2.)``."""
    total = in1 + in2
    half_of_second = in2 / 2.
    return total, in1 - half_of_second


def test_simple_arithm():
    # The same two-output arithmetic function is used on both devices.
    validate_cpu_vs_gpu(gpu_fun=arrays_arithmetic,
                        cpu_fun=arrays_arithmetic,
                        num_outputs=2)


# Element-wise cupy kernel computing z = x*x - y*y; used as the GPU
# counterpart of the plain-Python square_diff function.
square_diff_kernel = cupy.ElementwiseKernel(
    in_params='T x, T y',
    out_params='T z',
    operation='z = x*x - y*y',
    name='square_diff',
)


def square_diff(in1, in2):
    """Return ``in1*in1 - in2*in2``."""
    first_squared = in1 * in1
    second_squared = in2 * in2
    return first_squared - second_squared


def test_cupy_kernel():
    # Custom cupy kernel on the GPU versus the equivalent Python
    # expression on the CPU.
    validate_cpu_vs_gpu(gpu_fun=square_diff_kernel, cpu_fun=square_diff)


def test_builtin_func():
    # cupy's logaddexp on the GPU versus numpy's logaddexp on the CPU.
    validate_cpu_vs_gpu(gpu_fun=cupy.logaddexp, cpu_fun=np.logaddexp)
1 change: 1 addition & 0 deletions qa/TL0_python_self_test_frameworks/test_cupy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ target_dir=./dali/test/python

# Runs the cupy-dependent Python tests with nosetests.
test_body() {
# Only tests in test_dltensor_operator.py matched by the -m pattern
# (test names that mention cupy) are selected.
nosetests --verbose -m '(?:^|[\b_\./-])[Tt]est.*cupy' test_dltensor_operator.py
# Every test in the GPU PythonFunction test module is run.
nosetests --verbose test_gpu_python_function_operator.py
}

pushd ../..
Expand Down

0 comments on commit 94a3146

Please sign in to comment.