Commit 8567309

[Benchmark] Add embedding benchmark (#248)
1 parent ef9be0e commit 8567309

File tree

3 files changed: 10 additions, 2 deletions

benchmark/README.md (1 addition, 1 deletion)

```diff
@@ -2,7 +2,7 @@
 
 Performance comparison between Helion, torch.compile, Triton, and PyTorch eager is done by leveraging [TritonBench](https://github.com/pytorch-labs/tritonbench).
 
-Currently supported kernels for performance comparison are in `benchmark/`.
+Currently supported kernels for performance comparison are listed in `KERNEL_MAPPINGS` in `benchmark/run.py`.
 
 To run the benchmark:
 
```

benchmark/run.py (2 additions, 1 deletion)

```diff
@@ -1,6 +1,6 @@
 """Performance comparison between Helion, torch.compile, Triton, and PyTorch eager by leveraging TritonBench.
 
-Currently supported kernels are in `benchmark/`.
+Currently supported kernels are listed in `KERNEL_MAPPINGS` in `benchmark/run.py`.
 
 Usage:
     $ python benchmark/run.py [tritonbench args...] --kernel <kernel_name>
@@ -23,6 +23,7 @@
 KERNEL_MAPPINGS: dict[str, tuple[str, str]] = {
     # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("examples.add", "add"),
+    "embedding": ("examples.embedding", "embedding_tritonbench"),
 }
 
 
```
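Each `KERNEL_MAPPINGS` entry pairs a tritonbench op name with a module path and function name. A minimal sketch of how such an entry could be resolved to a callable at runtime (the actual lookup logic in `benchmark/run.py` may differ; a stdlib module stands in for `examples.embedding` here so the snippet is self-contained):

```python
import importlib

# Hypothetical mapping, mirroring the KERNEL_MAPPINGS shape:
# <op_name>: (<module_path>, <function_name>).
# "math"/"sqrt" are stand-ins so this runs without the Helion repo.
MAPPINGS: dict[str, tuple[str, str]] = {"sqrt_op": ("math", "sqrt")}

def resolve(op_name: str):
    """Import the module and fetch the named kernel function."""
    module_path, func_name = MAPPINGS[op_name]
    module = importlib.import_module(module_path)
    return getattr(module, func_name)

kernel = resolve("sqrt_op")
print(kernel(16.0))  # 4.0
```

This dotted-path indirection lets the benchmark runner stay decoupled from the kernels: adding a kernel only requires a new dictionary entry, as this commit does for `embedding`.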

examples/embedding.py (7 additions, 0 deletions)

```diff
@@ -24,6 +24,13 @@ def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
     return out.view(*x.size(), embedding_dim)
 
 
+def embedding_tritonbench(
+    V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor
+) -> torch.Tensor:
+    """Wrapper for tritonbench that matches its interface."""
+    return embedding(inp, shared_weight)
+
+
 def main() -> None:
     num_embeddings, embedding_dim = 16, 64
     x = torch.randint(0, num_embeddings, [256, 32], device="cuda", dtype=torch.int32)
```
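The wrapper above just adapts tritonbench's calling convention (which also passes the vocabulary size `V` and embedding dimension `D`) to the Helion kernel's `(x, weight)` signature. The lookup itself is a gather: each integer index selects a row of the weight table, and the output keeps the index tensor's shape with the embedding dimension appended. A plain-Python sketch of that semantics (illustrative only; the real kernel operates on CUDA tensors):

```python
def embedding_reference(x, weight):
    """Reference embedding lookup: out[..., :] = weight[x[...]]."""
    if isinstance(x, int):
        return list(weight[x])           # base case: one index -> one row
    return [embedding_reference(i, weight) for i in x]  # recurse over dims

weight = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]  # 3 embeddings, dim 2
x = [[2, 0], [1, 1]]                           # (2, 2) tensor of indices
out = embedding_reference(x, weight)
print(out)  # [[[2.0, 2.1], [0.0, 0.1]], [[1.0, 1.1], [1.0, 1.1]]]
```

The output shape is `(*x.shape, embedding_dim)`, matching the kernel's `out.view(*x.size(), embedding_dim)` in the diff above.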
