Skip to content

Commit

Permalink
use db executor
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszKielar committed Sep 11, 2024
1 parent 931c4bc commit 78fd26f
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 29 deletions.
4 changes: 2 additions & 2 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl App {
}

pub async fn init(&mut self) -> AppResult<()> {
let conversations = db::get_conversations(self.sqlite.clone()).await?;
let conversations = db::get_conversations(&self.sqlite).await?;
self.conversations.set_conversations(conversations);

Ok(())
Expand Down Expand Up @@ -119,7 +119,7 @@ impl App {
if let Some(conversation) = self.conversations.currently_selected() {
let user_prompt = self.prompt.get_content();
let user_message = db::create_message(
self.sqlite.clone(),
&self.sqlite,
Role::User,
user_prompt,
conversation.id,
Expand Down
2 changes: 1 addition & 1 deletion src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ impl Chat {
pub async fn load_messages(&mut self, conversation_id: u32) -> AppResult<()> {
self.reset();

let messages = db::get_messages(self.sqlite.clone(), conversation_id).await?;
let messages = db::get_messages(&self.sqlite, conversation_id).await?;
self.messages = messages;

Ok(())
Expand Down
45 changes: 30 additions & 15 deletions src/db.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
use std::time::Duration;

use sqlx::SqlitePool;
use sqlx::{Executor, Sqlite};

use crate::{
models::{Conversation, Message, Role},
AppResult,
};

pub async fn get_conversations(sqlite: SqlitePool) -> AppResult<Vec<Conversation>> {
pub async fn get_conversations<'e, E>(executor: E) -> AppResult<Vec<Conversation>>
where
E: Executor<'e, Database = Sqlite>,
{
let items = sqlx::query_as(
r#"
SELECT *
FROM conversations
ORDER BY created_at ASC
"#,
)
.fetch_all(&sqlite)
.persistent(false)
.fetch_all(executor)
.await?;

Ok(items)
}

pub async fn get_messages(sqlite: SqlitePool, conversation_id: u32) -> AppResult<Vec<Message>> {
pub async fn get_messages<'e, E>(executor: E, conversation_id: u32) -> AppResult<Vec<Message>>
where
E: Executor<'e, Database = Sqlite>,
{
// TODO: for some weird reason query is cached despite setting persistence to false
tokio::time::sleep(Duration::from_millis(10)).await;
let items = sqlx::query_as(
r#"
Expand All @@ -32,19 +40,22 @@ pub async fn get_messages(sqlite: SqlitePool, conversation_id: u32) -> AppResult
"#,
)
.bind(conversation_id)
.fetch_all(&sqlite)
.persistent(false)
.fetch_all(executor)
.await?;

Ok(items)
}

// TODO: fix transactions
pub async fn create_message(
sqlite: SqlitePool,
pub async fn create_message<'e, E>(
executor: E,
role: Role,
content: String,
conversation_id: u32,
) -> AppResult<Message> {
) -> AppResult<Message>
where
E: Executor<'e, Database = Sqlite>,
{
let new_message: Message = sqlx::query_as(
r#"
INSERT INTO messages(role, content, conversation_id)
Expand All @@ -55,18 +66,21 @@ pub async fn create_message(
.bind(role.to_string())
.bind(content)
.bind(conversation_id)
.fetch_one(&sqlite)
.persistent(false)
.fetch_one(executor)
.await?;

Ok(new_message)
}

// TODO: fix transactions
pub async fn update_message(
sqlite: SqlitePool,
pub async fn update_message<'e, E>(
executor: E,
content: String,
message_id: u32,
) -> AppResult<Message> {
) -> AppResult<Message>
where
E: Executor<'e, Database = Sqlite>,
{
let updated_message: Message = sqlx::query_as(
r#"
UPDATE messages
Expand All @@ -77,7 +91,8 @@ pub async fn update_message(
)
.bind(content)
.bind(message_id)
.fetch_one(&sqlite)
.persistent(false)
.fetch_one(executor)
.await?;

Ok(updated_message)
Expand Down
19 changes: 8 additions & 11 deletions src/ollama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async fn inference(
) -> AppResult<()> {
while let Some(inference_message) = inference_rx.recv().await {
let conversation_id = inference_message.conversation_id;
let messages = get_messages(sqlite.clone(), conversation_id).await?;
let messages = get_messages(&sqlite, conversation_id).await?;

let params = OllamaChatParams::new(DEFAULT_LLM_MODEL.to_string(), messages, false);

Expand All @@ -103,7 +103,7 @@ async fn inference(
let content = response.message.content.trim().to_string();

let assistant_response =
create_message(sqlite.clone(), Role::Assistant, content, conversation_id).await?;
create_message(&sqlite, Role::Assistant, content, conversation_id).await?;

let _ = event_tx.send(Event::Inference(
assistant_response,
Expand All @@ -122,7 +122,7 @@ async fn inference_stream(
) -> AppResult<()> {
while let Some(inference_message) = inference_rx.recv().await {
let conversation_id = inference_message.conversation_id;
let messages = get_messages(sqlite.clone(), conversation_id).await?;
let messages = get_messages(&sqlite, conversation_id).await?;

let params = OllamaChatParams::new(DEFAULT_LLM_MODEL.to_string(), messages, true);

Expand All @@ -135,13 +135,9 @@ async fn inference_stream(
.map(|chunk| chunk.unwrap())
.map(|chunk| serde_json::from_slice::<OllamaChatResponseStream>(&chunk));

let assistant_response = create_message(
sqlite.clone(),
Role::Assistant,
"".to_string(),
conversation_id,
)
.await?;
let mut tx = sqlite.begin().await?;
let assistant_response =
create_message(&mut *tx, Role::Assistant, "".to_string(), conversation_id).await?;

let mut is_first_chunk = true;
let mut content = String::new();
Expand Down Expand Up @@ -173,7 +169,8 @@ async fn inference_stream(
}
}

update_message(sqlite.clone(), content, assistant_response.id).await?;
update_message(&mut *tx, content, assistant_response.id).await?;
tx.commit().await?;
}

Ok(())
Expand Down

0 comments on commit 78fd26f

Please sign in to comment.