Skip to content

Commit 935d8b5

Browse files
sryapfacebook-github-bot
authored andcommitted
Fixed illegal memory access in [split|dense]_embedding_nobag_codegen_forward_unweighted_small_kernel (#1610)
Summary: Pull Request resolved: #1610 `[split|dense]_embedding_nobag_codegen_forward_unweighted_small_kernel` always accesses the LXU cache even when the cache is not available. This diff moves cache accesses under a runtime conditional to ensure that the access is valid. Reviewed By: jspark1105, mjanderson09 Differential Revision: D43581917 fbshipit-source-id: ab302092ffeb27b9a7fd18205c3856abc1dd8713
1 parent 45a861a commit 935d8b5

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

fbgemm_gpu/codegen/embedding_forward_split_template.cu

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,11 +106,7 @@ __global__ void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forw
106106
{% endif %}
107107

108108
{% if not dense %}
109-
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
110-
const_cast<emb_t*>(&weights[idx_j * D_emb]),
111-
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
112-
D,
113-
nullptr);
109+
114110
// assume cache is fp16/fp32 which doesn't require qparams
115111
float2 qparams_cache = make_float2(0.0f, 0.0f);
116112

@@ -128,6 +124,11 @@ __global__ void {{ "dense" if dense else "split" }}_embedding_nobag_codegen_forw
128124
if (d < D) {
129125
{% if not dense %}
130126
if (placement == PlacementType::MANAGED_CACHING && cache_idx_j != kCacheLocationMissing) {
127+
auto weight_row_cache = WeightRow<emb_t, cache_t, cache_t>(
128+
const_cast<emb_t*>(&weights[idx_j * D_emb]),
129+
const_cast<cache_t*>(&lxu_cache_weights[cache_idx_j][0]),
130+
D,
131+
nullptr);
131132
Vec4T<cache_t> weight = weight_row_cache.load(d, qparams_cache);
132133
weight.store(&output[output_j][d]);
133134
} else {

0 commit comments

Comments
 (0)