Skip to content

Commit

Permalink
Introduce custom quick sort in src/Runtime/OMSort.inc (No test execut…
Browse files Browse the repository at this point in the history
…ed on zOS!!) (onnx#1944)

* Fix zos issue (not supporting qsort_r) by using custom qsort function.
Tests and all the model zoo models pass with this PR on z/OS.

Signed-off-by: Yasushi Negishi <[email protected]>
Co-authored-by: Charles Volzka <[email protected]>
  • Loading branch information
negiyas and cjvolzka authored Jan 10, 2023
1 parent c2e20f1 commit 72600f6
Showing 1 changed file with 117 additions and 1 deletion.
118 changes: 117 additions & 1 deletion src/Runtime/OMSort.inc
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,126 @@ declare_compare_function(Double, double, Descending, <)
size_t dataNum, size_t dataSize,
compareFunctionType compFunc, void *dataPtr));
#endif

//
// Custom quick sort function for environments not supporting qsort_r (e.g. zos)
//
// Exchange the values of two uint64_t lvalues. Each argument is evaluated
// twice, so pass plain lvalues without side effects.
#define SWAP_INDEX(a, b)                                                       \
  do {                                                                         \
    uint64_t swapTmp = (b);                                                    \
    (b) = (a);                                                                 \
    (a) = swapTmp;                                                             \
  } while (0)
// stack for index (uint64_t)
// Bookkeeping for a stack of uint64_t values. The STACK_* macros in this file
// operate on it and store (begin, end) index pairs, using two slots per pair.
typedef struct indexStack {
uint64_t *stackData; // storage for the stacked values
int64_t stackSize; // capacity of stackData, counted in uint64_t slots
int64_t stackTop; // number of slots currently in use; 0 means empty
} indexStack;

// Initialize `stack` (an indexStack lvalue) with room for `size` uint64_t
// slots and mark it empty.
//
// This is a macro rather than a function because the storage is obtained with
// alloca(): it is carved out of the stack frame of the function that expands
// the macro and goes away when that function returns. Do not expand it inside
// a loop, and keep `size` small.
//
// The size parameter is deliberately not named `stackSize`. A macro parameter
// sharing its name with the struct member is also substituted in
// `.stackSize`, which made the macro compile only when the caller's argument
// was an identifier literally called `stackSize`. `size` is evaluated more
// than once, so it must not have side effects.
#define STACK_INIT(stack, size)                                                \
  do {                                                                         \
    assert((size) > 0);                                                        \
    (stack).stackData = (uint64_t *)alloca((size) * sizeof(uint64_t));         \
    assert((stack).stackData != NULL);                                         \
    (stack).stackSize = (size);                                                \
    (stack).stackTop = 0;                                                      \
  } while (0)
// Evaluates to non-zero when `stack` (an indexStack lvalue) holds no values.
// The argument is parenthesized so expressions such as `*ptr` work as well.
#define STACK_ISEMPTY(stack) ((stack).stackTop == 0)
// Push the pair (begin, end) onto `stack` (an indexStack lvalue); `begin` is
// stored first, `end` on top of it. Both values are converted to uint64_t on
// store. The assert checks that two free slots remain; with NDEBUG defined
// there is no bounds check at all, so the caller must size the stack
// correctly. All arguments are parenthesized so that arbitrary expressions
// (e.g. `*ptr`, `pivot - 1`) can be passed safely.
#define STACK_PUSH(stack, begin, end)                                          \
  do {                                                                         \
    assert((stack).stackTop <= (stack).stackSize - 2);                         \
    (stack).stackData[((stack).stackTop)++] = (begin);                         \
    (stack).stackData[((stack).stackTop)++] = (end);                           \
  } while (0)
// Pop the most recently pushed pair from `stack` (an indexStack lvalue) into
// the lvalues `begin` and `end`. `end` is read first because it was pushed
// last. The assert checks that at least one full pair is present; with NDEBUG
// defined there is no underflow check. Arguments are parenthesized so that
// expressions such as `*ptr` can be passed for `stack`.
#define STACK_POP(stack, begin, end)                                           \
  do {                                                                         \
    assert((stack).stackTop >= 2);                                             \
    (end) = (stack).stackData[--((stack).stackTop)];                           \
    (begin) = (stack).stackData[--((stack).stackTop)];                         \
  } while (0)
// Debug helper: dump the (begin, end) pairs held in `stack` (an indexStack
// lvalue), bottom first, followed by the top index and capacity, to stderr.
//
// The values are cast to long long and printed with %lld. The fixed-width
// types are `long` on some platforms and `long long` on others, so a bare %ld
// does not match the argument type everywhere. Stored values are printed as
// signed, so an index that was pushed as -1 shows up as -1.
#define STACK_PRINT(stack)                                                     \
  do {                                                                         \
    assert((stack).stackTop >= 0);                                             \
    fprintf(stderr, "Stack: [");                                               \
    for (int64_t i = 0; (i + 1) < (stack).stackTop; i += 2) {                  \
      fprintf(stderr, "<%lld:%lld>, ", (long long)(stack).stackData[i],        \
          (long long)(stack).stackData[i + 1]);                                \
    }                                                                          \
    fprintf(stderr, "] (Top=%lld,Size=%lld)\n", (long long)(stack).stackTop,   \
        (long long)(stack).stackSize);                                         \
    fflush(stderr);                                                            \
  } while (0)

// Return floor(log2(n)) for n > 0, i.e. the position of the highest set bit
// (log2u(1) == 0, log2u(2) == 1, log2u(1023) == 9, log2u(1024) == 10).
//
// quick_sort_custom uses this value to size its explicit job stack. The value
// must therefore halve n one bit per step: shifting by two bits per step
// yields roughly log4(n), about half of the depth the sort can actually
// reach, which undersizes the stack for larger inputs.
static int64_t log2u(uint64_t n) {
  assert(n > 0);
  int64_t b = 0;
  for (; n > 1; b++)
    n >>= 1;
  return b;
}

// Quick sort partition function
//
// Rearranges the entries idx[begin..end] (both bounds inclusive) around the
// pivot entry idx[end]. Every entry for which compFunc reports "<= 0" against
// the pivot is moved in front of the returned position, and the pivot itself
// is placed at the returned position.
//
// dataPtr is handed to compFunc unchanged. Its position in the argument list
// depends on the platform: first on Apple and MSVC builds, last otherwise.
//
// Returns the final position of the pivot entry within idx.
static int64_t quick_sort_partition(void *dataPtr, uint64_t *idx,
    compareFunctionType compFunc, int64_t begin, int64_t end) {
  int64_t boundary = begin; // next slot for an entry that sorts before pivot
  int64_t scan = begin;
  while (scan < end) {
#if defined(__APPLE__) || defined(_MSC_VER)
    if (compFunc(dataPtr, &idx[scan], &idx[end]) <= 0) {
#else
    if (compFunc(&idx[scan], &idx[end], dataPtr) <= 0) {
#endif
      const uint64_t moved = idx[boundary];
      idx[boundary] = idx[scan];
      idx[scan] = moved;
      boundary++;
    }
    scan++;
  }
  // Drop the pivot between the two groups.
  const uint64_t pivotEntry = idx[end];
  idx[end] = idx[boundary];
  idx[boundary] = pivotEntry;
  return boundary;
}

// Quick sort main function (custom version)
//
// Sorts the dataNum entries of `base`, which is treated as an array of
// uint64_t, in place. Ordering is decided by compFunc, which receives
// pointers to two entries plus dataPtr. dataSize is accepted for signature
// compatibility with the libc sort functions and is not used. The parameter
// order of compFunc/dataPtr differs on Apple builds.
//
// The sort is iterative: pending [begin, end] ranges are kept on an explicit
// stack whose storage is allocated in this function's frame by STACK_INIT.
#ifdef __APPLE__
void quick_sort_custom(void *base, size_t dataNum, size_t dataSize,
    void *dataPtr, compareFunctionType compFunc) {
#else
void quick_sort_custom(void *base, size_t dataNum, size_t dataSize,
    compareFunctionType compFunc, void *dataPtr) {
#endif
  uint64_t *idx = (uint64_t *)base;
  // Theoretical maximum number of stack slots needed (two per pending range).
  int64_t stackSize = (log2u(dataNum + 1) + 2) * 2;
  indexStack jobs;
  STACK_INIT(jobs, stackSize);

  // Seed the stack with the whole array.
  int64_t lo = 0;
  int64_t hi = dataNum - 1;
  STACK_PUSH(jobs, lo, hi);

  while (!STACK_ISEMPTY(jobs)) {
    STACK_POP(jobs, lo, hi);
    if (lo >= hi) {
      continue; // fewer than two entries: nothing to sort
    }
    const int64_t pivotIdx =
        quick_sort_partition(dataPtr, idx, compFunc, lo, hi);
    const int leftPending = lo < pivotIdx - 1;
    const int rightPending = pivotIdx + 1 < hi;
    const int leftIsLarger = (pivotIdx - lo) > (hi - pivotIdx);
    // To limit the stack size, the larger partition is pushed first so that
    // the smaller one is popped and processed next.
    if (leftIsLarger && leftPending) {
      STACK_PUSH(jobs, lo, pivotIdx - 1);
    }
    if (rightPending) {
      STACK_PUSH(jobs, pivotIdx + 1, hi);
    }
    if (!leftIsLarger && leftPending) {
      STACK_PUSH(jobs, lo, pivotIdx - 1);
    }
  }
}

compareFunctionType *getCompareFunction(
uint64_t ascending, OM_DATA_TYPE dataType) {
Expand Down Expand Up @@ -228,9 +342,11 @@ void omTensorSort(OMTensor *orderTensor, const OMTensor *inputTensor,
// Windows supports qsort_s
sortFunctionType *sortFunc = qsort_s;
#pragma warning(pop)
#else
#elif defined(__linux) || defined(__linux__) || defined(linux)
// Use standard quick sort in libc
sortFunctionType *sortFunc = qsort_r;
#else // for environments not supporting quick sort
sortFunctionType *sortFunc = quick_sort_custom; // custum quick sort
#endif
// To support input Tensor with various ranks in a uniform way.
// If the input rank < 6, upgrade the rank to 6 virtually without changing
Expand Down

0 comments on commit 72600f6

Please sign in to comment.