Skip to content

Commit

Permalink
eval: sqlite & calculator
Browse files Browse the repository at this point in the history
  • Loading branch information
XOR-op committed May 18, 2024
1 parent 755e862 commit 7a98432
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 2 deletions.
3 changes: 3 additions & 0 deletions conveyor/plugin/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ class BasePlugin:
def __init__(self):
pass

def post_init(self):
pass

def process_new_dat(self, data):
raise NotImplementedError

Expand Down
2 changes: 1 addition & 1 deletion conveyor/plugin/news_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,5 @@ def compute(self, data: dict):
self.abort = True
end_time = datetime.datetime.now()
dur = (end_time - self.start_time).total_seconds()
print(f"Plugin syntax error detected: {dur}s")
print(f"Result: {dur}")
self.answer = dur
12 changes: 11 additions & 1 deletion conveyor/plugin/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
from multiprocessing.connection import Connection

from conveyor.plugin.base_plugin import BasePlugin, PlaceholderPlugin
from conveyor.plugin.calculator_plugin import CalculatorPlugin
from conveyor.plugin.news_plugin import LocalNewsPlugin
from conveyor.plugin.planning_plugin import PlanningPlugin
from conveyor.plugin.python_plugin import PythonPlugin
from conveyor.plugin.search_plugin import SearchPlugin
from conveyor.plugin.sqlite_plugin import SqlitePlugin

from conveyor.utils import getLogger

Expand All @@ -22,6 +24,7 @@

def plugin_loop(plugin: BasePlugin, ep: Connection):
print("Starting plugin loop", file=sys.stderr)
plugin.post_init()
while True:
data = ep.recv()
if data == finish_str:
Expand Down Expand Up @@ -68,7 +71,14 @@ def start_plugin(self, client_id: str, plugin_name: str):
logging.debug(
f"[PluginScheduler:{client_id}] Starting local news plugin"
)

case "query_database":
plugin = SqlitePlugin(self.lazy)
logging.debug(f"[PluginScheduler:{client_id}] Starting sqlite plugin")
case "calculator":
plugin = CalculatorPlugin(self.lazy)
logging.debug(
f"[PluginScheduler:{client_id}] Starting calculator plugin"
)
case _:
plugin = PlaceholderPlugin()
logging.warn(
Expand Down
40 changes: 40 additions & 0 deletions conveyor/plugin/sqlite_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from conveyor.plugin.base_plugin import BasePlugin
from conveyor.utils import getLogger
import sqlite3

logging = getLogger(__name__)


class SqlitePlugin(BasePlugin):
def __init__(self, lazy: bool = False):
super().__init__()
self.lazy = lazy
self.connection = None
self.cursor = None
self.buf = None
self.answer = None

def post_init(self):
if not self.lazy:
self.connection = sqlite3.connect("_private/test.sqlite3")
self.cursor = self.connection.cursor()
return super().post_init()

def process_new_dat(self, data: dict):
if self.lazy:
self.buf = data
else:
self.compute(data)

def finish(self):
if self.lazy:
self.compute(self.buf)
return self.answer

def compute(self, data: dict):
if self.connection is None:
self.connection = sqlite3.connect("_private/test.sqlite3")
self.cursor = self.connection.cursor()
if data.get("query") is not None:
self.cursor.execute("SELECT * FROM employees ORDER BY name")
self.answer = self.cursor.fetchall()
170 changes: 170 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,172 @@ def eval_validation(lazy: bool) -> float:
return res


def eval_sqlite(lazy: bool) -> float:
model_name = "meetkai/functionary-small-v2.2"
plugin_scheduler = PluginScheduler(lazy=lazy)
engine = ScheduleEngine(
ModelConfig(model_name),
FunctionaryParser,
plugin_scheduler,
sequential_call=True,
)
logging.info(f"Model {model_name} loaded")
req_id = engine.request_pool.add_request(
generate_functionary_input(
messages=[
{
"role": "user",
"content": "Show me all the names of employees.",
}
],
tools=[
{
"type": "function",
"function": {
"name": "query_database",
"description": "Query the database to get info for employees",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Name of queried employees, all to query all employees",
},
},
"required": ["query"],
},
},
}
],
)
+ "\n<|from|> assistant\n<|recipient|>"
)
# let parser aware of <|recipient|> token
engine.request_pool.queued_requests[0].parser.buffer.append(32001)
init_tokens_len = len(engine.request_pool.queued_requests[0].tokens)
i = 0
finished = None
time_start = time.perf_counter()
while i < 500:
finished = engine.iteration_step()
if finished:
break
i += 1

if finished:
res = None
if not plugin_scheduler.lazy:
while len(plugin_scheduler.waiting_queue) > 0:
res = plugin_scheduler.poll_finished(
list(plugin_scheduler.waiting_queue.keys())[0]
)
if len(plugin_scheduler.lazy_queue) > 0:
cur_id = list(plugin_scheduler.lazy_queue.keys())[0]
while (
cur_id in plugin_scheduler.lazy_queue
and len(plugin_scheduler.lazy_queue[cur_id]) > 0
):
plugin_scheduler.flush_lazy_sequentially(cur_id)
while len(plugin_scheduler.waiting_queue) > 0:
res = plugin_scheduler.poll_finished(cur_id)
time_end = time.perf_counter()
logging.info(f"Plugin result: {res}")
logging.info(f"Finished: {finished[0].decode()}")
logging.info(
f"Speed: {(len(finished[0].tokens)-init_tokens_len)/(time_end-time_start)} tokens/s"
)
logging.info(f"Time: {time_end-time_start} s")
ret_val = time_end - time_start
else:
logging.info("Ongoing: " + engine.context.requests[0].decode())
ret_val = -1
plugin_scheduler.join_all()
return ret_val


def eval_calculator(lazy: bool) -> float:
model_name = "meetkai/functionary-small-v2.2"
plugin_scheduler = PluginScheduler(lazy=lazy)
engine = ScheduleEngine(
ModelConfig(model_name),
FunctionaryParser,
plugin_scheduler,
sequential_call=True,
)
logging.info(f"Model {model_name} loaded")
req_id = engine.request_pool.add_request(
generate_functionary_input(
messages=[
{
"role": "user",
"content": "What is the result of 200 * 701",
}
],
tools=[
{
"type": "function",
"function": {
"name": "calculator",
"description": "Evaluate expression",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "expression need evaluating, in the form of python expression",
},
},
"required": ["query"],
},
},
}
],
)
+ "\n<|from|> assistant\n<|recipient|>"
)
# let parser aware of <|recipient|> token
engine.request_pool.queued_requests[0].parser.buffer.append(32001)
init_tokens_len = len(engine.request_pool.queued_requests[0].tokens)
i = 0
finished = None
time_start = time.perf_counter()
while i < 500:
finished = engine.iteration_step()
if finished:
break
i += 1

if finished:
res = None
if not plugin_scheduler.lazy:
while len(plugin_scheduler.waiting_queue) > 0:
res = plugin_scheduler.poll_finished(
list(plugin_scheduler.waiting_queue.keys())[0]
)
if len(plugin_scheduler.lazy_queue) > 0:
cur_id = list(plugin_scheduler.lazy_queue.keys())[0]
while (
cur_id in plugin_scheduler.lazy_queue
and len(plugin_scheduler.lazy_queue[cur_id]) > 0
):
plugin_scheduler.flush_lazy_sequentially(cur_id)
while len(plugin_scheduler.waiting_queue) > 0:
res = plugin_scheduler.poll_finished(cur_id)
time_end = time.perf_counter()
logging.info(f"Plugin result: {res}")
logging.info(f"Finished: {finished[0].decode()}")
logging.info(
f"Speed: {(len(finished[0].tokens)-init_tokens_len)/(time_end-time_start)} tokens/s"
)
logging.info(f"Time: {time_end-time_start} s")
ret_val = time_end - time_start
else:
logging.info("Ongoing: " + engine.context.requests[0].decode())
ret_val = -1
plugin_scheduler.join_all()
return ret_val


def eval_scheduling():
# model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model_name = "meetkai/functionary-small-v2.2"
Expand Down Expand Up @@ -444,6 +610,10 @@ def eval_python_wrapper(lazy: bool) -> float:
result = eval_planning(lazy)
case "validation":
result = eval_validation(lazy)
case "sqlite":
result = eval_sqlite(lazy)
case "calculator":
result = eval_calculator(lazy)
case _:
print("Usage: python main.py [python|scheduling|search] [lazy?]")
sys.exit(1)
Expand Down

0 comments on commit 7a98432

Please sign in to comment.