From 93eda644e4160517ff84478fdc6d6de213114bff Mon Sep 17 00:00:00 2001
From: pc9441 <153177056+pc9441@users.noreply.github.com>
Date: Thu, 5 Dec 2024 09:35:11 +0100
Subject: [PATCH] Preliminary support for local LLM endpoints (tested on
 Ollama) (#71)

* Preliminary support for local LLM endpoints (tested on Ollama)

* Added early bail-out, changed the api_key handling order
---
 dailalib/__init__.py                |  2 ++
 dailalib/api/litellm/litellm_api.py | 37 ++++++++++++++++++++++++++---
 2 files changed, 36 insertions(+), 3 deletions(-)

diff --git a/dailalib/__init__.py b/dailalib/__init__.py
index c9ef2a2..fd79059 100644
--- a/dailalib/__init__.py
+++ b/dailalib/__init__.py
@@ -30,6 +30,8 @@ def create_plugin(*args, **kwargs):
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_api_key"] = ("Update API key...", litellm_api.ask_api_key)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_model"] = ("Change model...", litellm_api.ask_model)
+    gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_url"] = ("Set Custom OpenAI Endpoint...", litellm_api.ask_custom_endpoint)
+    gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_model"] = ("Set Custom OpenAI Model...", litellm_api.ask_custom_model)
 
     #
     # VarModel API (local variable renaming)
diff --git a/dailalib/api/litellm/litellm_api.py b/dailalib/api/litellm/litellm_api.py
index 8dac3db..afbcc52 100644
--- a/dailalib/api/litellm/litellm_api.py
+++ b/dailalib/api/litellm/litellm_api.py
@@ -36,6 +36,8 @@ def __init__(
         fit_to_tokens: bool = False,
         chat_use_ctx: bool = True,
         chat_event_callbacks: Optional[dict] = None,
+        custom_endpoint: Optional[str] = None,
+        custom_model: Optional[str] = None,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -47,6 +49,8 @@ def __init__(
         self.fit_to_tokens = fit_to_tokens
         self.chat_use_ctx = chat_use_ctx
         self.chat_event_callbacks = chat_event_callbacks or {"send": None, "receive": None}
+        self.custom_endpoint = custom_endpoint
+        self.custom_model = custom_model
 
         # delay prompt import
         from .prompts import PROMPTS
@@ -79,10 +83,10 @@ def query_model(
         # delay import because litellm attempts to query the server on import to collect cost information.
         from litellm import completion
 
-        if not self.api_key:
+        if not self.api_key and not self.custom_endpoint:
             raise ValueError(f"Model API key is not set. Please set it before querying the model {self.model}")
 
-        prompt_model = model or self.model
+        prompt_model = (model or self.model) if not self.custom_endpoint else self.custom_model
         response = completion(
             model=prompt_model,
             messages=[
@@ -90,6 +94,8 @@
             ],
             max_tokens=max_tokens,
             timeout=60,
+            api_base=self.custom_endpoint if self.custom_endpoint else None,  # use the custom endpoint if set
+            api_key=self.api_key if not self.custom_endpoint else "dummy"  # in most cases a custom endpoint doesn't need the api_key
         )
         # get the answer
         try:
@@ -97,6 +103,9 @@
         except (KeyError, IndexError) as e:
             answer = None
 
+        if self.custom_endpoint:
+            return answer, 0
+
         # get the estimated cost
         try:
             prompt_tokens = response.usage.prompt_tokens
@@ -189,7 +198,7 @@ def api_key(self, value):
             os.environ["ANTHROPIC_API_KEY"] = self._api_key
         elif "gemini/gemini" in self.model:
             os.environ["GEMINI_API_KEY"] = self._api_key
-        elif "perplexity" in self.model: 
+        elif "perplexity" in self.model:
             os.environ["PERPLEXITY_API_KEY"] = self._api_key
 
     def ask_api_key(self, *args, **kwargs):
@@ -202,6 +211,28 @@ def ask_api_key(self, *args, **kwargs):
             api_key = api_key_or_path
         self.api_key = api_key
 
+    def ask_custom_endpoint(self, *args, **kwargs):
+        custom_endpoint = self._dec_interface.gui_ask_for_string("Enter your custom OpenAI endpoint:", title="DAILA")
+        if not custom_endpoint.strip():
+            self.custom_endpoint = None
+            self._dec_interface.info("Custom endpoint disabled, defaulting to the online API")
+            return
+        if not (custom_endpoint.lower().startswith("http://") or custom_endpoint.lower().startswith("https://")):
+            self.custom_endpoint = None
+            self._dec_interface.error("Invalid endpoint format")
+            return
+        self.custom_endpoint = custom_endpoint.strip()
+        self._dec_interface.info(f"Custom endpoint set to {self.custom_endpoint}")
+
+    def ask_custom_model(self, *args, **kwargs):
+        custom_model = self._dec_interface.gui_ask_for_string("Enter your custom OpenAI model name:", title="DAILA")
+        if not custom_model.strip():
+            self.custom_model = None
+            self._dec_interface.info("Custom model selection cleared")
+            return
+        self.custom_model = "openai/" + custom_model.strip()
+        self._dec_interface.info(f"Custom model set to {self.custom_model}")
+
     def _set_prompt_style(self, prompt_style):
         self.prompt_style = prompt_style
         global active_prompt_style
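For reviewers, a minimal sketch of the non-GUI workflow this patch enables. It assumes an Ollama server exposing its OpenAI-compatible API at the default http://localhost:11434/v1 with a pulled model named llama3, that the class in litellm_api.py is LiteLLMAIAPI constructed with default arguments, and that query_model takes the prompt as its first argument; those names and the signature are assumptions, not part of the patch:

    # Sketch only: drive the new custom-endpoint path against a local Ollama
    # server. The endpoint URL, model name, class name, and query_model
    # signature are assumptions; the two kwargs and the (answer, cost) return
    # shape come from this patch.
    from dailalib.api.litellm.litellm_api import LiteLLMAIAPI

    api = LiteLLMAIAPI(
        custom_endpoint="http://localhost:11434/v1",  # new kwarg added by this patch
        custom_model="openai/llama3",  # "openai/" prefix makes litellm treat api_base as an OpenAI-compatible server
    )

    answer, cost = api.query_model("Explain what this decompiled function does: ...")
    print(answer, cost)  # cost is always 0 for custom endpoints (early bail-out above)

In the GUI, the same configuration is reachable through the two new DAILA/LLM/Settings entries; note that ask_custom_model prepends "openai/" automatically, so users enter only the bare model name.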