Skip to content

Commit 0b21002

Browse files
Junnan Wan authored and facebook-github-bot committed
Support scale_bias_last on tbe lookup kernel (#4363)
Summary: Pull Request resolved: #4363 X-link: facebookresearch/FBGEMM#1428 Check https://fb.workplace.com/groups/fbgemmusers/permalink/23950680467919409/ for context With scale_bias_last=true, the TBE tensor could be in same shape between publish and inference runtime which makes model loading much easier (no need to process each row). Reviewed By: sryap Differential Revision: D76615824 fbshipit-source-id: 528bc00955156de38f1ef9bc058a9350ad0d75ee
1 parent a70fc5d commit 0b21002

File tree

2 files changed

+108
-45
lines changed

2 files changed

+108
-45
lines changed

fbgemm_gpu/codegen/inference/embedding_forward_quantized_cpu_template.cpp

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,12 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
9191
continue;
9292
}
9393
const auto capacity = table_end - table_start;
94-
94+
9595
for (const auto b : c10::irange(B)) {
9696
const auto indices_start = offsets_acc[t * B + b];
9797
const auto indices_end = offsets_acc[t * B + b + 1];
9898
const auto L = indices_end - indices_start;
99-
99+
100100
for (const auto l : c10::irange(L)) {
101101
const auto idx = indices_acc[indices_start + l];
102102
const auto dense_idx = dense_indices_acc[indices_start + l];
@@ -109,20 +109,20 @@ void pruned_hashmap_insert_{{ wdesc }}_cpu(
109109
while (true) {
110110
const auto ht_idx = table_start + static_cast<int64_t>(slot);
111111
const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
112-
112+
113113
// Empty slot
114114
if (slot_sparse_idx == -1) {
115115
hash_table_acc[ht_idx][0] = static_cast<hash_t>(idx);
116116
hash_table_acc[ht_idx][1] = static_cast<hash_t>(dense_idx);
117117
break;
118118
}
119-
119+
120120
// Already exists (shouldn't happen in practice)
121121
if (slot_sparse_idx == idx) {
122122
hash_table_acc[ht_idx][1] = static_cast<hash_t>(dense_idx);
123123
break;
124124
}
125-
125+
126126
// Linear probe
127127
slot = (slot + 1) % capacity;
128128
}
@@ -158,7 +158,8 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
158158
{% endif %}
159159
int64_t output_dtype,
160160
int64_t fp8_exponent_bits,
161-
int64_t fp8_exponent_bias
161+
int64_t fp8_exponent_bias,
162+
bool scale_bias_last
162163
) {
163164
TENSOR_ON_CPU(dev_weights);
164165
TENSOR_ON_CPU(uvm_weights);
@@ -273,8 +274,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
273274
if (output_is_int8) {
274275
TORCH_CHECK(weight_ty == SparseType::INT8, "int8 output are only supported for int8 weights");
275276
}
277+
const int32_t scale_bias_size = (weight_ty == SparseType::INT8 && scale_bias_last) ? 8 : 4;
276278
// default to 1 byte alignment for CPU TBE
277-
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment);
279+
const int32_t D_bytes = nbit::padded_row_size_in_bytes(D, weight_ty, row_alignment, scale_bias_size);
278280

279281
int tt;
280282
for (tt = t + 1; tt < T && weights_offsets_acc[tt] == weights_offsets_acc[t]; ++tt);
@@ -352,7 +354,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
352354
/*exponent_bias=*/fp8_exponent_bias,
353355
{% endif %}
354356
{% if has_asmjit %}
355-
/*scale_bias_last=*/false,
357+
/*scale_bias_last=*/scale_bias_last,
356358
{% endif %}
357359
{% if use_base %}
358360
/*no_bag=*/nobag_op,
@@ -466,12 +468,12 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
466468
for (const auto l : c10::irange(L)) {
467469
dense_indices_acc[indices_start + l] = indices_acc[indices_start + l];
468470
}
469-
471+
470472
} else {
471473
for (const auto l : c10::irange(L)) {
472474
const auto idx = indices_acc[indices_start + l];
473475
auto slot = pruned_hash_function(static_cast<utdx_t>(idx)) % capacity;
474-
476+
475477
while (true) {
476478
const auto ht_idx = table_start + static_cast<int64_t>(slot);
477479
const auto slot_sparse_idx = hash_table_acc[ht_idx][0];
@@ -486,7 +488,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
486488
dense_indices_acc[indices_start + l] = static_cast<index_t>(hash_table_acc[ht_idx][1]);
487489
break;
488490
}
489-
491+
490492
// Linear probe
491493
slot = (slot + 1) % capacity;
492494
}
@@ -496,7 +498,7 @@ Tensor pruned_hashmap_lookup_{{ wdesc }}_cpu(
496498
}
497499
});
498500
});
499-
501+
500502
return dense_indices;
501503
}
502504

fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp

Lines changed: 94 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cpu(
4343
int64_t row_alignment,
4444
int64_t output_dtype,
4545
int64_t fp8_exponent_bits,
46-
int64_t fp8_exponent_bias);
46+
int64_t fp8_exponent_bias,
47+
bool scale_bias_last);
4748

4849
Tensor int_nbit_split_embedding_codegen_forward_weighted_cpu(
4950
Tensor dev_weights,
@@ -60,7 +61,8 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cpu(
6061
Tensor indice_weights,
6162
int64_t output_dtype,
6263
int64_t fp8_exponent_bits,
63-
int64_t fp8_exponent_bias);
64+
int64_t fp8_exponent_bias,
65+
bool scale_bias_last);
6466

6567
Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cpu(
6668
Tensor dev_weights,
@@ -75,10 +77,10 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cpu(
7577
int64_t row_alignment,
7678
int64_t output_dtype,
7779
int64_t fp8_exponent_bits,
78-
int64_t fp8_exponent_bias);
80+
int64_t fp8_exponent_bias,
81+
bool scale_bias_last);
7982

80-
///@ingroup embedding-cpu
81-
Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
83+
Tensor int_nbit_split_embedding_codegen_lookup_function_cpu_impl(
8284
Tensor dev_weights,
8385
Tensor uvm_weights, // to match the interface of CUDA op using UVM
8486
Tensor weights_placements, // to match the interface of CUDA op using UVM
@@ -103,10 +105,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
103105
std::optional<int64_t> row_alignment,
104106
std::optional<int64_t> max_float8_D,
105107
std::optional<int64_t> fp8_exponent_bits,
106-
std::optional<int64_t> fp8_exponent_bias) {
108+
std::optional<int64_t> fp8_exponent_bias,
109+
std::optional<bool> scale_bias_last) {
107110
if (offsets.scalar_type() != indices.scalar_type()) {
108111
offsets = offsets.toType(indices.scalar_type());
109112
}
113+
auto scale_bias_last_val = scale_bias_last ? *scale_bias_last : true;
110114
if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
111115
std::vector<int64_t> max_D_list{
112116
max_int2_D,
@@ -117,53 +121,110 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
117121
max_float32_D};
118122
int64_t max_D = *std::max_element(max_D_list.begin(), max_D_list.end());
119123
return int_nbit_split_embedding_nobag_codegen_forward_unweighted_cpu(
120-
dev_weights,
121-
uvm_weights,
122-
weights_placements,
123-
weights_offsets,
124-
weights_tys,
124+
std::move(dev_weights),
125+
std::move(uvm_weights),
126+
std::move(weights_placements),
127+
std::move(weights_offsets),
128+
std::move(weights_tys),
125129
max_D,
126-
indices,
127-
offsets,
130+
std::move(indices),
131+
std::move(offsets),
128132
pooling_mode,
129133
row_alignment ? *row_alignment : 1,
130134
output_dtype,
131135
fp8_exponent_bits ? *fp8_exponent_bits : -1,
132-
fp8_exponent_bias ? *fp8_exponent_bias : -1);
136+
fp8_exponent_bias ? *fp8_exponent_bias : -1,
137+
scale_bias_last_val);
133138
}
134139
if (!indice_weights || indice_weights->numel() == 0) {
135140
return int_nbit_split_embedding_codegen_forward_unweighted_cpu(
136-
dev_weights,
137-
uvm_weights,
138-
weights_placements,
139-
weights_offsets,
140-
weights_tys,
141-
D_offsets,
141+
std::move(dev_weights),
142+
std::move(uvm_weights),
143+
std::move(weights_placements),
144+
std::move(weights_offsets),
145+
std::move(weights_tys),
146+
std::move(D_offsets),
142147
total_D,
143-
indices,
144-
offsets,
148+
std::move(indices),
149+
std::move(offsets),
145150
pooling_mode,
146151
row_alignment ? *row_alignment : 1,
147152
output_dtype,
148153
fp8_exponent_bits ? *fp8_exponent_bits : -1,
149-
fp8_exponent_bias ? *fp8_exponent_bias : -1);
154+
fp8_exponent_bias ? *fp8_exponent_bias : -1,
155+
scale_bias_last_val);
150156
}
151157
return int_nbit_split_embedding_codegen_forward_weighted_cpu(
152-
dev_weights,
153-
uvm_weights,
154-
weights_placements,
155-
weights_offsets,
156-
weights_tys,
157-
D_offsets,
158+
std::move(dev_weights),
159+
std::move(uvm_weights),
160+
std::move(weights_placements),
161+
std::move(weights_offsets),
162+
std::move(weights_tys),
163+
std::move(D_offsets),
158164
total_D,
159-
indices,
160-
offsets,
165+
std::move(indices),
166+
std::move(offsets),
161167
pooling_mode,
162168
row_alignment ? *row_alignment : 1,
163-
*indice_weights,
169+
std::move(*indice_weights),
164170
output_dtype,
165171
fp8_exponent_bits ? *fp8_exponent_bits : -1,
166-
fp8_exponent_bias ? *fp8_exponent_bias : -1);
172+
fp8_exponent_bias ? *fp8_exponent_bias : -1,
173+
scale_bias_last_val);
174+
}
175+
176+
///@ingroup embedding-cpu
177+
Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
178+
Tensor dev_weights,
179+
Tensor uvm_weights, // to match the interface of CUDA op using UVM
180+
Tensor weights_placements, // to match the interface of CUDA op using UVM
181+
Tensor weights_offsets,
182+
Tensor weights_tys,
183+
Tensor D_offsets,
184+
int64_t total_D,
185+
int64_t max_int2_D,
186+
int64_t max_int4_D,
187+
int64_t max_int8_D,
188+
int64_t max_float16_D,
189+
int64_t max_float32_D,
190+
Tensor indices,
191+
Tensor offsets,
192+
int64_t pooling_mode,
193+
std::optional<Tensor> indice_weights,
194+
int64_t output_dtype,
195+
std::optional<Tensor>
196+
lxu_cache_weights, // Not used, to match cache interface for CUDA op
197+
std::optional<Tensor>
198+
lxu_cache_locations, // Not used, to match cache interface for CUDA op
199+
std::optional<int64_t> row_alignment,
200+
std::optional<int64_t> max_float8_D,
201+
std::optional<int64_t> fp8_exponent_bits,
202+
std::optional<int64_t> fp8_exponent_bias) {
203+
return int_nbit_split_embedding_codegen_lookup_function_cpu_impl(
204+
std::move(dev_weights),
205+
std::move(uvm_weights),
206+
std::move(weights_placements),
207+
std::move(weights_offsets),
208+
std::move(weights_tys),
209+
std::move(D_offsets),
210+
total_D,
211+
max_int2_D,
212+
max_int4_D,
213+
max_int8_D,
214+
max_float16_D,
215+
max_float32_D,
216+
std::move(indices),
217+
std::move(offsets),
218+
pooling_mode,
219+
std::move(indice_weights),
220+
output_dtype,
221+
std::move(lxu_cache_weights),
222+
std::move(lxu_cache_locations),
223+
std::move(row_alignment),
224+
std::move(max_float8_D),
225+
std::move(fp8_exponent_bits),
226+
std::move(fp8_exponent_bias),
227+
false);
167228
}
168229

169230
///@ingroup embedding-cpu

0 commit comments

Comments (0)