Skip to content

Commit dba8936

Browse files
FindHaofacebook-github-bot
authored andcommitted
Refactor logging configuration and clean up test cases (#42)
Summary: - Removed unused log parsing functions from `test_tritonparse.py` to streamline the test code. - Added calls to `triton.knobs.compilation.listener` and `tritonparse.structured_logging.clear_logging_config()` to ensure proper cleanup and reset of logging configurations after tests. - Improved test reliability by ensuring that logging settings do not persist between test runs. These changes enhance the maintainability of the test suite and ensure that logging configurations are correctly managed during testing. Pull Request resolved: #42 Reviewed By: Sibylau Differential Revision: D78772087 Pulled By: FindHao fbshipit-source-id: b44b8093d5f032f21db673766a3a2dc125cef910
1 parent 06fa342 commit dba8936

File tree

2 files changed

+45
-45
lines changed

2 files changed

+45
-45
lines changed

tests/test_tritonparse.py

Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def compile_listener(
164164
torch.cuda.synchronize()
165165
assert "python_source" in trace_data
166166
assert "file_path" in trace_data["python_source"]
167+
triton.knobs.compilation.listener = None
167168

168169
@unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
169170
def test_whole_workflow(self):
@@ -222,45 +223,6 @@ def run_test_kernel(x):
222223
)
223224
print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}")
224225

225-
def parse_log_line(line: str, line_num: int) -> dict | None:
226-
"""Parse a single log line and extract event data"""
227-
try:
228-
return json.loads(line.strip())
229-
except json.JSONDecodeError as e:
230-
print(f" Line {line_num}: JSON decode error - {e}")
231-
return None
232-
233-
def process_event_data(
234-
event_data: dict, line_num: int, event_counts: dict
235-
) -> None:
236-
"""Process event data and update counts"""
237-
try:
238-
event_type = event_data.get("event_type")
239-
if event_type is None:
240-
return
241-
242-
if event_type in event_counts:
243-
event_counts[event_type] += 1
244-
print(
245-
f" Line {line_num}: event_type = '{event_type}' (count: {event_counts[event_type]})"
246-
)
247-
else:
248-
print(
249-
f" Line {line_num}: event_type = '{event_type}' (not tracked)"
250-
)
251-
except (KeyError, TypeError) as e:
252-
print(f" Line {line_num}: Data structure error - {e}")
253-
254-
def count_events_in_file(file_path: str, event_counts: dict) -> None:
255-
"""Count events in a single log file"""
256-
print(f"Checking event types in: {os.path.basename(file_path)}")
257-
258-
with open(file_path, "r") as f:
259-
for line_num, line in enumerate(f, 1):
260-
event_data = parse_log_line(line, line_num)
261-
if event_data:
262-
process_event_data(event_data, line_num, event_counts)
263-
264226
def check_event_type_counts_in_logs(log_dir: str) -> dict:
265227
"""Count 'launch' and unique 'compilation' events in all log files"""
266228
event_counts = {"launch": 0}
@@ -326,6 +288,7 @@ def check_event_type_counts_in_logs(log_dir: str) -> dict:
326288
# Clean up
327289
shutil.rmtree(temp_dir)
328290
print("✓ Cleaned up temporary directory")
291+
tritonparse.structured_logging.clear_logging_config()
329292

330293
@unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
331294
def test_complex_kernels(self):
@@ -628,6 +591,7 @@ def fused_op(a, b, c, scale_factor: float, activation: str):
628591
# Clean up
629592
shutil.rmtree(temp_dir)
630593
print("✓ Cleaned up temporary directory")
594+
tritonparse.structured_logging.clear_logging_config()
631595

632596

633597
if __name__ == "__main__":

tritonparse/structured_logging.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -941,12 +941,12 @@ def __call__(self, metadata):
941941
def maybe_enable_trace_launch():
942942
global _trace_launch_enabled
943943
if TRITON_TRACE_LAUNCH and not _trace_launch_enabled:
944-
import triton
944+
from triton import knobs
945945

946946
launch_hook = LaunchHookImpl()
947947
jit_hook = JITHookImpl()
948-
triton.knobs.runtime.jit_post_compile_hook = jit_hook
949-
triton.knobs.runtime.launch_enter_hook = launch_hook
948+
knobs.runtime.jit_post_compile_hook = jit_hook
949+
knobs.runtime.launch_enter_hook = launch_hook
950950

951951
_trace_launch_enabled = True
952952

@@ -964,7 +964,7 @@ def init_basic(trace_folder: Optional[str] = None):
964964
maybe_enable_debug_logging()
965965
if triton_trace_folder is not None and trace_folder is not None:
966966
log.info(
967-
"Conflict settings: TRITON_TRACE is already set to %s, we will use provided trace_folder(%s) instead.",
967+
"Conflict settings: triton_trace_folder is already set to %s, we will use provided trace_folder(%s) instead.",
968968
triton_trace_folder,
969969
trace_folder,
970970
)
@@ -995,7 +995,43 @@ def init(trace_folder: Optional[str] = None, enable_trace_launch: bool = False):
995995
global TRITON_TRACE_LAUNCH
996996
if enable_trace_launch:
997997
TRITON_TRACE_LAUNCH = True
998-
import triton
999998

1000999
init_basic(trace_folder)
1001-
triton.knobs.compilation.listener = maybe_trace_triton
1000+
from triton import knobs
1001+
1002+
knobs.compilation.listener = maybe_trace_triton
1003+
1004+
1005+
def clear_logging_config():
1006+
"""
1007+
Clear all configurations made by init() and init_basic().
1008+
1009+
This function resets the logging handlers, global state variables,
1010+
and Triton knobs to their default states, effectively disabling
1011+
the custom tracing.
1012+
1013+
WARNING: This function is not supposed to be called unless you are sure
1014+
you want to clear the logging config.
1015+
"""
1016+
global TRITON_TRACE_HANDLER, triton_trace_folder, _KERNEL_ALLOWLIST_PATTERNS
1017+
global _trace_launch_enabled
1018+
1019+
# 1. Clean up the log handler
1020+
if TRITON_TRACE_HANDLER is not None:
1021+
if TRITON_TRACE_HANDLER in triton_trace_log.handlers:
1022+
triton_trace_log.removeHandler(TRITON_TRACE_HANDLER)
1023+
TRITON_TRACE_HANDLER.close()
1024+
TRITON_TRACE_HANDLER = None
1025+
1026+
# 2. Reset global state variables
1027+
triton_trace_folder = None
1028+
_KERNEL_ALLOWLIST_PATTERNS = None
1029+
_trace_launch_enabled = False
1030+
1031+
# 3. Reset Triton knobs
1032+
# Check if triton was actually imported and used
1033+
from triton import knobs
1034+
1035+
knobs.compilation.listener = None
1036+
knobs.runtime.jit_post_compile_hook = None
1037+
knobs.runtime.launch_enter_hook = None

0 commit comments

Comments
 (0)