Patch bug when not using PagedAttention (EricLBuehler#759)
* Patch bug when not using paged attn

* Typo
EricLBuehler authored Sep 7, 2024
1 parent a9a4c8e commit 1cec176
Showing 3 changed files with 58 additions and 48 deletions.
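The fix is the same in all three pipelines. Previously, each `step` implementation unconditionally unwrapped the optional PagedAttention metadata with `with_context`, so any forward pass without PagedAttention configured returned an error even though that is a valid configuration. The patch instead matches on the cache engine and the metadata together and treats "both absent" as a plain-attention run. Below is a minimal, self-contained sketch of the new control flow; `CacheEngine`, `KvCache`, and `PagedAttnMeta` are simplified stand-ins here, not the real mistralrs-core definitions.

// Simplified stand-in types; the real definitions in mistralrs-core
// carry tensors, block tables, and scheduler state.
#[derive(Clone)]
struct KvCache;

struct CacheEngine {
    kv_cache: KvCache,
}

impl CacheEngine {
    fn get_kv_cache(&self) -> &KvCache {
        &self.kv_cache
    }
}

struct PagedAttnMeta;

/// Pair the cache engine with the per-step metadata. Only the mismatched
/// (Some, None) and (None, Some) combinations are errors; (None, None)
/// is a normal non-PagedAttention forward pass.
fn resolve_paged_attn<'a>(
    cache_engine: Option<&CacheEngine>,
    paged_attn_meta: Option<&'a mut PagedAttnMeta>,
) -> Result<Option<(KvCache, &'a mut PagedAttnMeta)>, String> {
    match (cache_engine, paged_attn_meta) {
        // PagedAttention fully configured: clone the KV cache handle and
        // hand it to the model together with the metadata.
        (Some(engine), Some(meta)) => Ok(Some((engine.get_kv_cache().clone(), meta))),
        // Scheduler was built for PagedAttention but no metadata arrived.
        (Some(_), None) => Err("expected PagedAttention input metadata".to_string()),
        // Metadata without a cache engine should be unreachable.
        (None, Some(_)) => Err("got PagedAttention metadata but no cache engine".to_string()),
        // Plain attention: nothing to pass through.
        (None, None) => Ok(None),
    }
}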
mistralrs-core/src/pipeline/gguf.rs (24 additions, 30 deletions)
@@ -40,7 +40,7 @@ use crate::{
     utils::tokens::get_token,
     xlora_models::{XLoraQLlama, XLoraQPhi3},
 };
-use anyhow::{bail, Context, Result};
+use anyhow::{bail, Result};
 use candle_core::{DType, Device, Tensor};
 use either::Either;
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
@@ -645,30 +645,32 @@ impl Pipeline for GGUFPipeline {
             flash_meta,
             flash_meta_full,
         } = *inputs.downcast().expect("Downcast failed.");
-        let paged_attn_meta = paged_attn_meta
-            .as_mut()
-            .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
-            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
+        let paged_attn_meta = match (
+            self.get_metadata().cache_engine.as_ref(),
+            &mut paged_attn_meta,
+        ) {
+            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
+            (Some(_), None) => {
+                // This can happen if Rust-side user code is wrong
+                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
+            }
+            (None, Some(_)) => {
+                // This should never happen but we handle it anyway
+                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
+            }
+            (None, None) => None,
+        };
         let logits = match self.model {
             Model::Llama(ref model) => model.forward(
                 &input_ids,
                 &seqlen_offsets,
                 seqlen_offsets_kernel,
                 context_lens,
-                self.get_metadata()
-                    .cache_engine
-                    .as_ref()
-                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
+                paged_attn_meta,
             )?,
-            Model::Phi2(ref model) => model.forward(
-                &input_ids,
-                &seqlen_offsets,
-                context_lens,
-                self.get_metadata()
-                    .cache_engine
-                    .as_ref()
-                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
-            )?,
+            Model::Phi2(ref model) => {
+                model.forward(&input_ids, &seqlen_offsets, context_lens, paged_attn_meta)?
+            }
             Model::XLoraLlama(ref model) => model.forward(
                 &input_ids,
                 input_ids_full.as_ref().unwrap_or(&input_ids),
@@ -682,14 +684,9 @@ impl Pipeline for GGUFPipeline {
                 &flash_meta,
                 flash_meta_full.as_ref().unwrap_or(&flash_meta),
             )?,
-            Model::Phi3(ref model) => model.forward(
-                &input_ids,
-                &seqlen_offsets,
-                self.get_metadata()
-                    .cache_engine
-                    .as_ref()
-                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
-            )?,
+            Model::Phi3(ref model) => {
+                model.forward(&input_ids, &seqlen_offsets, paged_attn_meta)?
+            }
             Model::XLoraPhi3(ref model) => model.forward(
                 &input_ids,
                 input_ids_full.as_ref().unwrap_or(&input_ids),
@@ -707,10 +704,7 @@
                 &input_ids,
                 &seqlen_offsets,
                 seqlen_offsets_kernel,
-                self.get_metadata()
-                    .cache_engine
-                    .as_ref()
-                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
+                paged_attn_meta,
             )?,
         };
         Ok(ForwardInputsResult::CausalGeneration { logits })
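For contrast, here is a hypothetical reduction of the pre-patch logic, reusing the stand-in `PagedAttnMeta` type from the sketch above. Unwrapping the optional metadata before consulting the cache engine is what turned a legitimate non-PagedAttention run into a hard error; it also forced every call site above to rebuild the `(kv_cache, metadata)` pair inline via `.map(...)`, duplication the single up-front `match` now removes.

// Hypothetical reduction of the old code path, not the literal source:
// the metadata was unwrapped before the cache engine was ever consulted,
// so `None` (no PagedAttention configured) always became an error.
fn old_resolve(
    paged_attn_meta: Option<&mut PagedAttnMeta>,
) -> Result<&mut PagedAttnMeta, String> {
    paged_attn_meta.ok_or_else(|| {
        "Forward step expected a PagedAttention input metadata. This was not \
         provided, please ensure that the scheduler config is correctly \
         configured for PagedAttention."
            .to_string()
    })
}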
mistralrs-core/src/pipeline/normal.rs (17 additions, 9 deletions)
@@ -32,7 +32,7 @@ use crate::{
     normal_model_loader, xlora_model_loader, DeviceMapMetadata, PagedAttentionConfig, Pipeline,
     Topology, TryIntoDType,
 };
-use anyhow::{Context, Result};
+use anyhow::Result;
 use candle_core::{Device, Tensor, Var};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use mistralrs_quant::IsqType;
@@ -515,21 +515,29 @@ impl Pipeline for NormalPipeline {
             flash_meta,
             flash_meta_full,
         } = *inputs.downcast().expect("Downcast failed.");
-        let paged_attn_meta = paged_attn_meta
-            .as_mut()
-            .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
-            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
+        let paged_attn_meta = match (
+            self.get_metadata().cache_engine.as_ref(),
+            &mut paged_attn_meta,
+        ) {
+            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
+            (Some(_), None) => {
+                // This can happen if Rust-side user code is wrong
+                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
+            }
+            (None, Some(_)) => {
+                // This should never happen but we handle it anyway
+                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
+            }
+            (None, None) => None,
+        };
         let logits = match self.model.is_xlora() {
             false => self.model.forward(
                 &input_ids,
                 &seqlen_offsets,
                 seqlen_offsets_kernel,
                 context_lens,
                 position_ids,
-                self.get_metadata()
-                    .cache_engine
-                    .as_ref()
-                    .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
+                paged_attn_meta,
                 &flash_meta,
             )?,
             true => self.model.xlora_forward(
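The normal.rs change above is the identical guard; its call sites differ only because `NormalPipeline` dispatches on `is_xlora()` rather than on a model enum. As a quick sanity check, the four (cache engine, metadata) combinations can be exercised against the `resolve_paged_attn` sketch from earlier; the `(None, None)` case is the one this commit fixes.

fn main() {
    let engine = CacheEngine { kv_cache: KvCache };
    let mut meta = PagedAttnMeta;

    // PagedAttention configured end to end: metadata is paired with the cache.
    assert!(matches!(
        resolve_paged_attn(Some(&engine), Some(&mut meta)),
        Ok(Some(_))
    ));
    // Mismatched configurations still fail loudly.
    assert!(resolve_paged_attn(Some(&engine), None).is_err());
    assert!(resolve_paged_attn(None, Some(&mut meta)).is_err());
    // No PagedAttention at all: previously an error, now a clean `None`.
    assert!(matches!(resolve_paged_attn(None, None), Ok(None)));
}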
mistralrs-core/src/pipeline/vision.rs (17 additions, 9 deletions)
@@ -24,7 +24,7 @@ use crate::{
     api_dir_list, api_get_file, get_paths, vision_normal_model_loader, AnyMoeExpertType,
     DeviceMapMetadata, Ordering, PagedAttentionConfig, Pipeline, Topology, TryIntoDType,
 };
-use anyhow::{Context, Result};
+use anyhow::Result;
 use candle_core::{Device, Tensor, Var};
 use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
 use mistralrs_quant::IsqType;
@@ -411,10 +411,21 @@ impl Pipeline for VisionPipeline {
             mut paged_attn_meta,
             flash_meta,
         } = *inputs.downcast::<ModelInputs>().expect("Downcast failed.");
-        let paged_attn_meta = paged_attn_meta
-            .as_mut()
-            .with_context(|| "Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
-            .map_err(|e| candle_core::Error::Msg(e.to_string()))?;
+        let paged_attn_meta = match (
+            self.get_metadata().cache_engine.as_ref(),
+            &mut paged_attn_meta,
+        ) {
+            (Some(engine), Some(meta)) => Some((engine.get_kv_cache().clone(), meta)),
+            (Some(_), None) => {
+                // This can happen if Rust-side user code is wrong
+                candle_core::bail!("Forward step expected a PagedAttention input metadata. This was not provided, please ensure that the scheduler config is correctly configured for PagedAttention.")
+            }
+            (None, Some(_)) => {
+                // This should never happen but we handle it anyway
+                candle_core::bail!("Forward step got a PagedAttention input metadata but there is no cache engine. Please raise an issue.")
+            }
+            (None, None) => None,
+        };
         let logits = self.model.forward(
             &input_ids,
             pixel_values,
@@ -423,10 +434,7 @@
             context_lens,
             position_ids,
             model_specific_args,
-            self.get_metadata()
-                .cache_engine
-                .as_ref()
-                .map(|engine| (engine.get_kv_cache().clone(), paged_attn_meta)),
+            paged_attn_meta,
             &flash_meta,
         )?;
         Ok(ForwardInputsResult::CausalGeneration { logits })
