Skip to content

Commit 4aaa389

Browse files
sryapfacebook-github-bot
authored andcommitted
Fix illegal memory access when num indices = 0 in TBE inference (#1613)
Summary: Pull Request resolved: #1613 The TBE inference kernel (`*_split_embedding_codegen_forward*_kernel_small_L`) may illegally access memory if the number of indices is zero. This update ensures that the TBE inference operator returns a zero-tensor when the number of indices is zero, preventing the aforementioned problem. Reviewed By: jianyuh Differential Revision: D43588092 fbshipit-source-id: 95eb6b96c7b9cac8518072c03c6d382edfc47436
1 parent 9b4952f commit 4aaa389

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

fbgemm_gpu/codegen/embedding_forward_quantized_split_template.cu

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -667,17 +667,27 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
667667
if (o_dtype == SparseType::INT8) {
668668
total_adjusted_D += T * kINT8QparamsBytes;
669669
}
670-
output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
670+
if (indices.numel() == 0) {
671+
output = at::zeros({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
672+
}
673+
else {
674+
output = at::empty({B, total_adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
675+
}
671676
{% else %}
672677
int64_t adjusted_D = D;
673678
if (o_dtype == SparseType::INT8) {
674679
adjusted_D += T * kINT8QparamsBytes;
675680
}
676-
output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
681+
if (total_L == 0) {
682+
output = at::zeros({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
683+
}
684+
else {
685+
output = at::empty({total_L, adjusted_D}, dev_weights.options().dtype(getScalarType(o_dtype)));
686+
}
677687

678688
{% endif %}
679689

680-
if (B == 0) {
690+
if (B == 0 || indices.numel() == 0) {
681691
return output;
682692
}
683693

0 commit comments

Comments
 (0)