
Commit a87f472

FindHao authored and facebook-github-bot committed
Add Complex Kernel Test (#38)
Summary:
This pull request introduces a new, comprehensive test case (`test_complex_kernels` in `tests/test_tritonparse.py`) to validate the trace parsing logic, especially for the recently added `launch_diff` feature.

### Motivation

The existing tests did not fully exercise the parser's ability to handle realistic scenarios involving multiple, distinct kernels and varied launch parameters within a single run. This test provides robust, end-to-end validation of the kernel grouping and launch diffing functionality.

### What the Test Does

1. **Defines two kernels**:
   * `matmul_kernel`: an autotuned matrix multiplication kernel designed to test the parser's handling of autotuner-generated configurations.
   * `fused_op_kernel`: a simpler element-wise kernel used to test basic launch parameter variations.
2. **Multiple, varied launches**:
   * The `matmul` kernel is launched three times with different input tensor shapes.
   * The `fused_op` kernel is launched four times with different scalar arguments (`scale_factor`) and compile-time constants (`activation`).
3. **End-to-end parsing**:
   * After executing all kernel launches and generating a trace log, the test calls `tritonparse.utils.unified_parse`.
   * This ensures that the entire trace, containing multiple compilations (due to autotuning) and multiple launch events, can be processed into the structured, diff-summarized format (a minimal sketch of this flow follows the commit message).

### How it Improves Testing

* Provides a realistic test case that mimics a complex model with multiple custom kernels.
* Specifically validates that launch events are correctly grouped with their corresponding compilation events, even when interleaved.
* Acts as a strong regression test for `launch_diff` generation, ensuring that parameter variations are correctly identified and summarized.

Pull Request resolved: #38

Test Plan:

```bash
% python -m unittest tests.test_tritonparse.TestTritonparseCUDA.test_complex_kernels -v
test_complex_kernels (tests.test_tritonparse.TestTritonparseCUDA.test_complex_kernels)
A more complex test case involving two distinct Triton kernels, one of which uses autotuning. ...
Temporary directory: /tmp/tmp35ksprrc
--- Testing Matmul Kernel (3 launches) ---
WARNING:tritonparse.structured_logging:fn JitFunctionInfo(module='tests.test_tritonparse', name='TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel', jit_function=JITFunction(tests.test_tritonparse:TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel)) launch_metadata is not None: <function add_launch_metadata at 0x7f9a504e05e0>. It will be overridden by tritonparse.
WARNING:tritonparse.structured_logging:fn JitFunctionInfo(module='tests.test_tritonparse', name='TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel', jit_function=JITFunction(tests.test_tritonparse:TestTritonparseCUDA.test_complex_kernels.<locals>.matmul_kernel)) launch_metadata is not None: <function add_launch_metadata at 0x7f9a504e05e0>. It will be overridden by tritonparse.
Matmul Launch 1 (16x16 @ 16x16) done.
Matmul Launch 2 (32x16 @ 16x32) done.
Matmul Launch 3 (16x32 @ 32x16) done.

--- Testing Fused Op Kernel (4 launches) ---
Fused Op Launch 1: scale=1.0, activation=None
Fused Op Launch 2: scale=2.5, activation=None
Fused Op Launch 3: scale=1.0, activation='relu'
WARNING:tritonparse.structured_logging:fn JitFunctionInfo(module='tests.test_tritonparse', name='TestTritonparseCUDA.test_complex_kernels.<locals>.fused_op_kernel', jit_function=JITFunction(tests.test_tritonparse:TestTritonparseCUDA.test_complex_kernels.<locals>.fused_op_kernel)) launch_metadata is not None: <function add_launch_metadata at 0x7f9a504e05e0>. It will be overridden by tritonparse.
Fused Op Launch 4: scale=1.0, activation='relu', different size
All kernels executed.
tritonparse log file list: /tmp/tmpowi_t2et/log_file_list.json
INFO:tritonparse:Copying parsed logs from /tmp/tmpowi_t2et to /tmp/tmp35ksprrc/parsed_output_complex
✓ Generated 1 log files
✓ Generated 2 parsed files
✓ Found 1 .json files and 1 .ndjson.gz files
Checking launch_diff events in dedicated_log_triton_trace_yhao__mapped.ndjson.gz
Line 591: Found launch_diff event (count: 1)
Line 1315: Found launch_diff event (count: 2)
Line 2033: Found launch_diff event (count: 3)
Line 2037: Found launch_diff event (count: 4)
Line 2041: Found launch_diff event (count: 5)
✓ Total launch_diff events found: 5
✓ Verified 5 launch_diff events in parsed output
✓ Cleaned up temporary directory
ok

----------------------------------------------------------------------
Ran 1 test in 8.305s

OK
```

Reviewed By: davidberard98

Differential Revision: D78667083

Pulled By: FindHao

fbshipit-source-id: 8c3b38793fedf1cf40658fca3c3548988b389b6a
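Before the file diff below, here is the end-to-end flow the test exercises in its smallest form: enable tritonparse's structured logging, launch a Triton kernel, parse the raw trace with `unified_parse`, and scan the parsed `.ndjson.gz` output for `launch_diff` events. The sketch is an illustrative distillation and not part of this PR: it assumes a CUDA-capable GPU, the `add_kernel`, tensor sizes, and `BLOCK_SIZE` are hypothetical stand-ins for the matmul/fused-op kernels defined in the diff, and the tritonparse calls mirror the ones used in the test itself.

```python
# Minimal sketch of the tracing + parsing flow exercised by test_complex_kernels.
# add_kernel is an illustrative stand-in kernel (not part of this PR); the
# tritonparse calls below mirror the ones in the test.
import gzip
import json
import os
import tempfile

import torch
import triton
import triton.language as tl

import tritonparse.structured_logging
import tritonparse.utils


@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Trivial element-wise add, just so there is something to trace.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


temp_dir = tempfile.mkdtemp()
log_path = os.path.join(temp_dir, "logs")
parsed_output_path = os.path.join(temp_dir, "parsed_output")
os.makedirs(log_path, exist_ok=True)
os.makedirs(parsed_output_path, exist_ok=True)

# 1. Enable structured logging, including per-launch trace events.
tritonparse.structured_logging.init(log_path, enable_trace_launch=True)

# 2. Run a kernel so compilation and launch events are recorded.
x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
add_kernel[grid](x, y, out, x.numel(), BLOCK_SIZE=256)
torch.cuda.synchronize()

# 3. Parse the raw trace into the structured, diff-summarized format.
tritonparse.utils.unified_parse(
    source=log_path, out=parsed_output_path, overwrite=True
)

# 4. Count launch_diff events in the parsed .ndjson.gz output.
launch_diff_count = 0
for name in os.listdir(parsed_output_path):
    if not name.endswith(".ndjson.gz"):
        continue
    path = os.path.join(parsed_output_path, name)
    with gzip.open(path, "rt", encoding="utf-8") as f:
        for line in f:
            if json.loads(line).get("event_type") == "launch_diff":
                launch_diff_count += 1
print(f"launch_diff events found: {launch_diff_count}")
```

The test below follows the same steps with two more realistic kernels and many launches, then asserts that exactly five `launch_diff` events appear in the parsed output, matching the test plan above.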
1 parent 5076f8c commit a87f472

File tree

1 file changed (+302 additions, 0 deletions)


tests/test_tritonparse.py

Lines changed: 302 additions & 0 deletions
```diff
@@ -327,6 +327,308 @@ def check_event_type_counts_in_logs(log_dir: str) -> dict:
             shutil.rmtree(temp_dir)
             print("✓ Cleaned up temporary directory")
 
+    @unittest.skipUnless(torch.cuda.is_available(), "CUDA not available")
+    def test_complex_kernels(self):
+        """
+        A more complex test case involving two distinct Triton kernels, one of which uses autotuning.
+        This test is designed to validate the launch_diff functionality with multiple, varied launches.
+        """
+
+        # Kernel 1: Autotuned Matmul (simplified configs for small scale)
+        @triton.autotune(
+            configs=[
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 16,
+                        "BLOCK_SIZE_N": 16,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 1,
+                    },
+                    num_stages=1,
+                    num_warps=1,
+                ),
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 32,
+                        "BLOCK_SIZE_N": 16,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 1,
+                    },
+                    num_stages=1,
+                    num_warps=1,
+                ),
+                triton.Config(
+                    {
+                        "BLOCK_SIZE_M": 16,
+                        "BLOCK_SIZE_N": 32,
+                        "BLOCK_SIZE_K": 16,
+                        "GROUP_SIZE_M": 1,
+                    },
+                    num_stages=1,
+                    num_warps=1,
+                ),
+            ],
+            key=["M", "N", "K"],
+        )
+        @triton.jit
+        def matmul_kernel(
+            a,
+            b,
+            c,
+            M,
+            N,
+            K,
+            stride_am,
+            stride_ak,
+            stride_bk,
+            stride_bn,
+            stride_cm,
+            stride_cn,
+            BLOCK_SIZE_M: tl.constexpr,
+            BLOCK_SIZE_N: tl.constexpr,
+            BLOCK_SIZE_K: tl.constexpr,
+            GROUP_SIZE_M: tl.constexpr,
+            ACTIVATION: tl.constexpr,
+        ):
+            pid = tl.program_id(axis=0)
+            num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+            num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+            num_pid_in_group = GROUP_SIZE_M * num_pid_n
+            group_id = pid // num_pid_in_group
+            first_pid_m = group_id * GROUP_SIZE_M
+            group_size = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+            pid_m = first_pid_m + (pid % group_size)
+            pid_n = (pid % num_pid_in_group) // group_size
+
+            offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+            offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+            offs_k = tl.arange(0, BLOCK_SIZE_K)
+            a_ptrs = a + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+            b_ptrs = b + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+            accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+            for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+                a_block = tl.load(
+                    a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0
+                )
+                b_block = tl.load(
+                    b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0
+                )
+                accumulator += tl.dot(a_block, b_block)
+                a_ptrs += BLOCK_SIZE_K * stride_ak
+                b_ptrs += BLOCK_SIZE_K * stride_bk
+            c_block = accumulator.to(tl.float16)
+
+            offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+            offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+            c_ptrs = c + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+            c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+            tl.store(c_ptrs, c_block, mask=c_mask)
+
+        def matmul(a, b):
+            assert a.shape[1] == b.shape[0], "Incompatible dimensions"
+            M, K = a.shape
+            K, N = b.shape
+            c = torch.empty((M, N), device=a.device, dtype=a.dtype)
+
+            def grid(META):
+                return (
+                    triton.cdiv(M, META["BLOCK_SIZE_M"])
+                    * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+                )
+
+            matmul_kernel[grid](
+                a,
+                b,
+                c,
+                M,
+                N,
+                K,
+                a.stride(0),
+                a.stride(1),
+                b.stride(0),
+                b.stride(1),
+                c.stride(0),
+                c.stride(1),
+                ACTIVATION=None,
+            )
+            return c
+
+        # Kernel 2: Fused element-wise operation
+        @triton.jit
+        def fused_op_kernel(
+            a_ptr,
+            b_ptr,
+            c_ptr,
+            output_ptr,
+            n_elements,
+            scale_factor: float,
+            ACTIVATION: tl.constexpr,
+            BLOCK_SIZE: tl.constexpr,
+        ):
+            pid = tl.program_id(axis=0)
+            offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+            mask = offsets < n_elements
+
+            a = tl.load(a_ptr + offsets, mask=mask)
+            b = tl.load(b_ptr + offsets, mask=mask)
+            c = tl.load(c_ptr + offsets, mask=mask)
+
+            result = a * b * scale_factor + c
+            if ACTIVATION == "relu":
+                result = tl.where(result > 0, result, 0.0)
+
+            tl.store(output_ptr + offsets, result, mask=mask)
+
+        def fused_op(a, b, c, scale_factor: float, activation: str):
+            n_elements = a.numel()
+            output = torch.empty_like(a)
+            BLOCK_SIZE = 8  # Reduced from 1024 for small scale testing
+            grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
+            fused_op_kernel[grid](
+                a,
+                b,
+                c,
+                output,
+                n_elements,
+                scale_factor,
+                ACTIVATION=activation,
+                BLOCK_SIZE=BLOCK_SIZE,
+            )
+            return output
+
+        # Set up test environment
+        temp_dir = tempfile.mkdtemp()
+        log_path = os.path.join(temp_dir, "logs_complex")
+        parsed_output_path = os.path.join(temp_dir, "parsed_output_complex")
+        os.makedirs(log_path, exist_ok=True)
+        os.makedirs(parsed_output_path, exist_ok=True)
+        print(f"Temporary directory: {temp_dir}")
+
+        # Initialize logging
+        tritonparse.structured_logging.init(log_path, enable_trace_launch=True)
+
+        try:
+            # Main test function logic
+            torch.manual_seed(0)
+
+            # --- Matmul Launches (3 times with different configs) ---
+            print("--- Testing Matmul Kernel (3 launches) ---")
+            # Launch 1
+            a1 = torch.randn((16, 16), device="cuda", dtype=torch.float16)
+            b1 = torch.randn((16, 16), device="cuda", dtype=torch.float16)
+            c1 = matmul(a1, b1)
+            c1.sum()  # Synchronize
+            print("Matmul Launch 1 (16x16 @ 16x16) done.")
+
+            # Launch 2
+            a2 = torch.randn((32, 16), device="cuda", dtype=torch.float16)
+            b2 = torch.randn((16, 32), device="cuda", dtype=torch.float16)
+            c2 = matmul(a2, b2)
+            c2.sum()  # Synchronize
+            print("Matmul Launch 2 (32x16 @ 16x32) done.")
+
+            # Launch 3
+            a3 = torch.randn((16, 32), device="cuda", dtype=torch.float16)
+            b3 = torch.randn((32, 16), device="cuda", dtype=torch.float16)
+            c3 = matmul(a3, b3)
+            c3.sum()  # Synchronize
+            print("Matmul Launch 3 (16x32 @ 32x16) done.")
+
+            # --- Fused Op Launches (4 times with different parameters) ---
+            print("\n--- Testing Fused Op Kernel (4 launches) ---")
+            x = torch.randn((8,), device="cuda", dtype=torch.float32)
+            y = torch.randn((8,), device="cuda", dtype=torch.float32)
+            z = torch.randn((8,), device="cuda", dtype=torch.float32)
+
+            # Launch 1
+            print("Fused Op Launch 1: scale=1.0, activation=None")
+            out1 = fused_op(x, y, z, scale_factor=1.0, activation="none")
+            out1.sum()  # Synchronize
+
+            # Launch 2
+            print("Fused Op Launch 2: scale=2.5, activation=None")
+            out2 = fused_op(x, y, z, scale_factor=2.5, activation="none")
+            out2.sum()  # Synchronize
+
+            # Launch 3
+            print("Fused Op Launch 3: scale=1.0, activation='relu'")
+            out3 = fused_op(x, y, z, scale_factor=1.0, activation="relu")
+            out3.sum()  # Synchronize
+
+            # Launch 4 (different size)
+            print("Fused Op Launch 4: scale=1.0, activation='relu', different size")
+            x_large = torch.randn((6,), device="cuda", dtype=torch.float32)
+            y_large = torch.randn((6,), device="cuda", dtype=torch.float32)
+            z_large = torch.randn((6,), device="cuda", dtype=torch.float32)
+            out4 = fused_op(
+                x_large, y_large, z_large, scale_factor=1.0, activation="relu"
+            )
+            out4.sum()  # Synchronize
+            print("All kernels executed.")
+
+            # Use unified_parse to process the generated logs
+            tritonparse.utils.unified_parse(
+                source=log_path, out=parsed_output_path, overwrite=True
+            )
+
+            # Verify that logs and parsed output were generated
+            log_files = os.listdir(log_path)
+            assert len(log_files) > 0, f"No log files found in {log_path}"
+            print(f"✓ Generated {len(log_files)} log files")
+
+            parsed_files = os.listdir(parsed_output_path)
+            assert (
+                len(parsed_files) > 0
+            ), f"No parsed files found in {parsed_output_path}"
+            print(f"✓ Generated {len(parsed_files)} parsed files")
+
+            # Verify we have both json and ndjson.gz files
+            json_files = [f for f in parsed_files if f.endswith(".json")]
+            ndjson_gz_files = [f for f in parsed_files if f.endswith(".ndjson.gz")]
+
+            assert len(json_files) > 0, f"No .json files found in {parsed_output_path}"
+            assert (
+                len(ndjson_gz_files) > 0
+            ), f"No .ndjson.gz files found in {parsed_output_path}"
+            print(
+                f"✓ Found {len(json_files)} .json files and {len(ndjson_gz_files)} .ndjson.gz files"
+            )
+
+            # Unzip and check launch_diff events in the .ndjson.gz file
+            import gzip
+
+            for ndjson_gz_file in ndjson_gz_files:
+                ndjson_gz_path = os.path.join(parsed_output_path, ndjson_gz_file)
+                launch_diff_count = 0
+
+                print(f"Checking launch_diff events in {ndjson_gz_file}")
+                with gzip.open(ndjson_gz_path, "rt", encoding="utf-8") as f:
+                    for line_num, line in enumerate(f, 1):
+                        try:
+                            event_data = json.loads(line.strip())
+                            event_type = event_data.get("event_type")
+                            if event_type == "launch_diff":
+                                launch_diff_count += 1
+                                print(
+                                    f" Line {line_num}: Found launch_diff event (count: {launch_diff_count})"
+                                )
+                        except json.JSONDecodeError as e:
+                            print(f" Line {line_num}: JSON decode error - {e}")
+                        except Exception as e:
+                            print(f" Line {line_num}: Error processing line - {e}")
+
+                print(f"✓ Total launch_diff events found: {launch_diff_count}")
+                assert (
+                    launch_diff_count == 5
+                ), f"Expected 5 launch_diff events, found {launch_diff_count}"
+                print("✓ Verified 5 launch_diff events in parsed output")
+
+        finally:
+            # Clean up
+            shutil.rmtree(temp_dir)
+            print("✓ Cleaned up temporary directory")
+
 
 if __name__ == "__main__":
     unittest.main()
```
