Skip to content

Commit

Permalink
add database
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszKielar committed Apr 16, 2024
1 parent 5816d6b commit e70b8cc
Showing 11 changed files with 183 additions and 49 deletions.
2 changes: 0 additions & 2 deletions .devcontainer/Dockerfile
Original file line number Diff line number Diff line change
@@ -20,8 +20,6 @@ RUN cargo install sqlx-cli --no-default-features --features native-tls,sqlite
RUN curl -fsSL https://deb.nodesource.com/setup_20.x | bash - \
&& apt-get install -y nodejs

RUN npm install [email protected]

USER vscode

RUN sudo chown -R $(whoami) /usr/local/* \
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -13,3 +13,4 @@ end2end/playwright-report/
playwright/.cache/

db.sqlite3*
localhost.sql
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ sqlx = { version = "0.7.4", default-features = false, features = [
"uuid",
"macros",
], optional = true }
uuid = { version = "1.8", features = ["v4"] }
uuid = { version = "1.8", features = ["v7", "serde"] }

[features]
hydrate = ["leptos/hydrate", "leptos_meta/hydrate", "leptos_router/hydrate"]
149 changes: 119 additions & 30 deletions src/api.rs
Original file line number Diff line number Diff line change
@@ -1,49 +1,128 @@
use leptos::{logging, server, use_context, ServerFnError};
use serde::{Deserialize, Serialize};

#[cfg(feature = "ssr")]
use crate::api::{
db::save_message,
ollama::{default_model, OllamaChatParams, OllamaChatResponse, OllamaMessage},
};
use crate::models::{Conversation, Message};
use leptos::server;
use leptos::ServerFnError;

// const MODEL: &str = "llama2:7b";
const MODEL: &str = "mistral:7b";
// const MODEL: &str = "tinyllama";

// TODO: I need to save a context of the chat into DB
// that would help when user decided to come back to old conversation
// I won't be feeding model with previous prompts
// asynchronously save everything to DB (maybe in batch mode?? - future consideration)
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct ChatResponse {
pub message: Message,
}
// TODO: move to separate module backend/db
#[cfg(feature = "ssr")]
mod db {
use sqlx::SqlitePool;
use uuid::Uuid;

use crate::models::Message;

fn default_model() -> String {
MODEL.to_string()
// TODO: user proper error handling
pub async fn save_message(
pool: SqlitePool,
message: Message,
conversation_id: Uuid,
) -> Result<i64, String> {
let id = sqlx::query!(
r#"
INSERT INTO messages ( id, role, content, conversation_id )
VALUES ( ?1, ?2, ?3, ?4 )
"#,
message.id,
message.role,
message.content,
conversation_id
)
.execute(&pool)
.await
.unwrap()
.last_insert_rowid();

Ok(id)
}
}

#[derive(Deserialize, Serialize, Debug)]
struct ChatParams {
#[serde(default = "default_model")]
model: String,
messages: Vec<Message>,
#[serde(default)]
stream: bool,
// TODO: move to separate module backend/ollama
#[cfg(feature = "ssr")]
mod ollama {
use serde::{Deserialize, Serialize};

use crate::models::{Message, Role};

const MODEL: &str = "mistral:7b";
// const MODEL: &str = "llama2:7b";
// const MODEL: &str = "tinyllama";

pub fn default_model() -> String {
MODEL.to_string()
}

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OllamaMessage {
pub role: Role,
pub content: String,
}

impl From<&Message> for OllamaMessage {
fn from(value: &Message) -> Self {
Self {
role: Role::from(value.role.as_ref()),
content: value.content.clone(),
}
}
}

// TODO: I need to save a context of the chat into DB
// that would help when user decided to come back to old conversation
// I won't be feeding model with previous prompts
// asynchronously save everything to DB (maybe in batch mode?? - future consideration)
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct OllamaChatResponse {
pub message: OllamaMessage,
}

#[derive(Deserialize, Serialize, Debug)]
pub struct OllamaChatParams {
#[serde(default = "default_model")]
pub model: String,
pub messages: Vec<OllamaMessage>,
#[serde(default)]
pub stream: bool,
}
}

// TODO: save every prompt, response and context to database, async thread
// TODO: this function should take id of the conversation, prompt and context (history of conversation)
#[server(Chat, "/api/chat")]
pub async fn chat(conversation: Conversation) -> Result<Message, ServerFnError> {
#[server(Chat, "/api", "Url", "chat")]
pub async fn chat(
// TODO: I need to reduce amount of data send over the wire (maybe conversation_id and fetch it from the DB on a server?)
conversation: Conversation,
user_message: Message,
) -> Result<Message, ServerFnError> {
use leptos::{logging, use_context};
use reqwest;
use sqlx::SqlitePool;
use tokio::spawn;

let db_pool = use_context::<SqlitePool>().expect("SqlitePool not found");
let save_user_message = {
let db_pool = db_pool.clone();
let user_message = user_message.clone();
spawn(async move { save_message(db_pool, user_message, conversation.id) })
};

// TODO: handle lack of context
let client = use_context::<reqwest::Client>().expect("reqwest.Client not found");
let params = ChatParams {
let params = OllamaChatParams {
model: default_model(),
messages: conversation.messages,
messages: conversation
.iter()
.map(|m| OllamaMessage::from(m))
.collect(),
stream: false,
};
logging::log!("request params: {:?}", params);

// TODO: properly handle errors
let response: ChatResponse = client
let response: OllamaChatResponse = client
.post("http://host.docker.internal:11434/api/chat")
.json(&params)
.send()
@@ -54,6 +133,16 @@ pub async fn chat(conversation: Conversation) -> Result<Message, ServerFnError>
.unwrap();

logging::log!("response: {:?}", response);
let assistant_message = Message::assistant(response.message.content, conversation.id);
let save_assistant_message = {
let db_pool = db_pool.clone();
let assistant_message = assistant_message.clone();
spawn(async move { save_message(db_pool, assistant_message, conversation.id) })
};

// TODO: handle failures
let _ = save_user_message.await?;
let _ = save_assistant_message.await?;

Ok(response.message)
Ok(assistant_message)
}
9 changes: 6 additions & 3 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -17,10 +17,13 @@ pub fn App() -> impl IntoView {
// TODO: throw an error when prompt is empty
let send_prompt = create_action(move |prompt: &String| {
let prompt = prompt.to_owned();
let user_message = Message::user(prompt.clone());
set_conversation.update(move |c| c.append_message(user_message));
let user_message = Message::user(prompt.clone(), conversation().id);
{
let user_message = user_message.clone();
set_conversation.update(move |c| c.append_message(user_message));
}

chat(conversation())
chat(conversation(), user_message)
});

// TODO: disable submit button when we're waiting for server's response
2 changes: 1 addition & 1 deletion src/components/conversation_area.rs
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ pub fn ConversationArea(conversation: ReadSignal<Conversation>) -> impl IntoView
conversation()
.iter()
.map(move |msg| {
let message_class = match msg.role {
let message_class = match Role::from(msg.role.as_ref()) {
Role::User => USER_MESSAGE_CLASS,
Role::Assistant => ASSISTANT_MESSAGE_CLASS,
_ => panic!("system message not supported yet"),
4 changes: 2 additions & 2 deletions src/handlers.rs
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@ pub async fn server_fn_handler(
) -> impl IntoResponse {
handle_server_fns_with_context(
move || {
provide_context(app_state.pool.clone());
provide_context(app_state.db_pool.clone());
provide_context(app_state.reqwest_client.clone());
},
request,
@@ -30,7 +30,7 @@ pub async fn leptos_routes_handler(
app_state.leptos_options.clone(),
app_state.routes.clone(),
move || {
provide_context(app_state.pool.clone());
provide_context(app_state.db_pool.clone());
provide_context(app_state.reqwest_client.clone());
},
App,
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -15,7 +15,7 @@ async fn main() {
use sqlx::SqlitePool;

let db_url = env::var("DATABASE_URL").expect("DATABASE_URL not set");
let pool: SqlitePool = SqlitePoolOptions::new()
let db_pool: SqlitePool = SqlitePoolOptions::new()
.connect(&db_url)
.await
.expect("Could not make pool.");
@@ -32,7 +32,7 @@ async fn main() {

let app_state = AppState {
leptos_options,
pool: pool,
db_pool: db_pool,
reqwest_client: reqwest::Client::new(),
routes: routes.clone(),
};
49 changes: 42 additions & 7 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use serde::{Deserialize, Serialize};
#[cfg(feature = "ssr")]
use sqlx::FromRow;
use uuid::Uuid;

#[derive(Deserialize, Serialize, Debug, Clone, Copy)]
pub enum Role {
@@ -10,37 +13,69 @@ pub enum Role {
Assistant,
}

impl From<&str> for Role {
fn from(value: &str) -> Self {
match value {
"system" => Role::System,
"user" => Role::User,
"assistant" => Role::Assistant,
_ => panic!("Unknown Role!"),
}
}
}

impl std::fmt::Display for Role {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Role::System => write!(f, "system"),
Role::User => write!(f, "user"),
Role::Assistant => write!(f, "assistant"),
}
}
}

// TODO: message should be reactive, saying, whenever it changes, I should update UI
#[derive(Deserialize, Serialize, Debug, Clone)]
#[cfg_attr(feature = "ssr", derive(FromRow))]
pub struct Message {
pub role: Role,
pub id: Uuid,
pub role: String,
pub content: String,
pub conversation_id: Uuid,
}

impl Message {
fn new(role: Role, content: String) -> Self {
Self { role, content }
fn new(role: Role, content: String, conversation_id: Uuid) -> Self {
Self {
id: Uuid::now_v7(),
role: role.to_string(),
content,
conversation_id,
}
}

pub fn user(content: String) -> Self {
Self::new(Role::User, content)
pub fn user(content: String, conversation_id: Uuid) -> Self {
Self::new(Role::User, content, conversation_id)
}

pub fn assistant(content: String) -> Self {
Self::new(Role::Assistant, content)
pub fn assistant(content: String, conversation_id: Uuid) -> Self {
Self::new(Role::Assistant, content, conversation_id)
}
}

// TODO: it should contain: id(uuid), messages (vec<Message>))
// Message should contain: id (uuid), persona (enum or string - human/assistant), text (string)
#[derive(Deserialize, Serialize, Debug, Clone)]
#[cfg_attr(feature = "ssr", derive(FromRow))]
pub struct Conversation {
pub id: Uuid,
pub messages: Vec<Message>,
}

impl Conversation {
pub fn new() -> Self {
Self {
id: Uuid::now_v7(),
messages: Vec::new(),
}
}
2 changes: 1 addition & 1 deletion src/state.rs
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ use sqlx::SqlitePool;
#[derive(FromRef, Debug, Clone)]
pub struct AppState {
pub leptos_options: LeptosOptions,
pub pool: SqlitePool,
pub db_pool: SqlitePool,
pub reqwest_client: reqwest::Client,
pub routes: Vec<RouteListing>,
}

0 comments on commit e70b8cc

Please sign in to comment.