-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from debuggerone/add_core_functionality
added core functionality and tests for that, added an install.py file…
- Loading branch information
Showing
24 changed files
with
496 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -165,3 +165,6 @@ cython_debug/ | |
venv-py39/ | ||
venv-py310/ | ||
venv-py311/ | ||
venv-py312/ | ||
|
||
config/settings.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import sqlite3 | ||
import os | ||
import json | ||
import argparse | ||
|
||
# Function to create config/settings.json | ||
|
||
def create_settings(ci_mode=False): | ||
if ci_mode: | ||
# Use default values for CI | ||
api_key = "sk-test-key" | ||
tier = "tier-4" | ||
log_path = './var/logs/error.log' | ||
database_path = './var/data/agents.db' | ||
else: | ||
# Prompt user for settings | ||
api_key = input('Enter your OpenAI API key: ') | ||
tier = input('Enter your OpenAI tier level (e.g., tier-1): ') | ||
log_path = input('Enter the log directory path [default: ./var/logs/error.log]: ') or './var/logs/error.log' | ||
database_path = input('Enter the database path [default: ./var/data/agents.db]: ') or './var/data/agents.db' | ||
|
||
# Save settings to JSON file | ||
settings = { | ||
'openai_api_key': api_key, | ||
'tier': tier, | ||
'log_path': log_path, | ||
'database_path': database_path | ||
} | ||
os.makedirs('./config', exist_ok=True) | ||
with open('./config/settings.json', 'w') as f: | ||
json.dump(settings, f, indent=4) | ||
print('Settings saved to config/settings.json') | ||
|
||
|
||
# Function to create the database structure | ||
|
||
def create_database(db_path): | ||
os.makedirs(os.path.dirname(db_path), exist_ok=True) | ||
conn = sqlite3.connect(db_path) | ||
c = conn.cursor() | ||
|
||
# Create tables | ||
c.execute('''CREATE TABLE IF NOT EXISTS models ( | ||
id INTEGER PRIMARY KEY, | ||
model TEXT NOT NULL, | ||
price_per_prompt_token REAL NOT NULL, | ||
price_per_completion_token REAL NOT NULL)''') | ||
|
||
c.execute('''CREATE TABLE IF NOT EXISTS rate_limits ( | ||
id INTEGER PRIMARY KEY, | ||
model TEXT NOT NULL, | ||
tier TEXT NOT NULL, | ||
rpm_limit INTEGER NOT NULL, | ||
tpm_limit INTEGER NOT NULL, | ||
rpd_limit INTEGER NOT NULL)''') | ||
|
||
c.execute('''CREATE TABLE IF NOT EXISTS api_usage ( | ||
id INTEGER PRIMARY KEY, | ||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, | ||
session_id TEXT NOT NULL, | ||
model TEXT NOT NULL, | ||
prompt_tokens INTEGER NOT NULL, | ||
completion_tokens INTEGER NOT NULL, | ||
total_tokens INTEGER NOT NULL, | ||
price_per_prompt_token REAL NOT NULL, | ||
price_per_completion_token REAL NOT NULL, | ||
total_cost REAL NOT NULL)''') | ||
|
||
c.execute('''CREATE TABLE IF NOT EXISTS chat_sessions ( | ||
id INTEGER PRIMARY KEY, | ||
session_id TEXT NOT NULL, | ||
start_time DATETIME DEFAULT CURRENT_TIMESTAMP, | ||
end_time DATETIME)''') | ||
|
||
c.execute('''CREATE TABLE IF NOT EXISTS chats ( | ||
id INTEGER PRIMARY KEY, | ||
session_id TEXT NOT NULL, | ||
chat_id TEXT NOT NULL, | ||
message TEXT NOT NULL, | ||
role TEXT NOT NULL, | ||
timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)''') | ||
|
||
# Insert default models and rate limits | ||
c.execute("INSERT INTO models (model, price_per_prompt_token, price_per_completion_token) VALUES ('gpt-4o-mini', 0.03, 0.06)") | ||
c.execute("INSERT INTO rate_limits (model, tier, rpm_limit, tpm_limit, rpd_limit) VALUES ('gpt-4o-mini', 'tier-1', 60, 50000, 1000)") | ||
|
||
conn.commit() | ||
conn.close() | ||
print(f"Database created at {db_path}") | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser(description='Setup script for installation.') | ||
parser.add_argument('--ci', action='store_true', help='Use default values for CI without prompting.') | ||
args = parser.parse_args() | ||
|
||
create_settings(ci_mode=args.ci) | ||
|
||
with open('./config/settings.json', 'r') as f: | ||
settings = json.load(f) | ||
|
||
create_database(settings['database_path']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
[pytest] | ||
pythonpath = src |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,6 @@ urllib3==2.2.2 | |
virtualenv==20.26.3 | ||
black | ||
flake8 | ||
tiktoken | ||
anyio | ||
trio |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from setuptools import setup, find_packages | ||
|
||
setup( | ||
name='agentm-py', | ||
version='0.1', | ||
packages=find_packages(where='src'), | ||
package_dir={'': 'src'}, | ||
) |
Empty file.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import re | ||
|
||
|
||
def compose_prompt(template: str, variables: dict) -> str: | ||
return re.sub( | ||
r"{{\s*([^}\s]+)\s*}}", | ||
lambda match: str(variables.get(match.group(1), "")), | ||
template, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import asyncio | ||
|
||
|
||
class Semaphore: | ||
def __init__(self, max_concurrent_tasks): | ||
self.semaphore = asyncio.Semaphore(max_concurrent_tasks) | ||
|
||
async def __aenter__(self): | ||
await self.semaphore.acquire() | ||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb): | ||
self.semaphore.release() | ||
|
||
async def call_function(self, func, *args, **kwargs): | ||
async with self.semaphore: | ||
return await func(*args, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import sqlite3 | ||
from datetime import datetime | ||
|
||
|
||
class Database: | ||
def __init__(self, db_path): | ||
self.db_path = db_path | ||
|
||
def connect(self): | ||
return sqlite3.connect(self.db_path) | ||
|
||
def check_rate_limits(self, model): | ||
conn = self.connect() | ||
c = conn.cursor() | ||
|
||
# Check current API usage (RPM, TPM, RPD) | ||
c.execute( | ||
"SELECT SUM(total_tokens) FROM api_usage WHERE model = ? AND timestamp >= datetime('now', '-1 minute')", | ||
(model,), | ||
) | ||
tokens_last_minute = c.fetchone()[0] or 0 | ||
|
||
c.execute("SELECT tpm_limit FROM rate_limits WHERE model = ?", (model,)) | ||
tpm_limit = c.fetchone()[0] | ||
|
||
conn.close() | ||
return tokens_last_minute < tpm_limit | ||
|
||
def log_api_usage( | ||
self, session_id, model, prompt_tokens, completion_tokens, total_tokens | ||
): | ||
conn = self.connect() | ||
c = conn.cursor() | ||
|
||
# Fetch token prices | ||
c.execute( | ||
"SELECT price_per_prompt_token, price_per_completion_token FROM models WHERE model = ?", | ||
(model,), | ||
) | ||
prices = c.fetchone() | ||
prompt_price = prices[0] | ||
completion_price = prices[1] | ||
total_cost = (prompt_tokens * prompt_price) + ( | ||
completion_tokens * completion_price | ||
) | ||
|
||
c.execute( | ||
"INSERT INTO api_usage (session_id, model, prompt_tokens, completion_tokens, total_tokens, price_per_prompt_token, price_per_completion_token, total_cost) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", | ||
( | ||
session_id, | ||
model, | ||
prompt_tokens, | ||
completion_tokens, | ||
total_tokens, | ||
prompt_price, | ||
completion_price, | ||
total_cost, | ||
), | ||
) | ||
|
||
conn.commit() | ||
conn.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from core.logging import Logger | ||
|
||
|
||
class LogCompletePrompt: | ||
def __init__(self, complete_prompt_func): | ||
self.complete_prompt_func = complete_prompt_func | ||
self.logger = Logger() | ||
|
||
async def complete_prompt(self, *args, **kwargs): | ||
result = await self.complete_prompt_func(*args, **kwargs) | ||
|
||
if result["completed"]: | ||
self.logger.info("Prompt completed successfully.") | ||
else: | ||
self.logger.error("Prompt completion failed.") | ||
|
||
return result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import logging | ||
import json | ||
import os | ||
|
||
|
||
class Logger: | ||
def __init__(self, settings_path="../config/settings.json"): | ||
self.settings = self.load_settings(settings_path) | ||
self.log_path = self.settings["log_path"] | ||
os.makedirs(os.path.dirname(self.log_path), exist_ok=True) | ||
logging.basicConfig( | ||
filename=self.log_path, | ||
level=logging.INFO, | ||
format="%(asctime)s - %(levelname)s - %(message)s", | ||
) | ||
|
||
def load_settings(self, settings_path): | ||
try: | ||
with open(settings_path, "r") as f: | ||
return json.load(f) | ||
except FileNotFoundError: | ||
raise Exception(f"Settings file not found at {settings_path}") | ||
except KeyError as e: | ||
raise Exception(f"Missing key in settings: {e}") | ||
|
||
def info(self, message): | ||
logging.info(message) | ||
|
||
def error(self, message): | ||
logging.error(message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import openai | ||
import json | ||
import sqlite3 | ||
from datetime import datetime | ||
from .token_counter import TokenCounter | ||
from .database import Database | ||
|
||
|
||
class OpenAIClient: | ||
def __init__(self, settings_path="../config/settings.json"): | ||
settings = self.load_settings(settings_path) | ||
self.api_key = settings["openai_api_key"] | ||
openai.api_key = self.api_key | ||
self.db = Database(settings["database_path"]) | ||
self.token_counter = TokenCounter() | ||
|
||
def load_settings(self, settings_path): | ||
try: | ||
with open(settings_path, "r") as f: | ||
return json.load(f) | ||
except FileNotFoundError: | ||
raise Exception(f"Settings file not found at {settings_path}") | ||
except KeyError as e: | ||
raise Exception(f"Missing key in settings: {e}") | ||
|
||
def complete_chat(self, messages, model="gpt-4o-mini", max_tokens=1500): | ||
# Check rate limits | ||
if not self.db.check_rate_limits(model): | ||
raise Exception(f"Rate limit exceeded for model {model}") | ||
|
||
prompt_tokens = self.token_counter.count_tokens(messages) | ||
|
||
try: | ||
response = openai.ChatCompletion.create( | ||
model=model, messages=messages, max_tokens=max_tokens | ||
) | ||
|
||
completion_tokens = self.token_counter.count_tokens( | ||
response.choices[0].message["content"] | ||
) | ||
total_tokens = prompt_tokens + completion_tokens | ||
|
||
# Log token usage and cost in the database | ||
self.db.log_api_usage( | ||
"session-1", model, prompt_tokens, completion_tokens, total_tokens | ||
) | ||
|
||
return response.choices[0].message["content"] | ||
except openai.error.OpenAIError as e: | ||
raise Exception(f"Error with OpenAI API: {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import asyncio | ||
from .concurrency import Semaphore | ||
|
||
|
||
class ParallelCompletePrompt: | ||
def __init__( | ||
self, complete_prompt_func, parallel_completions=1, should_continue_func=None | ||
): | ||
self.complete_prompt_func = complete_prompt_func | ||
self.parallel_completions = parallel_completions | ||
self.should_continue_func = should_continue_func or (lambda: True) | ||
self.semaphore = Semaphore(parallel_completions) | ||
|
||
async def complete_prompt(self, *args, **kwargs): | ||
async with self.semaphore: | ||
if not self.should_continue_func(): | ||
raise asyncio.CancelledError("Operation cancelled.") | ||
return await self.complete_prompt_func(*args, **kwargs) |
Oops, something went wrong.