[Proposal] Support Multiple Prefill + Decode in a loop #9466
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9466
Note: Links to docs will display an error until the docs builds have been completed.
❌ 2 New Failures as of commit 785a121 with merge base 0342bab.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This pull request was exported from Phabricator. Differential Revision: D71567692
Force-pushed from aff0993 to 08b0f80.
Summary: Pull Request resolved: pytorch#9466 Differential Revision: D71567692
Yeah, I think each cache input pointer needs to be updated to base address + input_pos.
Just to elaborate: for shift pointers, when switching from one method to another, the following are needed:

The update from prefill to kv_io is already implemented by
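For the other direction (kv back to prefill), a rough sketch of the rebase being discussed, using the same `set_data` API that appears in the review snippets below; `base_ptrs` and `bytes_per_token` are illustrative names, not actual code from the PR:

```cpp
// Illustrative only: when switching methods, rebase every cache input
// pointer to base address + input_pos, as suggested above. base_ptrs is
// assumed to hold the unshifted data pointers captured at initialization.
for (size_t i = 0; i < k_cache_in.size(); ++i) {
  k_cache_in[i]->set_data(base_ptrs[i] + input_pos * bytes_per_token);
}
```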
Force-pushed from 08b0f80 to 03bd85a.
Thanks for the feedback. I updated
```cpp
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_in =
    v_cache_in_[prefill_forward_name_];
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& v_cache_out =
    v_cache_out_[prefill_forward_name_];
for (int i = 0, v_cache_stride = head_dim_ * pos; i < v_cache_in.size();
```
Since the pointers of `v_cache_in`/`v_cache_out` might already have been shifted over a few rounds (AR-N + KV + AR-N + KV...) in your scenario, the base pointers will not stay in their initial state. So if `pos` is an absolute value, the pointers will be advanced beyond the expected positions. (The same issue applies to `k_cache`.)
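In code, that suggestion would look roughly like the sketch below, assuming a member such as `last_pos_` records the position at the previous method switch (both the member and the exact stride are assumptions):

```cpp
// Advance by the delta since the last update, not the absolute position:
// the base pointers were already shifted in earlier AR-N/KV rounds.
int64_t pos_diff = pos - last_pos_;  // last_pos_: assumed bookkeeping member
for (int i = 0; i < v_cache_in.size(); i++) {
  v_cache_in[i]->set_data(
      v_cache_in[i]->mutable_data<uint8_t>() + head_dim_ * pos_diff);
}
last_pos_ = pos;  // remember where we left off for the next switch
```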
Force-pushed from 03bd85a to 274d55e.
Force-pushed from 274d55e to 5cd8a6d.
Force-pushed from 5cd8a6d to 785a121.
@haowhsu-quic I updated the PR with your comments; still seeing bad output, but I think we are setting the last position correctly.
```cpp
for (int i = 0, v_cache_stride = head_dim_ * pos_diff; i < v_cache_in.size();
     i++) {
  v_cache_in[i]->set_data(
      v_cache_in[i]->mutable_data<uint8_t>() + v_cache_stride);
```
`v_cache_out` needs to be updated as well; please refer to the resolved comment, thank you.
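A minimal sketch of that addition, assuming `v_cache_out` shares `v_cache_in`'s layout and stride:

```cpp
// Keep the output pointers in lockstep with the input pointers (sketch).
for (int i = 0; i < v_cache_out.size(); i++) {
  v_cache_out[i]->set_data(
      v_cache_out[i]->mutable_data<uint8_t>() + v_cache_stride);
}
```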
```cpp
     i++) {
  k_cache_in[i]->set_data(
      k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
```
Here we need to get the origin for the deep copy: `uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos;`

@haowhsu-quic do you mean changes like this?

It still seems not quite right. I need to learn your logic a bit more.
Yes, could you update to the latest change? Thank you.
```cpp
std::vector<std::unique_ptr<executorch::aten::TensorImpl>>& k_cache_in =
    k_cache_in_[prefill_forward_name_];

size_t copied_size = pos_diff * sizeof(uint8_t);
```
`copied_size` should be `pos * sizeof(uint8_t);`
```cpp
  k_cache_in[i]->set_data(
      k_cache_in[i]->mutable_data<uint8_t>() + k_cache_stride);
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>() - pos_diff;
  for (int j = 0; j < head_dim_; ++j) {
```
Sorry, we probably need a small change here: `for (int j = 0; j <= head_dim_; ++j) {`. I forgot we preserve extra space to prevent shifting beyond the buffer boundary.

Updated with the suggestion; still incorrect. I'm trying to dump the KV cache values to confirm.
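An aside on why `j <= head_dim_`: my understanding (an assumption from the comment above, not confirmed by the source) is that each cache buffer reserves one spare row beyond `head_dim_`, roughly:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative: allocate one spare row beyond head_dim, which is why the
// copy loop above runs j = 0 .. head_dim inclusive without overrunning.
std::vector<uint8_t> alloc_cache(size_t head_dim, size_t max_cache_len) {
  return std::vector<uint8_t>((head_dim + 1) * max_cache_len, 0);
}
```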
I'll ask @kushrast to update the PR tomorrow. I probably don't have permission to update it...
To double check, is the k_cache shape
I think the last for loop should be:

```cpp
for (int i = 0; i < k_cache_in.size(); i++) {
  // update first to the current absolute position
  k_cache_in[i]->set_data(k_cache_in[i]->mutable_data<uint8_t>() + pos_diff);
  uint8_t* ptr_in = k_cache_in[i]->mutable_data<uint8_t>();
  for (int j = 0; j <= head_dim_; ++j) {
    uint8_t* dst = ptr_in - pos + j * prefill_cache_len_;
    const uint8_t* src = ptr_in - pos + j * kv_cache_len_;
    memcpy(dst, src, copied_size);
  }
}
```

Shape of
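To make the intent of that inner loop concrete (my reading, not necessarily the author's): each of the `head_dim_ + 1` rows is copied from the decode layout (row stride `kv_cache_len_`) into the prefill layout (row stride `prefill_cache_len_`). A standalone toy version on separate buffers:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Toy re-striding copy: move `rows` rows of `valid_bytes` each from a buffer
// with row stride `src_stride` into one with row stride `dst_stride`.
// All names here are illustrative.
void restride_rows(const std::uint8_t* src, std::size_t src_stride,
                   std::uint8_t* dst, std::size_t dst_stride,
                   std::size_t rows, std::size_t valid_bytes) {
  for (std::size_t j = 0; j < rows; ++j) {
    std::memcpy(dst + j * dst_stride, src + j * src_stride, valid_bytes);
  }
}
```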
Is it possible we can have the
That might be tricky because it's an internal model. The easiest way is an online debug session.
Actually, let me try exporting the stories model and verifying accuracy with that. I think it will have a similar issue. In the meantime, this is the latest
Actually, you can repro the accuracy issue with this command line with the stories model.
In the second iteration, the output should include the previous prompt + generated output + second prompt.
I'm getting the correct output from #9662; wondering if you can take a look and see if it's correct.
We would like to support multi-turn conversations with the AR-N model by allowing prefill + decode to be called in a loop without resetting the internal KV cache state.
Example:
Initialize Runner and KV Cache
Prompt: "Call David"
Response: "Okay, Call David Lee?"
Prompt: "No, call David Smith"
Response: "Okay, Calling David Smith, not David Lee"
Clear KV Cache
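A sketch of the intended call pattern for the example above; the `Runner` interface here is a stub for illustration, not the actual qnn_llama runner API:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stub interface, for illustration only.
struct Runner {
  void prefill(const std::string& prompt) { /* extend the shared KV cache */ }
  std::string decode() { return "..."; /* generate tokens until EOS */ }
  void reset_kv_cache() { /* drop all cached state */ }
};

int main() {
  Runner runner;
  std::vector<std::string> prompts = {"Call David", "No, call David Smith"};
  for (const auto& prompt : prompts) {
    runner.prefill(prompt);                     // prefill on top of prior turns
    std::cout << runner.decode() << std::endl;  // decode in the same cache
  }
  runner.reset_kv_cache();  // clear state at conversation end
  return 0;
}
```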
Assumptions:

What this PR does:
Adds `update_kv_to_prefill_io` to advance the prefill pointers for `v_cache`. Also sets `attention_mask` up to `pos` (to cover tokens generated during decode).
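A reconstruction of that behavior from this thread (a sketch, not the PR's exact code); the `65535` "attend" value for a uint16 mask and the stand-in tensor type are assumptions:

```cpp
#include <cstdint>
#include <vector>

// Minimal stand-in for the tensor type, just enough for this sketch.
struct CacheTensor {
  uint8_t* data_ = nullptr;
  void set_data(uint8_t* p) { data_ = p; }
  uint8_t* mutable_data() { return data_; }
};

// Advance each prefill v_cache pointer by the tokens added since the last
// switch, then unmask attention up to `pos` so tokens generated during
// decode remain visible to the prefill graph.
void update_kv_to_prefill_io(std::vector<CacheTensor*>& v_cache_in,
                             uint16_t* attention_mask, int64_t head_dim,
                             int64_t pos, int64_t pos_diff) {
  for (auto* t : v_cache_in) {
    t->set_data(t->mutable_data() + head_dim * pos_diff);
  }
  for (int64_t i = 0; i < pos; ++i) {
    attention_mask[i] = 65535;  // assumed "attend" value for uint16 mask
  }
}
```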
Current State:
The code does not crash, but also does not produce correct output. Tested on a Samsung S24 with QNN 2.28 binaries:
```bash
./qnn_llama3_2_runner --model_path hybrid_llama_qnn.pte --tokenizer_path tiktokenizer.bin --eval_mode 1 --prompt "Call David" --kv_updater "ShiftPointer" --logits_scale 0.1 --output_path output.txt --num_iters 2
```
Example with a model trained for communication:
```
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Call David<|eot_id|><|start_header_id|>assistant<|end_header_id|>
Okay, call David

<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Call David<|eot_id|><|start_header_id|>assistant<|end_header_id|>
renamerenamerenamehabihabihabihabihabihabihabihabi date date date date date date握 culturesogh MMI MMI MMIhabihabihabihabihabislidesванetus Tangourt Abrams datefeedingfeedinghabi Date dateolved MMIhabihabihabihabihabihabihabihabihabihabihabihabi族OGLEalse date date date dateucusucusucushabihabi
```