Skip to content

Commit 1b9d2c1

Browse files
author
Vincent Moens
committed
[Feature] Colored logger
ghstack-source-id: 5f5827b Pull-Request-resolved: #2967
1 parent 50ecb15 commit 1b9d2c1

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

torchrl/_utils.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from copy import copy
2121
from functools import wraps
2222
from importlib import import_module
23+
from textwrap import indent
2324
from typing import Any, Callable, cast, TypeVar
2425

2526
import numpy as np
@@ -52,25 +53,37 @@ def strtobool(val: Any) -> bool:
5253
LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
5354
logger = logging.getLogger("torchrl")
5455
logger.setLevel(getattr(logging, LOGGING_LEVEL))
55-
# Disable propagation to the root logger
5656
logger.propagate = False
57-
# Remove all attached handlers
57+
# Clear existing handlers
5858
while logger.hasHandlers():
5959
logger.removeHandler(logger.handlers[0])
6060
stream_handlers = {
6161
"stdout": sys.stdout,
6262
"stderr": sys.stderr,
6363
}
6464
TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
65-
if TORCHRL_CONSOLE_STREAM:
66-
stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM]
67-
else:
68-
stream_handler = None
69-
console_handler = logging.StreamHandler(stream=stream_handler)
70-
71-
console_handler.setLevel(logging.INFO)
72-
formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s")
73-
console_handler.setFormatter(formatter)
65+
stream_handler = stream_handlers.get(TORCHRL_CONSOLE_STREAM, sys.stdout)
66+
67+
68+
# Create colored handler
69+
class _CustomFormatter(logging.Formatter):
70+
def format(self, record):
71+
# Format the initial part in green
72+
green_format = "\033[92m%(asctime)s [%(name)s][%(levelname)s]\033[0m"
73+
# Format the message part
74+
message_format = "%(message)s"
75+
# End marker in green
76+
end_marker = "\033[92m [END]\033[0m"
77+
# Combine all parts
78+
formatted_message = logging.Formatter(
79+
green_format + indent(message_format, " " * 4) + end_marker
80+
).format(record)
81+
82+
return formatted_message
83+
84+
85+
console_handler = logging.StreamHandler(stream_handler)
86+
console_handler.setFormatter(_CustomFormatter())
7487
logger.addHandler(console_handler)
7588

7689
VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))

0 commit comments

Comments
 (0)