Skip to content

Commit

Permalink
Fix logging configuration again
Browse files Browse the repository at this point in the history
* Only use `tqdm.write()` if `tqdm` is active, defer to stderr
* Correct log formatter for TqdmLoggingHandler
* If `rich` is installed and `SD_WEBUI_RICH_LOG` is set, use `rich`'s formatter
  • Loading branch information
akx committed Jan 4, 2024
1 parent f903b4d commit 6fa42e9
Showing 1 changed file with 39 additions and 23 deletions.
62 changes: 39 additions & 23 deletions modules/logging_config.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,57 @@
import os
import logging
import os

try:
from tqdm.auto import tqdm
from tqdm import tqdm


class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.INFO):
super().__init__(level)
def __init__(self, fallback_handler: logging.Handler):
super().__init__()
self.fallback_handler = fallback_handler

def emit(self, record):
try:
msg = self.format(record)
tqdm.write(msg)
self.flush()
# If there are active tqdm progress bars,
# attempt to not interfere with them.
if tqdm._instances:
tqdm.write(self.format(record))
else:
self.fallback_handler.emit(record)
except Exception:
self.handleError(record)
self.fallback_handler.emit(record)

TQDM_IMPORTED = True
except ImportError:
# tqdm does not exist before first launch
# I will import once the UI finishes seting up the enviroment and reloads.
TQDM_IMPORTED = False
TqdmLoggingHandler = None


def setup_logging(loglevel):
if loglevel is None:
loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL")

loghandlers = []
if not loglevel:
return

if logging.root.handlers:
# Already configured, do not interfere
return

if os.environ.get("SD_WEBUI_RICH_LOG"):
from rich.logging import RichHandler
handler = RichHandler()
else:
handler = logging.StreamHandler()

if TqdmLoggingHandler:
handler = TqdmLoggingHandler(handler)

formatter = logging.Formatter(
'%(asctime)s %(levelname)s [%(name)s] %(message)s',
'%Y-%m-%d %H:%M:%S',
)

if TQDM_IMPORTED:
loghandlers.append(TqdmLoggingHandler())
handler.setFormatter(formatter)

if loglevel:
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.basicConfig(
level=log_level,
format='%(asctime)s %(levelname)s [%(name)s] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
handlers=loghandlers
)
log_level = getattr(logging, loglevel.upper(), None) or logging.INFO
logging.root.setLevel(log_level)
logging.root.addHandler(handler)

0 comments on commit 6fa42e9

Please sign in to comment.