Add memory_efficient_threshold kwarg to sdpa kernel (ml-explore#1319)
Allows opt-in to the memory-efficient GPU shader at a prescribed sequence
length. Otherwise, uses aggregate MLX primitives for best latency.
bpkeene authored Aug 12, 2024
1 parent 9231617 commit 19fb69e
Showing 4 changed files with 13 additions and 4 deletions.
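As a usage sketch (not part of the diff), the new keyword argument can be exercised from Python roughly as follows; the shapes, dtype, and the threshold value of 512 are illustrative assumptions, not values taken from this commit:

```python
import mlx.core as mx

# Illustrative shapes: batch 1, 8 heads, query/key sequence length 1024, head dim 64.
B, n_heads, L, D = 1, 8, 1024, 64
scale = D**-0.5

q = mx.random.normal((B, n_heads, L, D), dtype=mx.float16)
k = mx.random.normal((B, n_heads, L, D), dtype=mx.float16)
v = mx.random.normal((B, n_heads, L, D), dtype=mx.float16)

# No threshold given: the default of 1e6 keeps the GPU shader disabled in
# practice, so the aggregate-primitives fallback runs (previous behavior).
out_default = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# Opt in: the query sequence length (1024) meets the threshold (512), so the
# memory-efficient shader is eligible when the use case is otherwise supported.
out_opt_in = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=scale, memory_efficient_threshold=512
)
```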
13 changes: 10 additions & 3 deletions mlx/fast.cpp
@@ -465,6 +465,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask,
+    const std::optional<int>& memory_efficient_threshold,
     StreamOrDevice s) {
   for (const auto& tensor : {queries, keys, values}) {
     if (tensor.ndim() != 4) {
@@ -535,6 +536,11 @@ array scaled_dot_product_attention(
    * * dtype is not fp32 or fp16
    */

+  int threshold = 1e6;
+  if (memory_efficient_threshold.has_value()) {
+    threshold = std::max(1, memory_efficient_threshold.value());
+  }
+
   bool needs_mask = mask.has_value();
   auto fallback = [scale, needs_mask, final_type, n_q_heads, n_kv_heads, &s](
                       const std::vector<array>& inputs) {
@@ -581,9 +587,10 @@ array scaled_dot_product_attention(
   bool implementation_supports_use_case =
       supports_sdpa || supports_full_self_attention;

-  // disabling full self attention until perf is tuned;
-  // likewise for sdpa
-  implementation_supports_use_case &= false;
+  // sdpa gpu shader is disabled except for memory efficient opt-in
+  const int seq_for_threshold = queries.shape(2);
+  bool use_memory_efficient_impl = seq_for_threshold >= threshold;
+  implementation_supports_use_case &= use_memory_efficient_impl;

   if (implementation_supports_use_case) {
     auto out_shape =
1 change: 1 addition & 0 deletions mlx/fast.h
@@ -37,6 +37,7 @@ array scaled_dot_product_attention(
     const array& values,
     const float scale,
     const std::optional<array>& mask = std::nullopt,
+    const std::optional<int>& memory_efficient_threshold = std::nullopt,
     StreamOrDevice s = {});

 std::tuple<array, array, array> affine_quantize(
1 change: 1 addition & 0 deletions python/src/fast.cpp
@@ -112,6 +112,7 @@ void init_fast(nb::module_& parent_module) {
       nb::kw_only(),
       "scale"_a,
       "mask"_a = nb::none(),
+      "memory_efficient_threshold"_a = nb::none(),
       "stream"_a = nb::none(),
       nb::sig(
           "def scaled_dot_product_attention(q: array, k: array, v: array, *, scale: float, mask: Union[None, array] = None, stream: Union[None, Stream, Device] = None) -> array"),
2 changes: 1 addition & 1 deletion python/tests/test_fast_sdpa.py
@@ -86,7 +86,7 @@ def test_fast_sdpa(self):

         reference = mlx_primitives_sdpa_with_gqa(q_mlx, k_mlx, v_mlx, scale)
         o_mlx = mx.fast.scaled_dot_product_attention(
-            q_mlx, k_mlx, v_mlx, scale=scale
+            q_mlx, k_mlx, v_mlx, scale=scale, memory_efficient_threshold=2
         )

         self.assertListEqual(list(reference.shape), list(o_mlx.shape))
