Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszKielar committed May 5, 2024
1 parent 5ef97e1 commit 27981ac
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 65 deletions.
13 changes: 2 additions & 11 deletions src/frontend/components/message.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::models::{Message as MessageModel, Role};
pub(crate) fn Message(message: MaybeSignal<MessageModel>) -> impl IntoView {
let message = message.get();
let message_role = message.role;
let message_content = message.content;
let message_content = view! { <p>{message.content}</p> }.into_view();

let is_user = match Role::from(message_role) {
Role::Assistant | Role::System => false,
Expand All @@ -30,15 +30,6 @@ pub(crate) fn Message(message: MaybeSignal<MessageModel>) -> impl IntoView {
}
};

let message_content_class = {
if !is_user && message_content == "" {
view! { <Icon icon=icondata::LuTextCursorInput class="h-6 w-6 animate-pulse"/> }
.into_view()
} else {
view! { <p>{message_content}</p> }.into_view()
}
};

view! {
<div class=message_class>
<div class="text-base gap-4 md:gap-6 md:max-w-2xl lg:max-w-xl xl:max-w-3xl flex lg:px-0 m-auto w-full">
Expand All @@ -57,7 +48,7 @@ pub(crate) fn Message(message: MaybeSignal<MessageModel>) -> impl IntoView {
<div class="flex flex-grow flex-col gap-3">
<div class="min-h-10 flex flex-col items-start gap-4 whitespace-pre-wrap break-words">
<div class="markdown prose w-full break-words dark:prose-invert dark">
{message_content_class}
{message_content}
</div>
</div>
</div>
Expand Down
20 changes: 10 additions & 10 deletions src/frontend/views/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub(crate) fn Chat() -> impl IntoView {
} = use_context().unwrap();

let UseWebsocketReturn {
ready_state,
message: assistant_message,
send,
..
Expand Down Expand Up @@ -57,10 +58,12 @@ pub(crate) fn Chat() -> impl IntoView {
let dispatch = move |send: &dyn Fn(&str)| {
let user_message =
models::Message::user(Uuid::new_v4(), user_prompt.get(), conversation_id());
if user_message.content != "" {
if user_message.content != "" && !server_response_pending.get() {
messages.update(|msgs| msgs.push(user_message.clone()));
server_response_pending.set(true);
send(&serde_json::to_string(&user_message).unwrap());
messages.update(|msgs| msgs.push(user_message.clone()));
logging::log!("prompt send to the server: {:?}", user_message.content);

user_prompt.set("".to_string());
}
};
Expand Down Expand Up @@ -88,20 +91,16 @@ pub(crate) fn Chat() -> impl IntoView {
match models::Role::from(last_msg.role.clone()) {
models::Role::System => panic!("cannot happen"),
models::Role::User => {
logging::log!("adding assistant to the list: {:?}", assistant_message);
messages.update(|msgs| msgs.push(assistant_message));
if assistant_message.content != "" {
messages.update(|msgs| msgs.push(assistant_message));
}
}
models::Role::Assistant => {
if let Some(message_to_update_position) = messages
.get()
.iter()
.position(|m| m.id == assistant_message.id)
{
logging::log!(
"adding assistant to the existing message at [{:?}]: {:?}",
message_to_update_position,
assistant_message
);
// SAFETY: it's safe to unwrap, because I've just got the index
let current_message = messages
.get()
Expand Down Expand Up @@ -132,7 +131,8 @@ pub(crate) fn Chat() -> impl IntoView {
<div class="scroll-to-bottom--css-ikyem-1n7m0yu">
<div class="flex flex-col items-center text-sm bg-gray-800">
<div class="flex w-full items-center justify-center gap-1 border-b border-black/10 bg-gray-50 p-3 text-gray-500 dark:border-gray-900/50 dark:bg-gray-700 dark:text-gray-300">
"Model: " <b>{model}</b>
"Model: " <b>{model}</b> ", Server: "
{move || ready_state.get().to_string()}
</div>
<Transition>
// TODO: reload only when necessary
Expand Down
15 changes: 6 additions & 9 deletions src/frontend/views/home.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ pub(crate) fn Home() -> impl IntoView {
} = use_context().unwrap();

let UseWebsocketReturn {
ready_state,
message: assistant_message,
send,
..
Expand Down Expand Up @@ -46,10 +47,11 @@ pub(crate) fn Home() -> impl IntoView {
let dispatch = move |send: &dyn Fn(&str)| {
let user_message =
models::Message::user(Uuid::new_v4(), user_prompt.get(), conversation_id());
if user_message.content != "" {
if user_message.content != "" && !server_response_pending.get() {
messages.update(|msgs| msgs.push(user_message.clone()));
server_response_pending.set(true);
send(&serde_json::to_string(&user_message).unwrap());
messages.update(|msgs| msgs.push(user_message.clone()));
logging::log!("prompt send to the server: {:?}", user_message.content);

// TODO: I should prob get this object from server response
let conversation = models::Conversation {
Expand Down Expand Up @@ -88,7 +90,6 @@ pub(crate) fn Home() -> impl IntoView {
match models::Role::from(last_msg.role.clone()) {
models::Role::System => panic!("cannot happen"),
models::Role::User => {
logging::log!("adding assistant to the list: {:?}", assistant_message);
messages.update(|msgs| msgs.push(assistant_message));
}
models::Role::Assistant => {
Expand All @@ -97,11 +98,6 @@ pub(crate) fn Home() -> impl IntoView {
.iter()
.position(|m| m.id == assistant_message.id)
{
logging::log!(
"adding assistant to the existing message at [{:?}]: {:?}",
message_to_update_position,
assistant_message
);
// SAFETY: it's safe to unwrap, because I've just got the index
let current_message = messages
.get()
Expand Down Expand Up @@ -133,7 +129,8 @@ pub(crate) fn Home() -> impl IntoView {
<div class="flex flex-col items-center text-sm bg-gray-800">

<div class="flex w-full items-center justify-center gap-1 border-b border-black/10 bg-gray-50 p-3 text-gray-500 dark:border-gray-900/50 dark:bg-gray-700 dark:text-gray-300">
"Model: " <b>{model}</b>
"Model: " <b>{model}</b> ", Server: "
{move || ready_state.get().to_string()}
</div>
<Messages messages=messages.into()/>
<div class="w-full h-32 flex-shrink-0"></div>
Expand Down
106 changes: 71 additions & 35 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async fn main() {
let db_pool = db_pool.clone();
let conversation = lokai::models::Conversation::new(
Uuid::new_v4(),
String::from("conversation 2 strasznie dluga nazwa"),
String::from("conversation 2 very loooooooooooong name"),
);
let _ = db::create_conversation(db_pool, conversation)
.await
Expand Down Expand Up @@ -109,41 +109,77 @@ cfg_if::cfg_if! {
use lokai::server::ollama::{OllamaChatResponseStream,OllamaChatParams, default_model};
use lokai::server::db;
use futures_util::StreamExt;
use futures_util::SinkExt as _;

/// Axum route handler for WebSocket connections.
///
/// Accepts the client's HTTP upgrade request and, once the protocol
/// switch completes, hands the live socket together with the shared
/// `AppState` over to `handle_socket`, which drives the chat session.
pub async fn websocket(ws: WebSocketUpgrade, State(app_state): State<AppState>) -> Response {
// NOTE(review): logs the raw upgrade request on every connection — presumably debug aid; confirm it should stay in production builds
logging::log!("ws: {:?}", ws);
ws.on_upgrade(|socket| handle_socket(socket, app_state))
}

async fn handle_socket(mut socket: WebSocket, app_state: AppState) {
let (tx, mut rx) = tokio::sync::mpsc::channel::<String>(100);

while let Some(msg) = socket.recv().await {
let tx = tx.clone();
let app_state = app_state.clone();
logging::log!("socket.recv(): {:?}", msg);
if let Ok(WebSocketMessage::Text(payload)) = msg {
logging::log!("axum::extract::ws::Message::Text: {:?}", payload);
let user_prompt = serde_json::from_str::<models::Message>(&payload).unwrap();
tokio::spawn(async move {
inference(user_prompt, tx, app_state).await
});
} else {
// client disconnected
return;
async fn handle_socket(socket: WebSocket, app_state: AppState) {
logging::log!("socket: {:?}", socket);
let (inference_request_tx, mut inference_request_rx) = tokio::sync::mpsc::channel::<models::Message>(100);
let (inference_response_tx, mut inference_response_rx) = tokio::sync::mpsc::channel::<models::Message>(100);

// inference thread
let inference_thread = tokio::spawn(async move {
logging::log!("inference thread started");
while let Some(user_prompt) = inference_request_rx.recv().await {
let inference_response_tx_clone = inference_response_tx.clone();
let app_state_clone = app_state.clone();
inference(user_prompt, inference_response_tx_clone, app_state_clone).await;
};
logging::log!("inference thread exited");
});

while let Some(msg) = rx.recv().await {
if socket.send(WebSocketMessage::Text(msg)).await.is_err() {
let (mut sender, mut receiver) = socket.split();

// receiver thread
let _ = tokio::spawn(async move {
while let Some(assistant_response_chunk) = inference_response_rx.recv().await {
logging::log!("got assistant response chunk: {:?}", assistant_response_chunk);
let assistant_response_chunk_json = serde_json::to_string(&assistant_response_chunk).unwrap();
if sender.send(WebSocketMessage::Text(assistant_response_chunk_json)).await.is_err() {
// client disconnected
return ;
}
};
});

// sender thread
let _ = tokio::spawn(async move {
while let Some(Ok(WebSocketMessage::Text(user_prompt))) = receiver.next().await {
logging::log!("message received through the socket: {:?}", user_prompt);
let user_prompt = serde_json::from_str::<models::Message>(&user_prompt).unwrap();
let _ = inference_request_tx.send(user_prompt).await;
}
};
});

// https://github.com/tokio-rs/axum/blob/main/examples/websockets/src/main.rs
// tokio::select! {
// rv_a = (&mut assistant_response) => {
// match rv_a {
// Ok(_) => println!("rv_a arm Ok"),
// Err(_) => println!("rv_a arm Err"),
// }
// user_prompt_request.abort();
// },
// rv_b = (&mut user_prompt_request) => {
// match rv_b {
// Ok(_) => println!("rv_a arm Ok"),
// Err(_) => println!("rv_a arm Err"),
// }
// assistant_response.abort();
// },
// }

let _ = inference_thread.await;
logging::log!("I exited!")
}

// TODO: use transactions
async fn inference(user_prompt: models::Message, tx: tokio::sync::mpsc::Sender<String>, app_state: AppState) {
logging::log!("Got user prompt for inference: {:?}", user_prompt);
async fn inference(user_prompt: models::Message, inference_response_tx: tokio::sync::mpsc::Sender<models::Message>, app_state: AppState) {
logging::log!("got user prompt for inference: {:?}", user_prompt);
let client = app_state.reqwest_client;

let conversation = models::Conversation::new(user_prompt.conversation_id, user_prompt.content.clone());
Expand Down Expand Up @@ -175,25 +211,25 @@ cfg_if::cfg_if! {
.unwrap()
.bytes_stream().map(|chunk| chunk.unwrap()).map(|chunk| serde_json::from_slice::<OllamaChatResponseStream>(&chunk));

let mut assistant_response = models::Message::assistant(Uuid::new_v4(), "".to_string(), conversation_id);
let mut assistant_response = models::Message::assistant(Uuid::new_v4(), "".to_string(), conversation_id);

while let Some(chunk) = stream.next().await {
if let Ok(chunk) = chunk {
assistant_response.update_content(&chunk.message.content);
while let Some(chunk) = stream.next().await {
if let Ok(chunk) = chunk {
assistant_response.update_content(&chunk.message.content);

let assistant_response_chunk = models::Message::assistant(assistant_response.id, chunk.message.content, conversation_id);
let assistant_response_json = serde_json::to_string(&assistant_response_chunk).unwrap();
if tx.send(assistant_response_json.to_string()).await.is_err() {
break;
};
let assistant_response_chunk = models::Message::assistant(assistant_response.id, chunk.message.content, conversation_id);
if inference_response_tx.send(assistant_response_chunk).await.is_err() {
break;
};

if chunk.done {
break;
}
if chunk.done {
break;
}
}
}

let _ = db::create_message(app_state.db_pool, assistant_response).await;
let _ = db::create_message(app_state.db_pool, assistant_response).await;
logging::log!("inference done");
}
}
}

0 comments on commit 27981ac

Please sign in to comment.