Commit 8325430

spcyppt authored and facebook-github-bot committed
Fix int_nbit int8 nobag CUDA kernel (#4421)
Summary: Pull Request resolved: #4421
X-link: facebookresearch/FBGEMM#1491

**TLDR;** Fix int8 nobag in the TBE inference CUDA kernel such that
- the output shape is `{total_L, D + kINT8QparamsBytes}`
- `kINT8QparamsBytes` = 4

**Detail**

For nobag int8, the output shape should be `{total_L, D + kINT8QparamsBytes}`, since the `total_L` dimension already includes `T`. The `T *` factor was unintentionally added in D36018114. `kINT8QparamsBytes` is 4 on CPU, since a half is used; however, 8 is used in CUDA. This diff removes `T *` from the output shape and changes `kINT8QparamsBytes` to 4 in the CUDA kernel implementation to match CPU and production. There has been no issue so far because neither of our int8 nobag CUDA kernels is currently used in production.

----

Note that the currently used meta function is [fbgemm_int_nbit_split_embedding_codegen_lookup_function_meta](https://www.internalfb.com/code/fbsource/[d4f61c30f747f0a8c2e6d806904bc8ef3ee5ea42]/fbcode/caffe2/torch/fb/model_transform/splitting/split_dispatcher.py?lines=231%2C423), which has different logic for the int8 and nobag cases. The discrepancy has not been an issue because:

- Nobag
  - split_dispatcher: D = average D
  - FBGEMM: D = max(max_D of each dtype)
  -> The embedding dimensions are the same, so average D = max D.
- Int8 pooled
  - split_dispatcher: [B, total_D]
  - FBGEMM: [B, total_D + T * 8]
  -> This is not being used in prod.

This will be a problem if embedding dimensions are mixed, or if int8 pooled is going to be used.

Reviewed By: q10

Differential Revision: D76488339

fbshipit-source-id: ae8ca9dcb9db01eec8aa25504d1a01202c7cd466
1 parent b2808c9 commit 8325430
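To make the shape change concrete, here is a minimal host-side sketch (not FBGEMM code; the helper names and the example values of `D`, `T`, and `total_L` are made up) of how the nobag int8 output column count is computed before and after this fix:

```cpp
// Minimal sketch, not FBGEMM code: illustrates the nobag int8 output shape
// described in the commit message. Helper names are hypothetical.
#include <cstdint>
#include <iostream>

// Before the fix: qparams bytes were scaled by T, even though the
// total_L row dimension already covers every table.
int64_t nobag_int8_cols_before(int64_t D, int64_t T) {
  const int kINT8QparamsBytes = 8;  // float scale + float bias
  return D + T * kINT8QparamsBytes;
}

// After the fix: one set of qparams per output row, so no T factor.
int64_t nobag_int8_cols_after(int64_t D) {
  const int kINT8QparamsBytes = 4;  // half scale + half bias, as on CPU
  return D + kINT8QparamsBytes;
}

int main() {
  const int64_t D = 128, T = 10, total_L = 1000;  // illustrative values only
  std::cout << "before: {" << total_L << ", " << nobag_int8_cols_before(D, T) << "}\n";
  std::cout << "after:  {" << total_L << ", " << nobag_int8_cols_after(D) << "}\n";
  return 0;
}
```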

File tree

1 file changed: +4 -4 lines changed

fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu

Lines changed: 4 additions & 4 deletions
@@ -130,13 +130,12 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
 
   // Construct output tensor
   Tensor output;
-  const int kINT8QparamsBytes = 8;
 
   SparseType o_dtype = static_cast<SparseType>(output_dtype);
   TORCH_CHECK(o_dtype == SparseType::FP32 || o_dtype == SparseType::FP16 || o_dtype == SparseType::BF16 || o_dtype == SparseType::INT8);
 
   {%- if not nobag %}
-
+  const int kINT8QparamsBytes = 8;
   int64_t total_adjusted_D = total_D;
   if (o_dtype == SparseType::INT8) {
     total_adjusted_D += T * kINT8QparamsBytes;
@@ -149,10 +148,11 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no
   }
 
   {%- else %}
-
+  // TODO: Change to use half to match CPU/Meta implementation
+  const int kINT8QparamsBytes = 8; // using float for scale and bias
   int64_t adjusted_D = D;
   if (o_dtype == SparseType::INT8) {
-    adjusted_D += T * kINT8QparamsBytes;
+    adjusted_D += kINT8QparamsBytes;
   }
 
   if (total_L == 0) {
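For context on the byte counts discussed in the summary, here is a hedged illustration (the struct names below are made up; FBGEMM does not define them) of why half-precision scale/bias gives `kINT8QparamsBytes` = 4 while float gives 8:

```cpp
// Illustration only, not FBGEMM code. Each int8 output row carries a scale
// and a bias at its tail; their storage precision determines the extra bytes.
#include <cstdint>
#include <cstdio>

struct HalfQparams  { uint16_t scale; uint16_t bias; };  // 2 + 2 = 4 bytes (CPU/Meta path)
struct FloatQparams { float    scale; float    bias; };  // 4 + 4 = 8 bytes (current CUDA path)

int main() {
  std::printf("half qparams:  %zu bytes\n", sizeof(HalfQparams));   // 4
  std::printf("float qparams: %zu bytes\n", sizeof(FloatQparams));  // 8
  return 0;
}
```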

0 commit comments