@@ -135,37 +135,37 @@ def silu_mul_quant(
135
135
136
136
torch .library .define (
137
137
"fbgemm::silu_mul" ,
138
- "(Tensor x0, Tensor x1) -> Tensor" ,
138
+ "(Tensor x0, Tensor x1, Tensor? valid_token_count=None ) -> Tensor" ,
139
139
)
140
140
141
141
142
142
@torch .library .impl (_SILU_MUL_OP_NAME , "Meta" )
143
- def silu_mul_meta (x0 , x1 ):
143
+ def silu_mul_meta (x0 , x1 , valid_token_count ):
144
144
return x0 .new_empty (x0 .shape )
145
145
146
146
147
147
@torch .library .impl (_SILU_MUL_OP_NAME , "CUDA" )
148
- def silu_mul_cuda (x0 , x1 ):
149
- return silu_mul (x0 , x1 )
148
+ def silu_mul_cuda (x0 , x1 , valid_token_count ):
149
+ return silu_mul (x0 , x1 , valid_token_count )
150
150
151
151
152
152
_SILU_MUL_OP_QUANT_NAME = "fbgemm::silu_mul_quant"
153
153
154
154
torch .library .define (
155
155
"fbgemm::silu_mul_quant" ,
156
- "(Tensor x0, Tensor x1, Tensor? scale_ub) -> Tensor" ,
156
+ "(Tensor x0, Tensor x1, Tensor? scale_ub=None, Tensor? valid_token_count=None ) -> Tensor" ,
157
157
)
158
158
159
159
160
160
@torch .library .impl (_SILU_MUL_OP_QUANT_NAME , "Meta" )
161
- def silu_mul_quant_meta (x0 , x1 , scale_ub ):
161
+ def silu_mul_quant_meta (x0 , x1 , scale_ub , valid_token_count ):
162
162
pt_dtype , tl_dtype , max_fp8 , eps = get_fp8_constants ()
163
163
return torch .empty (x0 .shape , device = x0 .device , dtype = pt_dtype )
164
164
165
165
166
166
@torch .library .impl (_SILU_MUL_OP_QUANT_NAME , "CUDA" )
167
- def silu_mul_quant_cuda (x0 , x1 , scale_ub = None ):
168
- return silu_mul_quant (x0 , x1 , scale_ub )
167
+ def silu_mul_quant_cuda (x0 , x1 , scale_ub = None , valid_token_count = None ):
168
+ return silu_mul_quant (x0 , x1 , scale_ub , valid_token_count )
169
169
170
170
171
171
# Kernel Implementations
0 commit comments