Commit

Merge pull request sambanova#475 from sambanova/bugfix/benchmarking_racing_per_user

Bugfix/benchmarking racing per user
snova-kwasia authored Dec 10, 2024
2 parents 4003a0e + f919b25 commit 01881d9
Showing 4 changed files with 266 additions and 148 deletions.
77 changes: 54 additions & 23 deletions benchmarking/src/performance_evaluation.py
@@ -21,7 +21,6 @@
import transformers
from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from stqdm import stqdm
from streamlit.runtime.scriptrunner import add_script_run_ctx
from tqdm import tqdm

@@ -65,6 +64,9 @@ def __init__(
self.is_stream_mode = is_stream_mode
self.timeout = timeout
self.tokenizer = get_tokenizer(self.model_name)
self.stop_event = threading.Event()
self.ui_progress_bar = None
self.cli_progress_bar = None

# To be set upon saving of results
self.summary_file_path: Optional[str] = None
@@ -165,18 +167,23 @@ def send_requests(
self,
request_config_batch: List[Any],
completed_requests: List[Any],
progress_bars: Dict[str, Any],
progress: int,
start_time: float,
num_requests: int,
) -> None:
"""Sends multiple requests to LLM and collects results
Args:
request_config_batch (list): list of request configs for LLM calls
completed_requests (list): list of completed outputs from requests
progress_bar (tqdm): progress bar
progress (int): progress value
start_time (float): start time of the process
num_requests (int): number of total requests
"""
for request_config in request_config_batch:
if self.stop_event.is_set():
logger.info('Stopping request processing in thread due to stop signal.')
break
if time.monotonic() - start_time >= self.timeout:
break
req_metrics, response_text, request_config = llm_request(request_config, self.tokenizer)
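
Note: each worker now leaves its request loop for either of two reasons, a set stop event or an exhausted wall-clock budget measured with time.monotonic(). A small standalone sketch of that combined guard, with a sleep standing in for llm_request (none of the names below come from this diff):

    import threading
    import time

    stop_event = threading.Event()
    timeout = 2.0
    start_time = time.monotonic()

    for _ in range(1000):
        if stop_event.is_set():                       # external stop signal
            break
        if time.monotonic() - start_time >= timeout:  # time budget exhausted
            break
        time.sleep(0.01)                              # stand-in for one llm_request call
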
@@ -186,8 +193,13 @@
metrics=req_metrics, response_text=response_text, request_config=request_config
)
completed_requests.extend([response_object])
progress_bars['stqdm'].update(1)
progress_bars['tqdm'].update(1)
update_unit = 1
progress += update_unit

if self.cli_progress_bar:
self.cli_progress_bar.update(update_unit)
if self.ui_progress_bar:
self.ui_progress_bar(progress, num_requests)

def build_metrics_summary(
self,
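
Note: progress is now reported through two channels, a tqdm bar for the terminal and an optional UI callable that receives the completed count and the total. A short sketch of that double update, using a print-based stand-in for the UI callback; the (completed, total) signature is inferred from the self.ui_progress_bar(progress, num_requests) call above:

    from tqdm import tqdm

    def ui_progress(completed, total):
        # stand-in UI callback; the real code only calls one when it was provided
        print(f'{completed}/{total} requests done')

    total_requests = 10
    cli_bar = tqdm(total=total_requests, desc='Running Requests')
    completed = 0
    for _ in range(total_requests):
        completed += 1
        cli_bar.update(1)                       # terminal channel
        ui_progress(completed, total_requests)  # UI channel: (completed, total)
    cli_bar.close()
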
@@ -435,6 +447,11 @@ def save_results(
logger.error('ERROR SAVING LLM OUTPUTS')
raise e

def stop_benchmark(self) -> None:
"""Stops the benchmarking process by setting the stop event."""
self.stop_event.set()
logger.info('Benchmarking process has been stopped.')

def run_benchmark(
self, sampling_params: Dict[str, Any] = {}, *args: Any, **kwargs: Any
) -> (
@@ -449,6 +466,9 @@ def run_benchmark(
Returns:
None
"""
self.cli_progress_bar = tqdm(total=len(self.dataset), desc='Running Requests')
self.ui_progress_bar = kwargs.get('progress_bar', None)

# Calculate performance metrics individually and summary
summary, individual_responses = self.get_token_throughput_latencies(
sampling_params=sampling_params,
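
Note: run_benchmark now reads an optional progress_bar kwarg and keeps it as the UI callback. A hedged sketch of how a Streamlit caller might wire one up; fake_run_benchmark below only models the kwarg handling and is not part of the repository:

    import streamlit as st

    def fake_run_benchmark(num_requests=20, **kwargs):
        # stand-in for run_benchmark(): only the progress wiring is modelled
        ui_bar = kwargs.get('progress_bar', None)
        for done in range(1, num_requests + 1):
            if ui_bar:
                ui_bar(done, num_requests)

    bar = st.progress(0.0)
    fake_run_benchmark(progress_bar=lambda done, total: bar.progress(done / total))
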
@@ -507,19 +527,16 @@ def get_token_throughput_latencies(

threads = []
llm_responses: List[LLMResponse] = []
stqdm_progress_bar = stqdm(total=total_request_count, desc='Running Requests', mininterval=1)
tqdm_progress_bar = tqdm(total=total_request_count, desc='Running Requests')
progress_bars = {'stqdm': stqdm_progress_bar, 'tqdm': tqdm_progress_bar}
progress = 0

for request_config_batch in request_config_batches:
if self.stop_event.is_set():
logger.info('Stopping thread creation due to stop signal.')
break

thread = threading.Thread(
target=self.send_requests,
args=(
request_config_batch,
llm_responses,
progress_bars,
start_time,
),
args=(request_config_batch, llm_responses, progress, start_time, total_request_count),
)
threads.append(thread)
add_script_run_ctx(thread) # Give Streamlit context to thread
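
Note: each request batch is handed to its own thread, and add_script_run_ctx attaches the current Streamlit script-run context so the worker can interact with Streamlit safely. A minimal sketch of that fan-out and join shape, with dummy batches in place of request configs:

    import threading
    from streamlit.runtime.scriptrunner import add_script_run_ctx

    def send_batch(batch, results):
        for item in batch:
            results.append(item * 2)   # stand-in for llm_request plus metric collection

    results, threads = [], []
    for batch in [[1, 2], [3, 4], [5, 6]]:
        t = threading.Thread(target=send_batch, args=(batch, results))
        threads.append(t)
        add_script_run_ctx(t)          # same call as in the diff above
        t.start()

    for t in threads:
        t.join()
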
@@ -529,6 +546,10 @@
add_script_run_ctx(thread)
thread.join()

if self.stop_event.is_set():
logger.info('Benchmarking process terminated early due to stop signal.')
return {}, []

if llm_responses[0].metrics[common_metrics.ERROR_CODE]:
raise Exception(
f"""Unexpected error happened when executing requests: {llm_responses[0].metrics['error_code']}.
@@ -660,6 +681,11 @@ def create_output_filename(self, num_input_tokens: int, num_output_tokens: int)
)
return self.sanitize_file_prefix(output_file_name)

def stop_benchmark(self) -> None:
"""Stops the benchmarking process by setting the stop event."""
self.stop_event.set()
logger.info('Benchmarking process has been stopped.')
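
Note: stop_benchmark is intended to be called from outside the benchmarking thread, for example from a Stop button handler, while run_benchmark is still executing elsewhere. A hedged usage sketch with a stand-in class, since the evaluator's constructor is not shown in this diff:

    import threading
    import time

    class Runner:
        # stand-in with the same stop mechanism as the evaluator classes above
        def __init__(self):
            self.stop_event = threading.Event()

        def run_benchmark(self):
            for _ in range(100):
                if self.stop_event.is_set():
                    break
                time.sleep(0.1)        # stand-in for one request

        def stop_benchmark(self):
            self.stop_event.set()

    runner = Runner()
    t = threading.Thread(target=runner.run_benchmark)
    t.start()
    time.sleep(0.5)
    runner.stop_benchmark()            # e.g. wired to a UI 'Stop' button
    t.join()
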

def run_benchmark(
self, sampling_params: Dict[str, Any] = {}, *args: Any, **kwargs: Any
) -> Tuple[Dict[str, Any] | Dict[str, object], List[Tuple[Dict[str, Any], str, RequestConfig]] | List[LLMResponse]]:
@@ -681,6 +707,10 @@ def run_benchmark(
num_input_tokens = kwargs.get('num_input_tokens', 1000)
num_output_tokens = kwargs.get('num_output_tokens', 10)
num_requests = kwargs.get('num_requests', 1)

self.cli_progress_bar = tqdm(total=num_requests, desc='Running Requests')
self.ui_progress_bar = kwargs.get('progress_bar', None)

if num_input_tokens < 40:
raise ValueError(
'The minimum number of input tokens that will be sent is 40' ' because of the prompting logic right now'
@@ -850,20 +880,17 @@ def get_token_throughput_latencies(
# completed requests respectively
threads: List[threading.Thread] = []
llm_responses: List[LLMResponse] = []
stqdm_progress_bar = stqdm(total=total_request_count, desc='Running Requests', mininterval=1)
tqdm_progress_bar = tqdm(total=total_request_count, desc='Running Requests')
progress_bars = {'stqdm': stqdm_progress_bar, 'tqdm': tqdm_progress_bar}
progress = 0

# Send request threads and add to the threads array
for request_config_batch in request_config_batches:
if self.stop_event.is_set():
logger.info('Stopping thread creation due to stop signal.')
break

thread = threading.Thread(
target=self.send_requests,
args=(
request_config_batch,
llm_responses,
progress_bars,
start_time,
),
args=(request_config_batch, llm_responses, progress, start_time, num_requests),
)
threads.append(thread)
add_script_run_ctx(thread) # Add Streamlit context to thread
@@ -874,6 +901,10 @@
add_script_run_ctx(thread)
thread.join()

if self.stop_event.is_set():
logger.info('Benchmarking process terminated early due to stop signal.')
return {}, []

# Error handling
error_codes = [llm_response.metrics['error_code'] for llm_response in llm_responses]
