Skip to content

Commit

Permalink
[refactor] Refactor AI module using new AiSummary trait
Browse files Browse the repository at this point in the history
  • Loading branch information
constantincerdan committed Jul 17, 2024
1 parent 39f1979 commit f72fdf3
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 130 deletions.
87 changes: 39 additions & 48 deletions src/ai/chat_gpt.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,55 @@
use super::prompt::RELEASE_SUMMARY_PROMPT;
use super::AiSummary;
use crate::utils::config;
use anyhow::Result;
use serde::Deserialize;
use serde_json::json;

pub async fn summarise_release(diff: &str, commit_messages: &[String]) -> Result<String> {
let base_url = config::get("OPENAI_BASE_URL")?;
let api_key = config::get("OPENAI_API_KEY")?;
let model = config::get("OPENAI_MODEL")?;

let commit_messages = commit_messages.join("\n");

let mut response = reqwest::Client::new()
.post(format!("{base_url}/chat/completions"))
.header("content-type", "application/json")
.bearer_auth(api_key)
.json(&json!({
"model": model,
"temperature": 0.0,
"frequency_penalty": 2.0,
"messages": [
{
"role": "system",
"content": RELEASE_SUMMARY_PROMPT
},
{
"role": "user",
"content": format!(
"<Diff>{diff}</Diff>
<Commit Messages>{commit_messages}<Commit Messages>"
)
}
]
}))
.send()
.await?
.error_for_status()?
.json::<Response>()
.await?;

Ok(response.get_summary())
pub struct ChatGpt;

impl AiSummary for ChatGpt {
const LLM_PROVIDER: &'static str = "OpenAI";

async fn make_request(input: String) -> Result<String> {
let base_url = config::get("OPENAI_BASE_URL")?;
let api_key = config::get("OPENAI_API_KEY")?;
let model = config::get("OPENAI_MODEL")?;

let mut response = reqwest::Client::new()
.post(format!("{base_url}/chat/completions"))
.header("content-type", "application/json")
.bearer_auth(api_key)
.json(&json!({
"model": model,
"temperature": 0.0,
"frequency_penalty": 2.0,
"messages": [
{ "role": "system", "content": Self::SYSTEM_PROMPT },
{ "role": "user", "content": input }
]
}))
.send()
.await?
.error_for_status()?
.json::<ApiResponse>()
.await?;

let summary = response.choices.remove(0).message.content;

Ok(summary)
}
}

#[derive(Deserialize)]
pub struct Response {
choices: Vec<Message>,
}

impl Response {
pub fn get_summary(&mut self) -> String {
self.choices.remove(0).message.content
}
pub struct ApiResponse {
choices: Vec<Choice>,
}

#[derive(Deserialize)]
struct Message {
message: Content,
struct Choice {
message: Message,
}

#[derive(Deserialize)]
struct Content {
struct Message {
content: String,
}
72 changes: 35 additions & 37 deletions src/ai/claude.rs
Original file line number Diff line number Diff line change
@@ -1,51 +1,49 @@
use super::prompt::RELEASE_SUMMARY_PROMPT;
use super::{prompt::RELEASE_SUMMARY_PROMPT, AiSummary};
use crate::utils::config;
use anyhow::Result;
use serde::Deserialize;
use serde_json::json;

pub async fn summarise_release(diff: &str, commit_messages: &[String]) -> Result<String> {
let base_url = config::get("ANTHROPIC_BASE_URL")?;
let api_key = config::get("ANTHROPIC_API_KEY")?;
let model = config::get("ANTHROPIC_MODEL")?;

let commit_messages = commit_messages.join("\n");

let mut response = reqwest::Client::new()
.post(format!("{base_url}/v1/messages"))
.header("content-type", "application/json")
.header("anthropic-version", "2023-06-01")
.header("x-api-key", api_key)
.json(&json!({
"model": model,
"max_tokens": 1024,
"temperature": 0.0,
"system": format!("Prompt: {RELEASE_SUMMARY_PROMPT}"),
"messages": [{
"role": "user",
"content": format!("
<Diff>{diff}</Diff>
<CommitMessages>{commit_messages}</CommitMessages>
")
}]
}))
.send()
.await?
.error_for_status()?
.json::<Response>()
.await?;

let summary = response.content.remove(0).text;

Ok(summary)
pub struct Claude;

impl AiSummary for Claude {
const LLM_PROVIDER: &'static str = "Anthropic";

async fn make_request(input: String) -> Result<String> {
let base_url = config::get("ANTHROPIC_BASE_URL")?;
let api_key = config::get("ANTHROPIC_API_KEY")?;
let model = config::get("ANTHROPIC_MODEL")?;

let mut response = reqwest::Client::new()
.post(format!("{base_url}/v1/messages"))
.header("content-type", "application/json")
.header("anthropic-version", "2023-06-01")
.header("x-api-key", api_key)
.json(&json!({
"model": model,
"max_tokens": 1024,
"temperature": 0.0,
"system": format!("Prompt: {RELEASE_SUMMARY_PROMPT}"),
"messages": [{ "role": "user", "content": input }]
}))
.send()
.await?
.error_for_status()?
.json::<Response>()
.await?;

let summary = response.content.remove(0).text;

Ok(summary)
}
}

#[derive(Deserialize)]
pub struct Response {
content: Vec<TextContent>,
content: Vec<Content>,
}

#[derive(Deserialize)]
struct TextContent {
struct Content {
text: String,
}
84 changes: 43 additions & 41 deletions src/ai/mod.rs
Original file line number Diff line number Diff line change
@@ -1,56 +1,58 @@
use anyhow::{Error, Result};
use regex_lite::Regex;
use std::sync::OnceLock;

use crate::utils::config;
mod chat_gpt;
mod claude;
mod prompt;

static CHANGES_REGEX: OnceLock<Regex> = OnceLock::new();

fn extract_summary(ai_response: String) -> Result<String> {
let summary_regex =
CHANGES_REGEX.get_or_init(|| Regex::new(r"(?s)<Changes>(.*?)<\/Changes>").unwrap());
use crate::utils::config;
use anyhow::Result;
use chat_gpt::ChatGpt;
use claude::Claude;
use prompt::RELEASE_SUMMARY_PROMPT;
use regex_lite::Regex;
use std::sync::OnceLock;

let Some(matches) = summary_regex.captures(&ai_response) else {
return Ok(ai_response);
};
pub async fn get_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
let llm_provider = config::get("LLM_PROVIDER")?;

if matches.len() == 0 {
return Ok(ai_response);
match llm_provider.as_str() {
"anthropic" => Claude::get_summary(diff, commit_messages).await,
_ => ChatGpt::get_summary(diff, commit_messages).await,
}

Ok(matches.get(1).unwrap().as_str().trim().to_string())
}

pub async fn get_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
let llm_provider = config::get("LLM_PROVIDER")?;
/// Shared summarisation flow for LLM providers: build the tagged input,
/// delegate the API call to the implementor, then extract the
/// `<Output>…</Output>` section from the reply.
trait AiSummary {
    /// System prompt shared by every provider.
    const SYSTEM_PROMPT: &'static str = RELEASE_SUMMARY_PROMPT;

    /// Human-readable provider name (e.g. for logging/error reporting).
    const LLM_PROVIDER: &'static str;

    /// Provider-specific call: send `input` to the API and return the raw
    /// model reply.
    async fn make_request(input: String) -> Result<String>;

    /// Produces the final summary for `diff` + `commit_messages`.
    async fn get_summary(diff: &str, commit_messages: &[String]) -> Result<String> {
        let commit_messages = commit_messages.join("\n");

        // Fix: the closing tag was previously `<CommitMessages>` (missing the
        // `/`), sending malformed markup to the model.
        let input = format!(
            "<Diff>{diff}</Diff>\n<CommitMessages>{commit_messages}</CommitMessages>"
        );

        let response = Self::make_request(input).await?;

        Ok(Self::extract_output(response))
    }

    /// Returns the trimmed text inside the first `<Output>…</Output>` pair,
    /// or the whole reply unchanged when the model omitted the tags.
    fn extract_output(output: String) -> String {
        let output_regex =
            OUTPUT_REGEX.get_or_init(|| Regex::new(r"(?s)<Output>(.*?)</Output>").unwrap());

        // A successful `captures` always contains group 0, and group 1 is
        // present whenever this pattern matches — so the previous
        // `matches.len() == 0` guard was dead code.
        match output_regex.captures(&output).and_then(|caps| caps.get(1)) {
            Some(summary) => summary.as_str().trim().to_string(),
            None => output,
        }
    }
}

// Compiled once on first use and reused across calls.
static OUTPUT_REGEX: OnceLock<Regex> = OnceLock::new();
8 changes: 4 additions & 4 deletions src/ai/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub const RELEASE_SUMMARY_PROMPT: &str = "
Exclude Unchanged Sections: Only include headings for New features, Improvements, Bug fixes, and Dependency changes if there are updates to list for those headings.
</Steps>
<ExampleOutPut1>
<Changes>
<Output>
*New features*:
• Search results can now be filtered by date and relevance.
• New avatar customisation options have been added to user profiles.
Expand All @@ -34,18 +34,18 @@ pub const RELEASE_SUMMARY_PROMPT: &str = "
</Output>
</ExampleOutPut1>
<ExampleOutPut2>
<Changes>
<Output>
*New features*:
• Added support for tracking URLs in Discord messages for new product discoveries.
</Output>
</ExampleOutPut2>
<ExampleOutPut3>
<Changes>
<Output>
*Bug fixes*:
• Fixed an issue where the Twitter hyperlink was not displaying properly.
</Output>
</ExampleOutPut3>
Please perform this analysis on the provided git code diff and commits and deliver a summary as described above based on that diff.
The output should be placed in <Changes> tags as demonstrated in the examples above.
The output should be placed in <Output> tags as demonstrated in the examples above.
Avoid including headings for New features, Improvements, Bug fixes, or Dependency changes if there are no updates to list for those headings.
";

0 comments on commit f72fdf3

Please sign in to comment.