
[Proposal] Support Multiple Prefill + Decode in a loop #9466

Open
wants to merge 1 commit into base: main

Conversation


@kushrast commented Mar 20, 2025

We would like to support multi-turn conversations with the AR-N model by allowing prefill + decode to be called in a loop without resetting the internal KV cache state (a sketch of the intended loop follows the example below).

Example:
Initialize Runner and KV Cache
Prompt: "Call David"
Response: "Okay, Call David Lee?"
Prompt: "No, call David Smith"
Response: "Okay, Calling David Smith, not David Lee"
Clear KV Cache
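
Sketched below is the intended loop, using a hypothetical Runner wrapper; prefill, decode, and reset_kv_cache are illustrative names, not the actual runner API:

// Hypothetical sketch of the multi-turn loop this proposal enables.
// Runner and its methods are illustrative stand-ins, not the real API.
#include <cstdint>
#include <string>
#include <vector>

struct Runner {
  int64_t pos = 0;  // absolute position: tokens prefilled + decoded so far
  void prefill(const std::string& prompt) { /* fills KV cache starting at pos */ }
  std::string decode() { /* appends generated tokens at pos */ return ""; }
  void reset_kv_cache() { pos = 0; }
};

int main() {
  Runner runner;  // Initialize Runner and KV cache once
  const std::vector<std::string> prompts = {"Call David",
                                            "No, call David Smith"};
  for (const auto& prompt : prompts) {
    runner.prefill(prompt);  // must continue from pos, not restart at 0
    runner.decode();         // generated tokens stay in the same KV cache
  }
  runner.reset_kv_cache();  // clear only when the conversation ends
  return 0;
}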

Assumptions:

  • To support multiple prefill + decode calls in a loop, we need to update prefill_input_pos and prefill_attention_mask to reflect the previously decoded tokens plus the new prompt tokens.
  • We also assume we need to update the k_cache and v_cache pointers for prefill.

What this PR does:

  • Adds a function update_kv_to_prefill_io to advance the prefill pointers for v_cache. Also sets attention_mask up to pos (to cover tokens generated during decode).
  • Updates fill_prefill_toks to take in the number of tokens prefilled + generated so far and uses this to set input_pos and the attention mask (see the sketch after this list).
  • Updates the test runner to save the number of tokens generated and pass it to the IO Manager. Also comments out resetting the KV cache state so it gets re-used.
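
A hedged sketch of the fill_prefill_toks change described in the list above; the signature and names here are assumed for illustration, not copied from the IO Manager:

// Illustrative only: continue input_pos from the tokens already in the
// cache (prefilled + decoded) instead of always starting at position 0.
#include <cstdint>
#include <vector>

void fill_prefill_toks_sketch(
    int64_t start_pos,                 // tokens prefilled + decoded so far
    const std::vector<int32_t>& toks,  // new prompt tokens
    int32_t* input_toks,               // model token buffer
    int32_t* input_pos) {              // model position buffer
  for (size_t i = 0; i < toks.size(); ++i) {
    input_toks[i] = toks[i];
    // Positions continue from the conversation so far instead of 0.
    input_pos[i] = static_cast<int32_t>(start_pos + i);
  }
  // The attention mask would likewise be opened up through start_pos so the
  // new prompt can attend to the earlier turns.
}

int main() {
  int32_t toks_buf[8] = {};
  int32_t pos_buf[8] = {};
  fill_prefill_toks_sketch(/*start_pos=*/12, {101, 102, 103}, toks_buf, pos_buf);
  return 0;  // pos_buf now holds {12, 13, 14, 0, ...}
}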

Current State:
The code does not crash, but it also does not produce correct output. Tested on a Samsung S24 with QNN 2.28 binaries.

./qnn_llama3_2_runner --model_path hybrid_llama_qnn.pte --tokenizer_path tiktokenizer.bin --eval_mode 1 --prompt "Call David" --kv_updater "ShiftPointer" --logits_scale 0.1 --output_path output.txt --num_iters 2

Example output with a model trained for communication:
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Call David<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Okay, call David
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Call David<|eot_id|><|start_header_id|>assistant<|end_header_id|>

renamerenamerenamehabihabihabihabihabihabihabihabi date date date date date date握 culturesogh MMI MMI MMIhabihabihabihabihabislidesванetus Tangourt Abrams datefeedingfeedinghabi Date dateolved MMIhabihabihabihabihabihabihabihabihabihabihabihabi族OGLEalse date date date dateucusucusucushabihabi

@kushrast requested a review from cccclai as a code owner March 20, 2025 19:23

pytorch-bot bot commented Mar 20, 2025

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9466

❌ 2 New Failures as of commit 785a121 with merge base 0342bab.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Mar 20, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D71567692


kushrast pushed a commit to kushrast/executorch that referenced this pull request Mar 20, 2025

@kushrast changed the title from "Adding KV to Prefill IO" to "[Proposal] Support Multiple Prefill + Decode in a loop" Mar 20, 2025
@sxu
Contributor

sxu commented Mar 20, 2025

> but we are unsure if we need to update the pointers for the k_cache.

Yeah, I think each cache input pointer needs to be updated to base address + input_pos.

@sxu
Contributor

sxu commented Mar 20, 2025


Just to elaborate: for shift pointers, when switching from one method to another, the following steps are needed:

  1. Update the K cache (the update must be scattered because K is transposed) and the V cache contents (a simple memcpy).
  2. Prepare the new method's KV input cache pointers: each K cache points to base address + input_pos (again, because K is transposed); each V cache points to base address + input_pos * head_dim.
  3. Update the new method's mask.

The update from prefill to kv_io is already implemented by update_prefill_to_kv_io, which performs all of steps 1-3 outlined above. The new update_kv_to_prefill_io only does step 2, and only for the V caches. I think you at least need to set the pointers for the K caches as well. You could also add steps 1 and 3, but they can also be performed by calling other existing functions (update_kv_io for step 1, fill_prefill_toks for step 3); it's up to you.
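
A minimal sketch of the step 2 pointer arithmetic, assuming 8-bit quantized caches; the function and parameter names are illustrative, not the actual IoMgr code:

#include <cstdint>

void set_kv_input_pointers(
    uint8_t* k_base, uint8_t* v_base,
    int64_t input_pos, int64_t head_dim,
    uint8_t*& k_in, uint8_t*& v_in) {
  // K is stored transposed (one row per channel of head_dim), so one
  // position advances one byte within each row: base + input_pos.
  k_in = k_base + input_pos;
  // V is stored position-major (one head_dim-sized row per position), so
  // one position advances a whole row: base + input_pos * head_dim.
  v_in = v_base + input_pos * head_dim;
}

int main() {
  uint8_t k_cache[1024] = {}, v_cache[1024] = {};
  uint8_t *k_in = nullptr, *v_in = nullptr;
  set_kv_input_pointers(k_cache, v_cache, /*input_pos=*/16, /*head_dim=*/8,
                        k_in, v_in);
  // k_in == k_cache + 16, v_in == v_cache + 128
  return 0;
}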

kushrast pushed a commit to kushrast/executorch that referenced this pull request Mar 20, 2025

@kushrast
Author

kushrast commented Mar 20, 2025


Thanks for the feedback. I updated update_kv_to_prefill_io to do steps 2 and 3. I am assuming step 1 is done by the last call to update_kv_io in kv_execute. I also noticed I was resetting the KV cache state instead of re-using it between iterations; I have commented that out now, though I still see bad output from the second iteration of the model through the runner.

std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
    v_cache_in_[prefill_forward_name_];
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_out =
    v_cache_out_[prefill_forward_name_];
for (int i = 0, v_cache_stride = head_dim_ * pos; i < v_cache_in.size();
Collaborator

Since the pointers of v_cache_in/out might have been updated a few rounds (AR-N + KV + AR-N + KV ...) earlier in your scenario, the base pointers will not stay in the initial state.
So if pos is an absolute value, the pointers will be advanced beyond the expected positions (the same issue applies to k_cache).
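
A toy illustration of the drift, with made-up sizes (not the IoMgr code):

// If the pointer was already advanced in earlier rounds, adding an
// absolute pos advances it twice; only the delta since the last update
// (pos - last_pos_) should be applied.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t cache[1024] = {};
  const int64_t head_dim = 4;
  uint8_t* v_in = cache;  // initial state: points at the base

  int64_t last_pos = 10;        // advanced in a previous AR-N/KV round
  v_in += head_dim * last_pos;  // pointer is no longer at the base

  int64_t pos = 16;             // new absolute position
  // Wrong: v_in += head_dim * pos;  // would land at base + 4 * 26
  int64_t pos_diff = pos - last_pos;
  v_in += head_dim * pos_diff;  // correct: base + 4 * 16
  assert(v_in == cache + head_dim * pos);
  return 0;
}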

kushrast pushed a commit to kushrast/executorch that referenced this pull request Mar 24, 2025

kushrast pushed a commit to kushrast/executorch that referenced this pull request Mar 24, 2025

@kushrast
Author

@haowhsu-quic I updated the PR with your comments. Still seeing bad output, but I think we are setting the last position correctly.

for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size();
     i++) {
  v_cache_in[i]->set_data(
      v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
Collaborator
v_cache_out needs to be updated as well; please refer to the resolved comment. Thank you.

for (int i = 0, k_cache_stride = pos_diff * sizeof(uint8_t); i < k_cache_in_.size();
     i++) {
  k_cache_in[i]->set_data(
      k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
Collaborator
Here we need to get the origin for the deep copy: uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos;

@cccclai
Contributor

cccclai commented Mar 26, 2025

@haowhsu-quic do you mean changes like this?

+
+  // update v_cache
+  std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_out =
+      v_cache_out_[prefill_forward_name_];
+
+  for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size(); ++i) {
+    v_cache_in[i]->set_data(v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
+    v_cache_out[i]->set_data(v_cache_out[i]->mutable_data<uint8_t>() + v_cache_stride);
   }
 
   // update k_cache
@@ -521,7 +525,7 @@
        i++) {
     k_cache_in[i]->set_data(
         k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
-    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
+    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos;
     for (int j = 0; j < head_dim_; ++j) {
       memcpy(
         ptr_in + j * prefill_cache_len_,

It seems still not quite right. I need to learn your logic a bit more.

@haowhsu-quic
Collaborator

> @haowhsu-quic do you mean changes like this?

Yes. Could you update the PR with the latest change? Thank you.

std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
    k_cache_in_[prefill_forward_name_];

size_t copied_size = pos_diff * sizeof(uint8_t);
Collaborator
copied_size should be pos * sizeof(uint8_t);

  k_cache_in[i]->set_data(
      k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
  for (int j = 0; j < head_dim_; ++j) {
Collaborator
Sorry, this probably needs a small change here: for (int j = 0; j <= head_dim_; ++j) {
I forgot we reserve extra space to prevent shifting beyond the boundary.

@cccclai
Contributor

cccclai commented Mar 26, 2025

Updated with the suggestion; still incorrect. I'm trying to dump the KV cache values to confirm.

   int64_t pos_diff = pos - last_pos_;
   std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
       v_cache_in_[prefill_forward_name_];
-  for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size();
-       i++) {
-    v_cache_in[i]->set_data(
-        v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
+
+  // update v_cache
+  std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_out =
+      v_cache_out_[prefill_forward_name_];
+
+  for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size(); ++i) {
+    v_cache_in[i]->set_data(v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
+    v_cache_out[i]->set_data(v_cache_out[i]->mutable_data<uint8_t>() + v_cache_stride);
   }
 
   // update k_cache
   std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
       k_cache_in_[prefill_forward_name_];
 
-  size_t copied_size = pos_diff * sizeof(uint8_t);
+  size_t copied_size = pos * sizeof(uint8_t);
 
-  for (int i = 0, k_cache_stride = pos_diff * sizeof(uint8_t); i < k_cache_in_.size();
-       i++) {
-    k_cache_in[i]->set_data(
-        k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
-    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
-    for (int j = 0; j < head_dim_; ++j) {
-      memcpy(
-        ptr_in + j * prefill_cache_len_,
-        ptr_in + j * kv_cache_len_,
-        copied_size);
+  for (int i = 0; i < k_cache_in.size(); i++) {
+    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
+    // Move pointer forward by pos_diff
+    k_cache_in[i]->set_data(ptr_in + pos_diff);
+    // Copy data from kv_cache region to prefill_cache region for each head
+    for (int j = 0; j <= head_dim_; ++j) {
+      uint8_t* dst = ptr_in - pos + j * prefill_cache_len_;
+      const uint8_t* src = ptr_in - pos + j * kv_cache_len_;
+      memcpy(dst, src, copied_size);
     }
   }
 

@cccclai
Contributor

cccclai commented Mar 26, 2025

I'll ask @kushrast to update the PR tomorrow; I probably don't have the permission to update it.

@cccclai
Contributor

cccclai commented Mar 26, 2025

To double-check: is the k_cache shape (head_dim + 1, seq_len - 1, num_layers)?

@haowhsu-quic
Collaborator

I think the last for loop should be:

for (int i = 0; i < k_cache_in.size(); i++) {
  // update the input pointer first, to the current absolute position
  k_cache_in[i]->set_data(k_cache_in[i]->mutable_data<uint8_t>() + pos_diff);
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
  for (int j = 0; j <= head_dim_; ++j) {
    uint8_t* dst = ptr_in - pos + j * prefill_cache_len_;
    const uint8_t* src = ptr_in - pos + j * kv_cache_len_;
    memcpy(dst, src, copied_size);
  }
}

The shape of k_cache_in (single head) is (head_dim_ + 1, prefill_ar_len_) in prefill mode and (head_dim_ + 1, kv_cache_len_) in decode mode. But both are actually mapped to the same chunk of memory; that's why we need a deep copy here, or the data fetching would be incorrect.
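
A toy model of those overlapping views, with made-up sizes chosen purely for illustration:

// One buffer, two row strides: (head_dim_+1, kv_cache_len_) in decode mode
// and (head_dim_+1, prefill_cache_len_) in prefill mode. Switching modes
// means re-scattering each row to the other stride. Sizes are made up.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  const int head_dim = 2;
  const int kv_cache_len = 8, prefill_cache_len = 6;  // assumed sizes
  const int pos = 4;                                  // valid tokens so far
  uint8_t buf[(head_dim + 1) * kv_cache_len] = {};    // +1 padding row

  // Populate the decode-mode view: row j begins at j * kv_cache_len.
  for (int j = 0; j < head_dim; ++j)
    for (int p = 0; p < pos; ++p)
      buf[j * kv_cache_len + p] = static_cast<uint8_t>(10 * j + p);

  // Deep-copy into the prefill-mode view: row j begins at
  // j * prefill_cache_len. memmove because the regions can overlap.
  for (int j = 0; j <= head_dim; ++j)
    std::memmove(buf + j * prefill_cache_len, buf + j * kv_cache_len, pos);

  // The prefill view now reads the same per-row data at its own stride.
  assert(buf[1 * prefill_cache_len + 3] == 13);
  return 0;
}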

@haowhsu-quic
Collaborator

Is it possible for us to get the .pte file to help resolve the issue?

@cccclai
Contributor

cccclai commented Mar 26, 2025

> Is it possible for us to get the .pte file to help resolve the issue?

That might be tricky, because it's an internal model. The easiest way is an online debug session.

@cccclai
Contributor

cccclai commented Mar 26, 2025

Actually, let me try exporting the stories model and verifying accuracy with that; I think it will have a similar issue. In the meantime, this is the latest:

void ShiftPointerIoMgr::update_kv_to_prefill_io(
  int64_t pos,
  std::vector<std::vector<executorch::aten::Tensor>>& output_tensors) {
  // update v_cache
  assert(pos <= 512);
  int64_t pos_diff = pos - last_pos_;
  std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
      v_cache_in_[prefill_forward_name_];

  // update v_cache
  std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_out =
      v_cache_out_[prefill_forward_name_];

  for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size(); ++i) {
    v_cache_in[i]->set_data(v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
    v_cache_out[i]->set_data(v_cache_out[i]->mutable_data<uint8_t>() + v_cache_stride);
  }

  // update k_cache
  std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
      k_cache_in_[prefill_forward_name_];

  size_t copied_size = pos * sizeof(uint8_t);

  for (int i = 0; i < k_cache_in.size(); i++) {
    // Advance the input pointer to the current absolute position first.
    k_cache_in[i]->set_data(k_cache_in[i]->mutable_data<uint8_t>() + pos_diff);
    uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
    // Copy data from kv_cache region to prefill_cache region for each head
    for (int j = 0; j <= head_dim_; ++j) {
      uint8_t* dst = ptr_in - pos + j * prefill_cache_len_;
      const uint8_t* src = ptr_in - pos + j * kv_cache_len_;
      memcpy(dst, src, copied_size);
    }
  }

  // Setting attention mask from context_len - prefill_ar_len - i to context_len
  IO* ptr = static_cast<IO*>(data_ptr_.get());
  for (int i = prefill_ar_len_; i < pos; i++) {
    for (int j = 0; j < prefill_ar_len_; j++) {
      ptr->prefill_attention_mask[j * context_len_ + context_len_ - prefill_ar_len_ - i] = 65535;
    }
  }
}
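
A toy check of the mask indexing in that last loop, with assumed small sizes (65535 being the value the function writes for previously generated positions):

// The prefill mask is a (prefill_ar_len_ x context_len_) row-major array;
// the loop opens the columns for tokens that precede the AR window.
// Sizes here are made up for illustration.
#include <cstdint>
#include <cstdio>

int main() {
  const int context_len = 12, prefill_ar_len = 4, pos = 7;
  uint16_t mask[prefill_ar_len * context_len] = {};

  // Same index arithmetic as update_kv_to_prefill_io above.
  for (int i = prefill_ar_len; i < pos; i++)
    for (int j = 0; j < prefill_ar_len; j++)
      mask[j * context_len + context_len - prefill_ar_len - i] = 65535;

  // Row 0: a '1' marks each column the loop opened.
  for (int c = 0; c < context_len; c++)
    printf("%d", mask[0 * context_len + c] ? 1 : 0);
  printf("\n");
  return 0;
}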

@cccclai
Contributor

cccclai commented Mar 26, 2025

Actually, you can repro the accuracy issue with this command line and the stories model:

./qnn_llama3_2_runner --model_path hybrid_stories_qnn.pte --tokenizer_path tokenizer.bin --eval_mode 1 --prompt "Once" --kv_updater "ShiftPointer" --logits_scale 0.1 --output_path output.txt --num_iters 2

In the second iteration, the context should include the previous prompt + the generated output + the second prompt.

@cccclai
Contributor

cccclai commented Mar 26, 2025

I'm getting the correct output from #9662; wondering if you can take a look and see if it's correct.
