Skip to content

Commit

Permalink
lib/sort: Add priv pointer to swap function
Browse files Browse the repository at this point in the history
Adding support to have priv pointer in swap callback function.

Following the initial change on cmp callback functions [1]
and adding SWAP_WRAPPER macro to identify sort call of sort_r.

Signed-off-by: Jiri Olsa <[email protected]>
Signed-off-by: Alexei Starovoitov <[email protected]>
Reviewed-by: Masami Hiramatsu <[email protected]>
Link: https://lore.kernel.org/bpf/[email protected]

[1] 4333fb9 ("media: lib/sort.c: implement sort() variant taking context argument")
  • Loading branch information
olsajiri authored and Alexei Starovoitov committed Mar 18, 2022
1 parent 245d949 commit a0019cd
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 11 deletions.
2 changes: 1 addition & 1 deletion include/linux/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

void sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func,
swap_func_t swap_func,
swap_r_func_t swap_func,
const void *priv);

void sort(void *base, size_t num, size_t size,
Expand Down
1 change: 1 addition & 0 deletions include/linux/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ struct callback_head {
typedef void (*rcu_callback_t)(struct rcu_head *head);
typedef void (*call_rcu_func_t)(struct rcu_head *head, rcu_callback_t func);

typedef void (*swap_r_func_t)(void *a, void *b, int size, const void *priv);
typedef void (*swap_func_t)(void *a, void *b, int size);

typedef int (*cmp_r_func_t)(const void *a, const void *b, const void *priv);
Expand Down
40 changes: 30 additions & 10 deletions lib/sort.c
Original file line number Diff line number Diff line change
Expand Up @@ -122,32 +122,43 @@ static void swap_bytes(void *a, void *b, size_t n)
* a pointer, but small integers make for the smallest compare
* instructions.
*/
#define SWAP_WORDS_64 (swap_func_t)0
#define SWAP_WORDS_32 (swap_func_t)1
#define SWAP_BYTES (swap_func_t)2
#define SWAP_WORDS_64 (swap_r_func_t)0
#define SWAP_WORDS_32 (swap_r_func_t)1
#define SWAP_BYTES (swap_r_func_t)2
#define SWAP_WRAPPER (swap_r_func_t)3

struct wrapper {
cmp_func_t cmp;
swap_func_t swap;
};

/*
* The function pointer is last to make tail calls most efficient if the
* compiler decides not to inline this function.
*/
static void do_swap(void *a, void *b, size_t size, swap_func_t swap_func)
static void do_swap(void *a, void *b, size_t size, swap_r_func_t swap_func, const void *priv)
{
if (swap_func == SWAP_WRAPPER) {
((const struct wrapper *)priv)->swap(a, b, (int)size);
return;
}

if (swap_func == SWAP_WORDS_64)
swap_words_64(a, b, size);
else if (swap_func == SWAP_WORDS_32)
swap_words_32(a, b, size);
else if (swap_func == SWAP_BYTES)
swap_bytes(a, b, size);
else
swap_func(a, b, (int)size);
swap_func(a, b, (int)size, priv);
}

#define _CMP_WRAPPER ((cmp_r_func_t)0L)

static int do_cmp(const void *a, const void *b, cmp_r_func_t cmp, const void *priv)
{
if (cmp == _CMP_WRAPPER)
return ((cmp_func_t)(priv))(a, b);
return ((const struct wrapper *)priv)->cmp(a, b);
return cmp(a, b, priv);
}

Expand Down Expand Up @@ -198,7 +209,7 @@ static size_t parent(size_t i, unsigned int lsbit, size_t size)
*/
void sort_r(void *base, size_t num, size_t size,
cmp_r_func_t cmp_func,
swap_func_t swap_func,
swap_r_func_t swap_func,
const void *priv)
{
/* pre-scale counters for performance */
Expand All @@ -208,6 +219,10 @@ void sort_r(void *base, size_t num, size_t size,
if (!a) /* num < 2 || size == 0 */
return;

/* called from 'sort' without swap function, let's pick the default */
if (swap_func == SWAP_WRAPPER && !((struct wrapper *)priv)->swap)
swap_func = NULL;

if (!swap_func) {
if (is_aligned(base, size, 8))
swap_func = SWAP_WORDS_64;
Expand All @@ -230,7 +245,7 @@ void sort_r(void *base, size_t num, size_t size,
if (a) /* Building heap: sift down --a */
a -= size;
else if (n -= size) /* Sorting: Extract root to --n */
do_swap(base, base + n, size, swap_func);
do_swap(base, base + n, size, swap_func, priv);
else /* Sort complete */
break;

Expand All @@ -257,7 +272,7 @@ void sort_r(void *base, size_t num, size_t size,
c = b; /* Where "a" belongs */
while (b != a) { /* Shift it into place */
b = parent(b, lsbit, size);
do_swap(base + b, base + c, size, swap_func);
do_swap(base + b, base + c, size, swap_func, priv);
}
}
}
Expand All @@ -267,6 +282,11 @@ void sort(void *base, size_t num, size_t size,
cmp_func_t cmp_func,
swap_func_t swap_func)
{
return sort_r(base, num, size, _CMP_WRAPPER, swap_func, cmp_func);
struct wrapper w = {
.cmp = cmp_func,
.swap = swap_func,
};

return sort_r(base, num, size, _CMP_WRAPPER, SWAP_WRAPPER, &w);
}
EXPORT_SYMBOL(sort);

0 comments on commit a0019cd

Please sign in to comment.