forked from qhjqhj00/MemoRAG
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.py
77 lines (68 loc) · 2.55 KB
/
agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import logging
import time
from functools import wraps

import openai
from openai import AzureOpenAI
from openai import OpenAI
logger = logging.getLogger(__name__)


def except_retry_dec(retry_num: int = 3, non_retryable=None):
    """Build a decorator that retries the wrapped callable when it raises.

    The wrapped callable is invoked once and, on failure, re-invoked up to
    ``retry_num`` more times. Before the k-th retry the wrapper sleeps for
    k seconds (1 s, 2 s, ...). Once the retries are used up, the last
    exception is re-raised to the caller without any further sleep.

    Args:
        retry_num: Maximum number of retries after the first attempt.
        non_retryable: Exception class, or tuple of classes, that is
            re-raised immediately without retrying. Defaults to
            ``(openai.BadRequestError, openai.AuthenticationError)``.

    Returns:
        A decorator. The decorated function keeps the metadata of the
        original through ``functools.wraps``.

    Raises:
        Whatever the wrapped callable raises, either immediately (for
        ``non_retryable`` errors) or after the final failed attempt.
    """
    if non_retryable is None:
        # error define: https://platform.openai.com/docs/guides/error-codes/python-library-error-types
        non_retryable = (openai.BadRequestError, openai.AuthenticationError)

    def decorator(func):
        @wraps(func)
        def wrapped_func(*args, **kwargs):
            attempt = 0
            while True:
                try:
                    logger.info("openai agent post...")
                    ret = func(*args, **kwargs)
                    logger.info("openai agent post finished")
                    return ret
                except non_retryable:
                    # Marked as not worth retrying: propagate at once.
                    raise
                except Exception as e:  # pylint: disable=W0703
                    logger.error("%s", e)
                    if attempt >= retry_num:
                        # Out of retries: fail now rather than sleeping
                        # before an attempt that will never happen.
                        raise
                    logger.info("sleep %s", attempt + 1)
                    time.sleep(attempt + 1)
                    logger.warning("do retry, time: %s", attempt)
                    attempt += 1

        return wrapped_func

    return decorator
class Agent:
    """Single-turn chat-completion client for Azure, OpenAI or DeepSeek.

    The instance stores the model name and sampling temperature and owns an
    ``openai`` client object chosen by ``source``.
    """

    def __init__(
        self, model, source, api_dict, temperature: float = 0.0):
        """Create the underlying client for the requested provider.

        Args:
            model: Model name sent with every completion request.
            source: One of ``"azure"``, ``"openai"`` or ``"deepseek"``.
            api_dict: Connection settings. Keys read per source:
                ``"azure"`` -> ``endpoint``, ``api_version``, ``api_key``;
                ``"openai"`` -> ``api_key``;
                ``"deepseek"`` -> ``base_url``, ``api_key``.
            temperature: Sampling temperature sent with every request.

        Raises:
            ValueError: If ``source`` is not one of the supported values.
            KeyError: If ``api_dict`` lacks a key required by ``source``.
        """
        self.model = model
        self.temperature = temperature
        if source == "azure":
            self.client = AzureOpenAI(
                azure_endpoint=api_dict["endpoint"],
                api_version=api_dict["api_version"],
                api_key=api_dict["api_key"],
            )
        elif source == "openai":
            self.client = OpenAI(
                api_key=api_dict["api_key"],
            )
        elif source == "deepseek":
            # DeepSeek is reached through the OpenAI client pointed at a
            # caller-supplied base URL.
            self.client = OpenAI(
                base_url=api_dict["base_url"],
                api_key=api_dict["api_key"],
            )
        else:
            # Fail here instead of leaving `self.client` unset, which would
            # only surface later as an AttributeError inside `generate`.
            raise ValueError(
                f"Unsupported source {source!r}; "
                "expected 'azure', 'openai' or 'deepseek'"
            )
        print(f"You are using {self.model} from {source}")

    @except_retry_dec()
    def generate(self, prompt: str, max_new_tokens: int = None) -> list:
        """Send ``prompt`` as a single user message and return the reply.

        Args:
            prompt: Text placed in the only (user) message of the request.
            max_new_tokens: Optional cap on generated tokens, forwarded to
                the API as ``max_tokens``. ``None`` leaves the limit to the
                provider's default.

        Returns:
            A one-element list holding the message content of the first
            returned choice.
        """
        request = {
            "messages": [
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            "temperature": self.temperature,
            "model": self.model,
        }
        if max_new_tokens is not None:
            # Only send the limit when the caller asked for one, so requests
            # without it stay identical to what was sent before.
            request["max_tokens"] = max_new_tokens
        _completion = self.client.chat.completions.create(**request)
        return [_completion.choices[0].message.content]