Skip to content

Commit

Permalink
[KVCache] Support returning query positions (apache#16578)
Browse files Browse the repository at this point in the history
This PR adds a new function to PagedKVCache to
return in-sequence positions for each location
in a batch of sequences that is being forwarded.
This function helps apply positional embeddings
for language models that do not use Rotary positional
embeddings.
  • Loading branch information
MasterJH5574 authored Feb 16, 2024
1 parent daa37e7 commit efc2ae9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/runtime/relax_vm/kv_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,15 @@ class AttentionKVCache : public Object {
virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
NDArray o_data) = 0;

/************** Positions **************/

/*!
 * \brief Get the in-sequence position of each slot in the query.
 * The positions can be used to apply positional embeddings for
 * models that do not use rotary positional embeddings.
 * \note This function is supposed to be invoked after calling BeginForward.
 * \return The in-sequence query positions, in shape `(total_length,)`.
 */
virtual NDArray GetQueryPositions() const = 0;

/************** Debug Helpers **************/

/*!
Expand Down
9 changes: 9 additions & 0 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCache {
AttentionInternal(layer_id, q_data, k_data, v_data, o_data);
}

/*!
 * \brief Return the in-sequence query positions stored in
 * `q_rope_position_map_view_`.
 * A CHECK fails when `dirty_aux_data_device_` is set, meaning the
 * auxiliary arrays have not been synchronized to device; call
 * `BeginForward` before calling this function.
 * \return The stored query position array.
 */
NDArray GetQueryPositions() const final {
  // Guard against handing out positions that were not yet synchronized to device.
  CHECK(!dirty_aux_data_device_)
      << "The auxiliary arrays are not synchronized to device. Please call "
         "`BeginForward` to synchronize before calling `GetQueryPositions`.";
  return q_rope_position_map_view_;
}

void DebugGetKV(int64_t seq_id, int64_t start_pos, int64_t end_pos, NDArray k_data,
NDArray v_data) final {
CHECK(f_debug_get_kv_.defined())
Expand Down Expand Up @@ -1231,6 +1238,8 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_begin_forward")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::BeginForward);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_end_forward")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::EndForward);
// Register PagedAttentionKVCacheObj::GetQueryPositions as a global function
// named "vm.builtin.paged_attention_kv_cache_get_query_positions".
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_get_query_positions")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::GetQueryPositions);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_debug_get_kv")
.set_body_method<PagedAttentionKVCache>(&PagedAttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_attention")
Expand Down

0 comments on commit efc2ae9

Please sign in to comment.