Skip to content

Commit

Permalink
injecting the executor callback to the engine
Browse files Browse the repository at this point in the history
  • Loading branch information
Aynur Adanbekova authored and Aynur Adanbekova committed Nov 17, 2024
1 parent 743a279 commit c01493b
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 17 deletions.
53 changes: 39 additions & 14 deletions src/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
from typing import Any, Callable, Dict, List, Optional, Union
from symai.core_ext import bind
from symai import Expression
from symai.functional import EngineRepository

lgr = logging.getLogger()
lgr.setLevel(logging.CRITICAL)



class BatchScheduler(Expression):
"""
A class for scheduling and executing batch operations with Expressions from symbolicai.
Expand All @@ -17,10 +19,12 @@ class BatchScheduler(Expression):
utilizing multiple workers and an external engine for processing.
"""

def __init__(self):
def __init__(self):
"""Initialize the BatchScheduler without parameters."""
super().__init__()

repository = EngineRepository()
repository.get('neurosymbolic').__setattr__("executor_callback",self.executor_callback)

@bind(engine="neurosymbolic", property="__call__")
def engine(self, *args, **kwargs):
"""
Expand All @@ -41,7 +45,8 @@ def single_expression(self, data_point: Any, **kwargs) -> Any:
"""
expr = self.expr
try:
return expr(data_point, executor_callback=self.executor_callback, **kwargs)
#return expr(data_point, executor_callback=self.executor_callback, **kwargs)
return expr(data_point, **kwargs)
except Exception as e:
print(f"Data point {data_point} generated an exception: {str(e)}")
return e
Expand All @@ -59,17 +64,27 @@ def executor_callback(self, argument: Any) -> Any:
Any: The processed response for the given argument.
"""
with self.lock:
self.arguments.append(argument)
self.llm_calls_queue.append(argument)
arg_id = id(argument)
if arg_id not in self.llm_responses.keys():
self.llm_responses[arg_id] = None
self.llm_response_ready[arg_id] = threading.Event()
if len(self.arguments) >= self.batch_size or self.pending_tasks < self.batch_size:
print("args: ")
print(len(self.llm_calls_queue))
print("batch size " +str(self.batch_size))
print("pending "+ str(self.pending_tasks))
#enough llm calls arrived for batch execution
if len(self.llm_calls_queue) >= self.batch_size: #or self.pending_tasks < self.batch_size:
print("enough llm calls arrived for batch execution")
self.batch_ready.set()
self.llm_response_ready[arg_id].wait()
print("response is ready")
with self.lock:
llm_response = self.llm_responses.pop(arg_id)
del self.llm_response_ready[arg_id]
print("args: "+str(len(self.llm_calls_queue)))
print("batch size " +str(self.batch_size))
print("pending "+ str(self.pending_tasks))
return llm_response

def execute_queries(self) -> None:
Expand All @@ -79,12 +94,14 @@ def execute_queries(self) -> None:
This method runs in a separate thread and processes batches of arguments
generated by the symbolicai Expressions.
"""
while not self.processing_complete.is_set() or self.arguments:
while not self.all_batches_complete.is_set() or self.llm_calls_queue:
print("condition met")
self.batch_ready.wait()
print("batch ready after condtion")
self.batch_ready.clear()
with self.lock:
current_arguments = self.arguments[:self.batch_size]
self.arguments = self.arguments[self.batch_size:]
current_arguments = self.llm_calls_queue[:self.batch_size]
self.llm_calls_queue = self.llm_calls_queue[self.batch_size:]
if current_arguments:
llm_batch_responses = self.engine()(current_arguments)
llm_batch_responses = [(resp[0] if isinstance(resp[0], list) else [resp[0]], resp[1]) for resp in llm_batch_responses]
Expand All @@ -93,8 +110,13 @@ def execute_queries(self) -> None:
arg_id = id(arg)
self.llm_responses[arg_id] = llm_response
self.llm_response_ready[arg_id].set()
if self.arguments and self.pending_tasks < self.batch_size:
print("PEEEEENDING" + str(self.pending_tasks))
print(len(self.llm_calls_queue))
if self.batch_size <= len(self.llm_calls_queue) or self.pending_tasks<=self.batch_size :
self.batch_ready.set()
print(len(self.llm_calls_queue))
print("batch ready")


def forward(self, expr: Expression, num_workers: int, dataset: List[Any], batch_size: int = 5, **kwargs) -> List[Any]:
"""
Expand All @@ -113,11 +135,11 @@ def forward(self, expr: Expression, num_workers: int, dataset: List[Any], batch_
self.num_workers = num_workers
self.dataset = dataset
self.results = {}
self.arguments = []
self.llm_calls_queue = []
self.lock = threading.Lock()
self.batch_size = min(batch_size, len(dataset) if dataset else 1, num_workers)
self.batch_ready = threading.Event()
self.processing_complete = threading.Event()
self.all_batches_complete = threading.Event()
self.llm_responses = {}
self.llm_response_ready = {}
self.pending_tasks = len(self.dataset)
Expand All @@ -133,13 +155,16 @@ def forward(self, expr: Expression, num_workers: int, dataset: List[Any], batch_
final_result = future.result()
self.results[data_point] = final_result
except Exception as exc:
print(f'Data point {data_point} generated an exception: {exc}')
print(f'batch {data_point} generated an exception: {exc}')
finally:
#decrement the total number of expressions to be run
self.pending_tasks -= 1
print("remaining tasks: " + str(self.pending_tasks))

if self.pending_tasks < self.batch_size:
self.batch_ready.set()
self.processing_complete.set()
print("processing complete")
self.all_batches_complete.set()
print("all batches complete")
self.batch_ready.set()
query_thread.join()
return [self.results.get(data_point) for data_point in self.dataset]
14 changes: 11 additions & 3 deletions tests/test_batchscheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def forward(self, input, **kwargs):
return Symbol(input).query("Quickly process this short input", **kwargs)


@pytest.mark.timeout(5)
def test_simple_batch():
expr = TestExpression
inputs = ["test1", "test2", "test3"]
Expand All @@ -130,6 +131,7 @@ def test_simple_batch():
assert f"test{i}" in str(result)
assert "Summarize this input" in str(result)

@pytest.mark.timeout(5)
def test_nested_batch():
expr = NestedExpression
inputs = ["nested1", "nested2"]
Expand All @@ -141,6 +143,7 @@ def test_nested_batch():
assert "Elaborate on this result" in str(result)
assert "Summarize this input" in str(result)

@pytest.mark.timeout(5)
def test_conditional_batch():
expr = ConditionalExpression
inputs = ["short", "this is a long input"]
Expand All @@ -150,6 +153,7 @@ def test_conditional_batch():
assert "Briefly comment on this short input" in str(results[0])
assert "Analyze this long input" in str(results[1])

@pytest.mark.timeout(5)
def test_slow_batch():
expr = SlowExpression
inputs = ["slow1", "slow2"]
Expand All @@ -160,6 +164,7 @@ def test_slow_batch():
assert f"slow{i}" in str(result)
assert "Process this input after a 5 second delay" in str(result)

@pytest.mark.timeout(5)
def test_double_nested_slow_batch():
expr = DoubleNestedExpressionSlow
inputs = ["input1", "input2"]
Expand All @@ -170,6 +175,7 @@ def test_double_nested_slow_batch():
assert f"input{i}" in str(result)
assert "Synthesize these results" in str(result)

@pytest.mark.timeout(5)
def test_simple_batch_variations():
expr = TestExpression
inputs = ["test1", "test2", "test3", "test4", "test5", "test6"]
Expand All @@ -189,6 +195,7 @@ def test_simple_batch_variations():
assert f"test{i}" in str(result)
assert "Summarize this input" in str(result)

@pytest.mark.timeout(5)
def test_nested_batch_variations():
expr = NestedExpression
inputs = ["nested1", "nested2", "nested3", "nested4"]
Expand All @@ -209,6 +216,7 @@ def test_nested_batch_variations():
assert f"nested{i}" in str(result)
assert "Elaborate on this result" in str(result)

@pytest.mark.timeout(5)
def test_conditional_batch_variations():
expr = ConditionalExpression
inputs = ["short", "this is a long input", "short+", "yet another long input"]
Expand All @@ -221,6 +229,7 @@ def test_conditional_batch_variations():
assert "Briefly comment on this short input" in str(results[2])
assert "Analyze this long input" in str(results[3])

@pytest.mark.timeout(5)
def test_slow_batch_variations():
expr = SlowExpression
inputs = ["slow1", "slow2", "slow3", "slow4", "slow5"]
Expand All @@ -239,6 +248,7 @@ def test_slow_batch_variations():
assert f"slow{i}" in str(result)
assert "Process this input after a 5 second delay" in str(result)

@pytest.mark.timeout(5)
def test_double_nested_slow_batch_variations():
expr = DoubleNestedExpressionSlow
inputs = ["input1", "input2", "input3"]
Expand All @@ -258,6 +268,7 @@ def test_double_nested_slow_batch_variations():
assert "Synthesize these results" in str(result)


@pytest.mark.timeout(5)
def test_double_nested_batch():
expr = DoubleNestedExpression
inputs = ["nested1", "nested2", "nested3"]
Expand All @@ -269,6 +280,3 @@ def test_double_nested_batch():
assert "Combine these results" in str(result)
assert "Summarize this input" in str(result)
assert "Elaborate on this result" in str(result)



40 changes: 40 additions & 0 deletions tests/test_compare_batch_and_single.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import itertools
import os
import sys
import time

import pytest
from symai import Expression, Symbol
from symai.backend.base import BatchEngine
from symai.functional import EngineRepository

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__name__), '..')))
from src.func import BatchScheduler

class TestExpression(Expression):
    """Minimal Expression used to compare batched vs. direct execution.

    Wraps the raw input in a Symbol and issues a single query against it,
    returning the underlying response value.
    """

    def __init__(self, **kwargs):
        # No state of its own; defer entirely to the Expression base class.
        super().__init__(**kwargs)

    def forward(self, input, **kwargs):
        """Query the wrapped input and return the raw response value."""
        symbol = Symbol(input)
        response = symbol.query("Follow this instruction: ", **kwargs)
        return response.value

def test_expression_with_and_without_scheduler():
    """Running an Expression through BatchScheduler should match a direct call.

    Executes the same input once directly and once via the scheduler (single
    worker, single-item dataset) and checks both results reflect the input
    and the prompt actually issued by TestExpression.forward.
    """
    expr = TestExpression()
    input_text = "Write an essay about AI"
    kwargs = {"temperature": 0}

    # Direct (unscheduled) execution.
    direct_result = expr(input_text, **kwargs)

    # Scheduled execution through the BatchScheduler.
    scheduler = BatchScheduler()
    scheduled_results = scheduler(TestExpression, num_workers=1, dataset=[input_text], **kwargs)
    scheduled_result = scheduled_results[0]

    # Both paths should reflect the original input.
    assert "Write an essay about AI" in str(direct_result)
    assert "Write an essay about AI" in str(scheduled_result)
    # NOTE(review): TestExpression.forward sends the prompt
    # "Follow this instruction: " — the previous version asserted on
    # "Summarize this input", a prompt this test never issues.
    assert "Follow this instruction" in str(direct_result)
    assert "Follow this instruction" in str(scheduled_result)

0 comments on commit c01493b

Please sign in to comment.