[refactor] Refactor AI module using new AiSummary trait

1 parent 39f1979 · commit f72fdf3
Showing 4 changed files with 121 additions and 130 deletions.
chat_gpt module (OpenAI provider)
@@ -1,64 +1,55 @@
-use super::prompt::RELEASE_SUMMARY_PROMPT;
+use super::AiSummary;
 use crate::utils::config;
 use anyhow::Result;
 use serde::Deserialize;
 use serde_json::json;
 
-pub async fn summarise_release(diff: &str, commit_messages: &[String]) -> Result<String> {
-    let base_url = config::get("OPENAI_BASE_URL")?;
-    let api_key = config::get("OPENAI_API_KEY")?;
-    let model = config::get("OPENAI_MODEL")?;
-
-    let commit_messages = commit_messages.join("\n");
-
-    let mut response = reqwest::Client::new()
-        .post(format!("{base_url}/chat/completions"))
-        .header("content-type", "application/json")
-        .bearer_auth(api_key)
-        .json(&json!({
-            "model": model,
-            "temperature": 0.0,
-            "frequency_penalty": 2.0,
-            "messages": [
-                {
-                    "role": "system",
-                    "content": RELEASE_SUMMARY_PROMPT
-                },
-                {
-                    "role": "user",
-                    "content": format!(
-                        "<Diff>{diff}</Diff>
-                        <Commit Messages>{commit_messages}<Commit Messages>"
-                    )
-                }
-            ]
-        }))
-        .send()
-        .await?
-        .error_for_status()?
-        .json::<Response>()
-        .await?;
-
-    Ok(response.get_summary())
-}
+pub struct ChatGpt;
+
+impl AiSummary for ChatGpt {
+    const LLM_PROVIDER: &'static str = "OpenAI";
+
+    async fn make_request(input: String) -> Result<String> {
+        let base_url = config::get("OPENAI_BASE_URL")?;
+        let api_key = config::get("OPENAI_API_KEY")?;
+        let model = config::get("OPENAI_MODEL")?;
+
+        let mut response = reqwest::Client::new()
+            .post(format!("{base_url}/chat/completions"))
+            .header("content-type", "application/json")
+            .bearer_auth(api_key)
+            .json(&json!({
+                "model": model,
+                "temperature": 0.0,
+                "frequency_penalty": 2.0,
+                "messages": [
+                    { "role": "system", "content": Self::SYSTEM_PROMPT },
+                    { "role": "user", "content": input }
+                ]
+            }))
+            .send()
+            .await?
+            .error_for_status()?
+            .json::<ApiResponse>()
+            .await?;
+
+        let summary = response.choices.remove(0).message.content;
+
+        Ok(summary)
+    }
+}
 
 #[derive(Deserialize)]
-pub struct Response {
-    choices: Vec<Message>,
-}
-
-impl Response {
-    pub fn get_summary(&mut self) -> String {
-        self.choices.remove(0).message.content
-    }
+pub struct ApiResponse {
+    choices: Vec<Choice>,
 }
 
 #[derive(Deserialize)]
-struct Message {
-    message: Content,
+struct Choice {
+    message: Message,
 }
 
 #[derive(Deserialize)]
-struct Content {
+struct Message {
     content: String,
 }
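For reference, here is a minimal, self-contained sketch of how the new `ApiResponse` / `Choice` / `Message` structs map onto an OpenAI chat-completions-style body. The sample JSON is illustrative only and is not taken from this repository.

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct ApiResponse {
    choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Choice {
    message: Message,
}

#[derive(Deserialize)]
struct Message {
    content: String,
}

fn main() {
    // Illustrative OpenAI-style response body; fields not modelled above are ignored by serde.
    let body = r#"{
        "id": "chatcmpl-123",
        "choices": [
            { "index": 0, "message": { "role": "assistant", "content": "Summary of the release." } }
        ]
    }"#;

    let mut response: ApiResponse = serde_json::from_str(body).expect("valid JSON");

    // Mirrors `response.choices.remove(0).message.content` in the new make_request.
    let summary = response.choices.remove(0).message.content;
    assert_eq!(summary, "Summary of the release.");
}
```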
claude module (Anthropic provider)
@@ -1,51 +1,49 @@
-use super::prompt::RELEASE_SUMMARY_PROMPT;
+use super::{prompt::RELEASE_SUMMARY_PROMPT, AiSummary};
 use crate::utils::config;
 use anyhow::Result;
 use serde::Deserialize;
 use serde_json::json;
 
-pub async fn summarise_release(diff: &str, commit_messages: &[String]) -> Result<String> {
-    let base_url = config::get("ANTHROPIC_BASE_URL")?;
-    let api_key = config::get("ANTHROPIC_API_KEY")?;
-    let model = config::get("ANTHROPIC_MODEL")?;
-
-    let commit_messages = commit_messages.join("\n");
-
-    let mut response = reqwest::Client::new()
-        .post(format!("{base_url}/v1/messages"))
-        .header("content-type", "application/json")
-        .header("anthropic-version", "2023-06-01")
-        .header("x-api-key", api_key)
-        .json(&json!({
-            "model": model,
-            "max_tokens": 1024,
-            "temperature": 0.0,
-            "system": format!("Prompt: {RELEASE_SUMMARY_PROMPT}"),
-            "messages": [{
-                "role": "user",
-                "content": format!("
-                    <Diff>{diff}</Diff>
-                    <CommitMessages>{commit_messages}</CommitMessages>
-                ")
-            }]
-        }))
-        .send()
-        .await?
-        .error_for_status()?
-        .json::<Response>()
-        .await?;
-
-    let summary = response.content.remove(0).text;
-
-    Ok(summary)
-}
+pub struct Claude;
+
+impl AiSummary for Claude {
+    const LLM_PROVIDER: &'static str = "Anthropic";
+
+    async fn make_request(input: String) -> Result<String> {
+        let base_url = config::get("ANTHROPIC_BASE_URL")?;
+        let api_key = config::get("ANTHROPIC_API_KEY")?;
+        let model = config::get("ANTHROPIC_MODEL")?;
+
+        let mut response = reqwest::Client::new()
+            .post(format!("{base_url}/v1/messages"))
+            .header("content-type", "application/json")
+            .header("anthropic-version", "2023-06-01")
+            .header("x-api-key", api_key)
+            .json(&json!({
+                "model": model,
+                "max_tokens": 1024,
+                "temperature": 0.0,
+                "system": format!("Prompt: {RELEASE_SUMMARY_PROMPT}"),
+                "messages": [{ "role": "user", "content": input }]
+            }))
+            .send()
+            .await?
+            .error_for_status()?
+            .json::<Response>()
+            .await?;
+
+        let summary = response.content.remove(0).text;
+
+        Ok(summary)
+    }
+}
 
 #[derive(Deserialize)]
 pub struct Response {
-    content: Vec<TextContent>,
+    content: Vec<Content>,
 }
 
 #[derive(Deserialize)]
-struct TextContent {
+struct Content {
     text: String,
 }
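And a parallel sketch for the Anthropic side: the `Response` / `Content` structs expect a `content` array of blocks, and only the first block's `text` is used. Again, the JSON body is illustrative, not from this repository.

```rust
use serde::Deserialize;

#[derive(Deserialize)]
struct Response {
    content: Vec<Content>,
}

#[derive(Deserialize)]
struct Content {
    text: String,
}

fn main() {
    // Illustrative Anthropic messages-style body; unmodelled fields (id, role, usage, ...)
    // are ignored by serde.
    let body = r#"{
        "id": "msg_abc",
        "role": "assistant",
        "content": [
            { "type": "text", "text": "<Output>Summary of the release.</Output>" }
        ]
    }"#;

    let mut response: Response = serde_json::from_str(body).expect("valid JSON");

    // Mirrors `response.content.remove(0).text`: only the first content block is used.
    let text = response.content.remove(0).text;
    assert_eq!(text, "<Output>Summary of the release.</Output>");
}
```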
AI module root (declares chat_gpt, claude, prompt)
@@ -1,56 +1,58 @@
-use anyhow::{Error, Result};
-use regex_lite::Regex;
-use std::sync::OnceLock;
-
-use crate::utils::config;
 mod chat_gpt;
 mod claude;
 mod prompt;
 
-static CHANGES_REGEX: OnceLock<Regex> = OnceLock::new();
-
-fn extract_summary(ai_response: String) -> Result<String> {
-    let summary_regex =
-        CHANGES_REGEX.get_or_init(|| Regex::new(r"(?s)<Changes>(.*?)<\/Changes>").unwrap());
-
-    let Some(matches) = summary_regex.captures(&ai_response) else {
-        return Ok(ai_response);
-    };
-
-    if matches.len() == 0 {
-        return Ok(ai_response);
-    }
-
-    Ok(matches.get(1).unwrap().as_str().trim().to_string())
-}
+use crate::utils::config;
+use anyhow::Result;
+use chat_gpt::ChatGpt;
+use claude::Claude;
+use prompt::RELEASE_SUMMARY_PROMPT;
+use regex_lite::Regex;
+use std::sync::OnceLock;
 
 pub async fn get_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
     let llm_provider = config::get("LLM_PROVIDER")?;
 
-    let ai_response = match llm_provider.as_str() {
-        "openai" => get_chat_gpt_summary(diff, commit_messages).await,
-        "anthropic" => get_claude_summary(diff, commit_messages).await,
-        _ => {
-            tracing::warn!("Unknown LLM provider '{llm_provider}', defaulting to Anthropic.");
-            get_claude_summary(diff, commit_messages).await
-        }
-    }?;
-
-    extract_summary(ai_response)
-}
-
-async fn get_claude_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
-    claude::summarise_release(diff, commit_messages)
-        .await
-        .map_err(|err| get_err("Anthropic", err))
-}
-
-async fn get_chat_gpt_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
-    chat_gpt::summarise_release(diff, commit_messages)
-        .await
-        .map_err(|err| get_err("OpenAI", err))
-}
-
-fn get_err(provider: &str, err: Error) -> Error {
-    anyhow::anyhow!("*⚠️ An error occurred while using the {provider} provider:*\n\n ```{err}```")
+    match llm_provider.as_str() {
+        "anthropic" => Claude::get_summary(diff, commit_messages).await,
+        _ => ChatGpt::get_summary(diff, commit_messages).await,
+    }
+}
+
+trait AiSummary {
+    const SYSTEM_PROMPT: &'static str = RELEASE_SUMMARY_PROMPT;
+    const LLM_PROVIDER: &'static str;
+
+    async fn make_request(input: String) -> Result<String>;
+
+    async fn get_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
+        let commit_messages = commit_messages.join("\n");
+
+        let input = format!(
+            "<Diff>{diff}</Diff>
+            <CommitMessages>{commit_messages}<CommitMessages>"
+        );
+
+        let response = Self::make_request(input).await?;
+        let summary = Self::extract_output(response);
+
+        Ok(summary)
+    }
+
+    fn extract_output(output: String) -> String {
+        let output_regex =
+            OUTPUT_REGEX.get_or_init(|| Regex::new(r"(?s)<Output>(.*?)<\/Output>").unwrap());
+
+        let Some(matches) = output_regex.captures(&output) else {
+            return output;
+        };
+
+        if matches.len() == 0 {
+            return output;
+        }
+
+        matches.get(1).unwrap().as_str().trim().to_string()
+    }
 }
+
+static OUTPUT_REGEX: OnceLock<Regex> = OnceLock::new();
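To show the shape the refactor settles on, here is a simplified, synchronous sketch of the `AiSummary` flow: a provider supplies only `make_request`, while the default `get_summary` builds the tagged input and `extract_output` strips the `<Output>` wrapper. `MockProvider`, the blocking signatures, and the per-call regex compilation are assumptions made for this example; the real trait is async, pulls provider settings from config, and caches the regex in a `OnceLock`.

```rust
use regex_lite::Regex;

// Simplified, synchronous stand-in for the async AiSummary trait.
trait AiSummary {
    const LLM_PROVIDER: &'static str;

    // Providers implement only the transport; everything else has a default.
    fn make_request(input: String) -> String;

    fn get_summary(diff: &str, commit_messages: &[String]) -> String {
        let commit_messages = commit_messages.join("\n");
        let input = format!("<Diff>{diff}</Diff>\n<CommitMessages>{commit_messages}</CommitMessages>");
        Self::extract_output(Self::make_request(input))
    }

    fn extract_output(output: String) -> String {
        // The real code caches this in a OnceLock; compiled per call here for brevity.
        let re = Regex::new(r"(?s)<Output>(.*?)</Output>").unwrap();
        match re.captures(&output).and_then(|caps| caps.get(1)) {
            Some(m) => m.as_str().trim().to_string(),
            None => output,
        }
    }
}

// Hypothetical provider used only for this example.
struct MockProvider;

impl AiSummary for MockProvider {
    const LLM_PROVIDER: &'static str = "Mock";

    fn make_request(_input: String) -> String {
        "<Output> Refactored the AI module around a shared trait. </Output>".to_string()
    }
}

fn main() {
    let summary = MockProvider::get_summary(
        "diff --git a/...",
        &["refactor: add AiSummary trait".to_string()],
    );
    assert_eq!(summary, "Refactored the AI module around a shared trait.");
    println!("{}: {summary}", MockProvider::LLM_PROVIDER);
}
```

Dispatch then reduces to the `match` in the module root's `get_summary`: a new provider needs only a unit struct, an `LLM_PROVIDER` name, and a `make_request` implementation.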