Compile stride bug (ml-explore#812)
* fix compile stride bug

* revert sdpa fix

* fix cpu

* fix bug with simplifying outputs
awni authored Mar 11, 2024 · 1 parent a4d290a · commit 7c44160
Showing 9 changed files with 58 additions and 12 deletions.
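The gist of the main fix: when a compiled kernel donates an input's buffer to an output, the output must keep its own strides rather than inheriting the input's, and the graph simplifier must never merge away an array that is itself an output. Below is a minimal reproduction of the stride issue, adapted from the regression test this commit adds in tests/compile_tests.cpp; the main() wrapper, includes, and assert are illustrative additions, the rest mirrors the test verbatim.

#include <cassert>
#include "mlx/mlx.h"

using namespace mlx::core;

// Adding a broadcast {8, 8} constant to a {1, 8, 8} input keeps the
// 3-d shape, so a (possibly donated) output needs 3-d strides too.
auto compile_broadcast_add(const std::vector<array>& inputs) {
  auto b = zeros({8, 8});
  return std::vector<array>{inputs[0] + b};
}

int main() {
  auto cfun = compile(compile_broadcast_add);
  auto out = cfun({zeros({1, 8, 8})})[0];
  eval(out);
  assert(out.strides().size() == 3); // held only after this fix
  return 0;
}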
21 changes: 16 additions & 5 deletions mlx/array.cpp
@@ -162,12 +162,23 @@ void array::copy_shared_buffer(const array& other) {
   copy_shared_buffer(other, other.strides(), other.flags(), other.data_size());
 }
 
-void array::move_shared_buffer(array other) {
+void array::move_shared_buffer(
+    array other,
+    const std::vector<size_t>& strides,
+    Flags flags,
+    size_t data_size,
+    size_t offset /* = 0 */) {
   array_desc_->data = std::move(other.array_desc_->data);
-  array_desc_->strides = other.strides();
-  array_desc_->flags = other.flags();
-  array_desc_->data_size = other.data_size();
-  array_desc_->data_ptr = other.array_desc_->data_ptr;
+  array_desc_->strides = strides;
+  array_desc_->flags = flags;
+  array_desc_->data_size = data_size;
+  auto char_offset = sizeof(char) * itemsize() * offset;
+  array_desc_->data_ptr = static_cast<void*>(
+      static_cast<char*>(other.array_desc_->data_ptr) + char_offset);
+}
+
+void array::move_shared_buffer(array other) {
+  move_shared_buffer(other, other.strides(), other.flags(), other.data_size());
 }
 
 array::ArrayDesc::ArrayDesc(const std::vector<int>& shape, Dtype dtype)
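The new overload lets the caller, rather than the donated array, describe the resulting view: strides, flags, and data_size are taken verbatim, and data_ptr is advanced by offset elements (itemsize() * offset bytes). A sketch of the call pattern as the backends use it below, where out is an output and in a donatable input:

// Take over `in`'s buffer but keep `out`'s own strides; flags and
// data_size still describe the donated storage. A nonzero offset
// would make `out` view the buffer starting `offset` elements in.
out.move_shared_buffer(in, out.strides(), in.flags(), in.data_size());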
7 changes: 7 additions & 0 deletions mlx/array.h
@@ -339,6 +339,13 @@ class array {
 
   void copy_shared_buffer(const array& other);
 
+  void move_shared_buffer(
+      array other,
+      const std::vector<size_t>& strides,
+      Flags flags,
+      size_t data_size,
+      size_t offset = 0);
+
   void move_shared_buffer(array other);
 
   void overwrite_descriptor(const array& other) {
4 changes: 3 additions & 1 deletion mlx/backend/common/compiled_cpu.cpp
@@ -385,7 +385,9 @@ void Compiled::eval_cpu(
     if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
         in.is_donatable() &&
         constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
-      outputs[o++].copy_shared_buffer(in);
+      outputs[o].copy_shared_buffer(
+          in, outputs[o].strides(), in.flags(), in.data_size());
+      o++;
     }
   }
   for (; o < outputs.size(); ++o) {
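Before this change, copy_shared_buffer(in) also copied the input's strides into the donated output; the fix keeps the output's own strides while still taking the input's flags and data size. A commented sketch of the donation branch above, using the same names as the surrounding loop (the Metal path in the next file applies the identical fix via move_shared_buffer):

// Donation is only considered when `in` is row-contiguous, uniquely
// owned (is_donatable), the same byte size as the output, and not one
// of the compiled graph's captured constants (constant_ids_).
if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
    in.is_donatable() &&
    constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
  // Reuse the buffer, but with the output's OWN strides: the input may
  // view the same bytes through a different shape, e.g. {1, 8, 8}.
  outputs[o].copy_shared_buffer(
      in, outputs[o].strides(), in.flags(), in.data_size());
  o++;
}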
4 changes: 3 additions & 1 deletion mlx/backend/metal/compiled.cpp
@@ -329,7 +329,9 @@ void Compiled::eval_gpu(
     if (in.flags().row_contiguous && in.nbytes() == outputs[o].nbytes() &&
         in.is_donatable() &&
         constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
-      outputs[o++].move_shared_buffer(in);
+      outputs[o].move_shared_buffer(
+          in, outputs[o].strides(), in.flags(), in.data_size());
+      o++;
     }
   }
   for (; o < outputs.size(); ++o) {
5 changes: 2 additions & 3 deletions mlx/backend/metal/kernels/scaled_dot_product_attention.metal
@@ -13,12 +13,10 @@ template<typename T, typename T2, typename T4, uint16_t TILE_SIZE_CONST, uint16_
     device float* O_partials [[buffer(5)]],
     device float* p_lse [[buffer(6)]],
     device float* p_maxes [[buffer(7)]],
+    threadgroup T* threadgroup_block [[threadgroup(0)]],
     uint simd_lane_id [[thread_index_in_simdgroup]],
     uint simd_group_id [[simdgroup_index_in_threadgroup]],
     uint3 tid [[threadgroup_position_in_grid]]) {
-
-  threadgroup T threadgroup_block[32768 / sizeof(T)];
-
   constexpr const size_t DK = 128;
   constexpr const ulong SIMDGROUP_MATRIX_LOAD_FACTOR = 8;
   constexpr const size_t THREADS_PER_SIMDGROUP = 32;
@@ -358,6 +356,7 @@ template [[host_name("fast_inference_sdpa_compute_partials_" #itype "_" #tile_si
     device float* O_partials [[buffer(5)]], \
     device float* p_lse [[buffer(6)]], \
     device float* p_maxes [[buffer(7)]], \
+    threadgroup itype *threadgroup_block [[threadgroup(0)]], \
     uint simd_lane_id [[thread_index_in_simdgroup]], \
     uint simd_group_id [[simdgroup_index_in_threadgroup]], \
     uint3 tid [[threadgroup_position_in_grid]]);
2 changes: 2 additions & 0 deletions mlx/backend/metal/scaled_dot_product_attention.cpp
@@ -97,6 +97,8 @@ void sdpa_metal(
   set_array_buffer(compute_encoder, p_lse, 6);
   set_array_buffer(compute_encoder, p_rowmaxes, 7);
 
+  constexpr const uint tgroupMemorySize = 32768;
+  compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
   compute_encoder->dispatchThreadgroups(grid_dims, group_dims);
 
   {
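Together, the two SDPA hunks (the "revert sdpa fix" bullet) move the kernel back to dynamically sized threadgroup memory: the 32 KB scratch block becomes a kernel argument bound to threadgroup slot 0, and the host must attach that storage at encode time. A sketch of the pairing, reusing the names from the hunks above with the encoder setup elided:

// Kernel side (Metal): scratch now arrives as a parameter,
//   threadgroup T* threadgroup_block [[threadgroup(0)]]
// Host side (metal-cpp): back slot 0 with 32768 bytes before the
// dispatch, or the pointer the kernel receives has no storage.
constexpr const uint tgroupMemorySize = 32768;
compute_encoder->setThreadgroupMemoryLength(tgroupMemorySize, 0);
compute_encoder->dispatchThreadgroups(grid_dims, group_dims);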
4 changes: 2 additions & 2 deletions mlx/compile.cpp
@@ -439,7 +439,8 @@ void compile_simplify(
       }
       auto& src = parents->second[j].first;
       auto& dst = parents->second[i].first;
-      if (src.id() != dst.id() && array_equivalent(src, dst)) {
+      if (src.id() != dst.id() && array_equivalent(src, dst) &&
+          output_set.find(src.id()) == output_set.end()) {
         merge(dst, src, parents_map);
         mask[j] = true;
       }
@@ -456,7 +457,6 @@
         return output_set.find(a.id()) == output_set.end();
       }
     };
-
     bool discard = maybe_merge_parents(arr);
     for (auto& s : arr.siblings()) {
       discard &= maybe_merge_parents(s);
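This guard is the "fix bug with simplifying outputs" bullet: common-subexpression merging could previously fold src into an equivalent dst even when src was itself an output of the compiled function, so outputs could vanish during simplification. Arrays in output_set are now never merged away. The new Python test below exercises this; a C++ sketch of the same shape (make_pair is a hypothetical name for illustration):

// Two outputs built from equivalent subgraphs must remain distinct
// arrays after simplification, because both are returned to the user.
auto make_pair = [](const std::vector<array>& inputs) {
  return std::vector<array>{
      multiply(array(0.1f), zeros({2})),
      multiply(array(0.1f), zeros({2}))};
};
auto outs = compile(make_pair)({});
eval(outs);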
8 changes: 8 additions & 0 deletions python/tests/test_compile.py
@@ -605,6 +605,14 @@ def fun(x, y):
         with self.assertRaises(ValueError):
             out = fun(mx.array(0.0), y=MyClass())
 
+    def test_compile_create_list(self):
+        @mx.compile
+        def fun():
+            return [0.1 * mx.zeros((2,)), 0.1 * mx.zeros((2,))]
+
+        out = fun()
+        mx.eval(out)
+
 
 if __name__ == "__main__":
     unittest.main()
15 changes: 15 additions & 0 deletions tests/compile_tests.cpp
@@ -703,3 +703,18 @@ TEST_CASE("test shapeless compile") {
     CHECK_NE(out.inputs()[1].id(), out2.inputs()[1].id());
   }
 }
+
+auto compile_broadcast_add(const std::vector<array>& inputs) {
+  auto b = zeros({8, 8});
+  return std::vector<array>{inputs[0] + b};
+}
+
+TEST_CASE("test compile strides") {
+  {
+    auto cfun = compile(compile_broadcast_add);
+    auto a = zeros({1, 8, 8});
+    auto out = cfun({a})[0];
+    eval(out);
+    CHECK_EQ(out.strides().size(), 3);
+  }
+}
