@@ -43,7 +43,8 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cpu(
     int64_t row_alignment,
     int64_t output_dtype,
     int64_t fp8_exponent_bits,
-    int64_t fp8_exponent_bias);
+    int64_t fp8_exponent_bias,
+    bool scale_bias_last);
 
 Tensor int_nbit_split_embedding_codegen_forward_weighted_cpu(
     Tensor dev_weights,
@@ -60,7 +61,8 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cpu(
     Tensor indice_weights,
     int64_t output_dtype,
     int64_t fp8_exponent_bits,
-    int64_t fp8_exponent_bias);
+    int64_t fp8_exponent_bias,
+    bool scale_bias_last);
 
 Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cpu(
     Tensor dev_weights,
@@ -75,10 +77,10 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cpu(
     int64_t row_alignment,
     int64_t output_dtype,
     int64_t fp8_exponent_bits,
-    int64_t fp8_exponent_bias);
+    int64_t fp8_exponent_bias,
+    bool scale_bias_last);
 
-///@ingroup embedding-cpu
-Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
+Tensor int_nbit_split_embedding_codegen_lookup_function_cpu_impl(
     Tensor dev_weights,
     Tensor uvm_weights, // to match the interface of CUDA op using UVM
     Tensor weights_placements, // to match the interface of CUDA op using UVM
@@ -103,10 +105,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
     std::optional<int64_t> row_alignment,
     std::optional<int64_t> max_float8_D,
     std::optional<int64_t> fp8_exponent_bits,
-    std::optional<int64_t> fp8_exponent_bias) {
+    std::optional<int64_t> fp8_exponent_bias,
+    std::optional<bool> scale_bias_last) {
   if (offsets.scalar_type() != indices.scalar_type()) {
     offsets = offsets.toType(indices.scalar_type());
   }
+  auto scale_bias_last_val = scale_bias_last ? *scale_bias_last : true;
   if (static_cast<PoolingMode>(pooling_mode) == PoolingMode::NONE) {
     std::vector<int64_t> max_D_list{
         max_int2_D,
@@ -117,53 +121,110 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
         max_float32_D};
     int64_t max_D = *std::max_element(max_D_list.begin(), max_D_list.end());
     return int_nbit_split_embedding_nobag_codegen_forward_unweighted_cpu(
-        dev_weights,
-        uvm_weights,
-        weights_placements,
-        weights_offsets,
-        weights_tys,
+        std::move(dev_weights),
+        std::move(uvm_weights),
+        std::move(weights_placements),
+        std::move(weights_offsets),
+        std::move(weights_tys),
         max_D,
-        indices,
-        offsets,
+        std::move(indices),
+        std::move(offsets),
         pooling_mode,
         row_alignment ? *row_alignment : 1,
         output_dtype,
         fp8_exponent_bits ? *fp8_exponent_bits : -1,
-        fp8_exponent_bias ? *fp8_exponent_bias : -1);
+        fp8_exponent_bias ? *fp8_exponent_bias : -1,
+        scale_bias_last_val);
   }
   if (!indice_weights || indice_weights->numel() == 0) {
     return int_nbit_split_embedding_codegen_forward_unweighted_cpu(
-        dev_weights,
-        uvm_weights,
-        weights_placements,
-        weights_offsets,
-        weights_tys,
-        D_offsets,
+        std::move(dev_weights),
+        std::move(uvm_weights),
+        std::move(weights_placements),
+        std::move(weights_offsets),
+        std::move(weights_tys),
+        std::move(D_offsets),
         total_D,
-        indices,
-        offsets,
+        std::move(indices),
+        std::move(offsets),
         pooling_mode,
         row_alignment ? *row_alignment : 1,
         output_dtype,
         fp8_exponent_bits ? *fp8_exponent_bits : -1,
-        fp8_exponent_bias ? *fp8_exponent_bias : -1);
+        fp8_exponent_bias ? *fp8_exponent_bias : -1,
+        scale_bias_last_val);
   }
   return int_nbit_split_embedding_codegen_forward_weighted_cpu(
-      dev_weights,
-      uvm_weights,
-      weights_placements,
-      weights_offsets,
-      weights_tys,
-      D_offsets,
+      std::move(dev_weights),
+      std::move(uvm_weights),
+      std::move(weights_placements),
+      std::move(weights_offsets),
+      std::move(weights_tys),
+      std::move(D_offsets),
       total_D,
-      indices,
-      offsets,
+      std::move(indices),
+      std::move(offsets),
       pooling_mode,
       row_alignment ? *row_alignment : 1,
-      *indice_weights,
+      std::move(*indice_weights),
       output_dtype,
       fp8_exponent_bits ? *fp8_exponent_bits : -1,
-      fp8_exponent_bias ? *fp8_exponent_bias : -1);
+      fp8_exponent_bias ? *fp8_exponent_bias : -1,
+      scale_bias_last_val);
+}
+
+///@ingroup embedding-cpu
+Tensor int_nbit_split_embedding_codegen_lookup_function_cpu(
+    Tensor dev_weights,
+    Tensor uvm_weights, // to match the interface of CUDA op using UVM
+    Tensor weights_placements, // to match the interface of CUDA op using UVM
+    Tensor weights_offsets,
+    Tensor weights_tys,
+    Tensor D_offsets,
+    int64_t total_D,
+    int64_t max_int2_D,
+    int64_t max_int4_D,
+    int64_t max_int8_D,
+    int64_t max_float16_D,
+    int64_t max_float32_D,
+    Tensor indices,
+    Tensor offsets,
+    int64_t pooling_mode,
+    std::optional<Tensor> indice_weights,
+    int64_t output_dtype,
+    std::optional<Tensor>
+        lxu_cache_weights, // Not used, to match cache interface for CUDA op
+    std::optional<Tensor>
+        lxu_cache_locations, // Not used, to match cache interface for CUDA op
+    std::optional<int64_t> row_alignment,
+    std::optional<int64_t> max_float8_D,
+    std::optional<int64_t> fp8_exponent_bits,
+    std::optional<int64_t> fp8_exponent_bias) {
+  return int_nbit_split_embedding_codegen_lookup_function_cpu_impl(
+      std::move(dev_weights),
+      std::move(uvm_weights),
+      std::move(weights_placements),
+      std::move(weights_offsets),
+      std::move(weights_tys),
+      std::move(D_offsets),
+      total_D,
+      max_int2_D,
+      max_int4_D,
+      max_int8_D,
+      max_float16_D,
+      max_float32_D,
+      std::move(indices),
+      std::move(offsets),
+      pooling_mode,
+      std::move(indice_weights),
+      output_dtype,
+      std::move(lxu_cache_weights),
+      std::move(lxu_cache_locations),
+      std::move(row_alignment),
+      std::move(max_float8_D),
+      std::move(fp8_exponent_bits),
+      std::move(fp8_exponent_bias),
+      false);
 }
 
 ///@ingroup embedding-cpu
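Note on the structure of this change (sketch, not part of the diff): the registered CPU operator keeps its original schema because the existing body is renamed to `int_nbit_split_embedding_codegen_lookup_function_cpu_impl`, which gains an extra `std::optional<bool> scale_bias_last`, while a thin wrapper with the old argument list forwards to it and pins the flag. The snippet below is a minimal, self-contained illustration of that wrapper/optional-default pattern; the `lookup_impl`/`lookup` names and simplified types are hypothetical and only mirror the structure above, not the FBGEMM API.

```cpp
#include <iostream>
#include <optional>

// Hypothetical stand-in for the real op body: the "_impl" function owns the
// new knob and resolves it the same way the diff does (opt ? *opt : default).
int lookup_impl(std::optional<bool> scale_bias_last) {
  const bool scale_bias_last_val = scale_bias_last ? *scale_bias_last : true;
  // A real kernel would use scale_bias_last_val to pick the row layout
  // (scale/bias stored after the quantized values vs. before them).
  return scale_bias_last_val ? 1 : 0;
}

// Public wrapper keeps the original parameter list, so existing call sites
// and the registered operator schema stay unchanged; the new flag is pinned.
int lookup() {
  return lookup_impl(/*scale_bias_last=*/false);
}

int main() {
  std::cout << lookup() << '\n';                   // wrapper pins the flag to false
  std::cout << lookup_impl(std::nullopt) << '\n';  // impl default resolves to true
}
```

The `std::move` calls added on the `Tensor` arguments in the real code avoid extra reference-count bumps on the by-value parameters when forwarding them to the underlying kernels; they do not change behavior.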