
Commit 800e983
chore: functions macro, restructure
1 parent: 9b5a658

11 files changed (+247 −205 lines)

src/constants.rs (+1 −3)

@@ -6,9 +6,7 @@ pub const EMBEDDINGS_DIMENSION: usize = 384;

 pub const SSE_CHANNEL_BUFFER_SIZE: usize = 1;

-pub const FINAL_EXPLANATION_TEMPERATURE: f64 = 0.7;
-
-pub const FUNCTIONS_CALLS_TEMPERATURE: f64 = 0.5;
+pub const CHAT_COMPLETION_TEMPERATURE: f64 = 0.5;

 pub const ACTIX_WEB_SERVER_PORT: usize = 3000;
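The two phase-specific temperatures are consolidated: both the function-selection requests and the final-explanation request now share CHAT_COMPLETION_TEMPERATURE (0.5), referenced from src/conversation/prompts.rs below.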

src/conversation/data.rs (new file, +63)

@@ -0,0 +1,63 @@
+use crate::prelude::*;
+use crate::{github::Repository, utils::functions::Function};
+use openai_api_rs::v1::chat_completion::FunctionCall;
+use serde::Deserialize;
+use std::str::FromStr;
+
+#[derive(Deserialize)]
+pub struct Query {
+    pub repository: Repository,
+    pub query: String,
+}
+
+impl ToString for Query {
+    fn to_string(&self) -> String {
+        let Query {
+            repository:
+                Repository {
+                    owner,
+                    name,
+                    branch,
+                },
+            query,
+        } = self;
+        format!(
+            "##Repository Info##\nOwner:{}\nName:{}\nBranch:{}\n##User Query##\nQuery:{}",
+            owner, name, branch, query
+        )
+    }
+}
+
+#[derive(Debug)]
+pub struct RelevantChunk {
+    pub path: String,
+    pub content: String,
+}
+
+impl ToString for RelevantChunk {
+    fn to_string(&self) -> String {
+        format!(
+            "##Relevant file chunk##\nPath argument:{}\nRelevant content: {}",
+            self.path,
+            self.content.trim()
+        )
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ParsedFunctionCall {
+    pub name: Function,
+    pub args: serde_json::Value,
+}
+
+impl TryFrom<&FunctionCall> for ParsedFunctionCall {
+    type Error = anyhow::Error;
+
+    fn try_from(func: &FunctionCall) -> Result<Self> {
+        let func = func.clone();
+        let name = Function::from_str(&func.name.unwrap_or("done".into()))?;
+        let args = func.arguments.unwrap_or("{}".to_string());
+        let args = serde_json::from_str::<serde_json::Value>(&args)?;
+        Ok(ParsedFunctionCall { name, args })
+    }
+}
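A minimal usage sketch for the new TryFrom conversion, assuming openai_api_rs's FunctionCall has publicly constructible name/arguments fields (the literal values here are illustrative, not from the commit):

fn demo_parse() -> anyhow::Result<()> {
    // Hypothetical model output: a call to search_codebase with one argument.
    let call = FunctionCall {
        name: Some("search_codebase".into()),
        arguments: Some(r#"{"query":"embeddings model"}"#.into()),
    };
    // A missing name falls back to "done"; missing arguments fall back to "{}".
    let parsed = ParsedFunctionCall::try_from(&call)?;
    assert_eq!(parsed.args["query"], "embeddings model");
    Ok(())
}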

src/utils/conversation/mod.rs → src/conversation/mod.rs (+12 −74)

@@ -1,82 +1,35 @@
 #![allow(unused_must_use)]
+mod data;
 mod prompts;

 use crate::{
-    prelude::*,
     constants::{RELEVANT_CHUNKS_LIMIT, RELEVANT_FILES_LIMIT},
     db::RepositoryEmbeddingsDB,
     embeddings::EmbeddingsModel,
-    github::Repository,
+    prelude::*,
     routes::events::{emit, QueryEvent},
 };
 use actix_web_lab::sse::Sender;
-use openai_api_rs::v1::chat_completion::{FinishReason, FunctionCall};
+pub use data::*;
+use openai_api_rs::v1::chat_completion::FinishReason;
 use openai_api_rs::v1::{
     api::Client,
     chat_completion::{
         ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, MessageRole,
     },
 };
-use serde::Deserialize;
 use std::env;
-use std::str::FromStr;
 use std::sync::Arc;

 use prompts::{generate_completion_request, system_message};

 use self::prompts::answer_generation_prompt;

-use super::functions::{
+use crate::utils::functions::{
     paths_to_completion_message, relevant_chunks_to_completion_message, search_codebase,
     search_file, search_path, Function,
 };

-#[derive(Deserialize)]
-pub struct Query {
-    pub repository: Repository,
-    pub query: String,
-}
-
-impl ToString for Query {
-    fn to_string(&self) -> String {
-        let Query {
-            repository:
-                Repository {
-                    owner,
-                    name,
-                    branch,
-                },
-            query,
-        } = self;
-        format!(
-            "##Repository Info##\nOwner:{}\nName:{}\nBranch:{}\n##User Query##\nQuery:{}",
-            owner, name, branch, query
-        )
-    }
-}
-
-#[derive(Debug)]
-pub struct RelevantChunk {
-    pub path: String,
-    pub content: String,
-}
-
-impl ToString for RelevantChunk {
-    fn to_string(&self) -> String {
-        format!(
-            "##Relevant file chunk##\nPath argument:{}\nRelevant content: {}",
-            self.path,
-            self.content.trim()
-        )
-    }
-}
-
-#[derive(Debug, Clone)]
-struct ParsedFunctionCall {
-    name: Function,
-    args: serde_json::Value,
-}
-
 pub struct Conversation<D: RepositoryEmbeddingsDB, M: EmbeddingsModel> {
     query: Query,
     client: Client,

@@ -133,15 +86,16 @@ impl<D: RepositoryEmbeddingsDB, M: EmbeddingsModel> Conversation<D, M> {
         #[allow(unused_labels)]
         'conversation: loop {
             //Generate a request with the message history and functions
-            let request = generate_completion_request(self.messages.clone(), true);
+            let request = generate_completion_request(self.messages.clone(), "auto");

             match self.send_request(request).await {
                 Ok(response) => {
                     if let FinishReason::function_call = response.choices[0].finish_reason {
                         if let Some(function_call) =
                             response.choices[0].message.function_call.clone()
                         {
-                            let parsed_function_call = parse_function_call(&function_call)?;
+                            let parsed_function_call =
+                                ParsedFunctionCall::try_from(&function_call)?;
                             let function_call_message = ChatCompletionMessage {
                                 name: None,
                                 function_call: Some(function_call),

@@ -229,22 +183,17 @@
                             );
                             self.append_message(completion_message);
                         }
-                        Function::None => {
+                        Function::Done => {
                            self.prepare_final_explanation_message();

                            //Generate a request with the message history and no functions
                            let request =
-                                generate_completion_request(self.messages.clone(), false);
-                            emit(
-                                &self.sender,
-                                QueryEvent::GenerateResponse(Some(
-                                    parsed_function_call.args,
-                                )),
-                            )
-                            .await;
+                                generate_completion_request(self.messages.clone(), "none");
+                            emit(&self.sender, QueryEvent::GenerateResponse(None)).await;
                            let response = match self.send_request(request).await {
                                Ok(response) => response,
                                Err(e) => {
+                                    dbg!(e.to_string());
                                    return Err(e);
                                }
                            };

@@ -268,14 +217,3 @@
         }
     }
 }
-
-fn parse_function_call(func: &FunctionCall) -> Result<ParsedFunctionCall> {
-    let func = func.clone();
-    let function_name = Function::from_str(&func.name.unwrap_or("none".into()))?;
-    let function_args = func.arguments.unwrap_or("{}".to_string());
-    let function_args = serde_json::from_str::<serde_json::Value>(&function_args)?;
-    Ok(ParsedFunctionCall {
-        name: function_name,
-        args: function_args,
-    })
-}
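In outline, the loop now drives the OpenAI function_call mode instead of a bool: "auto" while gathering context, "none" once the model signals it is done. A condensed, hedged sketch of the dispatch (the real code also emits SSE events and appends messages to the history):

// One iteration, heavily simplified; names are taken from this diff.
let request = generate_completion_request(messages.clone(), "auto");
// ...send the request; if the model responded with a function call:
match parsed_function_call.name {
    Function::SearchCodebase | Function::SearchPath | Function::SearchFile => {
        // Run the tool, turn its output into a completion message,
        // append it to the history, and loop again with "auto".
    }
    Function::Done => {
        // Enough context gathered: re-request with "none" so the model
        // must answer in plain text, then stream the final explanation.
    }
}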

src/utils/conversation/prompts.rs → src/conversation/prompts.rs (+33 −51)

@@ -1,66 +1,48 @@
 use openai_api_rs::v1::chat_completion::{
-    ChatCompletionMessage, ChatCompletionRequest, Function, FunctionParameters, JSONSchemaDefine,
-    JSONSchemaType, GPT3_5_TURBO,
+    ChatCompletionMessage, ChatCompletionRequest, Function as F, FunctionParameters,
+    JSONSchemaDefine, JSONSchemaType, GPT3_5_TURBO,
 };
 use std::collections::HashMap;

-use crate::constants::{FINAL_EXPLANATION_TEMPERATURE, FUNCTIONS_CALLS_TEMPERATURE};
+use crate::{constants::CHAT_COMPLETION_TEMPERATURE, utils::functions::Function};

 pub fn generate_completion_request(
     messages: Vec<ChatCompletionMessage>,
-    with_functions: bool,
+    function_call: &str,
 ) -> ChatCompletionRequest {
-    //All the chat completion requests will have functions except for the final explanation request
-    if with_functions {
-        ChatCompletionRequest {
-            model: GPT3_5_TURBO.into(),
-            messages,
-            functions: Some(functions()),
-            function_call: None,
-            temperature: Some(FUNCTIONS_CALLS_TEMPERATURE),
-            top_p: None,
-            n: None,
-            stream: None,
-            stop: None,
-            max_tokens: None,
-            presence_penalty: None,
-            frequency_penalty: None,
-            logit_bias: None,
-            user: None,
-        }
-    } else {
-        ChatCompletionRequest {
-            model: GPT3_5_TURBO.into(),
-            messages,
-            functions: None,
-            function_call: None,
-            temperature: Some(FINAL_EXPLANATION_TEMPERATURE),
-            top_p: None,
-            n: None,
-            stream: None,
-            stop: None,
-            max_tokens: None,
-            presence_penalty: None,
-            frequency_penalty: None,
-            logit_bias: None,
-            user: None,
-        }
+    // https://platform.openai.com/docs/api-reference/chat/create
+    ChatCompletionRequest {
+        model: GPT3_5_TURBO.into(),
+        messages,
+        functions: Some(functions()),
+        function_call: Some(function_call.into()),
+        temperature: Some(CHAT_COMPLETION_TEMPERATURE),
+        top_p: None,
+        n: None,
+        stream: None,
+        stop: None,
+        max_tokens: None,
+        presence_penalty: None,
+        frequency_penalty: None,
+        logit_bias: None,
+        user: None,
     }
 }

-pub fn functions() -> Vec<Function> {
+// https://platform.openai.com/docs/api-reference/chat/create#chat/create-functions
+pub fn functions() -> Vec<F> {
     vec![
-        Function {
-            name: "none".into(),
+        F {
+            name: Function::Done.to_string(),
             description: Some("This is the final step, and signals that you have enough information to respond to the user's query.".into()),
             parameters: Some(FunctionParameters {
                 schema_type: JSONSchemaType::Object,
                 properties: Some(HashMap::new()),
                 required: None,
             }),
         },
-        Function {
-            name: "search_codebase".into(),
+        F {
+            name: Function::SearchCodebase.to_string(),
             description: Some("Search the contents of files in a repository semantically. Results will not necessarily match search terms exactly, but should be related.".into()),
             parameters: Some(FunctionParameters {
                 schema_type: JSONSchemaType::Object,

@@ -77,8 +59,8 @@ pub fn functions() -> Vec<Function> {
             required: Some(vec!["query".into()]),
             })
         },
-        Function {
-            name: "search_path".into(),
+        F {
+            name: Function::SearchPath.to_string(),
             description: Some("Search the pathnames in a repository. Results may not be exact matches, but will be similar by some edit-distance. Use when you want to find a specific file".into()),
             parameters: Some(FunctionParameters {
                 schema_type: JSONSchemaType::Object,

@@ -95,8 +77,8 @@ pub fn functions() -> Vec<Function> {
             required: Some(vec!["path".into()]),
             })
         },
-        Function {
-            name: "search_file".into(),
+        F {
+            name: Function::SearchFile.to_string(),
             description: Some("Search a file returned from functions.search_path. Results will not necessarily match search terms exactly, but should be related.".into()),
             parameters: Some(FunctionParameters {
                 schema_type: JSONSchemaType::Object,

@@ -130,13 +112,13 @@ pub fn system_message() -> String {
     Follow these rules at all times:
     - Respond with functions to find information related to the query, until all relevant information has been found.
     - If the output of a function is not relevant or sufficient, try the same function again with different arguments or try using a different function
-    - When you have enough information to answer the user's query respond with functions.none
+    - When you have enough information to answer the user's query respond with functions.done
     - Do not assume the structure of the codebase, or the existence of files or folders
     - Never respond with a function that you've used before with the same arguments
     - Do NOT respond with functions.search_file unless you have already called functions.search_path
-    - If after making a path search the query can be answered by the existance of the paths, use the functions.none function
+    - If after making a path search the query can be answered by the existance of the paths, use the functions.done function
     - Only refer to paths that are returned by the functions.search_path function when calling functions.search_file
-    - If after attempting to gather information you are still unsure how to answer the query, respond with the functions.none function
+    - If after attempting to gather information you are still unsure how to answer the query, respond with the functions.done function
     - Always respond with a function call. Do NOT answer the question directly"#,
     )
 }
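The "functions macro" in the commit title presumably supplies the Display/FromStr pair that Function::Done.to_string() and Function::from_str depend on. utils/functions.rs is not part of this diff; one plausible, strum-based reconstruction that would yield the names used in the prompt ("done", "search_codebase", ...) is:

// Hypothetical reconstruction; the real definition is not shown in this
// commit. strum's snake_case serialization maps Done -> "done",
// SearchCodebase -> "search_codebase", and so on.
use strum_macros::{Display, EnumString};

#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString)]
#[strum(serialize_all = "snake_case")]
pub enum Function {
    Done,
    SearchCodebase,
    SearchPath,
    SearchFile,
}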

src/github/mod.rs (+3 −5)

@@ -128,7 +128,7 @@ pub async fn fetch_file_content(repository: &Repository, path: &str) -> Result<S
         let content = response.text().await?;
         Ok(content)
     } else {
-        Ok(String::new())
+        Err(anyhow::anyhow!("Unable to fetch file content"))
     }
 }

@@ -195,10 +195,8 @@ mod tests {

         let result = fetch_file_content(&repository, path).await;

-        //Assert that the function returns Result containing an empty string for invalid file path
-        assert!(result.is_ok());
-        let content = result.unwrap();
-        assert!(content.len() == 0);
+        //Assert that the function returns Err for an invalid file path
+        assert!(result.is_err());
     }

     #[test]
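For callers, the visible change is that a missing file now surfaces as an Err instead of Ok("") that had to be special-cased. A hedged caller-side sketch (the surrounding async context and Repository value are assumed, not from the commit):

async fn print_file(repository: &Repository) -> anyhow::Result<()> {
    match fetch_file_content(repository, "src/main.rs").await {
        Ok(content) => println!("{content}"),
        // A bad path is now an error that can be propagated or logged.
        Err(e) => eprintln!("unable to fetch file content: {e}"),
    }
    Ok(())
}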

src/main.rs (+1 −0)

@@ -1,4 +1,5 @@
 mod constants;
+mod conversation;
 mod db;
 mod embeddings;
 mod github;
