llm.py
from openai import OpenAI
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatOllama
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq

# Support running both as a package (relative imports) and as a script.
try:
    from .utils.db import (
        load_api_key,
        load_openai_url,
        load_model_settings,
        load_groq_api_key,
        load_google_api_key,
    )
    from .custom_callback import customcallback
    from .llm_settings import llm_settings
except ImportError:
    from utils.db import (
        load_api_key,
        load_openai_url,
        load_model_settings,
        load_groq_api_key,
        load_google_api_key,
    )
    from custom_callback import customcallback
    from llm_settings import llm_settings

# Shared streaming callback used by the OpenAI-compatible chat models below.
the_callback = customcallback(strip_tokens=False, answer_prefix_tokens=["Answer"])


def get_model(high_context=False):
    """Build a LangChain chat model for the provider configured in settings."""
    the_model = load_model_settings()
    the_api_key = load_api_key()
    the_groq_api_key = load_groq_api_key()
    the_google_api_key = load_google_api_key()
    the_openai_url = load_openai_url()

    def open_ai_base(high_context):
        # Constructor kwargs for ChatOpenAI; "default" means the official
        # OpenAI endpoint, anything else is a custom OpenAI-compatible URL.
        if the_openai_url == "default":
            true_model = the_model
            if high_context:
                # Swap to a larger-context model when the caller asks for it.
                true_model = "gpt-4-turbo"
            return {
                "model": true_model,
                "api_key": the_api_key,
                "max_retries": 15,
                "streaming": True,
                "callbacks": [the_callback],
            }
        else:
            return {
                "model": the_model,
                "api_key": the_api_key,
                "max_retries": 15,
                "streaming": True,
                "callbacks": [the_callback],
                "base_url": the_openai_url,
            }

    # Default constructor kwargs per model class.
    args_mapping = {
        ChatOpenAI: open_ai_base(high_context=high_context),
        ChatOllama: {"model": the_model},
        ChatGroq: {
            "temperature": 0,
            "model_name": the_model.replace("-groq", ""),
            "groq_api_key": the_groq_api_key,
        },
        ChatGoogleGenerativeAI: {
            "model": the_model,
            "google_api_key": the_google_api_key,
        },
    }

    # Resolve each configured model name to a (class, kwargs) pair.
    model_mapping = {}
    for model_name, model_args in llm_settings.items():
        the_tuple = None
        if model_args["provider"] == "openai":
            the_tuple = (ChatOpenAI, args_mapping[ChatOpenAI])
        elif model_args["provider"] == "ollama":
            # Ollama is driven through its OpenAI-compatible local endpoint.
            the_tuple = (
                ChatOpenAI,
                {
                    "api_key": "ollama",
                    "base_url": "http://localhost:11434/v1",
                    "model": model_name,
                },
            )
        elif model_args["provider"] == "google":
            the_tuple = (ChatGoogleGenerativeAI, args_mapping[ChatGoogleGenerativeAI])
        elif model_args["provider"] == "groq":
            the_tuple = (ChatGroq, args_mapping[ChatGroq])
        if the_tuple:
            model_mapping[model_name] = the_tuple

    # Fall back to None for unknown model names instead of raising KeyError.
    model_class, args = model_mapping.get(the_model, (None, None))
    return model_class(**args) if model_class else None
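
# For reference, llm_settings (imported above) is expected to map model names
# to provider records. The following shape is a hypothetical sketch inferred
# from the loop in get_model, not a confirmed schema:
#     {"gpt-4o": {"provider": "openai"}, "llama3.1-groq": {"provider": "groq"}}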


def get_client():
    """Return a raw OpenAI SDK client honoring the configured base URL."""
    the_api_key = load_api_key()
    the_openai_url = load_openai_url()
    if the_openai_url == "default":
        return OpenAI(api_key=the_api_key)
    else:
        return OpenAI(api_key=the_api_key, base_url=the_openai_url)
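

# --- Usage sketch (illustrative, not part of the upstream file) ---
# A minimal example of how these helpers might be exercised; it assumes the
# settings stored via utils.db name a model that is present in llm_settings.
if __name__ == "__main__":
    model = get_model()  # LangChain chat model, or None if the name is unknown
    if model is not None:
        # .invoke() is the standard LangChain entry point for chat models.
        print(model.invoke("Say hello in one word.").content)

    client = get_client()  # plain OpenAI client for non-LangChain calls
    print(type(client).__name__)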