Skip to content

Commit

Permalink
add json_object
Browse files Browse the repository at this point in the history
  • Loading branch information
HikaruEgashira committed Dec 21, 2024
1 parent a649267 commit 1d2054d
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 262 deletions.
4 changes: 2 additions & 2 deletions src/analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub async fn analyze_file(
content: prompt,
}];

let chat_response = llm.chat(&messages).await?;
let chat_response = llm.chat(&messages, Some("Response".to_string())).await?;
let response: Response = serde_json::from_str(&chat_response)?;
info!("Initial analysis complete");

Expand Down Expand Up @@ -87,7 +87,7 @@ pub async fn analyze_file(
role: "user".to_string(),
content: prompt.clone(),
}];
let chat_response = llm.chat(&messages).await?;
let chat_response = llm.chat(&messages, Some("Response".to_string())).await?;
let vuln_response: Response = serde_json::from_str(&chat_response)?;

if verbosity > 0 {
Expand Down
226 changes: 0 additions & 226 deletions src/evaluator.rs

This file was deleted.

1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
pub mod analyzer;
pub mod evaluator;
pub mod llms;
pub mod parser;
pub mod prompts;
Expand Down
36 changes: 30 additions & 6 deletions src/llms/claude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ impl Claude {

#[async_trait]
impl LLM for Claude {
async fn chat(&self, messages: &[ChatMessage]) -> Result<String> {
#[derive(Serialize)]
async fn chat(
&self,
messages: &[ChatMessage],
response_model: Option<String>,
) -> Result<String> {
#[derive(Serialize)]
struct Request {
model: String,
max_tokens: u32,
Expand Down Expand Up @@ -140,7 +144,27 @@ impl LLM for Claude {
log::error!("Empty response content: {}", response_text);
return Err(anyhow::anyhow!("Empty response content"));
}
Ok(response.content[0].text.clone())
if let Some(model) = response_model {
let json_str = response.content[0].text.clone();
let parsed_model: Result<_, _> = serde_json::from_str(&json_str);
match parsed_model {
Ok(model) => Ok(model),
Err(parse_error) => {
log::error!(
"JSON Parsing error: {} | Raw response: {}",
parse_error,
response_text
);
Err(anyhow::anyhow!(
"Parsing error: {} | Raw response: {}",
parse_error,
response_text
))
}
}
} else {
Ok(response.content[0].text.clone())
}
}
Err(parse_error) => {
log::error!(
Expand Down Expand Up @@ -187,7 +211,7 @@ mod tests {
content: "What is 2+2?".to_string(),
}];

let result = claude.chat(&messages).await;
let result = claude.chat(&messages, None).await;
assert!(result.is_ok(), "Chat should succeed with valid API key");

let response = result.unwrap();
Expand All @@ -204,7 +228,7 @@ mod tests {
content: "What is 2+2?".to_string(),
}];

let result = claude.chat(&messages).await;
let result = claude.chat(&messages, None).await;
assert!(result.is_err(), "Chat should fail with invalid API key");
}

Expand All @@ -224,7 +248,7 @@ mod tests {
content: "".to_string(),
}];

let result = claude.chat(&messages).await;
let result = claude.chat(&messages, None).await;
assert!(result.is_err(), "Chat should fail with empty message");
assert!(result
.unwrap_err()
Expand Down
12 changes: 8 additions & 4 deletions src/llms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@ use tokio;

#[async_trait]
pub trait LLM {
async fn chat(&self, messages: &[ChatMessage]) -> Result<String>;
async fn chat(
&self,
messages: &[ChatMessage],
response_model: Option<String>,
) -> Result<String>;
}

// OpenAI Message type
Expand Down Expand Up @@ -76,7 +80,7 @@ mod tests {
role: "user".to_string(),
content: "Say 'test successful' in exactly those words.".to_string(),
}];
let response = claude.chat(&messages[..]).await?;
let response = claude.chat(&messages[..], None).await?;
assert!(
response.len() > 0,
"Claude response has a length greater than 0"
Expand All @@ -93,7 +97,7 @@ mod tests {
role: "user".to_string(),
content: "Say 'test successful' in exactly those words.".to_string(),
}];
let response = openai.chat(&messages[..]).await?;
let response = openai.chat(&messages[..], None).await?;
assert!(
response.len() > 0,
"OpenAI response has a length greater than 0"
Expand All @@ -110,7 +114,7 @@ mod tests {
role: "user".to_string(),
content: "Say 'test successful' in exactly those words.".to_string(),
}];
let response = ollama.chat(&messages[..]).await?;
let response = ollama.chat(&messages[..], None).await?;
assert!(
response.len() > 0,
"Ollama response has a length greater than 0"
Expand Down
Loading

0 comments on commit 1d2054d

Please sign in to comment.