Skip to content

Commit

Permalink
[Relax][Runtime] RNNState for Space State Models (apache#16568)
Browse files Browse the repository at this point in the history
* [Relax][Runtime] RNNState for Space State Models

This commit adds the RNNState class to the Relax VM, similar to the
PagedKVCache, for space state models like RWKV and mamba

* refactor
  • Loading branch information
Hzfengsy authored Feb 21, 2024
1 parent d91fe45 commit 3ef478b
Show file tree
Hide file tree
Showing 6 changed files with 947 additions and 52 deletions.
80 changes: 80 additions & 0 deletions src/runtime/relax_vm/kv_state.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "kv_state.h"

#include <utility>

namespace tvm {
namespace runtime {
namespace relax_vm {

// Register Object Type
TVM_REGISTER_OBJECT_TYPE(KVStateObj);
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
TVM_REGISTER_OBJECT_TYPE(RNNStateObj);

// KV State base methods
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_clear").set_body_method<KVState>(&KVStateObj::Clear);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_add_sequence")
.set_body_method<KVState>(&KVStateObj::AddSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_remove_sequence")
.set_body_method<KVState>(&KVStateObj::RemoveSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
.set_body_method<KVState>(&KVStateObj::ForkSequence);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
.set_body_method<KVState>(&KVStateObj::BeginForward);
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
.set_body_method<KVState>(&KVStateObj::EndForward);

// Attention KV Cache methods
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetNumAvailablePages);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_query_positions")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::GetQueryPositions);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_debug_get_kv")
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::DebugGetKV);
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
NDArray v_data, NDArray o_data) {
kv_cache->Attention(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
NullOpt, std::move(o_data), attn_score_scaling_factor);
});
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_attention_with_fused_qkv")
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
double attn_score_scaling_factor, NDArray qkv_data, NDArray o_data) {
kv_cache->AttentionWithFusedQKV(layer_id, std::move(qkv_data), NullOpt, std::move(o_data),
attn_score_scaling_factor);
});

// RNN State methods
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
.set_body_typed([](RNNState state, int64_t layer_id, int64_t state_id, NDArray data) {
state->Set(layer_id, state_id, data);
return state;
});
TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_debug_get")
.set_body_method<RNNState>(&RNNStateObj::DebugGet);

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
118 changes: 93 additions & 25 deletions src/runtime/relax_vm/kv_cache.h → src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,45 @@
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
#define TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
#ifndef TVM_RUNTIME_RELAX_VM_KV_STATE_H_
#define TVM_RUNTIME_RELAX_VM_KV_STATE_H_
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include "tvm/runtime/object.h"

namespace tvm {
namespace runtime {
namespace relax_vm {

/*!
* \brief The base class of attention KV cache for efficient
* k/v data management and attention computation.
*/
class AttentionKVCache : public Object {
/*! \brief The base class of attention KV cache and rnn state. */
class KVStateObj : public Object {
public:
/*! \brief Reset the KV cache. */
/*! \brief Reset the KV State. */
virtual void Clear() = 0;

/************** Sequence Management **************/

/*!
* \brief Add a new sequence with empty K/V data in the cache.
* \brief Add a new sequence with empty K/V state in the cache.
* Check if the validity of the input sequence id.
* \param seq_id The id of the new sequence to be added.
* \throws Error if the given sequence id is not valid.
*/
virtual void AddSequence(int64_t seq_id) = 0;

/*!
* \brief Remove a sequence and its K/V data from the KV cache.
* \brief Remove a sequence and its K/V state from the KV cache.
* \param seq_id The sequence to remove from cache.
* \throws Error if the given sequence id is not valid.
*/
virtual void RemoveSequence(int64_t seq_id) = 0;

/*!
* \brief Fork the K/V data of parent sequence to the child sequence.
* After the fork, the child sequence has K/V data of the parent
* \brief Fork the K/V state of parent sequence to the child sequence.
* After the fork, the child sequence has K/V state of the parent
* sequence.
* \param parent_seq_id The parent (source) of the fork.
* \param child_seq_id The child (destination) of the fork.
Expand All @@ -73,18 +72,6 @@ class AttentionKVCache : public Object {
*/
virtual void PopN(int64_t seq_id, int32_t n) = 0;

/************** Raw Info Query **************/

/*!
* \brief Get the number of available pages in the KV cache.
* When the underlying KV cache implementation is not
* paged KV cache, the function falls back to return the
* number of remaining size (in terms of number of tokens).
*/
virtual int32_t GetNumAvailablePages() const = 0;

/************** Attention **************/

/*!
* \brief Mark the start of the forward function with the ids of
* the sequences and the sequence length to forward for each
Expand All @@ -109,6 +96,34 @@ class AttentionKVCache : public Object {
*/
virtual void EndForward() = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.KVState";
TVM_DECLARE_BASE_OBJECT_INFO(KVStateObj, Object)
};

class KVState : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(KVState, ObjectRef, KVStateObj);
};

/*!
* \brief The base class of attention KV cache for efficient
* k/v data management and attention computation.
*/
class AttentionKVCacheObj : public KVStateObj {
public:
/************** Raw Info Query **************/

/*!
* \brief Get the number of available pages in the KV cache.
* When the underlying KV cache implementation is not
* paged KV cache, the function falls back to return the
* number of remaining size (in terms of number of tokens).
*/
virtual int32_t GetNumAvailablePages() const = 0;

/************** Attention **************/

/*!
* \brief Compute attention with the given Q/K/V data at the specified
* layer with regard to the previously reserved append lengths.
Expand Down Expand Up @@ -197,10 +212,63 @@ class AttentionKVCache : public Object {
* \param v_data The V data to set in layout elaborated above.
*/
virtual void DebugSetKV(int64_t seq_id, int64_t start_pos, NDArray k_data, NDArray v_data) = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.AttentionKVCache";
TVM_DECLARE_BASE_OBJECT_INFO(AttentionKVCacheObj, KVStateObj);
};

class AttentionKVCache : public KVState {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCache, KVState, AttentionKVCacheObj);
};

/*!
* \brief The base class of RNN State for efficient
* State data management and attention computation.
*/
class RNNStateObj : public KVStateObj {
public:
/************** Interaction **************/
/*!
* \brief Get the State data for the specified sequence.
* \param layer_id The model layer where the state is set.
* \param state_id The state id within the layer.
* \param o_data The output data to be fetched.
* \return The array of State data, each element corresponds to a state.
* \throws Error if the given sequence id is not valid.
*/
virtual void Get(int64_t layer_id, int64_t state_id, NDArray o_data) = 0;

/*!
* \brief Set the State data for the specified sequence.
* \param layer_id The model layer where the state is set.
* \param state_id The state id within the layer.
* \param data The data to be set.
* \throws Error if the given sequence id is not valid.
*/
virtual void Set(int64_t layer_id, int64_t state_id, NDArray data) = 0;

/*!
* \brief Fetch the compact rnn state data of the given sequence.
* \param layer_id The model layer where the state is set.
* \param state_id The state id within the layer.
* \param seq_id The sequence whose state data is to be fetched.
*/
virtual NDArray DebugGet(int64_t layer_id, int64_t state_id, int64_t seq_id) = 0;

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.RNNState";
TVM_DECLARE_BASE_OBJECT_INFO(RNNStateObj, KVStateObj);
};

class RNNState : public KVState {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(RNNState, KVState, RNNStateObj);
};

} // namespace relax_vm
} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_RELAX_VM_KV_CACHE_H_
#endif // TVM_RUNTIME_RELAX_VM_KV_STATE_H_
11 changes: 6 additions & 5 deletions src/runtime/relax_vm/lm_support.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace relax_vm {
/*!
* \brief An object representing an attention kv cache.
*/
class AttentionKVCacheObj : public Object {
class AttentionKVCacheLegacyObj : public Object {
public:
/*!
* \brief Underlying support data.
Expand Down Expand Up @@ -227,7 +227,7 @@ class AttentionKVCacheObj : public Object {

static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
static constexpr const char* _type_key = "relax.vm.AttentionKVCacheLegacy";
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheObj, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(AttentionKVCacheLegacyObj, Object);
};

/*! \brief reference to closure. */
Expand All @@ -239,7 +239,7 @@ class AttentionKVCacheLegacy : public ObjectRef {
*/
static AttentionKVCacheLegacy Create(NDArray init_data, ShapeTuple reserve_shape,
int init_fill_count) {
auto n = make_object<AttentionKVCacheObj>();
auto n = make_object<AttentionKVCacheLegacyObj>();
n->data = NDArray::Empty(reserve_shape, init_data->dtype, init_data->device);
n->fill_count = 0;
n->Append(init_data);
Expand All @@ -250,10 +250,11 @@ class AttentionKVCacheLegacy : public ObjectRef {
return AttentionKVCacheLegacy(n);
}

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef, AttentionKVCacheObj);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(AttentionKVCacheLegacy, ObjectRef,
AttentionKVCacheLegacyObj);
};

TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheLegacyObj);

//-------------------------------------------------
// Register runtime functions
Expand Down
Loading

0 comments on commit 3ef478b

Please sign in to comment.