Skip to content

Commit

Permalink
Add show_progress: bool to robust_map.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 689799990
  • Loading branch information
Edward2 Team authored and edward-bot committed Oct 25, 2024
1 parent 10f63b7 commit b33fb98
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions edward2/maps.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def robust_map(
max_workers: int | None = ...,
raise_error: Literal[False] = ...,
retry_exception_types: list[type[Exception]] | None = ...,
show_progress: bool = ...,
) -> list[U | V]:
...

Expand All @@ -54,6 +55,7 @@ def robust_map(
max_workers: int | None = ...,
raise_error: Literal[True] = ...,
retry_exception_types: list[type[Exception]] | None = ...,
show_progress: bool = ...,
) -> list[U]:
...

Expand All @@ -69,6 +71,7 @@ def robust_map(
max_workers: int | None = ...,
raise_error: bool = ...,
retry_exception_types: list[type[Exception]] | None = ...,
show_progress: bool = ...,
) -> list[U | V]:
...

Expand All @@ -84,6 +87,7 @@ def robust_map(
max_workers: int | None = ...,
raise_error: bool = ...,
progress_desc: str = ...,
show_progress: bool = ...,
) -> list[U | V]:
...

Expand All @@ -100,6 +104,7 @@ def robust_map(
raise_error: bool = False,
retry_exception_types: list[type[Exception]] | None = None,
progress_desc: str = 'robust_map',
show_progress: bool = True,
) -> list[U | V]:
"""Maps a function to inputs using a threadpool.
Expand All @@ -126,6 +131,7 @@ def robust_map(
retry_exception_types: Exception types to retry on. Defaults to retrying
only on grpc's RPC exceptions.
progress_desc: A string to display in the progress bar.
show_progress: Whether to show the progress bar.
Returns:
A list of items each of type U. They are the outputs of `fn` applied to
Expand Down Expand Up @@ -162,7 +168,12 @@ def robust_map(
num_existing = len(index_to_output)
num_inputs = len(inputs)
logging.info('Found %s/%s existing examples.', num_existing, num_inputs)
progress_bar = tqdm.tqdm(total=num_inputs - num_existing, desc=progress_desc)
if show_progress:
progress_bar = tqdm.tqdm(
total=num_inputs - num_existing, desc=progress_desc
)
else:
progress_bar = None
indices = [i for i in range(num_inputs) if i not in index_to_output.keys()]
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers
Expand All @@ -175,7 +186,8 @@ def robust_map(
try:
output = future.result()
index_to_output[index] = output
progress_bar.update(1)
if progress_bar:
progress_bar.update(1)
except tenacity.RetryError as e:
if raise_error:
logging.exception('Item %s exceeded max retries.', index)
Expand All @@ -189,6 +201,7 @@ def robust_map(
e,
)
index_to_output[index] = error_output
progress_bar.update(1)
if progress_bar:
progress_bar.update(1)
outputs = [index_to_output[i] for i in range(num_inputs)]
return outputs

0 comments on commit b33fb98

Please sign in to comment.