Skip to content

Commit

Permalink
lf.concurrent to use pg.object_utils.TimeIt.StatusSummary for progress reporting.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 691501085
  • Loading branch information
daiyip authored and langfun authors committed Oct 30, 2024
1 parent 1d1ef24 commit 3f6039d
Showing 1 changed file with 9 additions and 38 deletions.
47 changes: 9 additions & 38 deletions langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ class Job:
func: Callable[[Any], Any]
arg: Any
result: Any = pg.MISSING_VALUE
error: BaseException | None = None
timeit: pg.object_utils.TimeIt = dataclasses.field(
default_factory=lambda: pg.object_utils.TimeIt('job')
)
Expand All @@ -269,56 +270,33 @@ def elapse(self) -> float:
"""Returns the running time in seconds since the job get started."""
return self.timeit.elapse

@property
def error(self) -> BaseException | None:
"""Returns the error if the job failed."""
# Delegates to the job's TimeIt object instead of storing the error itself.
# NOTE(review): this diff view has no +/- markers; the same diff also shows a
# plain `error` dataclass field and assignments to `self.error`, so this
# property is presumably the removed side of the change -- confirm.
return self.timeit.error

def __call__(self) -> Any:
# Runs `self.func(self.arg)` inside the `self.timeit` context, stores the
# outcome in `self.result` and returns it.
try:
with self.timeit:
self.result = self.func(self.arg)
return self.result
except BaseException as e: # pylint: disable=broad-exception-caught
# Any exception (the clause catches BaseException) is recorded on the job
# and handed back as the return value instead of being re-raised.
self.error = e
return e

def mark_canceled(self, error: BaseException) -> None:
"""Marks the job as canceled."""
# NOTE(review): this diff view has no +/- markers. Presumably
# `self.timeit.end(error)` is the removed statement and `self.error = error`
# is its replacement -- confirm against the commit before relying on both.
self.timeit.end(error)
self.error = error


@dataclasses.dataclass
class Progress:
"""Concurrent processing progress."""
total: int

@dataclasses.dataclass
class TimeItSummary:
"""Execution details for each `pg.timeit`."""

# Counters start at zero and are only changed by `aggregate` in this view.
num_started: int = 0
num_ended: int = 0
num_failed: int = 0
avg_duration: float = 0.0

def aggregate(self, status: pg.object_utils.TimeIt.Status):
# Folds one status into the counters and the running average.
# The average is updated first: `num_started` still holds the number of
# statuses aggregated before this one, which is the weight of the old mean.
self.avg_duration = (
(self.avg_duration * self.num_started + status.elapse)
/ (self.num_started + 1)
)
self.num_started += 1
# NOTE(review): indentation is lost in this view, so whether the
# `has_error` check is nested under `has_ended` cannot be told from here.
if status.has_ended:
self.num_ended += 1
if status.has_error:
self.num_failed += 1

_succeeded: int = 0
_failed: int = 0
_last_error: BaseException | None = None
_total_duration: float = 0.0
_job: Job | None = None
_timeit_summary: dict[str, TimeItSummary] = dataclasses.field(
default_factory=dict
_timeit_summary: pg.object_utils.TimeIt.StatusSummary = dataclasses.field(
default_factory=pg.object_utils.TimeIt.StatusSummary
)

@property
Expand Down Expand Up @@ -368,7 +346,7 @@ def avg_duration(self) -> float:
return self._total_duration / self.completed

@property
# NOTE(review): two signatures appear because this diff view has no +/-
# markers. Presumably the `dict[str, TimeItSummary]` return type is the old
# one and `pg.object_utils.TimeIt.StatusSummary` the new one -- confirm.
def timeit_summary(self) -> dict[str, TimeItSummary]:
def timeit_summary(self) -> pg.object_utils.TimeIt.StatusSummary:
"""Returns the aggregated summary for each `pg.timeit`."""
return self._timeit_summary

Expand All @@ -377,8 +355,8 @@ def timeit_summary_str(self) -> str | None:
return None
return ', '.join([
'%s (%.2fs, %d/%d)' % (
k, v.avg_duration, v.num_ended, v.num_started
) for k, v in self.timeit_summary.items()
k.lstrip('job.'), v.avg_duration, v.num_ended, v.num_started
) for k, v in self.timeit_summary.breakdown.items() if k != 'job'
])

def last_error_str(self) -> str | None:
Expand All @@ -398,14 +376,7 @@ def update(self, job: Job) -> None:
self._failed += 1
self._last_error = job.error
self._total_duration += job.elapse
self.merge_timeit_summary(job)

def merge_timeit_summary(self, job: Job):
# Folds the statuses reported by each child of the job's TimeIt into the
# per-name summaries kept in `self._timeit_summary`.
for child in job.timeit.children:
for name, status in child.status().items():
# First status seen under this name: start from an empty summary.
if name not in self._timeit_summary:
self._timeit_summary[name] = Progress.TimeItSummary()
self._timeit_summary[name].aggregate(status)
self._timeit_summary.aggregate(job.timeit.status())


class ProgressBar:
Expand Down

0 comments on commit 3f6039d

Please sign in to comment.