[NNPA] zDNN extension for splitting ztensors (onnx#2688)
* zDNN extension for splitting ztensors; applied to zdnn matmul in this commit.

Signed-off-by: Tung D. Le <[email protected]>

---------

Signed-off-by: Tung D. Le <[email protected]>
tungld authored Jan 26, 2024
1 parent 58f3746 commit 46787f3
Showing 8 changed files with 632 additions and 10 deletions.
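The extension splits a large ztensor along the e2 dimension (the M rows of an (M,N)*(N,P) MatMul) into chunks of OMZTensorSplitSize elements, runs zdnn on each chunk, and merges the results. Below is a minimal sketch of the chunk arithmetic that initSplitInfo presumably performs; the 1024-element chunk size and the standalone program are illustrative assumptions, not the committed code.

#include <stdint.h>
#include <stdio.h>

/* Illustrative only: partition dimension e2 into chunks of chunkSize
 * elements; the last chunk may be smaller. */
int main(void) {
  uint32_t e2 = 5000;        /* M rows of the MatMul to be split */
  uint32_t chunkSize = 1024; /* stand-in for OMZTensorSplitSize */
  uint32_t numOfChunks = (e2 + chunkSize - 1) / chunkSize; /* ceiling division */
  for (uint32_t i = 0; i < numOfChunks; ++i) {
    uint32_t offset = i * chunkSize;
    uint32_t size = (e2 - offset < chunkSize) ? (e2 - offset) : chunkSize;
    printf("chunk %u: rows [%u, %u)\n", i, offset, offset + size);
  }
  return 0;
}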
@@ -73,8 +73,8 @@ ApiRegistry RegisterAllApis(MLIRContext *context) {
ApiSpec(API::ZDNN_LSTM, "zdnn_lstm", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy, opaquePtrTy, opaquePtrTy}, false),
ApiSpec(API::ZDNN_GRU, "zdnn_gru", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy, opaquePtrTy}, false),
// Other operations
- ApiSpec(API::ZDNN_MATMUL_OP, "zdnn_matmul_op", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
- ApiSpec(API::ZDNN_MATMUL_BCAST_OP, "zdnn_matmul_bcast_op", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
+ ApiSpec(API::ZDNN_MATMUL_OP, "zdnn_matmul_op_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
+ ApiSpec(API::ZDNN_MATMUL_BCAST_OP, "zdnn_matmul_bcast_op_ext", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, opaquePtrTy}, false),
ApiSpec(API::ZDNN_CONV2D, "zdnn_conv2d", int32Ty, {opaquePtrTy, opaquePtrTy, opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy, opaquePtrTy}, false),
ApiSpec(API::ZDNN_AVGPOOL2D, "zdnn_avgpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false),
ApiSpec(API::ZDNN_MAXPOOL2D, "zdnn_maxpool2d", int32Ty, {opaquePtrTy, int64Ty, int64Ty, int64Ty, int64Ty, int64Ty, opaquePtrTy}, false),
3 changes: 3 additions & 0 deletions src/Accelerators/NNPA/Runtime/CMakeLists.txt
@@ -4,6 +4,8 @@

add_onnx_mlir_library(RuntimeNNPA STATIC
OMRuntimeNNPA.c
+ zDNNExtension/zDNNExtension.c
+ zDNNExtension/MatMul.c

EXCLUDE_FROM_OM_LIBS

@@ -17,5 +19,6 @@ set_target_properties(RuntimeNNPA
PROPERTIES
LANGUAGE C
POSITION_INDEPENDENT_CODE TRUE
+ COMPILE_OPTIONS -O3
)

5 changes: 5 additions & 0 deletions src/Accelerators/NNPA/Runtime/OMRuntimeNNPA.c
@@ -23,6 +23,7 @@
#include <stdio.h>
#include <stdlib.h>

#include "zDNNExtension/zDNNExtension.h"
#include "zdnn.h"

#ifdef __cplusplus
@@ -104,6 +105,8 @@ void OMInitAccelNNPA() {
if (!OMIsInitAccelNNPA) {
/* Still uninitialized, actual init. */
zdnn_init();
+ /* Initialize settings for ztensor splitting. */
+ zDNNExtensionInit();
/* No need for a fence due to strong consistency. */
OMIsInitAccelNNPA = 1;
} /* Release mutex. */
@@ -143,6 +146,8 @@ uint64_t OMInitCompatibleAccelNNPA(uint64_t versionNum) {
if (!OMIsInitAccelNNPA) {
/* Still uninitialized, actual init. */
zdnn_init();
+ /* Initialize settings for ztensor splitting. */
+ zDNNExtensionInit();
/* Check if version is compatible */
if (zdnn_is_version_runnable((uint32_t)versionNum))
isCompatible = 1;
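zDNNExtensionInit() is defined in zDNNExtension.c, which is among the eight changed files but not shown in this excerpt. A plausible sketch of such an initializer, assuming hypothetical environment-variable names and defaults (the committed code defines the real ones):

#include <stdbool.h>
#include <stdint.h>
#include <stdlib.h>

/* Split settings consumed by MatMul.c (declared in zDNNExtension.h). */
extern bool OMZTensorSplitEnabled;
extern bool OMZTensorSplitDebug;
extern uint32_t OMZTensorSplitSize;

void zDNNExtensionInitSketch(void) {
  /* The defaults and the OM_ZTENSOR_SPLIT_* names are assumptions. */
  const char *s;
  OMZTensorSplitEnabled = true;
  OMZTensorSplitDebug = false;
  OMZTensorSplitSize = 1024;
  if ((s = getenv("OM_ZTENSOR_SPLIT_ENABLED")))
    OMZTensorSplitEnabled = (atoi(s) != 0);
  if ((s = getenv("OM_ZTENSOR_SPLIT_DEBUG")))
    OMZTensorSplitDebug = (atoi(s) != 0);
  if ((s = getenv("OM_ZTENSOR_SPLIT_SIZE")))
    OMZTensorSplitSize = (uint32_t)atoi(s);
}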
149 changes: 149 additions & 0 deletions src/Accelerators/NNPA/Runtime/zDNNExtension/MatMul.c
@@ -0,0 +1,149 @@
/*
* SPDX-License-Identifier: Apache-2.0
*/

//===-------------------------- MatMul.c ----------------------------------===//
//
// Copyright 2024 The IBM Research Authors.
//
// =============================================================================
//
// A wrapper of zdnn_matmul_op for ztensor partition and parallelism.
//
//===----------------------------------------------------------------------===//

// Include pthreads (need special treatment on z/OS).
#ifdef __MVS__
#define _OPEN_THREADS
#endif
#include <pthread.h>

#include <assert.h>
#include <math.h>
#include <stdbool.h> /* bool is used in the wrappers below */
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
#include <time.h> /* clock(), clock_t, CLOCKS_PER_SEC for the debug timers */

#include "zDNNExtension.h"
#include "zdnn.h"

#ifdef __cplusplus
extern "C" {
#endif

static inline zdnn_status call_zdnn_matmul_op(const zdnn_ztensor *inputA,
const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int opType,
zdnn_ztensor *output, bool isBcast) {
if (isBcast)
return zdnn_matmul_bcast_op(
inputA, inputB, inputC, (zdnn_matmul_bcast_ops)opType, output);
return zdnn_matmul_op(
inputA, inputB, inputC, (zdnn_matmul_ops)opType, output);
}

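/*
 * zdnn_matmul_op_common below implements the split pipeline:
 *   1. Reject inputs whose e4, e3, or e1 exceed the NNPA maximum dimension
 *      index size (MDIS); only e2 is ever split.
 *   2. If splitting is disabled or e2 is already small enough, forward the
 *      call to zdnn unchanged.
 *   3. Otherwise split input A and the output along e2 into chunks of
 *      OMZTensorSplitSize elements, run the matmul chunk by chunk, and
 *      merge the partial results back into the full output ztensor.
 */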
static zdnn_status zdnn_matmul_op_common(const zdnn_ztensor *inputA,
const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int opType,
zdnn_ztensor *output, bool isBcast) {
// Verify that e4, e3, and e1 do not exceed the maximum dimension size, so
// that e2 can be split safely.
OrigShape origShapeOfA;
getOrigShape(inputA, &origShapeOfA);
uint32_t maxDimSize = zdnn_get_nnpa_max_dim_idx_size();
if ((origShapeOfA.e4 > maxDimSize) || (origShapeOfA.e3 > maxDimSize) ||
(origShapeOfA.e1 > maxDimSize)) {
printf("[MatMul] The 1st tensor dimension exceeds maximum dimension index "
"size (MDIS) of %d: e4 = %d, e3 = %d, e1 = %d.\n",
maxDimSize, origShapeOfA.e4, origShapeOfA.e3, origShapeOfA.e1);
return ZDNN_EXCEEDS_MDIS;
}

// For a MatMul of (M,N)*(N,P),
// We split M that is e2 in (e4, e3, e2, e1).
SplitInfo splitInfoA, splitInfoY;
splitInfoA.axis = 2;
splitInfoY.axis = 2;
splitInfoA.chunkSize = OMZTensorSplitSize;
splitInfoY.chunkSize = OMZTensorSplitSize;

// Dim is small or ztensor split is disabled.
if (!OMZTensorSplitEnabled || !initSplitInfo(inputA, &splitInfoA)) {
if (OMZTensorSplitDebug)
printf("[MatMul] Not split zTensor ...\n");
return call_zdnn_matmul_op(inputA, inputB, inputC, opType, output, isBcast);
}

// Split input A.
if (OMZTensorSplitDebug)
printf("[MatMul] Split the 1st ztensor along e2 into %d chunks of %d "
"elements \n",
splitInfoA.numOfChunks, splitInfoA.chunkSize);
initSplitInfo(output, &splitInfoY);

double splitTime = 0.;
double mmTime = 0.;
double mergeTime = 0.;
clock_t start_time, end_time;

// Split input A into chunks.
if (OMZTensorSplitDebug)
start_time = clock();
splitZTensor(inputA, &splitInfoA, /*copyData=*/true);
splitZTensor(output, &splitInfoY, /*copyData=*/false);
if (OMZTensorSplitDebug) {
end_time = clock();
splitTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000;
}

// Call zdnn_matmul_op on each chunk.
if (OMZTensorSplitDebug)
start_time = clock();
for (uint32_t i = 0; i < splitInfoA.numOfChunks; ++i) {
zdnn_status status = call_zdnn_matmul_op(splitInfoA.tensors + i, inputB,
inputC, opType, splitInfoY.tensors + i, isBcast);
assert(status == ZDNN_OK);
}
if (OMZTensorSplitDebug) {
end_time = clock();
mmTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000;
}

// Merge the chunks into the output.
if (OMZTensorSplitDebug)
start_time = clock();
mergeZTensors(&splitInfoY, output);
if (OMZTensorSplitDebug) {
end_time = clock();
mergeTime = ((float)(end_time - start_time) / (float)CLOCKS_PER_SEC) * 1000;
}

freeSplitInfoBuffer(&splitInfoA);
freeSplitInfoBuffer(&splitInfoY);

if (OMZTensorSplitDebug)
printf("[MatMul] split, %f, mm, %f, merge, %f (milliseconds)\n", splitTime,
mmTime, mergeTime);

return ZDNN_OK;
}

zdnn_status zdnn_matmul_op_ext(const zdnn_ztensor *inputA,
const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int opType,
zdnn_ztensor *output) {
return zdnn_matmul_op_common(
inputA, inputB, inputC, opType, output, /*isBcast=*/false);
}

zdnn_status zdnn_matmul_bcast_op_ext(const zdnn_ztensor *inputA,
const zdnn_ztensor *inputB, const zdnn_ztensor *inputC, int opType,
zdnn_ztensor *output) {
zdnn_status status = zdnn_matmul_op_common(
inputA, inputB, inputC, opType, output, /*isBcast=*/true);
// The compiler does not check the returned status at the moment, so check
// it here.
assert(status == ZDNN_OK && "Failed to execute MatMul on NNPA");
return status;
}

#ifdef __cplusplus
}
#endif
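For illustration, a minimal caller-side sketch, assuming the four ztensors were already created and transformed through the usual zdnn descriptor APIs (ztensor creation is elided; run_matmul_ext is a hypothetical helper, not part of this commit):

#include <assert.h>

#include "zDNNExtension/zDNNExtension.h"
#include "zdnn.h"

/* a: (M,N) stick-format ztensor with a potentially huge M (e2);
 * b: (N,P); c: bias of size P; y: (M,P) output. */
void run_matmul_ext(const zdnn_ztensor *a, const zdnn_ztensor *b,
    const zdnn_ztensor *c, zdnn_ztensor *y) {
  /* Drop-in replacement for zdnn_matmul_op: splits along e2 when needed. */
  zdnn_status status = zdnn_matmul_op_ext(a, b, c, MATMUL_OP_ADDITION, y);
  assert(status == ZDNN_OK);
  (void)status; /* silence the unused warning when NDEBUG is defined */
}

The ApiSpec change above is what makes compiler-generated code bind to these _ext symbols instead of the original zdnn entry points.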