Preliminary support for local LLM endpoints (tested on Ollama) (#71)
* Preliminary support for local LLM endpoints (tested on Ollama)

* Added an early bail-out; changed the api_key handling order
pc9441 authored and mahaloz committed Dec 6, 2024
1 parent f027a4f commit 93eda64
Showing 2 changed files with 36 additions and 3 deletions.
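
In practice, this change lets the plugin route completions to any OpenAI-compatible server through litellm's `api_base` override. A minimal sketch of the resulting call against a local Ollama server, assuming Ollama's default OpenAI-compatible endpoint at `http://localhost:11434/v1` and a locally pulled `llama3` model (both are illustrative assumptions, not part of the commit):

# Sketch only: query a local OpenAI-compatible endpoint the way the new code does.
# The endpoint URL and model name below are assumptions for illustration.
from litellm import completion

response = completion(
    model="openai/llama3",                 # "openai/" prefix, as ask_custom_model prepends
    messages=[{"role": "user", "content": "Explain this function: int f(int x) { return x * 2; }"}],
    max_tokens=256,
    timeout=60,
    api_base="http://localhost:11434/v1",  # local Ollama endpoint (assumed default)
    api_key="dummy",                       # local endpoints typically ignore the key
)
print(response.choices[0].message.content)
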
2 changes: 2 additions & 0 deletions dailalib/__init__.py
@@ -30,6 +30,8 @@ def create_plugin(*args, **kwargs):
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_api_key"] = ("Update API key...", litellm_api.ask_api_key)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_pmpt_style"] = ("Change prompt style...", litellm_api.ask_prompt_style)
     gui_ctx_menu_actions["DAILA/LLM/Settings/update_model"] = ("Change model...", litellm_api.ask_model)
+    gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_url"] = ("Set Custom OpenAI Endpoint...", litellm_api.ask_custom_endpoint)
+    gui_ctx_menu_actions["DAILA/LLM/Settings/update_custom_model"] = ("Set Custom OpenAI Model...", litellm_api.ask_custom_model)

     #
     # VarModel API (local variable renaming)
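
Each entry above maps a menu path to a (label, callback) tuple, from which the decompiler interface builds the nested context menu. A self-contained sketch of that registration shape (the dispatch loop is illustrative only, not DAILA's actual menu mechanism):

# Illustrative only: the (menu path -> (label, callback)) registration shape used above.
def ask_custom_endpoint():
    print("would open a text prompt for the endpoint URL")

gui_ctx_menu_actions = {
    "DAILA/LLM/Settings/update_custom_url": ("Set Custom OpenAI Endpoint...", ask_custom_endpoint),
}

for path, (label, callback) in gui_ctx_menu_actions.items():
    print(f"{path} -> {label}")
    callback()
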
37 changes: 34 additions & 3 deletions dailalib/api/litellm/litellm_api.py
@@ -36,6 +36,8 @@ def __init__(
         fit_to_tokens: bool = False,
         chat_use_ctx: bool = True,
         chat_event_callbacks: Optional[dict] = None,
+        custom_endpoint: Optional[str] = None,
+        custom_model: Optional[str] = None,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -47,6 +49,8 @@ def __init__(
         self.fit_to_tokens = fit_to_tokens
         self.chat_use_ctx = chat_use_ctx
         self.chat_event_callbacks = chat_event_callbacks or {"send": None, "receive": None}
+        self.custom_endpoint = custom_endpoint
+        self.custom_model = custom_model

         # delay prompt import
         from .prompts import PROMPTS
@@ -79,24 +83,29 @@ def query_model(
         # delay import because litellm attempts to query the server on import to collect cost information.
         from litellm import completion

-        if not self.api_key:
+        if not self.api_key and not self.custom_endpoint:
             raise ValueError(f"Model API key is not set. Please set it before querying the model {self.model}")

-        prompt_model = model or self.model
+        prompt_model = (model or self.model) if not self.custom_endpoint else self.custom_model
         response = completion(
             model=prompt_model,
             messages=[
                 {"role": "user", "content": prompt}
             ],
             max_tokens=max_tokens,
             timeout=60,
+            api_base=self.custom_endpoint if self.custom_endpoint else None,  # use the custom endpoint if set
+            api_key=self.api_key if not self.custom_endpoint else "dummy"  # in most cases a custom endpoint does not need a real API key
         )
         # get the answer
         try:
             answer = response.choices[0].message.content
         except (KeyError, IndexError) as e:
             answer = None

+        if self.custom_endpoint:
+            return answer, 0
+
         # get the estimated cost
         try:
             prompt_tokens = response.usage.prompt_tokens
@@ -189,7 +198,7 @@ def api_key(self, value):
             os.environ["ANTHROPIC_API_KEY"] = self._api_key
         elif "gemini/gemini" in self.model:
             os.environ["GEMINI_API_KEY"] = self._api_key
-        elif "perplexity" in self.model:
+        elif "perplexity" in self.model:
             os.environ["PERPLEXITY_API_KEY"] = self._api_key

     def ask_api_key(self, *args, **kwargs):
@@ -202,6 +211,28 @@ def ask_api_key(self, *args, **kwargs):
             api_key = api_key_or_path
         self.api_key = api_key

+    def ask_custom_endpoint(self, *args, **kwargs):
+        custom_endpoint = self._dec_interface.gui_ask_for_string("Enter your custom OpenAI endpoint:", title="DAILA")
+        if not custom_endpoint.strip():
+            self.custom_endpoint = None
+            self._dec_interface.info("Custom endpoint disabled, defaulting to the online API")
+            return
+        if not (custom_endpoint.lower().startswith("http://") or custom_endpoint.lower().startswith("https://")):
+            self.custom_endpoint = None
+            self._dec_interface.error("Invalid endpoint format")
+            return
+        self.custom_endpoint = custom_endpoint.strip()
+        self._dec_interface.info(f"Custom endpoint set to {self.custom_endpoint}")
+
+    def ask_custom_model(self, *args, **kwargs):
+        custom_model = self._dec_interface.gui_ask_for_string("Enter your custom OpenAI model name:", title="DAILA")
+        if not custom_model.strip():
+            self.custom_model = None
+            self._dec_interface.info("Custom model selection cleared")
+            return
+        self.custom_model = "openai/" + custom_model.strip()
+        self._dec_interface.info(f"Custom model set to {self.custom_model}")
+
     def _set_prompt_style(self, prompt_style):
         self.prompt_style = prompt_style
         global active_prompt_style
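
Distilled from the two new handlers and the early bail-out above: a blank entry clears the custom endpoint, anything that does not start with http:// or https:// is rejected, and queries routed through a custom endpoint report a cost of 0 because local models have no known per-token pricing. A standalone sketch of the URL check (the function name is ours; it also strips before validating and handles a cancelled dialog returning None, both small hardenings over the diff):

# Sketch: the endpoint validation rule from ask_custom_endpoint, as a pure function.
from typing import Optional

def normalize_endpoint(raw: Optional[str]) -> Optional[str]:
    """Blank or None clears the endpoint; non-HTTP(S) values are rejected."""
    if raw is None or not raw.strip():
        return None  # cleared: the plugin falls back to the online API
    raw = raw.strip()
    if not raw.lower().startswith(("http://", "https://")):
        raise ValueError("Invalid endpoint format")
    return raw

assert normalize_endpoint("http://localhost:11434/v1 ") == "http://localhost:11434/v1"
assert normalize_endpoint("") is None
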
