Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
lukaszKielar committed Sep 7, 2024
1 parent a84a037 commit 70eac00
Show file tree
Hide file tree
Showing 7 changed files with 205 additions and 114 deletions.
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"extensions": [
"esbenp.prettier-vscode",
"fill-labs.dependi",
"Gruntfuggly.todo-tree",
"ms-azuretools.vscode-docker",
"oderwat.indent-rainbow"
]
Expand Down
80 changes: 36 additions & 44 deletions src/app.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use sqlx::SqlitePool;
use tokio::sync::mpsc::{self, UnboundedSender};
use tokio::sync::mpsc::{self, Sender, UnboundedSender};

use crate::{
chat::Chat,
Expand Down Expand Up @@ -46,12 +46,12 @@ impl AppFocus {

// TODO: create shared AppState(SqlitePool)

// TODO: make all attrs private
pub struct App {
pub chat: Chat,
pub conversations: Conversations,
pub prompt: Prompt<'static>,
pub prompt: Prompt,
focus: AppFocus,
inference_tx: Sender<Message>,
running: bool,
sqlite: SqlitePool,
_ollama: Ollama,
Expand All @@ -61,10 +61,11 @@ impl App {
pub fn new(sqlite: SqlitePool, event_tx: UnboundedSender<Event>) -> Self {
let (inference_tx, inference_rx) = mpsc::channel::<Message>(10);
Self {
chat: Default::default(),
conversations: Default::default(),
prompt: Prompt::new(inference_tx),
chat: Chat::new(sqlite.clone()),
conversations: Conversations::new(sqlite.clone()),
prompt: Default::default(),
focus: Default::default(),
inference_tx,
running: true,
sqlite: sqlite.clone(),
_ollama: Ollama::new(sqlite, inference_rx, event_tx),
Expand All @@ -73,7 +74,7 @@ impl App {

pub async fn init(&mut self) -> AppResult<()> {
let conversations = db::get_conversations(self.sqlite.clone()).await?;
self.conversations.conversations = conversations;
self.conversations.set_conversations(conversations);

Ok(())
}
Expand Down Expand Up @@ -101,7 +102,7 @@ impl App {
if key_event.modifiers == KeyModifiers::CONTROL {
self.running = false;
} else {
self.prompt.text_area.input(key_event);
self.prompt.handle_input(key_event);
}
}
KeyCode::Enter => {
Expand All @@ -110,22 +111,20 @@ impl App {
// https://github.com/crossterm-rs/crossterm/issues/685
if let AppFocus::Prompt = self.current_focus() {
if key_event.modifiers == KeyModifiers::SHIFT {
self.prompt.text_area.insert_str("\n");
self.prompt.new_line();
} else {
// we're able to send only when we have selected conversation
if let Some(conversation) = self.conversations.currently_selected() {
// TODO: 1. get text, 2. send text to inference thread, 3. clear input
let user_input =
self.prompt.text_area.lines().join("\n").trim().to_string();
let user_prompt = self.prompt.get_content();
let user_message = db::create_message(
self.sqlite.clone(),
Role::User,
user_input,
user_prompt,
conversation.id,
)
.await?;
self.chat.messages.push(user_message.clone());
self.prompt.inference_tx.send(user_message).await?;
self.chat.push(user_message.clone());
self.inference_tx.send(user_message).await?;
self.prompt.clear();
}
}
Expand All @@ -136,72 +135,65 @@ impl App {
// 1. get index of conversation
// 2. get messages for conversation
// 3. mutate state of app by assigning messages to proper attr
self.conversations.state.scroll_down_by(1);
if let Some(current_index) = self.conversations.state.selected() {
self.chat.state.select(None);
if let Some(item) = self.conversations.conversations.get(current_index) {
let messages = db::get_messages(self.sqlite.clone(), item.id).await?;
self.chat.messages = messages;
}
};
self.conversations.down();
if let Some(conversation) = self.conversations.currently_selected() {
self.chat.load_messages(conversation.id).await?;
}
}
AppFocus::Messages => self.chat.state.scroll_down_by(1),
AppFocus::Messages => self.chat.down(),
AppFocus::Prompt => {
self.prompt.text_area.input(key_event);
self.prompt.handle_input(key_event);
}
},
KeyCode::Up => match self.current_focus() {
AppFocus::Conversation => {
self.conversations.state.scroll_up_by(1);
if let Some(current_index) = self.conversations.state.selected() {
if let Some(item) = self.conversations.conversations.get(current_index) {
let messages = db::get_messages(self.sqlite.clone(), item.id).await?;
self.chat.messages = messages;
}
};
self.conversations.up();
if let Some(conversation) = self.conversations.currently_selected() {
self.chat.load_messages(conversation.id).await?;
}
}
AppFocus::Messages => self.chat.state.scroll_up_by(1),
AppFocus::Messages => self.chat.up(),
AppFocus::Prompt => {
self.prompt.text_area.input(key_event);
self.prompt.handle_input(key_event);
}
},
KeyCode::Esc => match self.current_focus() {
AppFocus::Conversation => {
self.conversations.state.select(None);
self.chat.messages = vec![];
self.conversations.unselect();
self.chat.reset();
}
AppFocus::Messages => self.chat.state.select(None),
AppFocus::Prompt => {}
AppFocus::Messages => self.chat.unselect(),
_ => {}
},
KeyCode::Tab => self.next_focus(),
KeyCode::BackTab => self.previous_focus(),
_ => {
if let AppFocus::Prompt = self.current_focus() {
self.prompt.text_area.input(key_event);
self.prompt.handle_input(key_event);
}
}
}
Ok(())
}

async fn handle_inference_event(&mut self, message: Message) -> AppResult<()> {
self.chat.messages.push(message);
self.chat.push(message);

Ok(())
}

async fn handle_inference_stream_event(&mut self, message: Message) -> AppResult<()> {
if let Some(last_message) = self.chat.messages.last() {
if let Some(last_message) = self.chat.last() {
if let Some(conversation) = self.conversations.currently_selected() {
if conversation.id.eq(&message.conversation_id) {
match last_message.role {
Role::Assistant => {
self.chat.messages.pop();
self.chat.messages.push(message);
self.chat.pop();
self.chat.push(message);
}
Role::System => {}
Role::User => {
self.chat.messages.push(message);
self.chat.push(message);
}
}
}
Expand Down
76 changes: 70 additions & 6 deletions src/chat.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,75 @@
use ratatui::widgets::ListState;
use ratatui::widgets::{List, ListItem, ListState};
use sqlx::SqlitePool;

use crate::models::Message;
use crate::{db, models::Message, AppResult};

// TODO: make all attrs private
// TODO: create common StatefulList trait and implement it for ConversationList and MessageList
#[derive(Default)]
pub struct Chat {
pub messages: Vec<Message>,
messages: Vec<Message>,
pub state: ListState,
sqlite: SqlitePool,
}

impl Chat {
pub fn new(sqlite: SqlitePool) -> Self {
Self {
messages: vec![],
state: Default::default(),
sqlite,
}
}

pub fn currently_selected(&self) -> Option<Message> {
let selected_index = self.state.selected()?;
self.messages.get(selected_index).cloned()
}

pub fn reset(&mut self) {
self.unselect();
self.messages = vec![];
}

pub fn unselect(&mut self) {
self.state.select(None);
}

pub fn push(&mut self, message: Message) {
self.messages.push(message);
}

pub fn pop(&mut self) {
self.messages.pop();
}

pub fn last(&self) -> Option<&Message> {
self.messages.last()
}

pub fn up(&mut self) {
self.state.scroll_up_by(1);
}

pub fn down(&mut self) {
self.state.scroll_down_by(1);
}

pub async fn load_messages(&mut self, conversation_id: u32) -> AppResult<()> {
let messages = db::get_messages(self.sqlite.clone(), conversation_id).await?;
self.messages = messages;

Ok(())
}

pub fn as_list_widget<F, T>(&self, f: F) -> List<'static>
where
F: Fn(&Message) -> T,
T: Into<ListItem<'static>>,
{
let items = self
.messages
.iter()
.map(|elem| f(elem).into())
.collect::<Vec<ListItem>>();

List::new(items)
}
}
51 changes: 46 additions & 5 deletions src/conversations.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,59 @@
use ratatui::widgets::ListState;
use ratatui::widgets::{List, ListItem, ListState};
use sqlx::SqlitePool;

use crate::models::Conversation;

// TODO: make all attrs private
// TODO: create common StatefulList trait and implement it for Conversations and MessageList
#[derive(Default)]
pub struct Conversations {
pub conversations: Vec<Conversation>,
conversations: Vec<Conversation>,
pub state: ListState,
_sqlite: SqlitePool,
}

impl Conversations {
pub fn new(sqlite: SqlitePool) -> Self {
Self {
conversations: vec![],
state: Default::default(),
_sqlite: sqlite,
}
}

pub fn set_conversations(&mut self, conversations: Vec<Conversation>) {
self.conversations = conversations;
}

pub fn currently_selected(&self) -> Option<Conversation> {
let selected_index = self.state.selected()?;
self.conversations.get(selected_index).cloned()
}

pub fn unselect(&mut self) {
self.state.select(None);
}

pub fn push(&mut self, conversation: Conversation) {
self.conversations.push(conversation);
}

pub fn up(&mut self) {
self.state.scroll_up_by(1);
}

pub fn down(&mut self) {
self.state.scroll_down_by(1);
}

pub fn as_list_widget<F, T>(&self, f: F) -> List<'static>
where
F: Fn(&Conversation) -> T,
T: Into<ListItem<'static>>,
{
let items = self
.conversations
.iter()
.map(|elem| f(elem).into())
.collect::<Vec<ListItem>>();

List::new(items)
}
}
25 changes: 0 additions & 25 deletions src/models.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use chrono::{DateTime, Utc};
use ratatui::text::Text;
use serde::{Deserialize, Serialize};
use sqlx::{sqlite::SqliteRow, FromRow, Row};
use textwrap::Options;

#[derive(Serialize, Deserialize, FromRow, Debug, Clone)]
pub struct Conversation {
Expand Down Expand Up @@ -40,29 +38,6 @@ pub struct Message {
pub created_at: DateTime<Utc>,
}

impl Message {
pub fn wrapped(&self, width: usize) -> Text<'_> {
let icon = match self.role {
Role::Assistant => "🤖",
Role::System => "🧰",
Role::User => "👤",
};
let content = format!("{} {}", icon, self.content.trim());
Text::from(textwrap::wrap(&content, Options::new(width)).join("\n"))
}
}

impl<'a> From<&Message> for Text<'a> {
fn from(val: &Message) -> Self {
let icon = match val.role {
Role::Assistant => "🤖",
Role::System => "🧰",
Role::User => "👤",
};
format!("{} {}", icon, val.content).into()
}
}

impl FromRow<'_, SqliteRow> for Message {
fn from_row(row: &'_ SqliteRow) -> sqlx::Result<Self> {
let role = row
Expand Down
Loading

0 comments on commit 70eac00

Please sign in to comment.