Type safe cuDNN wrapper for the Rust programming language.
The current version of cuDNN targeted by this wrapper is the 8.3.2. You can refer to the official release notes and to the support matrix by NVIDIA.
The legacy API is somewhat complete and it is usable but the backend API is still to be considered a work in progress and its usage is therefore much discouraged. Both APIs are still being developed so expect bugs and reasonable breaking changes whilst using this crate.
The project is part of the Rust CUDA ecosystem and is actively maintained by frjnn.
Here follows a list of useful concepts that should be taken as a handbook for the users of the crate. This is not intended to be the full documentation, as each wrapped struct, enum and function has its own docs, but rather a quick sum up of the key points of the API. As a matter of fact, for a deeper view, you should refer both to the docs of each item and to the official ones by NVIDIA. Furthermore, if you are new to cuDNN we strongly suggest reading the official developer guide.
This crate is built around cust
which is the core wrapper for interfacing with the CUDA driver API of our choice.
All cuDNN library functions return their status. This crate uses Result
to achieve a leaner, idiomatic and easier to manage API.
The main entry point of the cuDNN library is the CudnnContext
struct. This handle is tied to a device and it is explicitly passed to every subsequent library function that operates on GPU data. It manages resources allocations both on the host and the device and takes care of the synchronization of all the the cuDNN primitives.
The handles, and the other cuDNN structs wrapped by this crate, are implementors of the Drop
trait which implicitly calls their destructors on the cuDNN side when they go out of scope.
cuDNN contexts can be created as shown in the following snippet:
use cudnn::CudnnContext;
let ctx = CudnnContext::new().unwrap();
In order to enforce type safety as much as possible at compile time, we shifted away from the original cuDNN enumerated data types and instead opted to leverage Rust's generics. In practice, this means that specifying the data type of a cuDNN tensor descriptor is done as follows:
use cudnn::{CudnnContext, TensorDescriptor};
let ctx = CudnnContext::new().unwrap();
let shape = &[5, 5, 10, 25];
let strides = &[1250, 250, 25, 1];
// f32 tensor
let desc = TensorDescriptor::<f32>::new_strides(shape, strides).unwrap();
This API also allows for using Rust own types as cuDNN data types, which we see as a desirable property.
Safely manipulating cuDNN data types that do not have any such direct match, such as vectorized ones, whilst still performing compile time compatibility checks can be done as follows:
use cudnn::{CudnnContext, TensorDescriptor, Vec4};
let ctx = CudnnContext::new().unwrap();
let shape = &[4, 32, 32, 32];
// in cuDNN this is equal to the INT8x4 data type and CUDNN_TENSOR_NCHW_VECT_C format
let desc = TensorDescriptor::<i8>::new_vectorized::<Vec4>(shape).unwrap();
The previous tensor descriptor can be used together with a i8
device buffer and cuDNN will see it as being a tensor of CUDNN_TENSOR_NCHW_VECT_C
format and CUDNN_DATA_INT8x4
data type.
Currently this crate does not support f16
and bf16
data types.
We decided not to check tensor format configurations at compile time, since it is too strong of a requirement. As a consequence, should you mess up, the program will fail at run-time. A proper understanding of the cuDNN API mechanics is thus fundamental to properly use this crate.
You can refer to this extract from the cuDNN developer guide to learn more about tensor formats.
We split the original cuDNN tensor format enum, which counts 3 variants, in 2 parts: the ScalarC
enum and the TensorFormat::NchwVectC
enum variant. The former stands for "scalar channel" and it encapsulates the Nchw
and Nhwc
formats. Scalar channel formats can be both converted to the TensorFormat
enum with .into()
.
use cudnn::{ScalarC, TensorFormat};
let sc_fmt = ScalarC::Nchw;
let vc_fmt = TensorFormat::NchwVectC;
let sc_to_tf: TensorFormat = sc_fmt.into();