Skip to content

Commit

Permalink
triton/audio: use templater for transcriptions text_prefix
Browse files Browse the repository at this point in the history
For Whisper, the following template should be put in
/etc/ai-router/templates/transcription/whisper.j2 and prompt_format for
the model should be set to "whisper".

TODO: is TEXT_PREFIX specific to Whisper? If so, we might need to
rethink this completely.
  • Loading branch information
stintel committed May 31, 2024
1 parent ce5b659 commit 194e429
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 7 deletions.
16 changes: 10 additions & 6 deletions src/backend/triton/routes/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ use crate::{
ModelInferRequest,
},
errors::AiRouterError,
request::AiRouterRequestData,
templater::Templater,
};

const MODEL_OUTPUT_NAME: &str = "TRANSCRIPTS";
Expand All @@ -37,12 +39,14 @@ struct AudioTranscriptionResponse {
pub(crate) async fn transcriptions(
mut client: GrpcInferenceServiceClient<Channel>,
parameters: AudioTranscriptionParameters,
request_data: AiRouterRequestData,
templater: Templater,
) -> Result<Response, AiRouterError<String>> {
// this results in the audio bytes being written to the OLTP endpoint
// tracing::debug!("triton audio transcriptions request: {:?}", parameters);
let response_format = parameters.response_format.clone();

let request = build_triton_request(parameters)?;
let request = build_triton_request(parameters, request_data, templater)?;
let request_stream = stream! { yield request };
let mut stream = client
.model_stream_infer(tonic::Request::new(request_stream))
Expand Down Expand Up @@ -166,6 +170,8 @@ fn get_audio_samples(cursor: Cursor<Bytes>) -> Result<(i64, Vec<f32>), AiRouterE
#[instrument(level = "debug", skip(request))]
fn build_triton_request(
request: AudioTranscriptionParameters,
request_data: AiRouterRequestData,
templater: Templater,
) -> Result<ModelInferRequest, AiRouterError<String>> {
let audio = match request.file {
AudioTranscriptionFile::Bytes(b) => {
Expand All @@ -184,18 +190,16 @@ fn build_triton_request(
}
};

let text_prefix = templater.apply_transcription(request.language, request_data.template)?;

let (num_samples, audio) = get_audio_samples(audio)?;

let builder = Builder::new()
.model_name(request.model)
.input(
"TEXT_PREFIX",
[1, 1],
InferTensorData::Bytes(vec![
"<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
.as_bytes()
.to_vec(),
]),
InferTensorData::Bytes(vec![text_prefix.as_bytes().to_vec()]),
)
.input("WAV", [1, num_samples], InferTensorData::FP32(audio))
.output(MODEL_OUTPUT_NAME);
Expand Down
16 changes: 15 additions & 1 deletion src/routes/audio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use crate::backend::openai::routes as openai_routes;
use crate::backend::triton::routes as triton_routes;
use crate::config::AiRouterModelType;
use crate::errors::AiRouterError;
use crate::request::AiRouterRequestData;
use crate::state::{BackendTypes, State};
use crate::utils::get_file_extension;

Expand Down Expand Up @@ -93,7 +94,20 @@ pub async fn transcriptions(
return openai_routes::audio::transcriptions(c, parameters).await
}
BackendTypes::Triton(c) => {
return triton_routes::audio::transcriptions(c.clone(), parameters).await
let request_data = AiRouterRequestData::build(
model,
model.backend_model.clone(),
&parameters.model,
&state,
)?;

return triton_routes::audio::transcriptions(
c.clone(),
parameters,
request_data,
state.templater.clone(),
)
.await;
}
}
}
Expand Down
35 changes: 35 additions & 0 deletions src/templater.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,41 @@ impl Templater {

Ok(rendered)
}

/// Render the transcription text prefix for the given model template.
///
/// Loads `transcription/{template}.j2` from the templater environment and
/// renders it with a `language` context variable (defaulting to `"en"` when
/// the request did not specify one). When the model has no template
/// configured, an empty prefix is returned and nothing is logged.
///
/// # Errors
///
/// Returns `AiRouterError::InternalServerError` when the template cannot be
/// loaded or fails to render.
///
/// NOTE(review): this consumes `self`; callers currently pass a clone of the
/// templater, so taking `&self` instead would avoid that clone — confirm no
/// caller relies on the move.
pub fn apply_transcription(
    self,
    language: Option<String>,
    template: Option<String>,
) -> Result<String, AiRouterError<String>> {
    // No template configured for this model: inject no prefix.
    // (Early return intentionally skips the debug log below.)
    let Some(template) = template else {
        return Ok(String::new());
    };

    let template = format!("transcription/{template}.j2");
    let tpl = self.env.get_template(&template).map_err(|e| {
        AiRouterError::InternalServerError(format!(
            "failed to load transcription template {template}: {e}",
        ))
    })?;

    // Default to English when the request does not specify a language.
    let language = language.unwrap_or_else(|| String::from("en"));

    let rendered = tpl.render(context! {language => language}).map_err(|e| {
        AiRouterError::InternalServerError(format!(
            "failed to render transcription template: {e}"
        ))
    })?;

    tracing::debug!("prefix after applying transcription template: {rendered}");

    Ok(rendered)
}
}

fn raise_exception(_state: &State, msg: String) -> Result<String, Error> {
Expand Down

0 comments on commit 194e429

Please sign in to comment.