lf.concurrent_map: support TimeIt summary in progress bar.
This allows users to wrap code blocks in the thread function with `pg.timeit` to measure their execution time and statistics.

Example:
```python
import time

import langfun as lf
import pyglove as pg


def foo(x):
  with pg.timeit('foo'):
    time.sleep(1)
  with pg.timeit('bar'):
    time.sleep(0.5)
  return x

for _, _, _ in lf.concurrent_map(foo, range(10), max_workers=2, show_progress=True):
  pass
```
The progress bar will print a status line like:
```
10/10 [00:06<00:00, 1.35it/s, Succeeded=100.00% (10/10), Failed=0.00% (0/10), AvgDuration=1.50s, TimeIt=foo (1.0s, 10/10), bar (0.5s, 10/10)]
```
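
The per-`pg.timeit` aggregates are also exposed on `lf.concurrent.Progress` (see `timeit_summary` and `timeit_summary_str()` in the diff below), so a custom `status_fn` can consume them as well. A minimal sketch, reusing `foo` and the imports from the example above; `my_status` and the `FooAvg` field are hypothetical names for illustration:
```python
def my_status(progress):
  # Hypothetical status_fn: adds a field computed from the per-timeit
  # aggregates. AvgDuration/TimeIt are still appended automatically by
  # concurrent_map when show_progress=True.
  status = {'Succeeded': '%d/%d' % (progress.succeeded, progress.total)}
  foo_summary = progress.timeit_summary.get('foo')
  if foo_summary is not None:
    status['FooAvg'] = '%.2fs' % foo_summary.avg_duration
  return status

for _, _, _ in lf.concurrent_map(
    foo, range(10), max_workers=2, show_progress=True, status_fn=my_status
):
  pass
```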
PiperOrigin-RevId: 679439662
daiyip authored and langfun authors committed Sep 27, 2024
1 parent 9f9cb9f commit 1f95b77
Showing 7 changed files with 122 additions and 62 deletions.
128 changes: 89 additions & 39 deletions langfun/core/concurrent.py
@@ -34,7 +34,7 @@
try:
from tqdm import auto as tqdm # pylint: disable=g-import-not-at-top
progress_bar = 'tqdm'
except ImportError as e:
except ImportError:
progress_bar = 'console'
tqdm = None

@@ -57,7 +57,7 @@ class RetryError(RuntimeError):
def __init__(
self,
func: Callable[..., Any],
errors: list[Exception],
errors: list[BaseException],
wait_intervals: list[int],
):
assert len(errors) == len(wait_intervals) + 1
@@ -112,8 +112,8 @@ def __hash__(self) -> int:
def with_retry(
func: Callable[[Any], Any],
retry_on_errors: Union[
Union[Type[Exception], Tuple[Type[Exception], str]],
Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
Union[Type[BaseException], Tuple[Type[BaseException], str]],
Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
],
max_attempts: int,
retry_interval: int | tuple[int, int] = (5, 60),
@@ -186,8 +186,8 @@ def concurrent_execute(
executor: Union[concurrent.futures.ThreadPoolExecutor, str, None] = None,
max_workers: int = 32,
retry_on_errors: Union[
Union[Type[Exception], Tuple[Type[Exception], str]],
Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
Union[Type[BaseException], Tuple[Type[BaseException], str]],
Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
None,
] = None,
max_attempts: int = 5,
@@ -251,46 +251,66 @@ class Job:
func: Callable[[Any], Any]
arg: Any
result: Any = pg.MISSING_VALUE
error: Exception | None = None
start_time: float | None = None
end_time: float | None = None
timeit: pg.object_utils.TimeIt = dataclasses.field(
default_factory=lambda: pg.object_utils.TimeIt('job')
)

@property
def elapse(self) -> float:
"""Returns the running time in seconds since the job get started."""
return self.timeit.elapse

@property
def error(self) -> BaseException | None:
"""Returns the error if the job failed."""
return self.timeit.error

def __call__(self) -> Any:
self.start_time = time.time()
try:
self.result = self.func(self.arg)
return self.result
except Exception as e: # pylint: disable=broad-exception-caught
self.error = e
with self.timeit:
self.result = self.func(self.arg)
return self.result
except BaseException as e: # pylint: disable=broad-exception-caught
return e
finally:
self.end_time = time.time()

def mark_canceled(self, error: Exception) -> None:
def mark_canceled(self, error: BaseException) -> None:
"""Marks the job as canceled."""
self.error = error
self.end_time = time.time()

@property
def elapse(self) -> float:
"""Returns the running time in seconds since the job get started."""
if self.start_time is None:
return 0.0
if self.end_time is None:
return time.time() - self.start_time
return self.end_time - self.start_time
self.timeit.end(error)


@dataclasses.dataclass
class Progress:
"""Concurrent processing progress."""
total: int

@dataclasses.dataclass
class TimeItSummary:
"""Execution details for each `pg.timeit`."""

num_started: int = 0
num_ended: int = 0
num_failed: int = 0
avg_duration: float = 0.0

def aggregate(self, status: pg.object_utils.TimeIt.Status):
self.avg_duration = (
(self.avg_duration * self.num_started + status.elapse)
/ (self.num_started + 1)
)
self.num_started += 1
if status.has_ended:
self.num_ended += 1
if status.has_error:
self.num_failed += 1

_succeeded: int = 0
_failed: int = 0
_last_error: Exception | None = None
_last_error: BaseException | None = None
_total_duration: float = 0.0
_job: Job | None = None
_timeit_summary: dict[str, TimeItSummary] = dataclasses.field(
default_factory=dict
)

@property
def succeeded(self) -> int:
@@ -308,7 +328,7 @@ def completed(self) -> int:
return self.succeeded + self.failed

@property
def last_error(self) -> Exception | None:
def last_error(self) -> BaseException | None:
"""Returns last error."""
return self._last_error

@@ -338,6 +358,28 @@ def avg_duration(self) -> float:
return 0.0
return self._total_duration / self.completed

@property
def timeit_summary(self) -> dict[str, TimeItSummary]:
"""Returns the aggregated summary for each `pg.timeit`."""
return self._timeit_summary

def timeit_summary_str(self) -> str | None:
if not self.timeit_summary:
return None
return ', '.join([
'%s (%.2fs, %d/%d)' % (
k, v.avg_duration, v.num_ended, v.num_started
) for k, v in self.timeit_summary.items()
])

def last_error_str(self) -> str | None:
if self.last_error is None:
return None
error_text = repr(self.last_error)
if len(error_text) >= 64:
error_text = error_text[:64] + '...'
return error_text

def update(self, job: Job) -> None:
"""Mark a job as completed."""
self._job = job
@@ -347,6 +389,14 @@ def update(self, job: Job) -> None:
self._failed += 1
self._last_error = job.error
self._total_duration += job.elapse
self.merge_timeit_summary(job)

def merge_timeit_summary(self, job: Job):
for child in job.timeit.children:
for name, status in child.status().items():
if name not in self._timeit_summary:
self._timeit_summary[name] = Progress.TimeItSummary()
self._timeit_summary[name].aggregate(status)


class ProgressBar:
@@ -498,17 +548,17 @@ def concurrent_map(
status_fn: Callable[[Progress], dict[str, Any]] | None = None,
timeout: int | None = None,
silence_on_errors: Union[
Type[Exception], Tuple[Type[Exception], ...], None
Type[BaseException], Tuple[Type[BaseException], ...], None
] = Exception,
retry_on_errors: Union[
Type[Exception],
Tuple[Type[Exception], ...],
Type[BaseException],
Tuple[Type[BaseException], ...],
None,
] = None,
max_attempts: int = 5,
retry_interval: int | tuple[int, int] = (5, 60),
exponential_backoff: bool = True,
) -> Iterator[tuple[Any, Any, Exception | None]]:
) -> Iterator[tuple[Any, Any, BaseException | None]]:
"""Maps inputs to outptus via func concurrently under current context.
Args:
@@ -608,13 +658,13 @@ def update_progress_bar(progress: Progress) -> None:
if show_progress:
status = status_fn(progress)
status.update({
'AvgDuration': '%.2f seconds' % progress.avg_duration
'AvgDuration': '%.2fs' % progress.avg_duration
})
if progress.last_error is not None:
error_text = repr(progress.last_error)
if len(error_text) >= 64:
error_text = error_text[:64] + '...'
status['LastError'] = error_text
status['LastError'] = progress.last_error_str()

if progress.timeit_summary:
status['TimeIt'] = progress.timeit_summary_str()
ProgressBar.update(bar_id, delta=1, status=status)

try:
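
For intuition on the new `Progress.TimeItSummary.aggregate` shown above: it keeps a running average of block durations. A standalone sketch of that update rule (illustrative, not langfun code):
```python
def running_avg(avg, n, x):
  # After n samples with average `avg`, folding in a new sample `x`
  # yields (avg * n + x) / (n + 1), the update used by aggregate().
  return (avg * n + x) / (n + 1)

durations = [1.2, 0.9, 1.5]
avg, n = 0.0, 0
for x in durations:
  avg = running_avg(avg, n, x)
  n += 1
assert abs(avg - sum(durations) / len(durations)) < 1e-9
```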
13 changes: 8 additions & 5 deletions langfun/core/concurrent_test.py
@@ -529,10 +529,11 @@ def fun(x):

def test_concurrent_map_with_showing_progress(self):
def fun(x):
if x == 2:
raise ValueError('Intentional error.')
time.sleep(x)
return x
with pg.timeit('foo'):
if x == 2:
raise ValueError('Intentional error.')
time.sleep(x)
return x

string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
@@ -549,7 +550,9 @@ def fun(x):
(3, pg.MISSING_VALUE),
],
)
self.assertIn('100%', string_io.getvalue())
output = string_io.getvalue()
self.assertIn('100%', output)
self.assertIn('TimeIt=foo (', output)

def test_concurrent_map_with_showing_progress_and_status_fn(self):
def fun(x):
11 changes: 10 additions & 1 deletion langfun/core/eval/base.py
@@ -1234,8 +1234,17 @@ def process_output(example, output):
del example, output

def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
status = {'Model': self.lm.model_id}
status.update(self._eval_status(progress))

if progress.last_error is not None:
status['LastError'] = progress.last_error_str()
if progress.timeit_summary:
status['TimeIt'] = progress.timeit_summary_str()
return status

def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
return {
'Model': self.lm.model_id,
'Succeeded': '%s (%d/%d)' % (
self._format_rate(progress.success_rate),
progress.succeeded,
3 changes: 1 addition & 2 deletions langfun/core/eval/matching.py
@@ -115,10 +115,9 @@ def match(self, answer: Any, groundtruth: Any) -> bool:
"""Matches answer against the groundtruth. Subclasses can override."""
return pg.eq(answer, groundtruth)

def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
del progress
return {
'Model': self.lm.model_id,
'Matches': '%s (%d/%d)' % (
self._format_rate(self.match_rate),
self.num_matches,
3 changes: 1 addition & 2 deletions langfun/core/eval/scoring.py
@@ -79,10 +79,9 @@ def audit_processed(
def score(self, example: Any, output: Any) -> float:
"""Scores the output against its input example."""

def _status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
def _eval_status(self, progress: lf.concurrent.Progress) -> dict[str, Any]:
del progress
return {
'Model': self.lm.model_id,
'Average Score': {self.avg_score},
'Scored': '%.2f%% (%d/%d)' % (
self.score_rate * 100,
4 changes: 2 additions & 2 deletions langfun/core/language_model.py
@@ -508,8 +508,8 @@ def _parallel_execute_with_currency_control(
inputs: Sequence[Any],
retry_on_errors: Union[
None,
Union[Type[Exception], Tuple[Type[Exception], str]],
Sequence[Union[Type[Exception], Tuple[Type[Exception], str]]],
Union[Type[BaseException], Tuple[Type[BaseException], str]],
Sequence[Union[Type[BaseException], Tuple[Type[BaseException], str]]],
] = RetryableLMError,
) -> Any:
"""Helper method for subclasses for implementing _sample."""
22 changes: 11 additions & 11 deletions langfun/core/sampling.py
@@ -28,14 +28,14 @@ def sweep(
*,
max_workers: int = 32,
silence_on_errors: Union[
Type[Exception], Tuple[Type[Exception]], None
Type[BaseException], Tuple[Type[BaseException], ...], None
] = None,
ignore_examples_with_errors: bool = True,
**kwargs,
) -> Iterator[
Tuple[
message_lib.Message | Exception, # LM input.
Union[message_lib.Message, Exception, None], # LM output.
message_lib.Message | BaseException, # LM input.
Union[message_lib.Message, BaseException, None], # LM output.
],
]:
"""Sweeps the input/output of this LangFunc concurrently.
Expand Down Expand Up @@ -73,15 +73,15 @@ def random_sample(
*,
max_workers: int = 32,
silence_on_errors: Union[
Type[Exception], Tuple[Type[Exception]], None
Type[BaseException], Tuple[Type[BaseException], ...], None
] = None,
ignore_examples_with_errors: bool = True,
seed: int | None = None,
**kwargs,
) -> Iterator[
Tuple[
message_lib.Message | Exception, # LM input.
Union[message_lib.Message, Exception, None], # LM output.
message_lib.Message | BaseException, # LM input.
Union[message_lib.Message, BaseException, None], # LM output.
],
]:
"""Random samples the input/output of this LangFunc concurrently.
@@ -121,14 +121,14 @@ def _concurrent_sample(
*,
max_workers: int = 32,
silence_on_errors: Union[
Type[Exception], Tuple[Type[Exception]], None
Type[BaseException], Tuple[Type[BaseException], ...], None
] = None,
ignore_examples_with_errors: bool = True,
**kwargs,
) -> Generator[
Tuple[
message_lib.Message | Exception, # LM input.
Union[message_lib.Message, Exception, None], # LM output.
message_lib.Message | BaseException, # LM input.
Union[message_lib.Message, BaseException, None], # LM output.
],
None,
None, # Sender type and return type.
@@ -177,6 +177,6 @@ def _call_fn(example):
else:
lm_input, lm_output = error, error
if (not ignore_examples_with_errors
or not (isinstance(lm_input, Exception)
or isinstance(lm_output, Exception))):
or not (isinstance(lm_input, BaseException)
or isinstance(lm_output, BaseException))):
yield lm_input, lm_output
