Merge pull request #4 from saipraveenb25/add-plain-vector-matrix-types

Add logic for plain vector/matrix inputs to CUDA kernels

saipraveenb25 authored Apr 30, 2024
2 parents 869a3ed + 8110624 commit e9d0c02
Showing 3 changed files with 129 additions and 1 deletion.

69 changes: 68 additions & 1 deletion slangtorch/util/builtin_wrappers.py
@@ -1,4 +1,5 @@
from collections import namedtuple
import re
import torch

DiffTensorView = namedtuple('DiffTensorView', ['value', 'grad'], defaults=[None])
@@ -72,7 +73,73 @@ def accept_array(inp):

    return tuple, accept_array


def make_vector_wrapper(module, typename, wrappedTypeMap, makeTypeWrapper):
    typeInfoFnName = f"__typeinfo__{typename}"
    if hasattr(module, typeInfoFnName):
        typeInfoFn = getattr(module, typeInfoFnName)
        (fieldnames, fieldtypenames) = typeInfoFn()

        # Vector types get converted into 'VectorStorage' types with an embedded array called "data".
        # Our strategy here is to use the array wrapper to parse the data and pack it into a singleton tuple.

        assert len(fieldnames) == 1
        assert "data" in fieldnames

        elementType = fieldtypenames[0]

        # Get the wrapper for the element type
        (_, innerArrayConvertFn) = makeTypeWrapper(module, elementType, wrappedTypeMap)

        def accept_vector(inp):
            return tuple([innerArrayConvertFn(inp)])

        wrappedTypeMap[typename] = (tuple, accept_vector)

        return tuple, accept_vector
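
# A worked example for the vector path (illustrative values, and assuming the
# element converter returned by makeTypeWrapper simply tuples its input):
#   accept_vector((1.0, 2.0, 3.0))  ->  ((1.0, 2.0, 3.0),)
# The parsed array becomes the sole entry, matching the generated
# VectorStorage struct's single "data" field.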


def make_matrix_wrapper(module, typename, wrappedTypeMap, makeTypeWrapper):
    typeInfoFnName = f"__typeinfo__{typename}"
    if hasattr(module, typeInfoFnName):
        typeInfoFn = getattr(module, typeInfoFnName)
        (fieldnames, fieldtypenames) = typeInfoFn()

        # Matrix types get converted into 'MatrixStorage' types with an embedded array called "data".
        # Our strategy here is to use the array wrapper to parse the data and pack it into a singleton tuple.

        assert len(fieldnames) == 1
        assert "data" in fieldnames

        # Parse the matrix type name to get its dimensions: find the first
        # 'NxM' pattern in the typename.
        m = re.search(r'\d+x\d+', typename)
        if m is None:
            raise ValueError(f"Could not parse matrix typename {typename}")
        dimensions = m.group(0).split('x')
        assert len(dimensions) == 2

        def accept_matrix(inp):
            # Check that the input is a tuple of tuples of the expected row length
            if not isinstance(inp, tuple):
                raise ValueError(f"Expected tuple, got {type(inp)}")
            if not all(isinstance(x, tuple) for x in inp):
                raise ValueError(f"Expected tuple of tuples, got {inp}")
            if not all(len(x) == int(dimensions[1]) for x in inp):
                raise ValueError(f"Expected tuple of tuples of length {dimensions[1]}, got {inp}")

            # Flatten the rows into a single tuple and nest it in another tuple
            return (sum(inp, ()),)

        wrappedTypeMap[typename] = (tuple, accept_matrix)

        return tuple, accept_matrix
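
# Likewise for the matrix path (illustrative values): a float3x3-style input
#   accept_matrix(((1., 2., 3.), (4., 5., 6.), (7., 8., 9.)))
# flattens row-major into ((1., 2., 3., 4., 5., 6., 7., 8., 9.),), again a
# singleton tuple holding the storage struct's "data" array.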

wrappers = {
    'DiffTensorView': make_diff_tensor_view_wrapper,
    'Array_*': make_array_wrapper,
    '_VectorStorage_*': make_vector_wrapper,
    '_MatrixStorage_*': make_matrix_wrapper,
}
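
The keys in wrappers are glob-style patterns matched against generated type
names. A minimal sketch of how such a lookup could resolve a converter (the
fnmatch-based matching and the typename '_VectorStorage_float3_0' are
assumptions for illustration, not necessarily slangtorch's actual mechanism):

from fnmatch import fnmatch

def lookup_wrapper_sketch(typename):
    # Return the first wrapper factory whose glob pattern matches the name.
    for pattern, makeWrapper in wrappers.items():
        if fnmatch(typename, pattern):
            return makeWrapper
    return None

# Hypothetical generated name for a float3 parameter:
assert lookup_wrapper_sketch('_VectorStorage_float3_0') is make_vector_wrapper
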
37 changes: 37 additions & 0 deletions tests/builtin-type-input.slang
@@ -0,0 +1,37 @@
[AutoPyBindCUDA]
[CUDAKernel]
void plain_copy_float3(float3 input, TensorView<float> output)
{
    // Get the 'global' index of this thread.
    uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim();

    // If the thread index is beyond the input size, exit early.
    if (dispatchIdx.x >= 1)
        return;

    output[0] = input.x;
    output[1] = input.y;
    output[2] = input.z;
}

[AutoPyBindCUDA]
[CUDAKernel]
void plain_copy_float3x3(float3x3 input, TensorView<float> output)
{
    // Get the 'global' index of this thread.
    uint3 dispatchIdx = cudaThreadIdx() + cudaBlockIdx() * cudaBlockDim();

    // If the thread index is beyond the input size, exit early.
    if (dispatchIdx.x >= 1)
        return;

    output[0] = input[0][0];
    output[1] = input[0][1];
    output[2] = input[0][2];
    output[3] = input[1][0];
    output[4] = input[1][1];
    output[5] = input[1][2];
    output[6] = input[2][0];
    output[7] = input[2][1];
    output[8] = input[2][2];
}
24 changes: 24 additions & 0 deletions tests/test.py
@@ -470,6 +470,30 @@ def test_struct_failed_input(self):
        with self.assertRaises(TypeError):
            self.module.multiply(foo={'A': A, 'Ba': B}, result=Y).launchRaw(blockSize=(32, 32, 1), gridSize=(1, 1, 1))

class TestBuiltinTypeInputs(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()
        test_dir = os.path.dirname(os.path.abspath(__file__))
        slangModuleSourceFile = os.path.join(test_dir, 'builtin-type-input.slang')

        module = slangtorch.loadModule(slangModuleSourceFile)
        self.module = module

    def test_plain_vector_input(self):
        Y = torch.tensor([0., 0., 0.]).cuda()

        self.module.plain_copy_float3(input=(1.0, 2.0, 3.0), output=Y).launchRaw(blockSize=(32, 1, 1), gridSize=(1, 1, 1))
        expected1 = torch.tensor([1., 2., 3.]).cpu()

        assert(torch.all(torch.eq(Y.cpu(), expected1)))

    def test_plain_matrix_input(self):
        Y = torch.tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.]).cuda()

        self.module.plain_copy_float3x3(input=((1.0, 2.0, 3.0), (4.0, 5.0, 6.0), (7.0, 8.0, 9.0)), output=Y).launchRaw(blockSize=(32, 1, 1), gridSize=(1, 1, 1))
        expected1 = torch.tensor([1., 2., 3., 4., 5., 6., 7., 8., 9.]).cpu()

        assert(torch.all(torch.eq(Y.cpu(), expected1)))

class TestEmptyTensor(unittest.TestCase):
    def setUp(self) -> None: