Skip to content

Commit 9db5e09

Browse files
sunfish2010 authored and facebook-github-bot committed
silu_mul API Update (#4359)
Summary:
Pull Request resolved: #4359
X-link: facebookresearch/FBGEMM#1427
As titled: add optional inputs to the torch API.
Reviewed By: levendlee
Differential Revision: D76395658
fbshipit-source-id: 4b31ab23351350fed55bd6bcdbb9a3bb049ce36f
1 parent db28973 commit 9db5e09

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

fbgemm_gpu/experimental/gen_ai/gen_ai/moe/activation.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -135,37 +135,37 @@ def silu_mul_quant(
135135

136136
torch.library.define(
137137
"fbgemm::silu_mul",
138-
"(Tensor x0, Tensor x1) -> Tensor",
138+
"(Tensor x0, Tensor x1, Tensor? valid_token_count=None) -> Tensor",
139139
)
140140

141141

142142
@torch.library.impl(_SILU_MUL_OP_NAME, "Meta")
143-
def silu_mul_meta(x0, x1):
143+
def silu_mul_meta(x0, x1, valid_token_count):
144144
return x0.new_empty(x0.shape)
145145

146146

147147
@torch.library.impl(_SILU_MUL_OP_NAME, "CUDA")
148-
def silu_mul_cuda(x0, x1):
149-
return silu_mul(x0, x1)
148+
def silu_mul_cuda(x0, x1, valid_token_count):
149+
return silu_mul(x0, x1, valid_token_count)
150150

151151

152152
_SILU_MUL_OP_QUANT_NAME = "fbgemm::silu_mul_quant"
153153

154154
torch.library.define(
155155
"fbgemm::silu_mul_quant",
156-
"(Tensor x0, Tensor x1, Tensor? scale_ub) -> Tensor",
156+
"(Tensor x0, Tensor x1, Tensor? scale_ub=None, Tensor? valid_token_count=None) -> Tensor",
157157
)
158158

159159

160160
@torch.library.impl(_SILU_MUL_OP_QUANT_NAME, "Meta")
161-
def silu_mul_quant_meta(x0, x1, scale_ub):
161+
def silu_mul_quant_meta(x0, x1, scale_ub, valid_token_count):
162162
pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
163163
return torch.empty(x0.shape, device=x0.device, dtype=pt_dtype)
164164

165165

166166
@torch.library.impl(_SILU_MUL_OP_QUANT_NAME, "CUDA")
167-
def silu_mul_quant_cuda(x0, x1, scale_ub=None):
168-
return silu_mul_quant(x0, x1, scale_ub)
167+
def silu_mul_quant_cuda(x0, x1, scale_ub=None, valid_token_count=None):
168+
return silu_mul_quant(x0, x1, scale_ub, valid_token_count)
169169

170170

171171
# Kernel Implementations

fbgemm_gpu/experimental/gen_ai/gen_ai/moe/gather_scatter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def gather_scale_quant_dense_tokens_cuda(
362362

363363
torch.library.define(
364364
"fbgemm::scatter_add_dense_tokens",
365-
"(Tensor out_tokens, Tensor in_tokens, Tensor token_indices) -> None",
365+
"(Tensor out_tokens, Tensor in_tokens, Tensor token_indices, Tensor? valid_token_count=None) -> None",
366366
)
367367

368368

0 commit comments

Comments
 (0)