Skip to content

Latest commit

 

History

History

cudnn

Folders and files

NameName
Last commit message
Last commit date

parent directory

..
 
 
 
 
 
 
 
 

cudnn

Type safe cuDNN wrapper for the Rust programming language.

Project status

The current version of cuDNN targeted by this wrapper is the 8.3.2. You can refer to the official release notes and to the support matrix by NVIDIA.

The legacy API is somewhat complete and it is usable but the backend API is still to be considered a work in progress and its usage is therefore much discouraged. Both APIs are still being developed so expect bugs and reasonable breaking changes whilst using this crate.

The project is part of the Rust CUDA ecosystem and is actively maintained by frjnn.

Primer

Here follows a list of useful concepts that should be taken as a handbook for the users of the crate. This is not intended to be the full documentation, as each wrapped struct, enum and function has its own docs, but rather a quick sum up of the key points of the API. As a matter of fact, for a deeper view, you should refer both to the docs of each item and to the official ones by NVIDIA. Furthermore, if you are new to cuDNN we strongly suggest reading the official developer guide.

Device buffers

This crate is built around cust which is the core wrapper for interfacing with the CUDA driver API of our choice.

cuDNN statuses and Result

All cuDNN library functions return their status. This crate uses Result to achieve a leaner, idiomatic and easier to manage API.

cuDNN handles and RAII

The main entry point of the cuDNN library is the CudnnContext struct. This handle is tied to a device and it is explicitly passed to every subsequent library function that operates on GPU data. It manages resources allocations both on the host and the device and takes care of the synchronization of all the the cuDNN primitives.

The handles, and the other cuDNN structs wrapped by this crate, are implementors of the Drop trait which implicitly calls their destructors on the cuDNN side when they go out of scope.

cuDNN contexts can be created as shown in the following snippet:

use cudnn::CudnnContext;

let ctx = CudnnContext::new().unwrap();

cuDNN data types

In order to enforce type safety as much as possible at compile time, we shifted away from the original cuDNN enumerated data types and instead opted to leverage Rust's generics. In practice, this means that specifying the data type of a cuDNN tensor descriptor is done as follows:

use cudnn::{CudnnContext, TensorDescriptor};

let ctx = CudnnContext::new().unwrap();

let shape = &[5, 5, 10, 25];
let strides = &[1250, 250, 25, 1];

// f32 tensor
let desc = TensorDescriptor::<f32>::new_strides(shape, strides).unwrap();

This API also allows for using Rust own types as cuDNN data types, which we see as a desirable property.

Safely manipulating cuDNN data types that do not have any such direct match, such as vectorized ones, whilst still performing compile time compatibility checks can be done as follows:

use cudnn::{CudnnContext, TensorDescriptor, Vec4};

let ctx = CudnnContext::new().unwrap();

let shape = &[4, 32, 32, 32];

// in cuDNN this is equal to the INT8x4 data type and CUDNN_TENSOR_NCHW_VECT_C format
let desc = TensorDescriptor::<i8>::new_vectorized::<Vec4>(shape).unwrap();

The previous tensor descriptor can be used together with a i8 device buffer and cuDNN will see it as being a tensor of CUDNN_TENSOR_NCHW_VECT_C format and CUDNN_DATA_INT8x4 data type.

Currently this crate does not support f16 and bf16 data types.

cuDNN tensor formats

We decided not to check tensor format configurations at compile time, since it is too strong of a requirement. As a consequence, should you mess up, the program will fail at run-time. A proper understanding of the cuDNN API mechanics is thus fundamental to properly use this crate.

You can refer to this extract from the cuDNN developer guide to learn more about tensor formats.

We split the original cuDNN tensor format enum, which counts 3 variants, in 2 parts: the ScalarC enum and the TensorFormat::NchwVectC enum variant. The former stands for "scalar channel" and it encapsulates the Nchw and Nhwc formats. Scalar channel formats can be both converted to the TensorFormat enum with .into().

use cudnn::{ScalarC, TensorFormat};

let sc_fmt = ScalarC::Nchw;

let vc_fmt = TensorFormat::NchwVectC;

let sc_to_tf: TensorFormat = sc_fmt.into();