templater: add minijinja templater

And use it in Triton chat completions and legacy completions.

For Mistral-7B-Instruct-v0.2, here is an example template for chat
completions. Put it in /etc/ai-router/templates/chat/mistral.j2:

```
{%- set bos_token = '<s>' -%}
{% set eos_token = '</s>' -%}
{{ bos_token -}}
{%- for message in messages -%}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif -%}
{% if message['role'] == 'user' -%}
{{ '[INST] ' + message['content'] + ' [/INST]' -}}
{% elif message['role'] == 'assistant' -%}
{{ ' ' + message['content'] + eos_token -}}
{% else -%}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif -%}
{% endfor %}
```
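
As a sketch of the result: a conversation of user "Hello", assistant "Hi!",
user "How are you?" should render through this template as:

```
<s>[INST] Hello [/INST] Hi!</s>[INST] How are you? [/INST]
```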

And configure the matching prompt_format (the template file name without the
.j2 extension) in /etc/ai-router.toml:

```
[models.chat_completions."Mistral-7B-Instruct-v0.2"]
...
prompt_format = "mistral"
```

For legacy completions, a different template is needed; put it in
/etc/ai-router/templates/completions/mistral.j2:

```
[INST] {% for message in messages -%}
{{ message -}}
{% endfor %} [/INST]
```
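
For example, a request whose prompt is ["Tell me a joke"] should render as
shown below; note that multiple prompt strings would be concatenated without
a separator:

```
[INST] Tell me a joke [/INST]
```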

Closes: #4
stintel committed Mar 27, 2024
1 parent ca31262 commit 0178c6d
Showing 13 changed files with 163 additions and 63 deletions.
24 changes: 24 additions & 0 deletions Cargo.lock

(generated file; diff not rendered)

1 change: 1 addition & 0 deletions Cargo.toml
```
@@ -26,6 +26,7 @@ bytemuck = "1.15.0"
 bytes = "1.6.0"
 clap = { version = "4.5.3", features = ["derive"] }
 figment = { version = "0.10.15", features = ["env"] }
+minijinja = { version = "1.0.15", features = ["loader"] }
 openai_dive = { version = "0.4.5", features = ["stream"] }
 opentelemetry = { version = "0.22.0", features = ["metrics"] }
 opentelemetry-jaeger-propagator = "0.1.0"
```

2 changes: 1 addition & 1 deletion client/openai_chatcompletion_client.py
```
@@ -77,7 +77,7 @@
 
 # We don't do prompt formatting in the proxy - yet
 # Mistral Instruct
-input = f'[INST] {FLAGS.input} [/INST]'
+input = f'{FLAGS.input}'
 
 start_time = time.time()
 chat_completion = client.chat.completions.create(
```

65 changes: 19 additions & 46 deletions src/backend/triton/routes/chat.rs
```
@@ -26,40 +26,43 @@ use crate::backend::triton::utils::get_output_idx;
 use crate::backend::triton::ModelInferRequest;
 use crate::errors::AiRouterError;
 use crate::request::{check_input_cc, AiRouterRequestData};
+use crate::templater::{TemplateType, Templater};
 use crate::utils::deserialize_bytes_tensor;
 
 const MAX_TOKENS: u32 = 131_072;
 const MODEL_OUTPUT_NAME: &str = "text_output";
 
-#[instrument(skip(client, request, request_data))]
+#[instrument(skip(client, request, request_data, templater))]
 pub async fn compat_chat_completions(
     client: GrpcInferenceServiceClient<Channel>,
     request: Json<ChatCompletionParameters>,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Response {
     tracing::debug!("request: {:?}", request);
 
     if request.stream.unwrap_or(false) {
-        chat_completions_stream(client, request, request_data)
+        chat_completions_stream(client, request, request_data, templater)
             .await
             .into_response()
     } else {
-        chat_completions(client, request, request_data)
+        chat_completions(client, request, request_data, templater)
             .await
             .into_response()
     }
 }
 
-#[instrument(skip(client, request, request_data))]
+#[instrument(skip(client, request, request_data, templater))]
 async fn chat_completions_stream(
     mut client: GrpcInferenceServiceClient<Channel>,
     Json(request): Json<ChatCompletionParameters>,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Result<Sse<impl Stream<Item = anyhow::Result<Event>>>, AiRouterError<String>> {
     let id = format!("cmpl-{}", Uuid::new_v4());
     let created = u32::try_from(SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs())?;
 
-    let request = build_triton_request(request, request_data)?;
+    let request = build_triton_request(request, request_data, templater)?;
     let model_name = request_data
         .original_model
         .clone()
@@ -159,13 +162,14 @@ async fn chat_completions_stream(
     Ok(Sse::new(response_stream).keep_alive(KeepAlive::default()))
 }
 
-#[instrument(skip(client, request, request_data), err(Debug))]
+#[instrument(skip(client, request, request_data, templater), err(Debug))]
 async fn chat_completions(
     mut client: GrpcInferenceServiceClient<Channel>,
     Json(request): Json<ChatCompletionParameters>,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Result<Json<ChatCompletionResponse>, AiRouterError<String>> {
-    let request = build_triton_request(request, request_data)?;
+    let request = build_triton_request(request, request_data, templater)?;
     let model_name = request_data
         .original_model
         .clone()
@@ -237,18 +241,21 @@ async fn chat_completions
 fn build_triton_request(
     request: ChatCompletionParameters,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Result<ModelInferRequest, AiRouterError<String>> {
-    let chat_history = build_chat_history(request.messages);
-    tracing::debug!("chat history after formatting: {}", chat_history);
-
-    check_input_cc(&chat_history, &request.model, request_data)?;
+    let input = templater.apply(
+        &request.messages,
+        request_data.template.clone(),
+        &TemplateType::ChatCompletion,
+    )?;
+    check_input_cc(&input, &request.model, request_data)?;
 
     let mut builder = Builder::new()
         .model_name(request.model)
         .input(
             "text_input",
             [1, 1],
-            InferTensorData::Bytes(vec![chat_history.as_bytes().to_vec()]),
+            InferTensorData::Bytes(vec![input.as_bytes().to_vec()]),
         )
         .input(
             "bad_words",
@@ -316,40 +323,6 @@ fn build_triton_request(
     Ok(builder.build().context("failed to build triton request")?)
 }
 
-fn build_chat_history(messages: Vec<ChatMessage>) -> String {
-    let mut history = String::new();
-    for message in messages {
-        let ChatMessageContent::Text(content) = message.content else {
-            continue;
-        };
-        match message.role {
-            Role::System => {
-                if let Some(name) = message.name {
-                    history.push_str(&format!("System {name}: {content}\n"));
-                } else {
-                    history.push_str(&format!("System: {content}\n"));
-                }
-            }
-            Role::User => {
-                if let Some(name) = message.name {
-                    history.push_str(&format!("User {name}: {content}\n"));
-                } else {
-                    history.push_str(&format!("User: {content}\n"));
-                }
-            }
-            Role::Assistant => {
-                history.push_str(&format!("Assistant: {content}\n"));
-            }
-            Role::Tool => {
-                history.push_str(&format!("Tool: {content}\n"));
-            }
-            Role::Function => {}
-        }
-    }
-    history.push_str("ASSISTANT:");
-    history
-}
-
 fn string_vec_to_byte_vecs(strings: &Vec<String>) -> Vec<Vec<u8>> {
     let mut byte_vecs: Vec<Vec<u8>> = Vec::new();
 
```

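For context on what the template mechanism replaces: the removed
build_chat_history turned every conversation into the same generic transcript
regardless of model. The three-message conversation from the example above
would have been sent to the model as:

```
User: Hello
Assistant: Hi!
User: How are you?
ASSISTANT:
```
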
34 changes: 19 additions & 15 deletions src/backend/triton/routes/completions.rs
```
@@ -23,39 +23,42 @@ use crate::backend::triton::utils::get_output_idx;
 use crate::backend::triton::ModelInferRequest;
 use crate::errors::AiRouterError;
 use crate::request::{check_input_cc, AiRouterRequestData};
+use crate::templater::{TemplateType, Templater};
 use crate::utils::{deserialize_bytes_tensor, string_or_seq_string};
 
 const MODEL_OUTPUT_NAME: &str = "text_output";
 
-#[instrument(skip(client, request, request_data))]
+#[instrument(skip(client, request, request_data, templater))]
 pub async fn compat_completions(
     client: GrpcInferenceServiceClient<Channel>,
     request: Json<CompletionCreateParams>,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Response {
     tracing::debug!("request: {:?}", request);
 
     if request.stream {
-        completions_stream(client, request, request_data)
+        completions_stream(client, request, request_data, templater)
             .await
             .into_response()
     } else {
-        completions(client, request, request_data)
+        completions(client, request, request_data, templater)
             .await
             .into_response()
     }
 }
 
-#[instrument(skip(client, request, request_data))]
+#[instrument(skip(client, request, request_data, templater))]
 async fn completions_stream(
     mut client: GrpcInferenceServiceClient<Channel>,
     Json(request): Json<CompletionCreateParams>,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Result<Sse<impl Stream<Item = anyhow::Result<Event>>>, AiRouterError<String>> {
     let id = format!("cmpl-{}", Uuid::new_v4());
     let created = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
 
-    let request = build_triton_request(request, request_data)?;
+    let request = build_triton_request(request, request_data, templater)?;
     let model_name = request_data
         .original_model
         .clone()
@@ -149,13 +152,14 @@ async fn completions_stream(
     Ok(Sse::new(response_stream).keep_alive(KeepAlive::default()))
 }
 
-#[instrument(skip(client, request, request_data), err(Debug))]
+#[instrument(skip(client, request, request_data, templater), err(Debug))]
 async fn completions(
     mut client: GrpcInferenceServiceClient<Channel>,
     Json(request): Json<CompletionCreateParams>,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Result<Json<Completion>, AiRouterError<String>> {
-    let request = build_triton_request(request, request_data)?;
+    let request = build_triton_request(request, request_data, templater)?;
     let model_name = request_data
         .original_model
         .clone()
@@ -221,22 +225,22 @@ async fn completions
 fn build_triton_request(
     request: CompletionCreateParams,
     request_data: &mut AiRouterRequestData,
+    templater: Templater,
 ) -> Result<ModelInferRequest, AiRouterError<String>> {
-    let input: String = request.prompt.join(" ");
+    let input = templater.apply(
+        &request.prompt,
+        request_data.template.clone(),
+        &TemplateType::LegacyCompletion,
+    )?;
+    tracing::debug!("input after applying template: {}", input);
     check_input_cc(&input, &request.model, request_data)?;
 
     let mut builder = Builder::new()
         .model_name(request.model)
         .input(
             "text_input",
             [1, 1],
-            InferTensorData::Bytes(
-                request
-                    .prompt
-                    .into_iter()
-                    .map(|s| s.as_bytes().to_vec())
-                    .collect(),
-            ),
+            InferTensorData::Bytes(vec![input.as_bytes().to_vec()]),
         )
         .input(
             "max_tokens",
```

7 changes: 7 additions & 0 deletions src/config.rs
```
@@ -7,6 +7,7 @@ use serde_with::{formats::PreferMany, serde_as, skip_serializing_none, OneOrMany
 use uuid::Uuid;
 
 const DEFAULT_CONFIG_FILE: &str = "/etc/ai-router.toml";
+const DEFAULT_TEMPLATE_DIR: &str = "/etc/ai-router/templates";
 
 pub type AiRouterModels = HashMap<AiRouterModelType, HashMap<String, AiRouterModel>>;
 
@@ -150,6 +151,8 @@ pub struct AiRouterDaemon {
     pub listen_ip: String,
     pub listen_port: u16,
     pub otlp_endpoint: Option<String>,
+    #[serde(default = "default_template_dir")]
+    pub template_dir: String,
 }
 
 #[skip_serializing_none]
@@ -177,6 +180,10 @@ fn default_instance_id() -> String {
     String::from(Uuid::new_v4())
 }
 
+fn default_template_dir() -> String {
+    String::from(DEFAULT_TEMPLATE_DIR)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
```

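A sketch of how the new setting could be overridden in the config file; the
[daemon] table name is inferred from the AiRouterDaemon struct and is an
assumption, since the full config layout is not shown here:

```
[daemon]
# defaults to /etc/ai-router/templates when omitted
template_dir = "/srv/ai-router/templates"
```
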
1 change: 1 addition & 0 deletions src/lib.rs
```
@@ -6,5 +6,6 @@ pub mod routes;
 pub mod startup;
 mod state;
 pub mod telemetry;
+mod templater;
 mod tokenizers;
 mod utils;
```

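The new templater module registered here is among the files not rendered on
this page. Going by the call sites above (Templater is cloneable; apply()
takes the messages, the Option<String> template name from prompt_format, and
a TemplateType) and the minijinja "loader" feature enabled in Cargo.toml,
here is a minimal sketch of what such a module could look like. All names,
the raise_exception registration, and the error handling are assumptions, not
the commit's actual code (which presumably maps errors into AiRouterError):

```
// Hypothetical sketch of the templater module; NOT the code from this commit.
use minijinja::{context, path_loader, Environment, Error, ErrorKind};
use serde::Serialize;

pub enum TemplateType {
    ChatCompletion,
    LegacyCompletion,
}

#[derive(Clone)]
pub struct Templater {
    env: Environment<'static>,
}

impl Templater {
    pub fn new(template_dir: &str) -> Self {
        let mut env = Environment::new();
        // Resolve names like "chat/mistral.j2" relative to template_dir
        // (requires minijinja's "loader" feature added in Cargo.toml).
        env.set_loader(path_loader(template_dir));
        // raise_exception() is not a minijinja builtin, yet the templates
        // above call it, so something like this must be registered.
        env.add_function("raise_exception", |msg: String| -> Result<String, Error> {
            Err(Error::new(ErrorKind::InvalidOperation, msg))
        });
        Self { env }
    }

    /// Render `messages` with the template selected by the model's
    /// `prompt_format` and the endpoint type.
    pub fn apply<T: Serialize>(
        &self,
        messages: &[T],
        template: Option<String>,
        template_type: &TemplateType,
    ) -> Result<String, Error> {
        let Some(name) = template else {
            return Err(Error::new(
                ErrorKind::TemplateNotFound,
                "no prompt_format configured for this model",
            ));
        };
        let subdir = match template_type {
            TemplateType::ChatCompletion => "chat",
            TemplateType::LegacyCompletion => "completions",
        };
        let tmpl = self.env.get_template(&format!("{subdir}/{name}.j2"))?;
        tmpl.render(context! { messages => messages })
    }
}
```
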
4 changes: 4 additions & 0 deletions src/request.rs
```
@@ -9,6 +9,7 @@ pub struct AiRouterRequestData {
     pub max_input: Option<usize>,
     pub original_model: Option<String>,
     pub prompt_tokens: usize,
+    pub template: Option<String>,
     pub tokenizer: Option<Tokenizer>,
 }
 
@@ -18,6 +19,7 @@ impl AiRouterRequestData {
             max_input: None,
             original_model: None,
             prompt_tokens: 0,
+            template: None,
             tokenizer: None,
         }
     }
@@ -44,6 +46,8 @@ impl AiRouterRequestData {
             }
         }
 
+        request_data.template = model.prompt_format.clone();
+
         Ok(request_data)
     }
 }
```

1 change: 1 addition & 0 deletions src/routes/chat.rs
```
@@ -57,6 +57,7 @@ pub async fn completion(
             c.clone(),
             request,
             &mut request_data,
+            state.templater.clone(),
         )
         .await;
     }
```

1 change: 1 addition & 0 deletions src/routes/completions.rs
```
@@ -51,6 +51,7 @@ pub async fn completion(
             c.clone(),
             request,
             &mut request_data,
+            state.templater.clone(),
         )
         .await;
     }
```