|
20 | 20 | from copy import copy
|
21 | 21 | from functools import wraps
|
22 | 22 | from importlib import import_module
|
| 23 | +from textwrap import indent |
23 | 24 | from typing import Any, Callable, cast, TypeVar
|
24 | 25 |
|
25 | 26 | import numpy as np
|
@@ -52,25 +53,37 @@ def strtobool(val: Any) -> bool:
|
52 | 53 | LOGGING_LEVEL = os.environ.get("RL_LOGGING_LEVEL", "INFO")
|
53 | 54 | logger = logging.getLogger("torchrl")
|
54 | 55 | logger.setLevel(getattr(logging, LOGGING_LEVEL))
|
55 |
| -# Disable propagation to the root logger |
56 | 56 | logger.propagate = False
|
57 |
| -# Remove all attached handlers |
| 57 | +# Clear existing handlers |
58 | 58 | while logger.hasHandlers():
|
59 | 59 | logger.removeHandler(logger.handlers[0])
|
60 | 60 | stream_handlers = {
|
61 | 61 | "stdout": sys.stdout,
|
62 | 62 | "stderr": sys.stderr,
|
63 | 63 | }
|
64 | 64 | TORCHRL_CONSOLE_STREAM = os.getenv("TORCHRL_CONSOLE_STREAM")
|
65 |
| -if TORCHRL_CONSOLE_STREAM: |
66 |
| - stream_handler = stream_handlers[TORCHRL_CONSOLE_STREAM] |
67 |
| -else: |
68 |
| - stream_handler = None |
69 |
| -console_handler = logging.StreamHandler(stream=stream_handler) |
70 |
| - |
71 |
| -console_handler.setLevel(logging.INFO) |
72 |
| -formatter = logging.Formatter("%(asctime)s [%(name)s][%(levelname)s] %(message)s") |
73 |
| -console_handler.setFormatter(formatter) |
| 65 | +stream_handler = stream_handlers.get(TORCHRL_CONSOLE_STREAM, sys.stdout) |
| 66 | + |
| 67 | + |
| 68 | +# Create colored handler |
| 69 | +class _CustomFormatter(logging.Formatter): |
| 70 | + def format(self, record): |
| 71 | + # Format the initial part in green |
| 72 | + green_format = "\033[92m%(asctime)s [%(name)s][%(levelname)s]\033[0m" |
| 73 | + # Format the message part |
| 74 | + message_format = "%(message)s" |
| 75 | + # End marker in green |
| 76 | + end_marker = "\033[92m [END]\033[0m" |
| 77 | + # Combine all parts |
| 78 | + formatted_message = logging.Formatter( |
| 79 | + green_format + indent(message_format, " " * 4) + end_marker |
| 80 | + ).format(record) |
| 81 | + |
| 82 | + return formatted_message |
| 83 | + |
| 84 | + |
| 85 | +console_handler = logging.StreamHandler(stream_handler) |
| 86 | +console_handler.setFormatter(_CustomFormatter()) |
74 | 87 | logger.addHandler(console_handler)
|
75 | 88 |
|
76 | 89 | VERBOSE = strtobool(os.environ.get("VERBOSE", str(logger.isEnabledFor(logging.DEBUG))))
|
|
0 commit comments