templater: add minijinja templater
And use it in Triton chat completions and legacy completions.

To activate, configure a prompt_format for a chat_completions model:

```
[models.chat_completions."Mistral-7B-Instruct-v0.2"]
...
prompt_format = "mistral"
```

This will look for templates in /etc/ai-router/templates. The template
for chat completions should go in the chat subdirectory, and for legacy
completions the template should go in the completions subdirectory.
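
For example, with prompt_format = "mistral" the expected layout is as
follows (the base directory shown is the built-in default; it can be
changed with the new template_dir daemon option):

```
/etc/ai-router/templates/
├── chat/
│   └── mistral.j2
└── completions/
    └── mistral.j2
```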

Example templates for Mistral-7B-Instruct-v0.2 (the ``` fences are not
part of the templates):

Chat template, based on the one from the Hugging Face Hub, which only
supports the user and assistant roles; place it in
/etc/ai-router/templates/chat/mistral.j2:

```
{%- set bos_token = '<s>' -%}
{% set eos_token = '</s>' -%}
{{ bos_token -}}
{%- for message in messages -%}
{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif -%}
{% if message['role'] == 'user' -%}
{{ '[INST] ' + message['content'] + ' [/INST]' -}}
{% elif message['role'] == 'assistant' -%}
{{ ' ' + message['content'] + eos_token -}}
{% else -%}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif -%}
{% endfor %}
```
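
For illustration, a request with the messages user: "Hello",
assistant: "Hi there!", user: "How are you?" would render to roughly
the following prompt (ignoring stray whitespace from the tags that do
not trim):

```
<s>[INST] Hello [/INST] Hi there!</s>[INST] How are you? [/INST]
```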

Modified version of the above template that injects a system prompt
before the first user prompt:

```
{%- set bos_token = '<s>' -%}
{% set eos_token = '</s>' -%}
{% set mod = 0 -%}
{% set system = '' -%}
{{ bos_token -}}
{%- for message in messages -%}
{% if (message['role'] == 'system' and loop.index0 == 0) -%}
{% set mod = 1 -%}
{% set system = message['content'] %}
{% else -%}
{% if (message['role'] == 'user') != (loop.index0 % 2 == mod) -%}
{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}
{% endif -%}
{% if message['role'] == 'user' -%}
{% if system and system | length > 0 and loop.index0 == 1 -%}
{{ '[INST] ' + system + '\n' + message['content'] + ' [/INST]' -}}
{% else -%}
{{ '[INST] ' + message['content'] + ' [/INST]' -}}
{% endif %}
{% elif message['role'] == 'assistant' -%}
{{ ' ' + message['content'] + eos_token -}}
{% else -%}
{{ raise_exception('Only user and assistant roles are supported!') }}
{% endif -%}
{% endif -%}
{% endfor -%}
```
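
With this template, a conversation starting with system: "You are a
pirate." followed by user: "Hello" is intended to render to roughly
(again ignoring stray whitespace):

```
<s>[INST] You are a pirate.
Hello [/INST]
```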

Legacy completions do not support roles, so a much simpler template can
be used, in /etc/ai-router/templates/completions/mistral.j2:

```
[INST] {% for message in messages -%}
{{ message -}}
{% endfor %} [/INST]
```
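
With this template, a legacy completions request with
prompt = ["Write a haiku about routers"] renders to:

```
[INST] Write a haiku about routers [/INST]
```

Multiple prompt strings are concatenated without a separator.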

As we use chat_completions models in the config for both chat
completions and legacy completions, configuring a prompt_format for a
model requires you to place a template file for both chat completions
and legacy completions in the expected locations. If one of them is
missing, ai-router will not start. The error message should point out
why:

Error: config file validation failed: model `meta-llama/Llama-2-70b-chat-hf` has prompt_format configured but template legacy completions (/etc/ai-router/templates/completions/llama.j2) is missing

If you wish to enable only chat completions for a model and disable
legacy completions, you can do so by simply raising an exception in
the legacy completions template:

```
{{ raise_exception('Legacy completions are disabled for this model') }}
```

Closes: #4
stintel committed Apr 24, 2024
1 parent 68e2833 commit c19ce88
Showing 13 changed files with 160 additions and 62 deletions.
24 changes: 24 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -27,6 +27,7 @@ bytes = "1.6.0"
clap = { version = "4.5.4", features = ["derive"] }
figment = { version = "0.10.17", features = ["env"] }
openai_dive = { version = "0.4.6", default-features = false, features = ["rustls-tls", "stream", "tokio", "tokio-util"] }
minijinja = { version = "1.0.20", features = ["loader"] }
opentelemetry = { version = "0.22.0", features = ["metrics"] }
opentelemetry-jaeger-propagator = "0.1.0"
opentelemetry-otlp = "0.15.0"
2 changes: 1 addition & 1 deletion client/openai_chatcompletion_client.py
@@ -77,7 +77,7 @@

# We don't do prompt formatting in the proxy - yet
# Mistral Instruct
input = f'[INST] {FLAGS.input} [/INST]'
input = f'{FLAGS.input}'

start_time = time.time()
chat_completion = client.chat.completions.create(
65 changes: 19 additions & 46 deletions src/backend/triton/routes/chat.rs
@@ -26,40 +26,43 @@ use crate::backend::triton::utils::get_output_idx;
use crate::backend::triton::ModelInferRequest;
use crate::errors::AiRouterError;
use crate::request::{check_input_cc, AiRouterRequestData};
use crate::templater::{TemplateType, Templater};
use crate::utils::deserialize_bytes_tensor;

const MAX_TOKENS: u32 = 131_072;
const MODEL_OUTPUT_NAME: &str = "text_output";

#[instrument(skip(client, request, request_data))]
#[instrument(skip(client, request, request_data, templater))]
pub async fn compat_chat_completions(
client: GrpcInferenceServiceClient<Channel>,
request: Json<ChatCompletionParameters>,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Response {
tracing::debug!("request: {:?}", request);

if request.stream.unwrap_or(false) {
chat_completions_stream(client, request, request_data)
chat_completions_stream(client, request, request_data, templater)
.await
.into_response()
} else {
chat_completions(client, request, request_data)
chat_completions(client, request, request_data, templater)
.await
.into_response()
}
}

#[instrument(skip(client, request, request_data))]
#[instrument(skip(client, request, request_data, templater))]
async fn chat_completions_stream(
mut client: GrpcInferenceServiceClient<Channel>,
Json(request): Json<ChatCompletionParameters>,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Result<Sse<impl Stream<Item = anyhow::Result<Event>>>, AiRouterError<String>> {
let id = format!("cmpl-{}", Uuid::new_v4());
let created = u32::try_from(SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs())?;

let request = build_triton_request(request, request_data)?;
let request = build_triton_request(request, request_data, templater)?;
let model_name = request_data
.original_model
.clone()
@@ -159,13 +162,14 @@ async fn chat_completions_stream(
Ok(Sse::new(response_stream).keep_alive(KeepAlive::default()))
}

#[instrument(skip(client, request, request_data), err(Debug))]
#[instrument(skip(client, request, request_data, templater), err(Debug))]
async fn chat_completions(
mut client: GrpcInferenceServiceClient<Channel>,
Json(request): Json<ChatCompletionParameters>,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Result<Json<ChatCompletionResponse>, AiRouterError<String>> {
let request = build_triton_request(request, request_data)?;
let request = build_triton_request(request, request_data, templater)?;
let model_name = request_data
.original_model
.clone()
@@ -237,18 +241,21 @@ async fn chat_completions(
fn build_triton_request(
request: ChatCompletionParameters,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Result<ModelInferRequest, AiRouterError<String>> {
let chat_history = build_chat_history(request.messages);
tracing::debug!("chat history after formatting: {}", chat_history);

check_input_cc(&chat_history, &request.model, request_data)?;
let input = templater.apply(
&request.messages,
request_data.template.clone(),
&TemplateType::ChatCompletion,
)?;
check_input_cc(&input, &request.model, request_data)?;

let mut builder = Builder::new()
.model_name(request.model)
.input(
"text_input",
[1, 1],
InferTensorData::Bytes(vec![chat_history.as_bytes().to_vec()]),
InferTensorData::Bytes(vec![input.as_bytes().to_vec()]),
)
.input(
"bad_words",
@@ -318,40 +325,6 @@ fn build_triton_request(
Ok(builder.build().context("failed to build triton request")?)
}

fn build_chat_history(messages: Vec<ChatMessage>) -> String {
let mut history = String::new();
for message in messages {
let ChatMessageContent::Text(content) = message.content else {
continue;
};
match message.role {
Role::System => {
if let Some(name) = message.name {
history.push_str(&format!("System {name}: {content}\n"));
} else {
history.push_str(&format!("System: {content}\n"));
}
}
Role::User => {
if let Some(name) = message.name {
history.push_str(&format!("User {name}: {content}\n"));
} else {
history.push_str(&format!("User: {content}\n"));
}
}
Role::Assistant => {
history.push_str(&format!("Assistant: {content}\n"));
}
Role::Tool => {
history.push_str(&format!("Tool: {content}\n"));
}
Role::Function => {}
}
}
history.push_str("ASSISTANT:");
history
}

fn string_vec_to_byte_vecs(strings: &Vec<String>) -> Vec<Vec<u8>> {
let mut byte_vecs: Vec<Vec<u8>> = Vec::new();

33 changes: 18 additions & 15 deletions src/backend/triton/routes/completions.rs
@@ -23,40 +23,43 @@ use crate::backend::triton::utils::get_output_idx;
use crate::backend::triton::ModelInferRequest;
use crate::errors::AiRouterError;
use crate::request::{check_input_cc, AiRouterRequestData};
use crate::templater::{TemplateType, Templater};
use crate::utils::{deserialize_bytes_tensor, string_or_seq_string};

const MAX_TOKENS: u32 = 131_072;
const MODEL_OUTPUT_NAME: &str = "text_output";

#[instrument(skip(client, request, request_data))]
#[instrument(skip(client, request, request_data, templater))]
pub async fn compat_completions(
client: GrpcInferenceServiceClient<Channel>,
request: Json<CompletionCreateParams>,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Response {
tracing::debug!("request: {:?}", request);

if request.stream {
completions_stream(client, request, request_data)
completions_stream(client, request, request_data, templater)
.await
.into_response()
} else {
completions(client, request, request_data)
completions(client, request, request_data, templater)
.await
.into_response()
}
}

#[instrument(skip(client, request, request_data))]
#[instrument(skip(client, request, request_data, templater))]
async fn completions_stream(
mut client: GrpcInferenceServiceClient<Channel>,
Json(request): Json<CompletionCreateParams>,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Result<Sse<impl Stream<Item = anyhow::Result<Event>>>, AiRouterError<String>> {
let id = format!("cmpl-{}", Uuid::new_v4());
let created = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();

let request = build_triton_request(request, request_data)?;
let request = build_triton_request(request, request_data, templater)?;
let model_name = request_data
.original_model
.clone()
@@ -150,13 +153,14 @@ async fn completions_stream(
Ok(Sse::new(response_stream).keep_alive(KeepAlive::default()))
}

#[instrument(skip(client, request, request_data), err(Debug))]
#[instrument(skip(client, request, request_data, templater), err(Debug))]
async fn completions(
mut client: GrpcInferenceServiceClient<Channel>,
Json(request): Json<CompletionCreateParams>,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Result<Json<Completion>, AiRouterError<String>> {
let request = build_triton_request(request, request_data)?;
let request = build_triton_request(request, request_data, templater)?;
let model_name = request_data
.original_model
.clone()
@@ -222,22 +226,21 @@ async fn completions(
fn build_triton_request(
request: CompletionCreateParams,
request_data: &mut AiRouterRequestData,
templater: Templater,
) -> Result<ModelInferRequest, AiRouterError<String>> {
let input: String = request.prompt.join(" ");
let input = templater.apply(
&request.prompt,
request_data.template.clone(),
&TemplateType::LegacyCompletion,
)?;
check_input_cc(&input, &request.model, request_data)?;

let mut builder = Builder::new()
.model_name(request.model)
.input(
"text_input",
[1, 1],
InferTensorData::Bytes(
request
.prompt
.into_iter()
.map(|s| s.as_bytes().to_vec())
.collect(),
),
InferTensorData::Bytes(vec![input.as_bytes().to_vec()]),
)
.input(
"max_tokens",
7 changes: 7 additions & 0 deletions src/config.rs
@@ -7,6 +7,7 @@ use serde_with::{formats::PreferMany, serde_as, skip_serializing_none, OneOrMany
use uuid::Uuid;

const DEFAULT_CONFIG_FILE: &str = "/etc/ai-router.toml";
const DEFAULT_TEMPLATE_DIR: &str = "/etc/ai-router/templates";

pub type AiRouterModels = HashMap<AiRouterModelType, HashMap<String, AiRouterModel>>;

@@ -150,6 +151,8 @@ pub struct AiRouterDaemon {
pub listen_ip: String,
pub listen_port: u16,
pub otlp_endpoint: Option<String>,
#[serde(default = "default_template_dir")]
pub template_dir: String,
}

#[skip_serializing_none]
@@ -180,6 +183,10 @@ fn default_instance_id() -> String {
String::from(Uuid::new_v4())
}

fn default_template_dir() -> String {
String::from(DEFAULT_TEMPLATE_DIR)
}

#[cfg(test)]
mod tests {
use super::*;
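
For reference, a hedged sketch of overriding the new template_dir
option in ai-router.toml; the [daemon] section name is an assumption
based on the AiRouterDaemon struct and may differ in the real config:

```
[daemon]
template_dir = "/srv/ai-router/templates"
```
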
1 change: 1 addition & 0 deletions src/lib.rs
@@ -6,5 +6,6 @@ pub mod routes;
pub mod startup;
mod state;
pub mod telemetry;
mod templater;
mod tokenizers;
mod utils;
4 changes: 4 additions & 0 deletions src/request.rs
@@ -11,6 +11,7 @@ pub struct AiRouterRequestData {
pub max_tokens: Option<u32>,
pub original_model: Option<String>,
pub prompt_tokens: usize,
pub template: Option<String>,
pub tokenizer: Option<Tokenizer>,
}

@@ -21,6 +22,7 @@ impl AiRouterRequestData {
max_tokens: None,
original_model: None,
prompt_tokens: 0,
template: None,
tokenizer: None,
}
}
@@ -52,6 +54,8 @@ impl AiRouterRequestData {
request_data.max_tokens = Some(max_tokens);
}

request_data.template = model.prompt_format.clone();

Ok(request_data)
}
}
1 change: 1 addition & 0 deletions src/routes/chat.rs
@@ -57,6 +57,7 @@ pub async fn completion(
c.clone(),
request,
&mut request_data,
state.templater.clone(),
)
.await;
}
1 change: 1 addition & 0 deletions src/routes/completions.rs
@@ -52,6 +52,7 @@ pub async fn completion(
c.clone(),
request,
&mut request_data,
state.templater.clone(),
)
.await;
}
2 changes: 2 additions & 0 deletions src/startup.rs
@@ -18,6 +18,8 @@ use crate::state::State;
/// - when we're unable to connect to the Triton endpoint
/// - when we're unable to bind the `TCPListener` for the axum server
/// - when we're unable to start the axum server
/// # Panics
/// - when we're unable to initialize the templater
pub async fn run_server(config_file: &AiRouterConfigFile) -> anyhow::Result<()> {
let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();

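The remaining changed files, including the new src/templater.rs module
itself, are not rendered in this view. As a rough, non-authoritative
sketch only, a minijinja-based templater along these lines could look
like the following; the names Templater, TemplateType and apply are
taken from the diff above, while the loader setup, the raise_exception
helper and the error handling are assumptions rather than the commit's
actual code.

```
use minijinja::{context, path_loader, Environment, ErrorKind};
use serde::Serialize;

#[derive(Clone)]
pub enum TemplateType {
    ChatCompletion,
    LegacyCompletion,
}

#[derive(Clone)]
pub struct Templater {
    env: Environment<'static>,
}

impl Templater {
    /// Load templates from e.g. /etc/ai-router/templates/{chat,completions}/.
    pub fn new(template_dir: &str) -> Self {
        let mut env = Environment::new();
        env.set_loader(path_loader(template_dir));
        // Expose raise_exception() to templates, mirroring the Hugging Face
        // chat template convention used in the examples above.
        env.add_function(
            "raise_exception",
            |msg: String| -> Result<String, minijinja::Error> {
                Err(minijinja::Error::new(ErrorKind::InvalidOperation, msg))
            },
        );
        Self { env }
    }

    /// Render the configured prompt_format template against the request
    /// messages (chat completions) or prompt strings (legacy completions).
    pub fn apply<T: Serialize>(
        &self,
        messages: &T,
        template: Option<String>,
        template_type: &TemplateType,
    ) -> Result<String, minijinja::Error> {
        let Some(template) = template else {
            // The real module presumably falls back to some default
            // formatting when no prompt_format is configured; omitted here.
            return Err(minijinja::Error::new(
                ErrorKind::TemplateNotFound,
                "no prompt_format configured for this model",
            ));
        };
        let subdir = match template_type {
            TemplateType::ChatCompletion => "chat",
            TemplateType::LegacyCompletion => "completions",
        };
        let name = format!("{subdir}/{template}.j2");
        let tmpl = self.env.get_template(&name)?;
        tmpl.render(context! { messages => messages })
    }
}
```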
