Skip to content

Commit

Permalink
Merge pull request #109 from FrancescoCaracciolo/master
Browse files Browse the repository at this point in the history
  • Loading branch information
qwersyk authored Dec 19, 2024
2 parents f99d3a6 + 30a34bd commit a76bc78
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@
Assistant: Hello, how can I assist you today?
User: Can you help me?
Assistant: Yes, of course, what do you need help with?""",
"get_suggestions_prompt": """Suggest a few questions that the user would ask and put them in a JSON array. You have to write ONLY the JSON array an nothing else""",
"get_suggestions_prompt": """Suggest a few questions that the user would ask and put them in a JSON array. You have to write ONLY the JSON array of strings an nothing else.""",
"custom_prompt": "",

}
Expand Down
77 changes: 77 additions & 0 deletions src/extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re, base64, io
import os, sys
import xml.dom.minidom, html
import json

class ReplaceHelper:
DISTRO = None
Expand Down Expand Up @@ -263,6 +264,7 @@ def can_escape_sandbox() -> bool:
except subprocess.CalledProcessError as _:
return False
return True

def get_streaming_extra_setting():
return {
"key": "streaming",
Expand All @@ -271,6 +273,7 @@ def get_streaming_extra_setting():
"type": "toggle",
"default": True
}

def override_prompts(override_setting, PROMPTS):
prompt_list = {}
for prompt in PROMPTS:
Expand All @@ -279,3 +282,77 @@ def override_prompts(override_setting, PROMPTS):
else:
prompt_list[prompt] = PROMPTS[prompt]
return prompt_list


def extract_json(input_string: str) -> str:
    """Extract the first valid JSON object or array embedded in a string.

    Typically used to pull a JSON payload out of an LLM reply that wraps
    it in prose or Markdown code fences.

    Args:
        input_string (str): Text that may contain a JSON fragment

    Returns:
        str: The first substring that parses as valid JSON, or the
            literal string "[]" when no valid JSON is found (so callers
            can always pass the result to json.loads())
    """
    # Candidate spans delimited by braces or brackets; DOTALL lets a
    # candidate span multiple lines.
    # NOTE: the non-greedy patterns stop at the first closing delimiter,
    # so nested structures like {"a": {"b": 1}} are not captured.
    json_pattern = re.compile(r'\{.*?\}|\[.*?\]', re.DOTALL)

    # Validate each candidate and return the first one that parses.
    for match in json_pattern.findall(input_string):
        try:
            json.loads(match)
        except json.JSONDecodeError:
            continue
        return match
    print("Wrong JSON", input_string)
    # Return an empty JSON array *string* (not a list) so the declared
    # return type holds and json.loads() on the result succeeds.
    return "[]"


def remove_markdown(text: str) -> str:
    """
    Strip common Markdown formatting from text, returning plain text.

    Args:
        text: The Markdown-formatted text
    Returns:
        str: The text with Markdown syntax removed
    """
    # Remove fenced code blocks FIRST: stripping inline code before this
    # would consume the ``` fences backtick-pair by backtick-pair, so the
    # fence pattern would never match and block contents would leak
    # through (e.g. into TTS output).
    text = re.sub(r'```[\s\S]*?```', '', text)

    # Remove headers
    text = re.sub(r'^#{1,6}\s*', '', text, flags=re.MULTILINE)

    # Remove emphasis (bold before italic so ** is not split into two *)
    text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)  # Bold
    text = re.sub(r'__(.*?)__', r'\1', text)      # Bold
    text = re.sub(r'\*(.*?)\*', r'\1', text)      # Italic
    text = re.sub(r'_(.*?)_', r'\1', text)        # Italic

    # Remove inline code
    text = re.sub(r'`([^`]*)`', r'\1', text)

    # Remove images BEFORE links: the link pattern would otherwise match
    # the [alt](url) part of ![alt](url) and leave a stray '!' behind.
    text = re.sub(r'!\[([^\]]*)\]\([^\)]+\)', r'\1', text)

    # Remove links, keep the link text
    text = re.sub(r'\[([^\]]+)\]\([^\)]+\)', r'\1', text)

    # Remove strikethrough
    text = re.sub(r'~~(.*?)~~', r'\1', text)

    # Remove blockquote markers
    text = re.sub(r'^>\s*', '', text, flags=re.MULTILINE)

    # Remove unordered list markers
    text = re.sub(r'^\s*[-+*]\s+', '', text, flags=re.MULTILINE)

    # Remove ordered list markers
    text = re.sub(r'^\s*\d+\.\s+', '', text, flags=re.MULTILINE)

    # Collapse runs of blank lines
    text = re.sub(r'\n{2,}', '\n', text)

    return text.strip()
45 changes: 40 additions & 5 deletions src/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Any
import json
import base64
from .extra import convert_history_openai, extract_image, find_module, get_streaming_extra_setting, install_module, open_website, get_image_path, get_spawn_command, quote_string
from .extra import convert_history_openai, extract_image, extract_json, find_module, get_streaming_extra_setting, install_module, open_website, get_image_path, get_spawn_command, quote_string
from .handler import Handler

class LLMHandler(Handler):
Expand Down Expand Up @@ -130,7 +130,7 @@ def get_suggestions(self, request_prompt:str = "", amount:int=1) -> list[str]:
history += message["User"] + ": " + message["Message"] + "\n"
for i in range(0, amount):
generated = self.generate_text(history + "\n\n" + request_prompt)
generated = generated.replace("```json", "").replace("```", "")
generated = extract_json(generated)
try:
j = json.loads(generated)
except Exception as _:
Expand Down Expand Up @@ -455,10 +455,36 @@ class GeminiHandler(LLMHandler):
Official Google Gemini APIs, they support history and system prompts
"""

default_models = [("gemini-1.5-flash","gemini-1.5-flash"), ("gemini-1.5-flash-8b", "gemini-1.5-flash-8b") , ("gemini-1.0-pro", "gemini-1.0-pro"), ("gemini-1.5-pro","gemini-1.5-pro") ]

    def __init__(self, settings, path):
        """Initialize the Gemini handler.

        Args:
            settings: settings store forwarded to the Handler base class
            path: base path forwarded to the Handler base class
        """
        super().__init__(settings, path)
        # Cache dict; its consumers are not visible in this chunk —
        # presumably per-prompt response caching, verify against usage.
        self.cache = {}
        # NOTE(review): the trailing `or True` makes this condition always
        # true, so the cached "models" setting in the else branch is dead
        # code — every startup uses default_models and refreshes in the
        # background. Confirm whether `or True` is debug leftover.
        if self.get_setting("models", False) is None or len(self.get_setting("models", False)) == 0 or True:
            self.models = self.default_models
            # Fetch the live model list asynchronously so startup is not
            # blocked on a network call.
            threading.Thread(target=self.get_models).start()
        else:
            self.models = json.loads(self.get_setting("models", False))

def get_models(self):
if self.is_installed():
try:
import google.generativeai as genai
api = self.get_setting("apikey", False)
if api is None:
return
genai.configure(api_key=api)
models = genai.list_models()
result = tuple()
for model in models:
if "generateContent" in model.supported_generation_methods:
result += ((model.display_name, model.name,),)
self.models = result
self.set_setting("models", json.dumps(result))
self.settings_update()
except Exception as e:
print("Error getting " + self.key + " models: " + str(e))

@staticmethod
def get_extra_requirements() -> list:
return ["google-generativeai"]
Expand All @@ -479,13 +505,21 @@ def get_extra_settings(self) -> list:
"type": "entry",
"default": ""
},
{
"key": "refresh_models",
"title": _("Refresh Models"),
"description": _("Refresh models list"),
"type": "button",
"icon": "view-refresh-symbolic",
"callback": lambda button: self.get_models(),
},
{
"key": "model",
"title": _("Model"),
"description": _("AI Model to use, available: gemini-1.5-pro, gemini-1.0-pro, gemini-1.5-flash"),
"type": "combo",
"default": "gemini-1.5-flash",
"values": [("gemini-1.5-flash-8b", "gemini-1.5-flash-8b"), ("gemini-1.5-flash","gemini-1.5-flash") , ("gemini-1.0-pro", "gemini-1.0-pro"), ("gemini-1.5-pro","gemini-1.5-pro") ]
"default": self.models[0][1],
"values": self.models,
},
{
"key": "streaming",
Expand Down Expand Up @@ -1071,6 +1105,7 @@ def generate_text_stream(self, prompt: str, history: list[dict[str, str]] = [],
from openai import OpenAI
history.append({"User": "User", "Message": prompt})
messages = self.convert_history(history, system_prompt)
print([message["role"] for message in messages])
api = self.get_setting("api")
if api == "":
api = "nokey"
Expand Down Expand Up @@ -1116,7 +1151,7 @@ def get_extra_settings(self) -> list:

class GroqHandler(OpenAIHandler):
key = "groq"
default_models = (("llama-3.1-70B-versatile", "llama-3.1-70B-versatile" ), )
default_models = (("llama-3.3-70B-versatile", "llama-3.3-70B-versatile" ), )
def supports_vision(self) -> bool:
return "vision" in self.get_setting("model")

Expand Down
25 changes: 15 additions & 10 deletions src/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .constants import AVAILABLE_LLMS, AVAILABLE_PROMPTS, PROMPTS, AVAILABLE_TTS, AVAILABLE_STT
from gi.repository import Gtk, Adw, Pango, Gio, Gdk, GObject, GLib, GdkPixbuf
from .stt import AudioRecorder
from .extra import get_spawn_command, install_module, markwon_to_pango, override_prompts, replace_variables
from .extra import get_spawn_command, install_module, markwon_to_pango, override_prompts, remove_markdown, replace_variables
import threading
import posixpath
import json
Expand Down Expand Up @@ -537,7 +537,7 @@ def handle_second_block_change(self,*a):
elif type(header_widget) is Gtk.Box:
self.explorer_panel_headerbox.append(self.headerbox)

def on_flap_button_toggled(self, toggle_button):
def on_flap_button_toggled(self, toggle_button):
self.flap_button_left.set_active(True)
if self.main_program_block.get_name() == "visible":
self.main_program_block.set_name("hide")
Expand Down Expand Up @@ -711,8 +711,8 @@ def generate_chat_name(self, button, multithreading=False):
button.set_child(spinner)
button.set_can_target(False)
button.set_has_frame(True)
# TODO: take the history for the correct chat
self.model.set_history([], self)
self.model.set_history([], self.get_history(self.chats[int(button.get_name())]["chat"]))
name = self.model.generate_chat_name(self.prompts["generate_name_prompt"])
if name != "Chat has been stopped":
self.chats[int(button.get_name())]["name"] = name
Expand Down Expand Up @@ -1248,10 +1248,12 @@ def wait_threads_sm():
GLib.idle_add(self.scrolled_chat)
self.save_chat()

def get_history(self) -> list[dict[str, str]]:
def get_history(self, chat = None) -> list[dict[str, str]]:
if chat is None:
chat = self.chat
history = []
count = self.memory
for msg in self.chat[:-1]:
for msg in chat[:-1]:
if count == 0:
break
if msg["User"] == "Console" and msg["Message"] == "None":
Expand Down Expand Up @@ -1281,7 +1283,7 @@ def send_message(self):
self.curr_label = ""
GLib.idle_add(self.create_streaming_message_label)
self.streaming_lable = None
message_label = self.model.send_message_stream(self, self.chat[-1]["Message"], self.update_message)
message_label = self.model.send_message_stream(self, self.chat[-1]["Message"], self.update_message, [stream_number_variable])
try:
self.streaming_box.get_parent().set_visible(False)
except:
Expand All @@ -1295,6 +1297,7 @@ def send_message(self):
tts_thread = None
if self.tts_enabled:
message=re.sub(r"```.*?```", "", message_label, flags=re.DOTALL)
message = remove_markdown(message)
if not(not message.strip() or message.isspace() or all(char == '\n' for char in message)):
tts_thread = threading.Thread(target=self.tts.play_audio, args=(message, ))
tts_thread.start()
Expand All @@ -1315,16 +1318,18 @@ def create_streaming_message_label(self):
self.streaming_label = Gtk.TextView(wrap_mode=Gtk.WrapMode.WORD_CHAR, editable=False, hexpand=True)
scrolled_window.add_css_class("scroll")
self.streaming_label.add_css_class("scroll")
apply_css_to_widget(scrolled_window, ".scroll { background-color: rgba(0,0,0,0)}")
apply_css_to_widget(self.streaming_label, ".scroll { background-color: rgba(0,0,0,0)}")
apply_css_to_widget(scrolled_window, ".scroll { background-color: rgba(0,0,0,0);}")
apply_css_to_widget(self.streaming_label, ".scroll { background-color: rgba(0,0,0,0);}")
scrolled_window.set_child(self.streaming_label)
text_buffer = self.streaming_label.get_buffer()
tag = text_buffer.create_tag("no-background", background_set=False, paragraph_background_set=False)
text_buffer.apply_tag(tag, text_buffer.get_start_iter(), text_buffer.get_end_iter())
self.streaming_box=self.add_message("Assistant", scrolled_window)
self.streaming_box.set_overflow(Gtk.Overflow.VISIBLE)

def update_message(self, message):
def update_message(self, message, stream_number_variable):
if self.stream_number_variable != stream_number_variable:
return
self.streamed_message = message
if self.streaming_label is not None:
added_message = message[len(self.curr_label):]
Expand Down

0 comments on commit a76bc78

Please sign in to comment.