Summary:
Pull Request resolved: #4421
X-link: facebookresearch/FBGEMM#1491
**TL;DR**
Fix int8 nobag in the TBE inference CUDA kernel such that
- the output shape is `{total_L, D + kINT8QparamsBytes}`
- `kINT8QparamsBytes` = 4
**Detail**
For nobag int8, the output shape should be `{total_L, D + kINT8QparamsBytes}`, since the `total_L` dimension already accounts for all `T` tables. The `T *` factor was unintentionally added in D36018114.
`kINT8QparamsBytes` is 4 on CPU, since half-precision quantization parameters are used; however, the CUDA kernel used 8.
This diff removes `T *` from the output shape and changes `kINT8QparamsBytes` to 4 in the CUDA kernel implementation to match the CPU implementation and production behavior.
This has not caused issues so far because neither of our int8 nobag CUDA kernels is currently used in production.
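As a rough illustration, below is a minimal Python sketch of the fixed layout. The function names are hypothetical, and the exact qparam layout (a half scale followed by a half bias, consistent with the 4-byte CPU convention noted above) is an assumption for this sketch, not FBGEMM code:

```python
# Sketch only: int8 nobag output layout after this fix. Each of the
# total_L rows holds D int8 embedding bytes followed by
# kINT8QparamsBytes = 4 bytes of quantization parameters (assumed here
# to be a half scale and a half bias, as on CPU).
import torch

kINT8QparamsBytes = 4  # 2-byte half scale + 2-byte half bias (assumed layout)

def int8_nobag_output_shape(total_L: int, D: int) -> tuple[int, int]:
    # Before this diff, the CUDA kernel's output shape included an
    # unintended T* factor and used kINT8QparamsBytes = 8; total_L
    # already sums lengths over all T tables, so no T* is needed.
    return (total_L, D + kINT8QparamsBytes)

def dequantize_row(row: torch.Tensor, D: int) -> torch.Tensor:
    # row: a uint8 tensor of size D + 4 whose trailing 4 bytes hold
    # [scale: half][bias: half] under the assumed layout.
    qparams = row[D:].clone().view(torch.half)
    scale, bias = qparams[0].float(), qparams[1].float()
    return row[:D].float() * scale + bias
```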
----
Note that the currently used meta function is [fbgemm_int_nbit_split_embedding_codegen_lookup_function_meta](https://www.internalfb.com/code/fbsource/[d4f61c30f747f0a8c2e6d806904bc8ef3ee5ea42]/fbcode/caffe2/torch/fb/model_transform/splitting/split_dispatcher.py?lines=231%2C423), which has different logic for the int8 and nobag cases.
The discrepancy has not been an issue because:
- Nobag
  - split_dispatcher: D = average D
  - FBGEMM: D = max(max_D of each dtype)
  -> The embedding dimensions are all the same, so average D = max D.
- Int8 pooled
  - split_dispatcher: [B, total_D]
  - FBGEMM: [B, total_D + T * 8]
  -> This path is not used in prod.

This will become a problem if mixed embedding dimensions are used, or if int8 pooled is put into use; the sketch below illustrates both conditions.
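A hypothetical illustration of both conditions, with the shapes taken from the comparison above and made-up numbers and variable names:

```python
# Sketch only: why the discrepancies above are currently harmless,
# and when they would start to bite.
T, B = 10, 128
dims = [64] * T  # today all tables share one embedding dim

# Nobag: average D equals max D only while dims are uniform.
assert sum(dims) / T == max(dims)  # breaks once dims are mixed, e.g. [64, 128, ...]

# Int8 pooled: the meta function and the CUDA kernel disagree on shape.
total_D = sum(dims)
meta_shape = (B, total_D)            # split_dispatcher
fbgemm_shape = (B, total_D + T * 8)  # FBGEMM CUDA kernel
assert meta_shape != fbgemm_shape    # harmless only while int8 pooled is unused
```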
Reviewed By: q10
Differential Revision: D76488339
fbshipit-source-id: ae8ca9dcb9db01eec8aa25504d1a01202c7cd466