From 91a4559f1b0169b2d8f9c1158f624b547952d194 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Thu, 30 Jan 2025 12:17:13 -0600 Subject: [PATCH 01/10] fixed find_max_Ls function to return int type --- ..._table_batched_embeddings_ops_inference.py | 40 +++++++++++++++++-- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 23e3397d76..2586889b3e 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -55,7 +55,13 @@ import fbgemm_gpu # noqa - +def find_max_ls(ty: SparseType, weights_tys:List[SparseType], offsets: Tensor ): + # bag_sizes = None + for type_ in weights_tys: + if type_ == ty or type_.value == ty.value: + bag_sizes = offsets[1:] - offsets[:-1] + return bag_sizes.max().item() + return 0 def rounded_row_size_in_bytes( dim: int, weight_ty: SparseType, @@ -469,6 +475,8 @@ def max_ty_D(ty: SparseType) -> int: ], default=0, ) + + self.max_int2_D: int = max_ty_D(SparseType.INT2) self.max_int4_D: int = max_ty_D(SparseType.INT4) @@ -476,7 +484,6 @@ def max_ty_D(ty: SparseType) -> int: self.max_float8_D: int = max_ty_D(SparseType.FP8) self.max_float16_D: int = max_ty_D(SparseType.FP16) self.max_float32_D: int = max_ty_D(SparseType.FP32) - self.register_buffer( "D_offsets", torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), @@ -932,6 +939,27 @@ def _forward_impl( indices, offsets, per_sample_weights = inputs_to_device( indices, offsets, per_sample_weights, self.bounds_check_warning.device ) + # bag_sizes = offsets[1:] - offsets[:-1] + # max_Ls = bag_sizes.max() + max_ls_tys = [] + weights_tys: List[SparseType] = [e[3] for e in self.embedding_specs] + + type_list = [SparseType.INT2, SparseType.INT4, SparseType.INT8, SparseType.FP8, SparseType.FP16, SparseType.FP32] + INT2_max_ls = find_max_ls(SparseType.INT2, weights_tys, offsets) + print(INT2_max_ls) + INT4_max_ls = find_max_ls(SparseType.INT4, weights_tys, offsets) + print(INT4_max_ls) + INT8_max_ls = find_max_ls(SparseType.INT8, weights_tys, offsets) + print(INT8_max_ls) + FP8_max_ls = find_max_ls(SparseType.FP8, weights_tys, offsets) + print(FP8_max_ls) + FP16_max_ls = find_max_ls(SparseType.FP16, weights_tys, offsets) + print(FP16_max_ls) + FP32_max_ls = find_max_ls(SparseType.FP32, weights_tys, offsets) + print(FP32_max_ls) + + + # First bound check: check if the indices/offsets are within the boundary # of the original embedding rows before pruning. 
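[Editor's note] For reference, a minimal standalone sketch of what the find_max_ls helper above computes, assuming offsets follows the usual TBE convention of a 1-D prefix sum over all (table, batch) bags so that offsets[1:] - offsets[:-1] yields every bag length. The names below are illustrative and not part of the patch:

    import torch
    from typing import List

    def find_max_ls_sketch(ty, weights_tys: List, offsets: torch.Tensor) -> int:
        # If any table stores its weights in dtype `ty`, return the maximum bag
        # length observed across the whole batch as a Python int; otherwise 0.
        # Note that, like the patched helper, the max is taken over all bags,
        # not only over the bags belonging to tables of dtype `ty`.
        if any(t == ty or t.value == ty.value for t in weights_tys):
            return int((offsets[1:] - offsets[:-1]).max().item())
        return 0

    # Example: offsets = torch.tensor([0, 1, 2, 4]) describes bags of sizes
    # 1, 1 and 2, so the sketch returns 2 for any dtype present in weights_tys.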
@@ -1009,6 +1037,12 @@ def _forward_impl( max_int8_D=self.max_int8_D, max_float16_D=self.max_float16_D, max_float32_D=self.max_float32_D, + INT2_max_ls=INT2_max_ls, + INT4_max_ls=INT4_max_ls, + INT8_max_ls=INT8_max_ls, + FP8_max_ls = FP8_max_ls, + FP16_max_ls=FP16_max_ls, + FP32_max_ls=FP32_max_ls, indices=indices, offsets=offsets, pooling_mode=int(self.pooling_mode), @@ -1019,7 +1053,7 @@ def _forward_impl( row_alignment=self.row_alignment, max_float8_D=self.max_float8_D, fp8_exponent_bits=self.fp8_exponent_bits, - fp8_exponent_bias=self.fp8_exponent_bias, + fp8_exponent_bias=self.fp8_exponent_bias ) def forward( From 139b529bbd545ec7c9dc8ae01e352bda9ff63294 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Thu, 30 Jan 2025 12:18:24 -0600 Subject: [PATCH 02/10] removed prints on max_ls vars --- .../split_table_batched_embeddings_ops_inference.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 2586889b3e..16105638c1 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -939,24 +939,15 @@ def _forward_impl( indices, offsets, per_sample_weights = inputs_to_device( indices, offsets, per_sample_weights, self.bounds_check_warning.device ) - # bag_sizes = offsets[1:] - offsets[:-1] - # max_Ls = bag_sizes.max() - max_ls_tys = [] weights_tys: List[SparseType] = [e[3] for e in self.embedding_specs] - type_list = [SparseType.INT2, SparseType.INT4, SparseType.INT8, SparseType.FP8, SparseType.FP16, SparseType.FP32] INT2_max_ls = find_max_ls(SparseType.INT2, weights_tys, offsets) - print(INT2_max_ls) INT4_max_ls = find_max_ls(SparseType.INT4, weights_tys, offsets) - print(INT4_max_ls) INT8_max_ls = find_max_ls(SparseType.INT8, weights_tys, offsets) - print(INT8_max_ls) FP8_max_ls = find_max_ls(SparseType.FP8, weights_tys, offsets) - print(FP8_max_ls) FP16_max_ls = find_max_ls(SparseType.FP16, weights_tys, offsets) - print(FP16_max_ls) FP32_max_ls = find_max_ls(SparseType.FP32, weights_tys, offsets) - print(FP32_max_ls) + From cdf3977c6db4f01f8688cb053f35ead4e5e337b2 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Fri, 20 Jun 2025 17:26:26 +0000 Subject: [PATCH 03/10] optimized inference for L=1 case --- .../bench/tbe/tbe_inference_benchmark.py | 27 +- .../embedding_forward_quantized_host.cpp | 83 +- .../embedding_forward_quantized_host_cpu.cpp | 49 +- ...ward_quantized_split_nbit_host_template.cu | 98 ++- ...rd_quantized_split_nbit_kernel_template.cu | 831 ++++++++++++------ ..._table_batched_embeddings_ops_inference.py | 64 +- fbgemm_gpu/test/tbe/inference/common.py | 1 + .../test/tbe/inference/nbit_forward_test.py | 102 +-- 8 files changed, 774 insertions(+), 481 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py index b1e59a495e..38e99d85eb 100644 --- a/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py @@ -189,6 +189,7 @@ def nbit_cpu( # noqa C901 emb = IntNBitTableBatchedEmbeddingBagsCodegen( [("", E, d, weights_precision, EmbeddingLocation.HOST) for d in Ds], device="cpu", + Ls=L, index_remapping=[torch.arange(E) for _ in Ds] if index_remapping else None, output_dtype=output_dtype, pooling_mode=pooling_mode, @@ -403,6 +404,7 @@ def nbit_device( # noqa C901 emb = 
IntNBitTableBatchedEmbeddingBagsCodegen( [("", E, d, weights_precision, managed_option) for d in Ds], bounds_check_mode=BoundsCheckMode(bounds_check_mode), + Ls=L, index_remapping=index_remapping, pruning_hash_load_factor=pruning_hash_load_factor, use_array_for_index_remapping=use_array_for_index_remapping, @@ -791,6 +793,7 @@ def nbit_device_with_spec( # noqa C901 emb = IntNBitTableBatchedEmbeddingBagsCodegen( [("", e, d, weights_precision, managed_option) for d, e in zip(Ds, Es)], device="cpu" if use_cpu else None, + Ls=Ls, bounds_check_mode=BoundsCheckMode(bounds_check_mode), index_remapping=index_remapping, pruning_hash_load_factor=pruning_hash_load_factor, @@ -856,7 +859,7 @@ def nbit_device_with_spec( # noqa C901 # don't use zipf if e isn't large enough compared to bag_size. alpha=alpha if (e / bag_size) > 2.0 else 1.0, # need many more samples for zipf if bag_size is very small. - zipf_oversample_ratio=3 if bag_size > 5 else 10, + zipf_oversample_ratio=10 if bag_size > 5 else 20, weighted=weighted, use_cpu=use_cpu, ) @@ -919,14 +922,12 @@ def nbit_device_with_spec( # noqa C901 # free up memory del requests - result_msg = ( - f"Iteration {i}: " - f"{weights_precision} Forward, B: {B}, " - f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " - f"BW: {cpu_copies * float(read_write_bytes) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 - f"Time: {time_per_iter * 1.0e6:.0f}us, " - f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" - ) + result_msg = f"Iteration {i}: " + f"{weights_precision} Forward, B: {B}, " + f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, " + f"BW: {cpu_copies * float(read_write_bytes) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950 + f"Time: {time_per_iter * 1.0e6:.0f}us, " + f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB" if use_cpu and cpu_copies > 1: result_msg += f", Parallel Copies: {cpu_copies}" @@ -1111,6 +1112,7 @@ def nbit_uvm( ) for d in Ds[:T_uvm] ], + Ls=L, output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, @@ -1133,6 +1135,7 @@ def nbit_uvm( ) for d in Ds[T_uvm:] ], + Ls=L, output_dtype=output_dtype, fp8_exponent_bits=fp8_exponent_bits, fp8_exponent_bias=fp8_exponent_bias, @@ -1155,6 +1158,7 @@ def nbit_uvm( [managed_type] * T_uvm + [EmbeddingLocation.DEVICE] * T_gpu, ) ], + Ls=L, output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, @@ -1456,6 +1460,7 @@ def bench_uvm_cls( ) for d in Ds[:T] ], + Ls=L, output_dtype=output_dtype, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, @@ -1637,6 +1642,7 @@ def nbit_cache( # noqa C901 ) for d in Ds ], + Ls=L, output_dtype=output_dtype, enforce_hbm=enforce_hbm, fp8_exponent_bits=fp8_exponent_bits, @@ -1661,6 +1667,7 @@ def nbit_cache( # noqa C901 record_cache_metrics=RecordCacheMetrics( record_cache_miss_counter, record_tablewise_cache_miss ), + Ls=L, gather_uvm_cache_stats=gather_uvm_cache_stats, cache_load_factor=cache_load_factor, cache_algorithm=cache_alg, @@ -1832,4 +1839,4 @@ def nbit_cache( # noqa C901 if __name__ == "__main__": - cli() + cli() \ No newline at end of file diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp index 5fcc3a0176..988a5508df 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp @@ -1,10 +1,10 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. 
- * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. +*/ #include #include @@ -200,6 +200,12 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -224,6 +230,12 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -248,6 +260,12 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t row_alignment, @@ -272,6 +290,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -308,6 +332,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices.to(at::kInt), offsets.to(at::kInt), row_alignment ? *row_alignment : 16, @@ -316,7 +346,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1 + ); } if (!indice_weights || indice_weights->numel() == 0) { return int_nbit_split_embedding_codegen_forward_unweighted_cuda( @@ -332,6 +363,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -341,7 +378,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? *max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1 + ); } // Force casting indice_weights to float (doing this in the backend to avoid // JIT issue) @@ -359,6 +397,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -369,15 +413,15 @@ Tensor int_nbit_split_embedding_codegen_lookup_function( lxu_cache_locations.value_or(at::empty({0}, at::kInt)), max_float8_D ? 
*max_float8_D : 0, fp8_exponent_bits ? *fp8_exponent_bits : -1, - fp8_exponent_bias ? *fp8_exponent_bias : -1); + fp8_exponent_bias ? *fp8_exponent_bias : -1 + ); } ///@ingroup embedding-cuda -/// Simlar to int_nbit_split_embedding_codegen_lookup_function, but it does /// UVM_CACHING lookup. Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( // First args should be the same to those of - // int_nbit_split_embedding_codegen_lookup_function. + Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, @@ -390,6 +434,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -547,6 +597,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -557,7 +613,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function( row_alignment, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias + ); } ///@ingroup embedding-cuda @@ -583,4 +640,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { int_nbit_split_embedding_uvm_caching_codegen_lookup_function); DISPATCH_TO_CUDA("pruned_hashmap_lookup", pruned_hashmap_lookup_cuda); DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda); -} +} \ No newline at end of file diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp index 2d2008abea..cd72d8251d 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_host_cpu.cpp @@ -1,10 +1,10 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ +* Copyright (c) Meta Platforms, Inc. and affiliates. +* All rights reserved. +* +* This source code is licensed under the BSD-style license found in the +* LICENSE file in the root directory of this source tree. 
+*/ #include #include @@ -91,6 +91,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function_cpu( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -180,6 +186,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( int64_t max_int8_D, int64_t max_float16_D, int64_t max_float32_D, + int64_t INT2_max_ls, + int64_t INT4_max_ls, + int64_t INT8_max_ls, + int64_t FP8_max_ls, + int64_t FP16_max_ls, + int64_t FP32_max_ls, Tensor indices, Tensor offsets, int64_t pooling_mode, @@ -212,6 +224,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, pooling_mode, @@ -222,7 +240,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu( row_alignment, max_float8_D, fp8_exponent_bits, - fp8_exponent_bias); + fp8_exponent_bias + ); } ///@ingroup embedding-cpu @@ -254,14 +273,14 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py"); #endif m.def( - "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1) -> Tensor", + "int_nbit_split_embedding_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D ,int INT2_max_ls, int INT4_max_ls, int INT8_max_ls, int FP8_max_ls, int FP16_max_ls, int FP32_max_ls, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment = None, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1 ) -> Tensor", {PT2_COMPLIANT_TAG}); DISPATCH_TO_CPU( "int_nbit_split_embedding_codegen_lookup_function", int_nbit_split_embedding_codegen_lookup_function_cpu); m.def( - "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? 
lxu_state=None) -> Tensor"); + "int_nbit_split_embedding_uvm_caching_codegen_lookup_function(Tensor dev_weights, Tensor uvm_weights, Tensor weights_placements, Tensor weights_offsets, Tensor weights_tys, Tensor D_offsets, SymInt total_D, int max_int2_D, int max_int4_D, int max_int8_D, int max_float16_D, int max_float32_D ,int INT2_max_ls, int INT4_max_ls, int INT8_max_ls, int FP8_max_ls, int FP16_max_ls, int FP32_max_ls, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights=None, int output_dtype=1, Tensor? lxu_cache_weights=None, Tensor? lxu_cache_locations=None, int? row_alignment=-1, int? max_float8_D=0, int? fp8_exponent_bits=-1, int? fp8_exponent_bias=-1, Tensor? cache_hash_size_cumsum=None, int? total_cache_hash_size=-1, Tensor? cache_index_table_map=None, Tensor? lxu_cache_state=None, Tensor? lxu_state=None) -> Tensor"); DISPATCH_TO_CPU( "int_nbit_split_embedding_uvm_caching_codegen_lookup_function", int_nbit_split_embedding_uvm_caching_codegen_lookup_function_cpu); @@ -287,7 +306,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { } class PrunedMapCPU : public torch::jit::CustomClassHolder { - public: +public: PrunedMapCPU() {} explicit PrunedMapCPU(std::string serialized) { torch::serialize::InputArchive archive; @@ -409,7 +428,7 @@ class PrunedMapCPU : public torch::jit::CustomClassHolder { return dense_indices; } - private: +private: #ifdef FBCODE_CAFFE2 std::vector> maps_; #else @@ -433,7 +452,7 @@ static auto PrunedMapCPURegistry = }); class AtomicCounter : public torch::jit::CustomClassHolder { - public: +public: AtomicCounter() { counter_ = 0; } @@ -470,7 +489,7 @@ class AtomicCounter : public torch::jit::CustomClassHolder { return oss.str(); } - private: +private: std::atomic counter_{0}; }; @@ -570,7 +589,7 @@ struct TensorQueue : torch::CustomClassHolder { std::make_tuple("queue", queue_vec)); } - private: +private: std::deque queue_; std::mutex mutex_; Tensor init_tensor_; @@ -594,4 +613,4 @@ static auto TensorQueueRegistry = [](c10::Dict data) -> c10::intrusive_ptr { return c10::make_intrusive(std::move(data)); - }); + }); \ No newline at end of file diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index d9dc720c30..130cc52061 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -9,8 +9,7 @@ // clang-format off {%- set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" -#include "fbgemm_gpu/utils/kernel_launcher.cuh" -#include "fbgemm_gpu/utils/tensor_accessor_builder.h" +#include "fbgemm_gpu/utils/tensor_accessor.h" #include "fbgemm_gpu/config/feature_gates.h" using namespace fbgemm_gpu; @@ -25,7 +24,7 @@ namespace nbit { same generated source file. 
*/ {%- for emb_weight_type in ["FP32", "FP16", "FP8", "INT8", "INT4", "INT2"] %} -template +template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const pta::PackedTensorAccessor64 dev_weights, @@ -53,6 +52,7 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no const int fp8_exponent_bias, {%- endif %} const int32_t num_packed_bags, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations @@ -64,46 +64,52 @@ __global__ void {{ type_map[emb_weight_type].enum_name }}_split_embedding{{ "_no {%- macro define_kernel_invocation(emb_weight_type) %} {%- set func_name = "nbit::" + emb_weight_type + "_split_embedding" + ("_nobag" if nobag else "") + "_codegen_forward_" + wdesc + "_kernel_small_L" %} + #ifdef FBGEMM_GPU_MEMCHECK + const auto func_name_{{ emb_weight_type }} = "{{ func_name }}_{{ emb_weight_type }}"; + #endif + #ifdef X #undef X #endif - #define X(DeviceOnly, PackedMode, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - FBGEMM_LAUNCH_KERNEL( \ - ({{ func_name }}), \ + // Define {{ emb_weight_type }} kernel invocation macro + #define X(DeviceOnly, PackedMode, PackedModeL, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + {{ func_name }}<<< \ nbit::div_round_up(T * nbit::div_round_up(B, num_packed_bags * OutputRowsPerThread), kWarpsPerBlock), \ dim3(kWarpSize, kWarpsPerBlock), \ 0, \ - at::cuda::getCurrentCUDAStream(), \ - PTA_B(dev_weights, uint8_t, 1, 64), \ - PTA_B(uvm_weights, uint8_t, 1, 64), \ - PTA_B(weights_placements, int32_t, 1, 32), \ - PTA_B(weights_offsets, int64_t, 1, 32), \ - PTA_B(weights_tys, uint8_t, 1, 32), \ + at::cuda::getCurrentCUDAStream()>>>( \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, dev_weights, uint8_t, 1, 64), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, uvm_weights, uint8_t, 1, 64), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_placements, int32_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_offsets, int64_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, weights_tys, uint8_t, 1, 32), \ {%- if not nobag %} - PTA_B(D_offsets, int32_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, D_offsets, int32_t, 1, 32), \ {%- else %} D, \ {%- endif %} FixedDivisor(div_round_up(B, num_packed_bags * OutputRowsPerThread)), \ - PTA_B(indices, index_t, 1, 32), \ - PTA_B(offsets, index_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indices, index_t, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, offsets, index_t, 1, 32), \ {%- if not nobag %} pooling_mode, \ {%- endif %} row_alignment, \ {%- if weighted %} - PTA_B(indice_weights, float, 1, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, indice_weights, float, 1, 32), \ {%- endif %} {%- if emb_weight_type == "FP8" %} fp8_exponent_bits, \ fp8_exponent_bias, \ {%- endif %} num_packed_bags, \ - PTA_B(output, output_t, 2, 32), \ - PTA_B(lxu_cache_weights, uint8_t, 2, 64), \ - PTA_B(lxu_cache_locations, int32_t, 1, 32) \ - ); + num_packed_bags_L, \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, output, output_t, 2, 32), \ + MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_weights, uint8_t, 2, 64), \ + 
MAKE_PTA_WITH_NAME(func_name_{{ emb_weight_type }}, lxu_cache_locations, int32_t, 1, 32) \ + ); \ + C10_CUDA_KERNEL_LAUNCH_CHECK(); \ {%- endmacro %} {%- macro construct_and_return_output_tensor() %} @@ -186,6 +192,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int64_t max_int8_D, const int64_t max_float16_D, const int64_t max_float32_D, + const int64_t INT2_max_ls, + const int64_t INT4_max_ls, + const int64_t INT8_max_ls, + const int64_t FP8_max_ls, + const int64_t FP16_max_ls, + const int64_t FP32_max_ls, Tensor indices, Tensor offsets, {%- if not nobag %} @@ -229,13 +241,16 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const static bool use_rocm_packed_bag_mode = kIsRocm && fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); /* * Helper macro for run-time packed mode dispatch. Computes maximum number of bags - * (num_packed_bags) that fits into NumUint4LoadsPerRow given embeddings' type and - * size. num_packed_bags is to be used for additional bags indexing + * (num_packed_bags) that fits into NumUint4LoadsPerRow given embeddings' type and + * size. num_packed_bags is to be used for additional bags indexing * * Current support range: ROCm and output_t != uint8_t and sparse_type != FP32 */ #define PACKED_MODE_SWITCH(dev_only, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ int32_t num_packed_bags = 1; \ + int32_t num_packed_bags_D = 1; \ + int32_t num_packed_bags_L = 1; \ + const int64_t max_L = max_Ls; \ {%-if is_rocm and not nobag %} const static bool use_packed_bag_mode = fbgemm_gpu::config::is_feature_enabled( \ fbgemm_gpu::config::FeatureGateName::TBE_ROCM_INFERENCE_PACKED_BAGS); \ @@ -243,14 +258,19 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ /* The actual maximum number of uint4 reads per row w.r.t. row size, type and alignment */ \ const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \ constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \ + constexpr int32_t max_indices_per_warp = kWarpSize / NumUint4LoadsPerRow; \ + num_packed_bags_L = max_indices_per_warp > max_L && !std::is_same_v && sparse_type != SparseType::FP32? max_indices_per_warp / max_L : 1; \ + num_packed_bags_D = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v && sparse_type != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \ /* Number of bags that might be fitted to shared memory. */ \ - num_packed_bags = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v && sparse_type != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \ + num_packed_bags = max_L>1 ? 
num_packed_bags_D : num_packed_bags_L * num_packed_bags_D; \ } \ {%- endif %} - if (num_packed_bags > 1) { \ - X(dev_only, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + if (num_packed_bags > 1 && max_L>1) { \ + X(dev_only, true, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } else if (num_packed_bags > 1 && max_L<=1) { \ + X(dev_only, true, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ } else { \ - X(dev_only, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + X(dev_only, false, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ }; #define Y(...) \ @@ -270,6 +290,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int2_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int2_D > 0) { const auto max_D = max_int2_D; + const auto max_Ls = INT2_max_ls; constexpr auto sparse_type = SparseType::INT2; auto max_int2_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int2_128b_rows <= 8); @@ -299,6 +320,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int4_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int4_D > 0) { const auto max_D = max_int4_D; + const auto max_Ls = INT4_max_ls; constexpr auto sparse_type = SparseType::INT4; auto max_int4_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int4_128b_rows <= 16); @@ -327,9 +349,9 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ if (max_int4_128b_rows > 8) { if(use_rocm_packed_bag_mode) { Y(1, 1, 8, 16); - } else { + } else { Y(1, 4, 8, 16); - } + } } } })); @@ -345,6 +367,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "int8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_int8_D > 0) { const auto max_D = max_int8_D; + const auto max_Ls = INT8_max_ls; constexpr auto sparse_type = SparseType::INT8; auto max_int8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_int8_128b_rows <= 32); @@ -386,7 +409,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ Y(1, 1, 16, 32); } else { Y(1, 2, 16, 32); - } + } } } })); @@ -402,6 +425,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp8_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float8_D > 0) { const auto max_D = max_float8_D; + const auto max_Ls = FP8_max_ls; constexpr auto sparse_type = SparseType::FP8; auto max_fp8_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp8_128b_rows <= 32); @@ -437,6 +461,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp16_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float16_D > 0) { const auto max_D = max_float16_D; + const auto max_Ls = FP16_max_ls; constexpr auto sparse_type 
= SparseType::FP16; auto max_fp16_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp16_128b_rows <= 64); @@ -472,6 +497,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ DISPATCH_OUTPUT_TYPES(output.scalar_type(), "fp32_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_kernel", ([&] { if (max_float32_D > 0) { const auto max_D = max_float32_D; + const auto max_Ls = FP32_max_ls; constexpr auto sparse_type = SparseType::FP32; auto max_fp32_128b_rows = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), 128); TORCH_CHECK(max_fp32_128b_rows <= 64); // 128 doesn't fit in 48KB SM, so FP32 TBE supports a smaller dimension than others @@ -511,6 +537,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int64_t max_int8_D, const int64_t max_float16_D, const int64_t max_float32_D, + const int64_t INT2_max_ls, + const int64_t INT4_max_ls, + const int64_t INT8_max_ls, + const int64_t FP8_max_ls, + const int64_t FP16_max_ls, + const int64_t FP32_max_ls, Tensor indices, Tensor offsets, {%- if not nobag %} @@ -572,6 +604,12 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ max_int8_D, max_float16_D, max_float32_D, + INT2_max_ls, + INT4_max_ls, + INT8_max_ls, + FP8_max_ls, + FP16_max_ls, + FP32_max_ls, indices, offsets, {%- if not nobag %} @@ -592,4 +630,4 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ return output; } - // clang-format on + // clang-format on \ No newline at end of file diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu index c17fe9fb0f..42ee5ef2ea 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_kernel_template.cu @@ -9,7 +9,7 @@ // clang-format off {% set wdesc = "weighted" if weighted else "unweighted" %} #include "fbgemm_gpu/embedding_forward_template_helpers.cuh" -#include "fbgemm_gpu/utils/tensor_accessor_builder.h" +#include "fbgemm_gpu/utils/tensor_accessor.h" using namespace fbgemm_gpu; using Tensor = at::Tensor; @@ -17,7 +17,7 @@ using Tensor = at::Tensor; namespace nbit { // TODO: increase code sharing (templates for accumulator_ty, accumulation, outputs per thread, etc?) -template +template __launch_bounds__(WarpsPerBlock * kWarpSize) __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{{ wdesc }}_kernel_small_L( const pta::PackedTensorAccessor64 dev_weights, @@ -46,40 +46,40 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no {% endif %} // The number of bags that one warp/wave is able to process in one go. 
(NumUint4LoadsPerRow / uint4_loads_per_row) const int32_t num_packed_bags, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32 output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations ) { - const int32_t T = weights_offsets.size(0); - {% if not nobag %} - const bool mean_pooling = static_cast(pooling_mode) == PoolingMode::MEAN; - const int32_t B = output.size(0); - {% else %} - const int32_t B = (offsets.size(0) - 1) / T; - {% endif %} - const auto bb_t = blockIdx.x * blockDim.y + threadIdx.y; - if (bb_t >= fd_B.D() * T) { - return; - } - static_assert( - std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, - "output_t can only be float or half or bytes now" - ); - - int32_t t; - int32_t bb; - fd_B.DivMod(bb_t, &t, &bb); - - {% if not nobag %} - const int32_t D_start = D_offsets[t]; - const int32_t D_end = D_offsets[t + 1]; - const int32_t D = D_end - D_start; - {% endif %} - SparseType weight_ty = static_cast(weights_tys[t]); - if (weight_ty != SparseType::{{ emb_weight_type.enum_name }}) { + const int32_t T = weights_offsets.size(0); + {% if not nobag %} + const bool mean_pooling = static_cast(pooling_mode) == PoolingMode::MEAN; + const int32_t B = output.size(0); + {% else %} + const int32_t B = (offsets.size(0) - 1) / T; + {% endif %} + const auto bb_t = blockIdx.x * blockDim.y + threadIdx.y; + if (bb_t >= fd_B.D() * T) { return; - } - + } + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, + "output_t can only be float or half or bytes now" + ); + + int32_t t; + int32_t bb; + fd_B.DivMod(bb_t, &t, &bb); + + {% if not nobag %} + const int32_t D_start = D_offsets[t]; + const int32_t D_end = D_offsets[t + 1]; + const int32_t D = D_end - D_start; + {% endif %} + SparseType weight_ty = static_cast(weights_tys[t]); + if (weight_ty != SparseType::{{ emb_weight_type.enum_name }}) { + return; + } // default to 16 byte alignment for GPU TBE const int32_t D_bytes = padded_row_size_in_bytes(D, weight_ty, row_alignment); @@ -100,7 +100,7 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no constexpr uint32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); const uint32_t uint4_loads_per_row = div_round_up(D_bytes, sizeof(uint4)); - + const int32_t bag_size_offset = num_packed_bags_L > 1 ? kWarpSize/(num_packed_bags_L * NumUint4LoadsPerRow) : 1; // Index of packed bag during load stage in current warp/wave. Should fit into NumUint4LoadsPerRow (3rd) shared // memory buffer's dimension w.r.t. the actual size of the row in the bag. const uint32_t packed_bag_load_idx = PackedMode ? (threadIdx.x % NumUint4LoadsPerRow) / uint4_loads_per_row : 0; @@ -109,7 +109,10 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no // Packed bag accumulation index in current warp/wave. Each thread/lane process 1 uint instead of // 4 uints during load stage, so the index should be recomputed accordingly. const int32_t packed_bag_acc_idx = PackedMode ? (threadIdx.x / uints_per_row) % num_packed_bags : 0; + const uint32_t packed_bag_idx_L = num_packed_bags_L > 1 ? 
(threadIdx.x / NumUint4LoadsPerRow) / bag_size_offset : 0; + const uint32_t packed_bag_idx = (packed_bag_idx_L * num_packed_bags) + packed_bag_load_idx; + // const int32_t bag_d = kWarpSize/num_packed_bags_L; // num_packed_bags_L can be {1, 2, 4, 8} for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { // In case of PackedMode, b should be offseted with num_packed_bags and indexed with packed_bag_load_idx // to take into account reduced grid size in host kernel call and that the warp/wave may contain several @@ -134,315 +137,568 @@ __global__ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if no {% if not nobag %} VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> accumulators[OutputRowsPerThread][AccumulateStoreRequests]; - {% endif %} - - for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { - uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); - - typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow]; - __shared__ AllBuffers buffers; - {% if weighted %} - // In case of PackedMode, overallocate indice weights buffer to store additional per-row weights for - // packed bags. - typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1]; - __shared__ AllIndiceWeights buffers_indice_weights; - {% endif %} - - for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * NumUint4LoadsPerRow; load_idx += kWarpSize) { - uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow; - if constexpr (PackedMode) { - // The actual row index in packed bag w.r.t. the required uint4 loads. - row_load_idx %= uint4_loads_per_row; - } - uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow); - // In case of PackedMode, packed_bag_load_idx already takes into account uint4_loads_per_row, - // so only the packed_bag index should be evaluated against total number of packed bags. - bool load_idx_valid = PackedMode ? packed_bag_load_idx < num_packed_bags : row_load_idx < uint4_loads_per_row; - {%- if is_rocm %} - constexpr uint32_t kMaxRowUnroll = 4; - constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? OutputRowsPerThread : kMaxRowUnroll; - - #pragma unroll - for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { - uint4 row_data_v[kRowUnroll]; - const uint4* row_v[kRowUnroll]; - int32_t idx_v[kRowUnroll]; - int32_t cache_idx_v[kRowUnroll]; + {% endif %} + typedef uint4 AllBuffers[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][NumUint4LoadsPerRow]; + __shared__ AllBuffers buffers; + {% if weighted %} + // In case of PackedMode, overallocate indice weights buffer to store additional per-row weights for + // packed bags. + typedef float AllIndiceWeights[WarpsPerBlock][OutputRowsPerThread][InputRowsInFlight][PackedMode ? NumUint4LoadsPerRow : 1]; + __shared__ AllIndiceWeights buffers_indice_weights; + {% endif %} + {%- if is_rocm %} + constexpr uint32_t kMaxRowUnroll = 4; + constexpr uint32_t kRowUnroll = OutputRowsPerThread < kMaxRowUnroll ? 
OutputRowsPerThread : kMaxRowUnroll; + {% endif %} + if constexpr (PackedModeL){ + for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { + uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); + for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * num_packed_bags_L * NumUint4LoadsPerRow; load_idx += kWarpSize) { + uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow % uint4_loads_per_row; + uint32_t input_row_idx = num_packed_bags_L>1? (load_idx / NumUint4LoadsPerRow) % bag_size_offset: (load_idx / NumUint4LoadsPerRow); + bool load_idx_valid = packed_bag_load_idx < num_packed_bags && packed_bag_idx_L < num_packed_bags_L; + {%- if is_rocm %} #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; - cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { + uint4 row_data_v[kRowUnroll]; + const uint4* row_v[kRowUnroll]; + int32_t idx_v[kRowUnroll]; + int32_t cache_idx_v[kRowUnroll]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + valid = valid && (idx_v[inner_i] != -1); + if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { + row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); + } else + if (valid) { + row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + } else { + row_v[inner_i] = reinterpret_cast(&weights[0]); + } + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + } + uint4 zeros = {0, 0, 0, 0}; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); + uint4 data = valid ? row_data_v[inner_i] : zeros; + buffers[warp_idx][i][input_row_idx + bag_size_offset *packed_bag_idx_L][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + {% if weighted %} + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = valid ? 
indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% endif %} + } } - - + {%- endif %} + + {%- if is_rocm %} + if constexpr (OutputRowsPerThread % kRowUnroll) + { #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { - uint32_t i = outer_i + inner_i; + for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { + {%- else %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + {%- endif %} bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - valid = valid && (idx_v[inner_i] != -1); - if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { - row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); - } else - if (valid) { - row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + int32_t cache_idx = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + valid = valid && (idx != -1); + const uint4* row; + if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { + row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); + } else if (valid) { + row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); + } else { + row = reinterpret_cast(&weights[0]); + } + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx + bag_size_offset * packed_bag_idx_L][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] , &row[row_load_idx], valid); + {% if weighted %} + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + {% endif %} + } + {%- if is_rocm %} + } // constexpr if (OutputRowsPerThread % kRowUnroll) + {%- endif %} + } + // equivalent to fence + wait. + cp_async_wait<0>(); + syncwarp(); + const int32_t packed_bag_load_idx = (threadIdx.x / uints_per_row) % num_packed_bags; + input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_load_idx * uint4_loads_per_row); + constexpr int32_t max_indices_per_warp = kWarpSize / (MaxNum128BRows * 128 / sizeof(uint4)); + int32_t Ls_shfl[kWarpSize]; + for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + Ls_shfl[k*OutputRowsPerThread+i] = shfl_sync(Ls[i], k * bag_size_offset * NumUint4LoadsPerRow + packed_bag_load_idx * uint4_loads_per_row); + } + } + for(uint32_t k = 0; k < num_packed_bags_L ; ++k){ + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls_shfl[k*OutputRowsPerThread+i]; + if (!valid) { + continue; + } + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx + bag_size_offset *k][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. 
+ {% if emb_weight_type.primitive_type == "INT" %} + half2 shift_scale = reinterpret_cast(row)[(packed_bag_load_idx * uints_per_row)]; + {% endif %} + + {% if weighted %} + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; + {% endif %} + + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + + {% if not nobag %} + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + {% if weighted %} + accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + + {% endif %} + } + + {% else %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const 
int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); + } + } + {% endif %} + } + } + {% if not nobag %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + const int32_t num_stores_with_padding_per_row = 4 * uint4_loads_per_row; + const int32_t packed_bag_load_idx = threadIdx.x / num_stores_with_padding_per_row; + uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + k*num_packed_bags + packed_bag_load_idx), static_cast(B - 1)); + const float inv_L = (mean_pooling &&Ls_shfl[k*OutputRowsPerThread+i] != 0) ? static_cast(1.0) / Ls_shfl[k*OutputRowsPerThread+i] : static_cast(1.0); + + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding \ + - packed_bag_load_idx * kOutputsPerThread * num_stores_with_padding_per_row; + accumulators[i][j].mul(inv_L); + + if (output_d >= 0 && output_d < D && packed_bag_load_idx < num_packed_bags) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } + + } + + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + float thread_local_min = std::numeric_limits::max(); + float thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + accumulators[i][j].mul(inv_L); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc)); + } + } + + qparams = warp_find_qparams(thread_local_min, thread_local_max); + const int output_D_start = D_start + t * 8; + const int output_D_end = output_D_start + D; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[b][output_D_end], qparams); + } } else { - row_v[inner_i] = reinterpret_cast(&weights[0]); + // INT4: not implemented yet } } + + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + accumulators[i][j].mul(0.0); // Use a dedicated clear 
method + } + } + + + {% endif %} + } + } + } + else{ + for (uint32_t L_start = 0; L_start < max_Ls; L_start += InputRowsInFlight) { + uint32_t input_rows_in_flight = min(static_cast(InputRowsInFlight), max_Ls - L_start); + + for (uint32_t load_idx = threadIdx.x; load_idx < input_rows_in_flight * NumUint4LoadsPerRow; load_idx += kWarpSize) { + uint32_t row_load_idx = load_idx % NumUint4LoadsPerRow; + if constexpr (PackedMode) { + // The actual row index in packed bag w.r.t. the required uint4 loads. + row_load_idx %= uint4_loads_per_row; + } + uint32_t input_row_idx = (load_idx / NumUint4LoadsPerRow); + // In case of PackedMode, packed_bag_load_idx already takes into account uint4_loads_per_row, + // so only the packed_bag index should be evaluated against total number of packed bags. + bool load_idx_valid = PackedMode ? packed_bag_load_idx < num_packed_bags : row_load_idx < uint4_loads_per_row; + {%- if is_rocm %} #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + for (uint32_t outer_i = 0; outer_i < OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; outer_i += kRowUnroll) { + uint4 row_data_v[kRowUnroll]; + const uint4* row_v[kRowUnroll]; + int32_t idx_v[kRowUnroll]; + int32_t cache_idx_v[kRowUnroll]; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + idx_v[inner_i] = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + cache_idx_v[inner_i] = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + } + + + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; ++inner_i) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + valid = valid && (idx_v[inner_i] != -1); + if (!DeviceOnly && cache_valid && cache_idx_v[inner_i] != kCacheLocationMissing) { + row_v[inner_i] = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx_v[inner_i])][0]); + } else + if (valid) { + row_v[inner_i] = reinterpret_cast(&weights[static_cast(idx_v[inner_i]) * D_bytes]); + } else { + row_v[inner_i] = reinterpret_cast(&weights[0]); + } + } + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + row_data_v[inner_i] = row_v[inner_i][row_load_idx]; + } + uint4 zeros = {0, 0, 0, 0}; + #pragma unroll + for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { + uint32_t i = outer_i + inner_i; + bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); + uint4 data = valid ? 
row_data_v[inner_i] : zeros; + if constexpr (PackedMode) { + // Store row data with uint4_loads_per_row offset + buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + } else { + buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + } + {% if weighted %} + if (valid && row_load_idx == 0) { + // Use only one thread to load the index weight to prevent a race + // condition when writing to the shared memory + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = indice_weights[indices_starts[i] + L_start + input_row_idx]; + } + {% endif %} + } } - uint4 zeros = {0, 0, 0, 0}; + {%- endif %} + + {%- if is_rocm %} + if constexpr (OutputRowsPerThread % kRowUnroll) + { #pragma unroll - for (uint32_t inner_i = 0; inner_i < kRowUnroll; inner_i++) { - uint32_t i = outer_i + inner_i; - bool valid = load_idx_valid && (L_start + input_row_idx < Ls[i]) && (idx_v[inner_i] != -1); - uint4 data = valid ? row_data_v[inner_i] : zeros; + for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { + {%- else %} + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + {%- endif %} + bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; + bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); + int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; + int32_t cache_idx = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; + valid = valid && (idx != -1); + const uint4* row; + if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { + row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); + } else if (valid) { + row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); + } else { + row = reinterpret_cast(&weights[0]); + } if constexpr (PackedMode) { - // Store row data with uint4_loads_per_row offset - buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx] = data; + // Load valid packed row data w.r.t. packed_bag offset + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx], &row[row_load_idx], valid); } else { - buffers[warp_idx][i][input_row_idx][row_load_idx] = data; + cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); } {% if weighted %} - if (row_load_idx == 0) { + if (valid && row_load_idx == 0) { // Use only one thread to load the index weight to prevent a race // condition when writing to the shared memory - buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = - valid ? 
indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; + buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = indice_weights[indices_starts[i] + L_start + input_row_idx]; } {% endif %} } + {%- if is_rocm %} + } // constexpr if (OutputRowsPerThread % kRowUnroll) + {%- endif %} } - {%- endif %} - - {%- if is_rocm %} - if constexpr (OutputRowsPerThread % kRowUnroll) - { - #pragma unroll - for (uint32_t i = OutputRowsPerThread - OutputRowsPerThread % kRowUnroll; i < OutputRowsPerThread; ++i) { - {%- else %} - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - {%- endif %} - bool valid = load_idx_valid && L_start + input_row_idx < Ls[i]; - bool cache_valid = !DeviceOnly && (placement == PlacementType::MANAGED_CACHING && valid); - int32_t idx = valid ? indices_[indices_starts[i] + L_start + input_row_idx] : -1; - int32_t cache_idx = (!DeviceOnly && cache_valid) ? lxu_cache_locations[indices_starts[i] + L_start + input_row_idx] : -1; - valid = valid && (idx != -1); - const uint4* row; - if (!DeviceOnly && cache_valid && cache_idx != kCacheLocationMissing) { - row = reinterpret_cast(&lxu_cache_weights[static_cast(cache_idx)][0]); - } else if (valid) { - row = reinterpret_cast(&weights[static_cast(idx) * D_bytes]); - } else { - row = reinterpret_cast(&weights[0]); - } - if constexpr (PackedMode) { - // Load valid packed row data w.r.t. packed_bag offset - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx + uint4_loads_per_row * packed_bag_load_idx], &row[row_load_idx], valid); - } else { - cp_async_zfill_cg(&buffers[warp_idx][i][input_row_idx][row_load_idx], &row[row_load_idx], valid); - } - {% if weighted %} - if (row_load_idx == 0) { - // Use only one thread to load the index weight to prevent a race - // condition when writing to the shared memory - buffers_indice_weights[warp_idx][i][input_row_idx][packed_bag_load_idx] = - valid ? indice_weights[indices_starts[i] + L_start + input_row_idx] : 0.0; - } - {% endif %} - } - {%- if is_rocm %} - } // constexpr if (OutputRowsPerThread % kRowUnroll) - {%- endif %} - } - // equivalent to fence + wait. - cp_async_wait<0>(); - syncwarp(); - - if constexpr (PackedMode) { - // Since in PackedMode one warp/wave may contain different bags with different sizes, - // the permutation should be done after switching from uint4 processing during load stage - // to uint processing during accumulate and store. - input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_acc_idx * uint4_loads_per_row); - - #pragma unroll OutputRowsPerThread - for(uint32_t i = 0; i < OutputRowsPerThread; ++i) - { - Ls[i] = shfl_sync(Ls[i], packed_bag_acc_idx * uint4_loads_per_row); - } - } + // equivalent to fence + wait. + cp_async_wait<0>(); + syncwarp(); - for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - bool valid = L_start + input_row_idx < Ls[i]; - if (!valid) { - continue; + if constexpr (PackedMode) { + // Since in PackedMode one warp/wave may contain different bags with different sizes, + // the permutation should be done after switching from uint4 processing during load stage + // to uint processing during accumulate and store. 
+ input_rows_in_flight = shfl_sync(input_rows_in_flight, packed_bag_acc_idx * uint4_loads_per_row); + + #pragma unroll OutputRowsPerThread + for(uint32_t i = 0; i < OutputRowsPerThread; ++i) + { + Ls[i] = shfl_sync(Ls[i], packed_bag_acc_idx * uint4_loads_per_row); } - const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); - // scale and bias are at the beginning of each row. - // rationale: have scale/shift at start since these get loaded first - // and then broadcasted around so it might speed up the first cache miss. - {% if emb_weight_type.primitive_type == "INT" %} - // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should - // read the certain shift_scale related to the row in the packed_bag. - half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; - {% endif %} - - {% if weighted %} - float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; - {% endif %} - - using scalar_t = {{ emb_weight_type.cpp_type_name }}; + } + + for (uint32_t input_row_idx = 0; input_row_idx < input_rows_in_flight; ++input_row_idx) { + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + bool valid = L_start + input_row_idx < Ls[i]; + if (!valid) { + continue; + } + const uint32_t* row = reinterpret_cast(&buffers[warp_idx][i][input_row_idx][0]); + // scale and bias are at the beginning of each row. + // rationale: have scale/shift at start since these get loaded first + // and then broadcasted around so it might speed up the first cache miss. + {% if emb_weight_type.primitive_type == "INT" %} + // In PackedMode, row pointer may contain several rows from different bags, so each thread/lane should + // read the certain shift_scale related to the row in the packed_bag. + half2 shift_scale = reinterpret_cast(row)[PackedMode ? packed_bag_acc_idx * uints_per_row : 0]; + {% endif %} - {% if not nobag %} - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; {% if weighted %} - accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); - {% else %} - accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + float row_weight = buffers_indice_weights[warp_idx][i][input_row_idx][PackedMode ? packed_bag_acc_idx : 0]; {% endif %} - } - {% else %} - const int32_t output_j = indices_starts[i] + L_start + input_row_idx; - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + + using scalar_t = {{ emb_weight_type.cpp_type_name }}; + + {% if not nobag %} #pragma unroll AccumulateStoreRequests for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: - // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to - // the scale/shift handling). - // Reason: to avoid divergence the first thread in the warp computes garbage. 
- const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], num_valid_outputs); - } + {% if weighted %} + accumulators[i][j].fma(v, {% if emb_weight_type.primitive_type == "INT" %} shift_scale, {% elif emb_weight_type.enum_name == "FP8" %} exponent_bits, exponent_bias, {% endif %} row_weight); + {% else %} + accumulators[i][j].add(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + {% endif %} } - } else if constexpr (std::is_same_v) { - // INT8: - // apply per feature row-wise int8 - auto thread_local_min = std::numeric_limits::max(); - auto thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); - thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + {% else %} + const int32_t output_j = indices_starts[i] + L_start + input_row_idx; + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + // Read the uint8/4/2 values: note that first 4 Bytes will be ditched later: + // We shift back by 4/8/16 elements to remove the first 4 Bytes (which is garbage due to + // the scale/shift handling). + // Reason: to avoid divergence the first thread in the warp computes garbage. 
+ const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], num_valid_outputs); + } } - } - qparams = warp_find_qparams(thread_local_min, thread_local_max); - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; - if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + auto thread_local_min = std::numeric_limits::max(); + auto thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); - acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + if (output_d >= 0 && output_d < D) { + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(acc.acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(acc.acc)); + } + } + qparams = warp_find_qparams(thread_local_min, thread_local_max); + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + scalar_t v = reinterpret_cast(row)[kWarpSize * j + threadIdx.x]; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + VecNT<{{ (32 // emb_weight_type.bit_width) }}, PrimitiveType::{{ emb_weight_type.primitive_type }}> acc(v{% if emb_weight_type.primitive_type == "INT" %}, shift_scale {% elif emb_weight_type.enum_name == "FP8" %}, exponent_bits, exponent_bias {% endif %}); + acc.store(&output[output_j][output_d], qparams, num_valid_outputs); + } + } + if (threadIdx.x == 0) { + store_qparams_to_row(&output[output_j][D], qparams); } } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[output_j][D], qparams); - } + {% endif %} } - {% endif %} } } - } - {% if not nobag %} - // In case of PackedMode, computes the packed bag index during store stage w.r.t. - // the real number of uints in the rows. - const auto packed_bag_store_idx = PackedMode ? 
threadIdx.x / uints_per_row : 0; + {% if not nobag %} + // In case of PackedMode, computes the packed bag index during store stage w.r.t. + // the real number of uints in the rows. + const int32_t packed_bag_store_idx = PackedMode ? threadIdx.x / uints_per_row : 0; - #pragma unroll OutputRowsPerThread - for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { - const uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_store_idx), static_cast(B - 1)); - const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast(1.0) / Ls[i] : static_cast(1.0); + #pragma unroll OutputRowsPerThread + for (uint32_t i = 0; i < OutputRowsPerThread; ++i) { + const uint32_t b = min(static_cast(bb * num_packed_bags * OutputRowsPerThread + i * num_packed_bags + packed_bag_store_idx), static_cast(B - 1)); + const float inv_L = (mean_pooling && Ls[i] != 0) ? static_cast(1.0) / Ls[i] : static_cast(1.0); - if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if constexpr (PackedMode) { - // Offset global output_d index with the size of outputs per bag w.r.t. current - // packed bag index - output_d -= packed_bag_store_idx * kOutputsPerThread * uints_per_row; - } - accumulators[i][j].mul(inv_L); + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + if constexpr (PackedMode) { + // Offset global output_d index with the size of outputs per bag w.r.t. 
current + // packed bag index + output_d -= packed_bag_store_idx * kOutputsPerThread * uints_per_row; + } + accumulators[i][j].mul(inv_L); - if constexpr (PackedMode) { - // Take into account the packed bag index overflow - if (output_d >= 0 && output_d < D && packed_bag_store_idx < num_packed_bags) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + if constexpr (PackedMode) { + // Take into account the packed bag index overflow + if (output_d >= 0 && output_d < D && packed_bag_store_idx < num_packed_bags) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } + } else { + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + } } - } else { + + } + } else if constexpr (std::is_same_v) { + // INT8: + // apply per feature row-wise int8 + float thread_local_min = std::numeric_limits::max(); + float thread_local_max = std::numeric_limits::lowest(); + float2 qparams; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + accumulators[i][j].mul(inv_L); if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - accumulators[i][j].store(&output[b][D_start + output_d], num_valid_outputs); + thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc)); + thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc)); } } - } - } else if constexpr (std::is_same_v) { - // INT8: - // apply per feature row-wise int8 - float thread_local_min = std::numeric_limits::max(); - float thread_local_max = std::numeric_limits::lowest(); - float2 qparams; - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - accumulators[i][j].mul(inv_L); - if (output_d >= 0 && output_d < D) { - thread_local_max = max(thread_local_max, float{{ (32 // emb_weight_type.bit_width) }}_max(accumulators[i][j].acc)); - thread_local_min = min(thread_local_min, float{{ (32 // emb_weight_type.bit_width) }}_min(accumulators[i][j].acc)); + qparams = warp_find_qparams(thread_local_min, thread_local_max); + const int output_D_start = D_start + t * 8; + const int output_D_end = output_D_start + D; + #pragma unroll AccumulateStoreRequests + for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { + const int32_t output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; + if (output_d >= 0 && output_d < D) { + const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); + accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs); + } } - } - - qparams = warp_find_qparams(thread_local_min, thread_local_max); - const int output_D_start = D_start + t * 8; - const int 
output_D_end = output_D_start + D; - #pragma unroll AccumulateStoreRequests - for (uint32_t j = 0; j < AccumulateStoreRequests; ++j) { - const auto output_d = kWarpSize * j * kOutputsPerThread + threadIdx.x * kOutputsPerThread - D_padding; - if (output_d >= 0 && output_d < D) { - const int num_valid_outputs = min(static_cast(D - output_d), static_cast({{ (32 // emb_weight_type.bit_width) }})); - accumulators[i][j].store(&output[b][output_D_start + output_d], qparams, num_valid_outputs); + if (threadIdx.x == 0) { + store_qparams_to_row(&output[b][output_D_end], qparams); } + } else { + // INT4: not implemented yet } - if (threadIdx.x == 0) { - store_qparams_to_row(&output[b][output_D_end], qparams); - } - } else { - // INT4: not implemented yet } - } - {% endif %} + {% endif %} + } } // kWarpsPerBlock is defined in embedding_forward_quantized_split_nbit_host_template.cu {% set warps_per_block = '4' %} {% for packed_mode in ['true', 'false'] %} +{% for packed_mode_L in ['true', 'false'] %} {% for device_only in ['true', 'false'] %} {% for output_type in ['at::Half', 'at::BFloat16', 'float', 'uint8_t'] %} {% for index_type in ['int32_t', 'int64_t'] %} @@ -464,7 +720,8 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" {{ params.min_128b_rows }}, {{ params.max_128b_rows }}, {{ device_only }}, - {{ packed_mode }} > ( + {{ packed_mode }}, + {{ packed_mode_L }} > ( const pta::PackedTensorAccessor64 dev_weights, const pta::PackedTensorAccessor64 uvm_weights, const pta::PackedTensorAccessor32 weights_placements, @@ -490,6 +747,7 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" const int exponent_bias, {% endif %} const int32_t num_packed_bags, + const int32_t num_packed_bags_L, pta::PackedTensorAccessor32<{{ output_type }}, 2, at::RestrictPtrTraits> output, // [B][total_D], const pta::PackedTensorAccessor64 lxu_cache_weights, const pta::PackedTensorAccessor32 lxu_cache_locations @@ -504,7 +762,8 @@ void {{ emb_weight_type.enum_name }}_split_embedding{{ "_nobag" if nobag else "" {% endfor %} // for output_type in [True, False] {% endfor %} // device_only in [True, False] {% endfor %} // packed_bags in ['true', 'false'] +{% endfor %} // packed_bags in ['true', 'false'] } - // clang-format on + // clang-format on \ No newline at end of file diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index c360cbd091..3c2e5836e8 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -28,7 +28,6 @@ DEFAULT_SCALE_BIAS_SIZE_IN_BYTES, EmbeddingLocation, EmbeddingSpecInfo, - get_bounds_check_version_for_platform, get_new_embedding_location, MAX_PREFETCH_DEPTH, PoolingMode, @@ -58,13 +57,7 @@ import fbgemm_gpu # noqa -def find_max_ls(ty: SparseType, weights_tys:List[SparseType], offsets: Tensor ): - # bag_sizes = None - for type_ in weights_tys: - if type_ == ty or type_.value == ty.value: - bag_sizes = offsets[1:] - offsets[:-1] - return bag_sizes.max().item() - return 0 + def rounded_row_size_in_bytes( dim: int, weight_ty: SparseType, @@ -358,6 +351,7 @@ def __init__( # noqa C901 feature_table_map: Optional[List[int]] = None, # [T] index_remapping: Optional[List[Tensor]] = None, pooling_mode: PoolingMode = PoolingMode.SUM, + Ls=None, device: Optional[Union[str, int, torch.device]] = None, bounds_check_mode: BoundsCheckMode = 
BoundsCheckMode.WARNING, weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None, @@ -487,8 +481,25 @@ def max_ty_D(ty: SparseType) -> int: ], default=0, ) - + def find_max_ls(ty: SparseType, weights_tys: List[SparseType], Ls) -> int: + if isinstance(Ls, list): + return ( + 0 + if not any(t.value == ty.value for t in weights_tys) + else int(max(Ls)) + ) + else: + return ( + 0 if not any(t.value == ty.value for t in weights_tys) else int(Ls) + ) + + self.INT2_max_ls = find_max_ls(SparseType.INT2, weights_tys, Ls) + self.INT4_max_ls = find_max_ls(SparseType.INT4, weights_tys, Ls) + self.INT8_max_ls = find_max_ls(SparseType.INT8, weights_tys, Ls) + self.FP8_max_ls = find_max_ls(SparseType.FP8, weights_tys, Ls) + self.FP16_max_ls = find_max_ls(SparseType.FP16, weights_tys, Ls) + self.FP32_max_ls = find_max_ls(SparseType.FP32, weights_tys, Ls) self.max_int2_D: int = max_ty_D(SparseType.INT2) self.max_int4_D: int = max_ty_D(SparseType.INT4) @@ -496,6 +507,7 @@ def max_ty_D(ty: SparseType) -> int: self.max_float8_D: int = max_ty_D(SparseType.FP8) self.max_float16_D: int = max_ty_D(SparseType.FP16) self.max_float32_D: int = max_ty_D(SparseType.FP32) + self.register_buffer( "D_offsets", torch.tensor(D_offsets, device=self.current_device, dtype=torch.int32), @@ -643,8 +655,6 @@ def max_ty_D(ty: SparseType) -> int: self.fp8_exponent_bits = -1 self.fp8_exponent_bias = -1 - self.bounds_check_version: int = get_bounds_check_version_for_platform() - @torch.jit.ignore def log(self, msg: str) -> None: """ @@ -967,18 +977,6 @@ def _forward_impl( indices, offsets, per_sample_weights = inputs_to_device( indices, offsets, per_sample_weights, self.bounds_check_warning ) - weights_tys: List[SparseType] = [e[3] for e in self.embedding_specs] - type_list = [SparseType.INT2, SparseType.INT4, SparseType.INT8, SparseType.FP8, SparseType.FP16, SparseType.FP32] - INT2_max_ls = find_max_ls(SparseType.INT2, weights_tys, offsets) - INT4_max_ls = find_max_ls(SparseType.INT4, weights_tys, offsets) - INT8_max_ls = find_max_ls(SparseType.INT8, weights_tys, offsets) - FP8_max_ls = find_max_ls(SparseType.FP8, weights_tys, offsets) - FP16_max_ls = find_max_ls(SparseType.FP16, weights_tys, offsets) - FP32_max_ls = find_max_ls(SparseType.FP32, weights_tys, offsets) - - - - # First bound check: check if the indices/offsets are within the boundary # of the original embedding rows before pruning. @@ -997,7 +995,6 @@ def _forward_impl( self.bounds_check_mode_int, self.bounds_check_warning, per_sample_weights, - bounds_check_version=self.bounds_check_version, ) # Index remapping changes input indices, and some of them becomes -1 (prunned rows). @@ -1040,7 +1037,6 @@ def _forward_impl( self.bounds_check_mode_int, self.bounds_check_warning, per_sample_weights, - bounds_check_version=self.bounds_check_version, ) # Note: CPU and CUDA ops use the same interface to facilitate JIT IR # generation for CUDA/CPU. 
For CPU op, we don't need weights_uvm and @@ -1058,12 +1054,12 @@ def _forward_impl( max_int8_D=self.max_int8_D, max_float16_D=self.max_float16_D, max_float32_D=self.max_float32_D, - INT2_max_ls=INT2_max_ls, - INT4_max_ls=INT4_max_ls, - INT8_max_ls=INT8_max_ls, - FP8_max_ls = FP8_max_ls, - FP16_max_ls=FP16_max_ls, - FP32_max_ls=FP32_max_ls, + INT2_max_ls=self.INT2_max_ls, + INT4_max_ls=self.INT4_max_ls, + INT8_max_ls=self.INT8_max_ls, + FP8_max_ls=self.FP8_max_ls, + FP16_max_ls=self.FP16_max_ls, + FP32_max_ls=self.FP32_max_ls, indices=indices, offsets=offsets, pooling_mode=int(self.pooling_mode), @@ -1074,7 +1070,7 @@ def _forward_impl( row_alignment=self.row_alignment, max_float8_D=self.max_float8_D, fp8_exponent_bits=self.fp8_exponent_bits, - fp8_exponent_bias=self.fp8_exponent_bias + fp8_exponent_bias=self.fp8_exponent_bias, ) def forward( @@ -1546,7 +1542,6 @@ def move_to_device_with_cache( for i, weight in enumerate(weights): weights[i] = ( weight[0].to(device), - # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `to`. weight[1].to(device) if weight[1] is not None else None, ) ( @@ -1816,7 +1811,6 @@ def assign_embedding_weights( dest_weight[0].copy_(input_weight[0]) if input_weight[1] is not None: assert dest_weight[1] is not None - # pyre-fixme[16]: Undefined attribute: `Optional` has no attribute `copy_`. dest_weight[1].copy_(input_weight[1]) else: assert dest_weight[1] is None @@ -2064,4 +2058,4 @@ def embedding_inplace_update_internal( row_alignment=self.row_alignment, lxu_cache_weights=self.lxu_cache_weights, lxu_cache_locations=lxu_cache_locations, - ) + ) \ No newline at end of file diff --git a/fbgemm_gpu/test/tbe/inference/common.py b/fbgemm_gpu/test/tbe/inference/common.py index 8c04441e47..146a3a0b0a 100644 --- a/fbgemm_gpu/test/tbe/inference/common.py +++ b/fbgemm_gpu/test/tbe/inference/common.py @@ -240,6 +240,7 @@ def execute_nbit_forward_( # noqa C901 ) for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) ], + Ls=L, pooling_mode=pooling_mode, index_remapping=index_remappings_array if B != 0 else None, device="cpu" if use_cpu else torch.cuda.current_device(), diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py index 1a1cde1753..b7ca592ece 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py @@ -48,8 +48,8 @@ VERBOSITY: Verbosity = Verbosity.verbose - - +print(open_source) +print(*gpu_unavailable) # pyre-ignore additional_decorators: Dict[str, List[Callable]] = { "test_faketensor__test_nbit_forward_uvm_cache": [ @@ -115,9 +115,6 @@ "Operator outputs int4 tensors which do not support opcheck tests" ), ], - "test_faketensor__test_nbit_forward_fused_pooled_emb_quant_nan_weighted": [ - unittest.skip("Operator not implemented for fake tensors"), - ], } @@ -177,6 +174,7 @@ def execute_nbit_forward_fused_pooled_emb_quant_( ) for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) ], + Ls=L, output_dtype=output_dtype, device=torch.cuda.current_device(), ) @@ -211,6 +209,7 @@ def execute_nbit_forward_fused_pooled_emb_quant_( ) for (E, D, M, W_TY) in zip(Es, Ds, managed, weights_ty_list) ], + Ls=L, output_dtype=SparseType.FP32, device=torch.cuda.current_device(), ) @@ -357,92 +356,6 @@ def test_nbit_forward_fused_pooled_emb_quant_against_ref( **kwargs, ) - @unittest.skipIf(*gpu_unavailable) - def test_nbit_forward_fused_pooled_emb_quant_nan_weighted(self) -> None: - # Hash size - E = 10 - # Embedding dimensoin - D = 160 
- # Pooling factor - L = 64 - - # Use TBE training op as a reference - op_ref = SplitTableBatchedEmbeddingBagsCodegen( - [ - (E, D, EmbeddingLocation.DEVICE, ComputeDevice.CUDA), - ], - weights_precision=SparseType.FP32, - output_dtype=SparseType.FP32, - device=torch.cuda.current_device(), - ) - - # Instantiate TBE inference - op = IntNBitTableBatchedEmbeddingBagsCodegen( - embedding_specs=[ - ( - "", - E, - D, - SparseType.INT4, - EmbeddingLocation.DEVICE, - ), - ], - output_dtype=SparseType.FP16, - ) - - # Initialize weights_ref with 1.0 - weights_ref = op_ref.split_embedding_weights() - weights_ref[0].fill_(1.0) - - # Copy weights_ref to weights - op.initialize_weights() - weights = op.split_embedding_weights() - quant_weights, quant_scale_shift = quantize_embs( - weights_ref[0], SparseType.INT4 - ) - weights[0][0].copy_(quant_weights) - weights[0][1].copy_(quant_scale_shift) - - # Generate inputs - indices = torch.as_tensor( - [0] * L, device=torch.cuda.current_device(), dtype=torch.int - ) - offsets = torch.as_tensor( - [0, L], device=torch.cuda.current_device(), dtype=torch.int - ) - per_sample_weights = torch.arange( - L, device=torch.cuda.current_device(), dtype=torch.float - ) - - # Set a bunch of indices to -1 to simulate pruning. - pruned_indices = indices.clone().detach() - prune_select = torch.arange(pruned_indices.numel()) % 8 == 0 - pruned_indices[prune_select] = -1 - - # Pre-prune per_sample_weights for reference - pruned_per_sample_weights = per_sample_weights.clone().detach() - pruned_per_sample_weights[prune_select] = 0.0 - - # Run reference - output_ref = op_ref( - indices=indices, - offsets=offsets, - per_sample_weights=pruned_per_sample_weights, - ) - - # Initialize shared memory to NaNs. - torch.ops.fbgemm.initialize_nan_shared_mem(torch.cuda.current_device()) - - # Run test - output = op( - indices=pruned_indices, - offsets=offsets, - per_sample_weights=per_sample_weights, - ) - - # Expect the outputs to be bit-wise equivalent - assert torch.equal(output_ref, output) - @unittest.skipIf(*gpu_unavailable) @given( T=st.integers(min_value=1, max_value=10), @@ -788,6 +701,7 @@ def test_nbit_forward_uvm_cache( ) for (E, D) in zip(Es, Ds) ], + Ls=L, index_remapping=index_remapping, use_array_for_index_remapping=use_array_for_index_remapping, pruning_hash_load_factor=pruning_hash_load_factor, @@ -795,6 +709,7 @@ def test_nbit_forward_uvm_cache( cc_ref.fill_random_weights() cc = IntNBitTableBatchedEmbeddingBagsCodegen( [("", E, D, weights_ty, M) for (E, D, M) in zip(Es, Ds, managed)], + Ls=L, cache_algorithm=cache_algorithm, cache_assoc=associativity, index_remapping=index_remapping, @@ -868,6 +783,7 @@ def test_nbit_forward_cpu_seq_int8( ) for H in T_H ], + Ls=L, pooling_mode=pooling_mode, device="cpu", output_dtype=nbit_weights_ty, @@ -1014,6 +930,7 @@ def test_nbit_forward_cpu_gpu_dequantize_parity( ) for H in T_H ], + Ls=L, pooling_mode=pooling_mode, device="cpu", output_dtype=nbit_weights_ty, @@ -1031,6 +948,7 @@ def test_nbit_forward_cpu_gpu_dequantize_parity( ) for H in T_H ], + Ls=L, pooling_mode=pooling_mode, device="cpu", output_dtype=output_dtype, @@ -1121,4 +1039,4 @@ def test_nbit_forward_cpu_gpu_dequantize_parity( if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file From f3654d6794d9394dd3708340afdb91d5d0cc14a1 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 02:38:38 +0000 Subject: [PATCH 04/10] fixed packed_bags logic on max_L and added new condition on find_max_ls --- 
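Note for this patch: the macro change below adds a max_L > 0 guard and dispatches the L-packed kernel variant only when max_L == 1. A rough Python sketch of the resulting packed-bag count follows, with the output-dtype / FP32-weight guards omitted for brevity; the helper name and the Python form are illustrative only, while kWarpSize, NumUint4LoadsPerRow and num_uint4_loads_per_row mirror the names used in the .cu template:

def num_packed_bags(max_L, num_uint4_loads_per_row, NumUint4LoadsPerRow, kWarpSize):
    # Bags' worth of indices a single warp/wave can hold for this row width.
    max_indices_per_warp = kWarpSize // NumUint4LoadsPerRow
    # Pack along L only for a known, positive max_L smaller than that capacity.
    packed_L = (max_indices_per_warp // max_L
                if 0 < max_L < max_indices_per_warp else 1)
    # Pack along D when one row needs fewer uint4 loads than a full row slot provides.
    packed_D = (NumUint4LoadsPerRow // num_uint4_loads_per_row
                if num_uint4_loads_per_row < NumUint4LoadsPerRow else 1)
    # The L-packed kernel is only taken on the max_L == 1 fast path.
    return packed_L * packed_D if max_L == 1 else packed_D

For example, with max_L == 1, NumUint4LoadsPerRow == 8 and a 64-lane wave, a row that needs all 8 uint4 loads still packs 8 bags per warp (packed_L == 8, packed_D == 1).
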
...g_forward_quantized_split_nbit_host_template.cu | 14 ++++++++------ ...split_table_batched_embeddings_ops_inference.py | 3 +++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu index 130cc52061..827cf141a7 100644 --- a/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu +++ b/fbgemm_gpu/codegen/inference/embedding_forward_quantized_split_nbit_host_template.cu @@ -259,16 +259,18 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{ const int32_t num_uint4_loads_per_row = nbit::div_round_up(nbit::padded_row_size_in_bytes(max_D, sparse_type, row_alignment), sizeof(uint4)); \ constexpr int32_t NumUint4LoadsPerRow = MaxNum128BRows * 128 / sizeof(uint4); \ constexpr int32_t max_indices_per_warp = kWarpSize / NumUint4LoadsPerRow; \ - num_packed_bags_L = max_indices_per_warp > max_L && !std::is_same_v && sparse_type != SparseType::FP32? max_indices_per_warp / max_L : 1; \ + num_packed_bags_L = max_L > 0 && max_indices_per_warp > max_L && !std::is_same_v && sparse_type != SparseType::FP32? max_indices_per_warp / max_L : 1; \ num_packed_bags_D = NumUint4LoadsPerRow > num_uint4_loads_per_row && !std::is_same_v && sparse_type != SparseType::FP32 ? NumUint4LoadsPerRow / num_uint4_loads_per_row : 1; \ /* Number of bags that might be fitted to shared memory. */ \ - num_packed_bags = max_L>1 ? num_packed_bags_D : num_packed_bags_L * num_packed_bags_D; \ + num_packed_bags = max_L==1 ? num_packed_bags_L * num_packed_bags_D : num_packed_bags_D; \ } \ {%- endif %} - if (num_packed_bags > 1 && max_L>1) { \ - X(dev_only, true, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ - } else if (num_packed_bags > 1 && max_L<=1) { \ - X(dev_only, true, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + if (num_packed_bags > 1) { \ + if (max_L==1){ \ + X(dev_only, true, true, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } else{ \ + X(dev_only, true, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ + } \ } else { \ X(dev_only, false, false, OutputRowsPerThread, InputRowsInFlight, MinNum128BRows, MaxNum128BRows) \ }; diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 3c2e5836e8..34ccf4df71 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -489,6 +489,9 @@ def find_max_ls(ty: SparseType, weights_tys: List[SparseType], Ls) -> int: if not any(t.value == ty.value for t in weights_tys) else int(max(Ls)) ) + elif Ls == None: + return 0 + else: return ( 0 if not any(t.value == ty.value for t in weights_tys) else int(Ls) From 43114ad54edbe4cb177d96c403eef15719091560 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 02:45:50 +0000 Subject: [PATCH 05/10] adapted linter on None condition --- .../fbgemm_gpu/split_table_batched_embeddings_ops_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 34ccf4df71..1a4a995468 100644 --- 
a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -489,9 +489,8 @@ def find_max_ls(ty: SparseType, weights_tys: List[SparseType], Ls) -> int: if not any(t.value == ty.value for t in weights_tys) else int(max(Ls)) ) - elif Ls == None: + elif Ls is None: return 0 - else: return ( 0 if not any(t.value == ty.value for t in weights_tys) else int(Ls) From 30095f79a86abb35af561f392e7da364a49c049b Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 18:34:04 +0000 Subject: [PATCH 06/10] added args of max_ls on codegen_lookup_func --- .../inference/nbit_split_embeddings_test.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index 381b2ba1e1..ca83ac3e31 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -27,12 +27,12 @@ from fbgemm_gpu.tbe.utils import generate_requests, round_up from hypothesis import assume, given, HealthCheck, settings, Verbosity -from .. import common # noqa E402 -from ..common import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source +import common_prev # noqa E402 +from common_prev import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source if open_source: # pyre-ignore[21] - from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM + from test_utils_ import gpu_unavailable, optests, TEST_WITH_ROCM else: from fbgemm_gpu.test.test_utils import gpu_unavailable, optests, TEST_WITH_ROCM @@ -196,6 +196,7 @@ def test_nbit_split_embedding_weights_with_scale_and_bias( indices_dtype=st.sampled_from([torch.int, torch.int64]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) + def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( self, weights_ty: SparseType, @@ -288,6 +289,12 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( max_int8_D=cc_ref.max_int8_D, max_float16_D=cc_ref.max_float16_D, max_float32_D=cc_ref.max_float32_D, + INT2_max_ls=cc_ref.INT2_max_ls, + INT4_max_ls=cc_ref.INT4_max_ls, + INT8_max_ls=cc_ref.INT8_max_ls, + FP8_max_ls=cc_ref.FP8_max_ls, + FP16_max_ls=cc_ref.FP16_max_ls, + FP32_max_ls=cc_ref.FP32_max_ls, indices=indices, offsets=offsets, pooling_mode=int(cc_ref.pooling_mode), @@ -362,6 +369,12 @@ def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( max_int8_D=cc_ref.max_int8_D, max_float16_D=cc_ref.max_float16_D, max_float32_D=cc_ref.max_float32_D, + INT2_max_ls=cc_ref.INT2_max_ls, + INT4_max_ls=cc_ref.INT4_max_ls, + INT8_max_ls=cc_ref.INT8_max_ls, + FP8_max_ls=cc_ref.FP8_max_ls, + FP16_max_ls=cc_ref.FP16_max_ls, + FP32_max_ls=cc_ref.FP32_max_ls, indices=indices, offsets=offsets, pooling_mode=int(cc_ref.pooling_mode), From b93ebacad1e72a812eaaabeeacea15c488204b9d Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 18:43:39 +0000 Subject: [PATCH 07/10] fixed errors on flake8 --- fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py | 2 +- .../fbgemm_gpu/split_table_batched_embeddings_ops_inference.py | 2 +- fbgemm_gpu/test/tbe/inference/nbit_forward_test.py | 2 +- fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py index 38e99d85eb..a02e43cae0 
100644 --- a/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py @@ -1839,4 +1839,4 @@ def nbit_cache( # noqa C901 if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py index 1a4a995468..7af4f41da8 100644 --- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_inference.py @@ -2060,4 +2060,4 @@ def embedding_inplace_update_internal( row_alignment=self.row_alignment, lxu_cache_weights=self.lxu_cache_weights, lxu_cache_locations=lxu_cache_locations, - ) \ No newline at end of file + ) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py index b7ca592ece..060fb798c0 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_forward_test.py @@ -1039,4 +1039,4 @@ def test_nbit_forward_cpu_gpu_dequantize_parity( if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index ca83ac3e31..4e1bdf867f 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -196,7 +196,6 @@ def test_nbit_split_embedding_weights_with_scale_and_bias( indices_dtype=st.sampled_from([torch.int, torch.int64]), ) @settings(verbosity=VERBOSITY, max_examples=MAX_EXAMPLES, deadline=None) - def test_int_nbit_split_embedding_uvm_caching_codegen_lookup_function( self, weights_ty: SparseType, From c8896466047150f167461a7a7d2e4fe45a7b4f22 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 19:07:13 +0000 Subject: [PATCH 08/10] fixed the relative path on common on test scripts --- fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index 4e1bdf867f..238ccefbab 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -27,12 +27,12 @@ from fbgemm_gpu.tbe.utils import generate_requests, round_up from hypothesis import assume, given, HealthCheck, settings, Verbosity -import common_prev # noqa E402 -from common_prev import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source +from .. 
import common # noqa E402 +from ..common import open_source import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source if open_source: # pyre-ignore[21] - from test_utils_ import gpu_unavailable, optests, TEST_WITH_ROCM + from test_utils import gpu_unavailable, optests, TEST_WITH_ROCM else: from fbgemm_gpu.test.test_utils import gpu_unavailable, optests, TEST_WITH_ROCM From 24123669efff40415cfa306a6a4bf69474412c23 Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 19:12:55 +0000 Subject: [PATCH 09/10] formatted ufmt and fixed flake8 on nbit_split_embeddings_test.py --- fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index 238ccefbab..9312e106a5 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -14,6 +14,7 @@ from typing import Callable, Dict, List import hypothesis.strategies as st +import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source import numpy as np import torch from fbgemm_gpu.split_embedding_configs import SparseType @@ -28,7 +29,6 @@ from hypothesis import assume, given, HealthCheck, settings, Verbosity from .. import common # noqa E402 -from ..common import open_source import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source if open_source: # pyre-ignore[21] From 65c77dafdae137de73872c0d71bad025377292ed Mon Sep 17 00:00:00 2001 From: kudomcho Date: Tue, 24 Jun 2025 19:17:36 +0000 Subject: [PATCH 10/10] formatted ufmt and fixed flake8 on nbit_split_embeddings_test.py --- fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py index 9312e106a5..f9aaac6243 100644 --- a/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py +++ b/fbgemm_gpu/test/tbe/inference/nbit_split_embeddings_test.py @@ -14,7 +14,6 @@ from typing import Callable, Dict, List import hypothesis.strategies as st -import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source import numpy as np import torch from fbgemm_gpu.split_embedding_configs import SparseType @@ -29,6 +28,7 @@ from hypothesis import assume, given, HealthCheck, settings, Verbosity from .. import common # noqa E402 +from ..common import MAX_EXAMPLES, MAX_EXAMPLES_LONG_RUNNING, open_source if open_source: # pyre-ignore[21]
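For reference, a minimal usage sketch of the Ls plumbing added in this series; table counts, dimensions and the exact import paths are assumptions, not taken verbatim from the repo. Ls may be a single int, a list of ints (the max is used), or None; with None every *_max_ls falls back to 0 and the packed-L fast path is not used:

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)

E, Ds, L = 10_000, [64, 128], 1  # placeholder sizes; L == 1 targets the fast path
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
    [("", E, D, SparseType.INT4, EmbeddingLocation.DEVICE) for D in Ds],
    Ls=L,  # int, list of ints, or None
    output_dtype=SparseType.FP16,
    device=torch.cuda.current_device(),
)
emb.fill_random_weights()
# At construction the per-dtype maxima (e.g. emb.INT4_max_ls here, and 0 for dtypes
# no table uses) are derived once from Ls and forwarded to the lookup op as the new
# INT*_max_ls / FP*_max_ls arguments, instead of being recomputed from offsets on
# every forward call.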