Skip to content

Commit

Permalink
Make tqdm and termcolor dependencies optional.
Browse files Browse the repository at this point in the history
Users could use `lf.concurrent.progress_bar = 'console'` or `lf.concurrent.progress_bar = None` to redirect progress to console or not to show progress.

PiperOrigin-RevId: 668166587
  • Loading branch information
daiyip authored and langfun authors committed Aug 27, 2024
1 parent 2ef2cae commit d70c209
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 37 deletions.
169 changes: 148 additions & 21 deletions langfun/core/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,30 @@
# limitations under the License.
"""Utility library for handling concurrency in langfun."""

import abc
import collections
import concurrent.futures
import dataclasses
import io
import random
import sys
import threading
import time
from typing import Any, Callable, Iterable, Iterator, Literal, Sequence, Tuple, Type, Union

from langfun.core import component
from langfun.core import text_formatting
import pyglove as pg
from tqdm import auto as tqdm


progress_bar: Literal['tqdm', 'console', None] = None

try:
from tqdm import auto as tqdm # pylint: disable=g-import-not-at-top
progress_bar = 'tqdm'
except ImportError as e:
progress_bar = 'console'
tqdm = None


def with_context_access(func: Callable[..., Any]) -> Callable[..., Any]:
Expand Down Expand Up @@ -142,7 +155,6 @@ def next_wait_interval(attempt: int) -> float:
attempt = 1
return base_interval() * (2 ** (attempt - 1))

wait_interval = None
wait_intervals = []
errors = []
while True:
Expand Down Expand Up @@ -356,17 +368,17 @@ class Settings:
label: str | None
total: int
color: str | None = None
postfix: dict[str, str] | None = None
status: dict[str, Any] | None = None

@dataclasses.dataclass
class Update:
"""Progress bar update."""
bar_id: int
delta: int
postfix: Union[dict[str, str], str, None] = None
status: Union[dict[str, Any], str, None] = None
color: str | None = None

_progress_bars: dict[int, tqdm.tqdm] = {}
_progress_bars: dict[int, '_ProgressControl'] = {}
_install_requests: list[tuple[int, Settings]] = []
_updates: collections.deque[Update] = collections.deque()
_uninstall_requests: list[int] = []
Expand All @@ -378,11 +390,11 @@ def install(
label: str | None,
total: int,
color: str | None = None,
postfix: dict[str, str] | None = None,
status: dict[str, Any] | None = None,
) -> int:
"""Installs a progress bar and returns a reference id."""
with cls._lock:
settings = ProgressBar.Settings(label, total, color, postfix)
settings = ProgressBar.Settings(label, total, color, status)
bar_id = id(settings)
cls._install_requests.append((bar_id, settings))
return bar_id
Expand All @@ -392,15 +404,17 @@ def update(
cls,
bar_id: int,
delta: int = 0,
postfix: Union[dict[str, str], str, None] = None,
status: Union[dict[str, Any], str, None] = None,
color: str | None = None,
refresh: bool = True,
) -> None:
"""Report the progress for a label."""
if status is not None and not isinstance(status, (str, dict)):
raise ValueError(f'Unsupported status: {status}')
with cls._lock:
cls._updates.append(
ProgressBar.Update(
bar_id=bar_id, delta=delta, postfix=postfix, color=color,
bar_id=bar_id, delta=delta, status=status, color=color,
)
)
if refresh:
Expand All @@ -422,11 +436,11 @@ def refresh(cls) -> None:
# Process install requests.
if cls._install_requests:
for bar_id, settings in cls._install_requests:
cls._progress_bars[bar_id] = tqdm.tqdm(
cls._progress_bars[bar_id] = _progress_control(
total=settings.total,
desc=settings.label,
colour=settings.color,
postfix=settings.postfix)
label=settings.label,
color=settings.color,
status=settings.status)
cls._install_requests.clear()

# Process updates.
Expand All @@ -441,15 +455,11 @@ def refresh(cls) -> None:
if update.delta > 0:
bar.update(update.delta)

if isinstance(update.postfix, str):
bar.set_postfix_str(update.postfix, refresh=False)
elif isinstance(update.postfix, dict):
bar.set_postfix(update.postfix, refresh=False)
elif update.postfix is not None:
raise ValueError(f'Unsupported postfix: {update.postfix}')
if update.status is not None:
bar.set_status(update.status)

if update.color is not None:
bar.colour = update.color
bar.set_color(update.color)
updated_bars.add(bar)

# Refresh each updated bar just once.
Expand Down Expand Up @@ -603,7 +613,7 @@ def update_progress_bar(progress: Progress) -> None:
if len(error_text) >= 64:
error_text = error_text[:64] + '...'
status['LastError'] = error_text
ProgressBar.update(bar_id, delta=1, postfix=status)
ProgressBar.update(bar_id, delta=1, status=status)

try:
if ordered:
Expand Down Expand Up @@ -729,5 +739,122 @@ def executor_from(
raise ValueError(f'Unsupported value: {maybe_executor}.')


class _ProgressControl(pg.Object):
"""Abstract progress control."""
# Disable symbolic comparison so the hash is based on object address.
use_symbolic_comparison = False

total: int
label: str | None
color: str | None
status: str | dict[str, Any] | None

def set_color(self, color: str | None):
with pg.notify_on_change(False):
self.rebind(color=color)

def set_status(self, status: str | dict[str, Any] | None):
with pg.notify_on_change(False):
self.rebind(status=status)

@abc.abstractmethod
def update(self, delta):
"""Update progress."""

@abc.abstractmethod
def refresh(self) -> None:
"""Refresh progress bar."""


class _TqdmProgressControl(_ProgressControl):
"""Tqdm-based progress control."""

def _on_bound(self):
super()._on_bound()
assert tqdm is not None
self._tqdm = tqdm.tqdm(
total=self.total,
desc=self.label,
colour=self.color,
postfix=self.status,
)

def update(self, delta: int) -> None:
self._tqdm.update(delta)

def refresh(self):
self._tqdm.set_description(self.label, refresh=False)
if isinstance(self.status, str):
self._tqdm.set_postfix_str(self.status, refresh=False)
else:
self._tqdm.set_postfix(self.status, refresh=False)
self._tqdm.colour = self.color
self._tqdm.refresh()


class _ConsoleProgressControl(_ProgressControl):
"""Simple progress control by printing the status to the console."""

def _on_bound(self):
super()._on_bound()
self._progress = 0

def update(self, delta: int) -> None:
self._progress += delta

def refresh(self):
s = io.StringIO()
if self.label is not None:
s.write(text_formatting.colored(self.label, 'red', styles=['bold']))
s.write(': ')
s.write(
text_formatting.colored(
'%d%% (%d/%d)' %
(
self._progress * 100 // self.total,
self._progress,
self.total,
),
color=self.color or 'green'
)
)
if self.status is not None:
status = repr(self.status) if isinstance(
self.status, dict) else self.status
s.write(f' : {status}')
sys.stderr.write(s.getvalue() + '\n')


class _NoopProgressControl(_ProgressControl):
"""No-op progress control."""

def update(self, delta: int) -> None:
pass

def refresh(self) -> None:
pass


def _progress_control(
total: int,
label: str | None,
color: str | None,
status: str | dict[str, Any] | None,
) -> _ProgressControl:
"""Creates a process control."""
if progress_bar == 'tqdm':
if not tqdm:
raise RuntimeError(
'Please install package "tqdm" to use `tqdm` progress bar.'
)
return _TqdmProgressControl(total, label, color, status)
elif progress_bar == 'console':
return _ConsoleProgressControl(total, label, color, status)
elif progress_bar is None:
return _NoopProgressControl(total, label, color, status)
else:
raise ValueError(f'Unsupported progress bar type: {progress_bar}')


# The global executor pool based on resource IDs.
_executor_pool = ExecutorPool()
66 changes: 58 additions & 8 deletions langfun/core/concurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,57 @@ def fun2(unused_x):
self.assertIs(p.last_error, job2.error)


class ProgressControlTest(unittest.TestCase):

def test_noop(self):
concurrent.progress_bar = None
ctrl = concurrent._progress_control(100, 'noop', 'blue', None)
self.assertIsInstance(ctrl, concurrent._NoopProgressControl)
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
ctrl.update(1)
ctrl.refresh()
self.assertEqual(string_io.getvalue(), '')
concurrent.progress_bar = 'tqdm'

def test_console(self):
concurrent.progress_bar = 'console'
ctrl = concurrent._progress_control(100, 'foo', 'blue', None)
self.assertIsInstance(ctrl, concurrent._ConsoleProgressControl)
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
ctrl.set_status('bar')
ctrl.update(10)
ctrl.refresh()
self.assertEqual(
string_io.getvalue(),
'\x1b[1m\x1b[31mfoo\x1b[0m: \x1b[34m10% (10/100)\x1b[0m : bar\n'
)
concurrent.progress_bar = 'tqdm'

def test_tqdm(self):
concurrent.progress_bar = 'tqdm'
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
ctrl = concurrent._progress_control(100, 'foo', 'blue', None)
self.assertIsInstance(ctrl, concurrent._TqdmProgressControl)
ctrl.update(10)
ctrl.refresh()
self.assertIn('10/100', string_io.getvalue())

tqdm = concurrent.tqdm
concurrent.tqdm = None
with self.assertRaisesRegex(RuntimeError, 'install package "tqdm"'):
_ = concurrent._progress_control(100, 'foo', 'blue', None)
concurrent.tqdm = tqdm

def test_unsupported(self):
concurrent.progress_bar = 'unknown'
with self.assertRaisesRegex(ValueError, 'Unsupported progress bar type'):
_ = concurrent._progress_control(100, 'foo', 'blue', None)
concurrent.progress_bar = 'tqdm'


class ProgressBarTest(unittest.TestCase):

def test_multithread_support(self):
Expand All @@ -241,26 +292,25 @@ def test_multithread_support(self):
bar_id = concurrent.ProgressBar.install(None, 5)
def fun(x):
del x
concurrent.ProgressBar.update(bar_id, 1, postfix=None)
concurrent.ProgressBar.update(bar_id, 1, status=None)

for _ in concurrent.concurrent_execute(fun, range(5)):
concurrent.ProgressBar.refresh()
concurrent.ProgressBar.uninstall(bar_id)
output_str = string_io.getvalue()
print(output_str)
self.assertIn('100%', output_str)
self.assertIn('5/5', output_str)

def test_report(self):
string_io = io.StringIO()
with contextlib.redirect_stderr(string_io):
bar_id = concurrent.ProgressBar.install(None, 4)
concurrent.ProgressBar.update(bar_id, 1, postfix=None)
concurrent.ProgressBar.update(bar_id, 1, postfix='hello')
concurrent.ProgressBar.update(bar_id, color='lightgreen')
concurrent.ProgressBar.update(bar_id, 2, postfix=dict(x=1))
with self.assertRaisesRegex(ValueError, 'Unsupported postfix'):
concurrent.ProgressBar.update(bar_id, 0, postfix=1)
concurrent.ProgressBar.update(bar_id, 1, status=None)
concurrent.ProgressBar.update(bar_id, 1, status='hello')
concurrent.ProgressBar.update(bar_id, color='green')
concurrent.ProgressBar.update(bar_id, 2, status=dict(x=1))
with self.assertRaisesRegex(ValueError, 'Unsupported status'):
concurrent.ProgressBar.update(bar_id, 0, status=1)
concurrent.ProgressBar.uninstall(bar_id)
self.assertIn('1/4', string_io.getvalue())
self.assertIn('2/4', string_io.getvalue())
Expand Down
Loading

0 comments on commit d70c209

Please sign in to comment.