Skip to content

Commit

Permalink
Fix GPU sort for large arrays (ml-explore#1285)
Browse files Browse the repository at this point in the history
* Fix GPU sort for large arrays
  • Loading branch information
jagrit06 authored Jul 24, 2024
1 parent ebd7135 commit 7f91436
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 25 deletions.
36 changes: 21 additions & 15 deletions mlx/backend/metal/kernels/sort.h
Original file line number Diff line number Diff line change
Expand Up @@ -522,13 +522,13 @@ template <
bool ARG_SORT,
short BLOCK_THREADS,
short N_PER_THREAD>
[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void
mb_block_partition(
[[kernel]] void mb_block_partition(
device idx_t* block_partitions [[buffer(0)]],
const device val_t* dev_vals [[buffer(1)]],
const device idx_t* dev_idxs [[buffer(2)]],
const constant int& size_sorted_axis [[buffer(3)]],
const constant int& merge_tiles [[buffer(4)]],
const constant int& n_blocks [[buffer(5)]],
uint3 tid [[threadgroup_position_in_grid]],
uint3 lid [[thread_position_in_threadgroup]],
uint3 tgp_dims [[threads_per_threadgroup]]) {
Expand All @@ -543,23 +543,29 @@ mb_block_partition(
dev_vals += tid.y * size_sorted_axis;
dev_idxs += tid.y * size_sorted_axis;

// Find location in merge step
int merge_group = lid.x / merge_tiles;
int merge_lane = lid.x % merge_tiles;
for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) {
// Find location in merge step
int merge_group = i / merge_tiles;
int merge_lane = i % merge_tiles;

int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;

int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
int A_st = min(size_sorted_axis, sort_st);
int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
int B_st = A_ed;
int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);

int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st, dev_vals + B_st, A_ed - A_st, B_ed - B_st, partition_at);
int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
int partition = sort_kernel::merge_partition(
dev_vals + A_st,
dev_vals + B_st,
A_ed - A_st,
B_ed - B_st,
partition_at);

block_partitions[lid.x] = A_st + partition;
block_partitions[i] = A_st + partition;
}
}

template <
Expand Down
5 changes: 4 additions & 1 deletion mlx/backend/metal/sort.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ void multi_block_sort(
array dev_vals_out = dev_vals_1;
array dev_idxs_out = dev_idxs_1;

int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;

for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks; merge_tiles *= 2) {
dev_vals_in = ping ? dev_vals_1 : dev_vals_0;
dev_idxs_in = ping ? dev_idxs_1 : dev_idxs_0;
Expand All @@ -199,8 +201,9 @@ void multi_block_sort(
compute_encoder.set_input_array(dev_idxs_in, 2);
compute_encoder->setBytes(&size_sorted_axis, sizeof(int), 3);
compute_encoder->setBytes(&merge_tiles, sizeof(int), 4);
compute_encoder->setBytes(&n_blocks, sizeof(int), 5);

MTL::Size group_dims = MTL::Size(n_blocks + 1, 1, 1);
MTL::Size group_dims = MTL::Size(n_thr_per_group, 1, 1);
MTL::Size grid_dims = MTL::Size(1, n_rows, 1);

compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
Expand Down
9 changes: 0 additions & 9 deletions mlx/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1785,15 +1785,6 @@ array sort(const array& a, int axis, StreamOrDevice s /* = {} */) {
throw std::invalid_argument(msg.str());
}

// TODO: Fix GPU kernel
if (a.shape(axis) >= (1u << 21) && to_stream(s).device.type == Device::gpu) {
std::ostringstream msg;
msg << "[sort] GPU sort cannot handle sort axis of >= 2M elements,"
<< " got array with sort axis size " << a.shape(axis) << "."
<< " Please place this operation on the CPU instead.";
throw std::runtime_error(msg.str());
}

return array(
a.shape(), a.dtype(), std::make_shared<Sort>(to_stream(s), axis), {a});
}
Expand Down
9 changes: 9 additions & 0 deletions python/tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1840,6 +1840,15 @@ def test_sort(self):
self.assertTrue(np.array_equal(c_np, c_mx))
self.assertEqual(b_mx.dtype, c_mx.dtype)

# Test very large array
if mx.default_device() == mx.gpu:
a_np = np.random.normal(20, 20, size=(2**22)).astype(np.float32)
a_mx = mx.array(a_np)

b_np = np.sort(a_np)
b_mx = mx.sort(a_mx)
self.assertTrue(np.array_equal(b_np, b_mx))

def test_partition(self):
shape = (3, 4, 5)
for dtype in ("int32", "float32"):
Expand Down

0 comments on commit 7f91436

Please sign in to comment.