Skip to content

Commit 0432615

Browse files
sunfish2010facebook-github-bot
authored andcommitted
silu_mul_quant fix (#4395)
Summary: Pull Request resolved: #4395 X-link: facebookresearch/FBGEMM#1466 Avoid division by 0 when T == 0 Reviewed By: jianyuh Differential Revision: D77236510 fbshipit-source-id: f8943125b358443bea9b05e136875f5c93822b26
1 parent fe9946e commit 0432615

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/moe/activation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,12 @@ def silu_mul_quant(
104104

105105
out = torch.empty((T, D), device="cuda", dtype=pt_dtype)
106106
out_inv_scale = torch.empty((T,), device="cuda", dtype=torch.float32)
107+
if T == 0:
108+
return out, out_inv_scale
107109

108110
NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
109111
BLOCK_T = triton.cdiv(T, NUM_SMS)
112+
110113
NUM_CTAS = triton.cdiv(T, BLOCK_T)
111114

112115
grid = (NUM_CTAS,)

fbgemm_gpu/experimental/gen_ai/test/moe/activation_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class ActivationTests(unittest.TestCase):
3636
"""Test activation kernels."""
3737

3838
@given(
39-
T=st.sampled_from([1, 128, 2048, 4096, 16384]),
39+
T=st.sampled_from([0, 1, 128, 2048, 4096, 16384]),
4040
D=st.sampled_from([5120, 7168]),
4141
contiguous=st.sampled_from([True, False]),
4242
partial=st.sampled_from([True, False]),
@@ -94,7 +94,7 @@ def ref_fn() -> torch.Tensor:
9494
"Skip when H100 is not available",
9595
)
9696
@given(
97-
T=st.sampled_from([1, 128, 2048, 4096, 16384]),
97+
T=st.sampled_from([0, 1, 128, 2048, 4096, 16384]),
9898
D=st.sampled_from([5120, 7168]),
9999
scale_ub=st.sampled_from([None, 1200.00]),
100100
contiguous=st.sampled_from([True, False]),

0 commit comments

Comments
 (0)