optimized inference for L=1 case on AMD GPU #4382

Open

wants to merge 12 commits into base: main
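At a high level, the changes below thread a bag-length hint (Ls) from the Python benchmark into the quantized TBE inference op, and add per-precision max-L arguments (INT2_max_ls through FP32_max_ls) to the C++ entry points so the kernels can select an L=1-specialized path. As an illustrative sketch only (not code from this PR; imports follow the current fbgemm_gpu module layout, and the reading of Ls as a per-table bag-length hint is inferred from the benchmark diff), constructing the inference op with the new argument might look like:

    # Illustrative sketch only -- not code from this PR.
    import torch
    from fbgemm_gpu.split_embedding_configs import SparseType
    from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
    from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
        IntNBitTableBatchedEmbeddingBagsCodegen,
    )

    E, D, L = 100_000, 128, 1  # rows, embedding dim, bag length (the optimized case)
    emb = IntNBitTableBatchedEmbeddingBagsCodegen(
        [("", E, D, SparseType.INT4, EmbeddingLocation.DEVICE)],
        device="cuda",             # the HIP device on a ROCm build
        Ls=L,                      # new in this PR: hint used to pick the L=1 path
        output_dtype=SparseType.FP16,
    )
    emb.fill_random_weights()
    B = 4
    indices = torch.randint(0, E, (B,), dtype=torch.int32, device="cuda")
    offsets = torch.arange(B + 1, dtype=torch.int32, device="cuda")  # B bags of L=1
    out = emb(indices, offsets)    # (B, D) pooled FP16 output
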
25 changes: 16 additions & 9 deletions fbgemm_gpu/bench/tbe/tbe_inference_benchmark.py
Expand Up @@ -189,6 +189,7 @@ def nbit_cpu( # noqa C901
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
[("", E, d, weights_precision, EmbeddingLocation.HOST) for d in Ds],
device="cpu",
Ls=L,
index_remapping=[torch.arange(E) for _ in Ds] if index_remapping else None,
output_dtype=output_dtype,
pooling_mode=pooling_mode,
Expand Down Expand Up @@ -403,6 +404,7 @@ def nbit_device( # noqa C901
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
[("", E, d, weights_precision, managed_option) for d in Ds],
bounds_check_mode=BoundsCheckMode(bounds_check_mode),
Ls=L,
index_remapping=index_remapping,
pruning_hash_load_factor=pruning_hash_load_factor,
use_array_for_index_remapping=use_array_for_index_remapping,
Expand Down Expand Up @@ -791,6 +793,7 @@ def nbit_device_with_spec( # noqa C901
emb = IntNBitTableBatchedEmbeddingBagsCodegen(
[("", e, d, weights_precision, managed_option) for d, e in zip(Ds, Es)],
device="cpu" if use_cpu else None,
Ls=Ls,
bounds_check_mode=BoundsCheckMode(bounds_check_mode),
index_remapping=index_remapping,
pruning_hash_load_factor=pruning_hash_load_factor,
Expand Down Expand Up @@ -856,7 +859,7 @@ def nbit_device_with_spec( # noqa C901
# don't use zipf if e isn't large enough compared to bag_size.
alpha=alpha if (e / bag_size) > 2.0 else 1.0,
# need many more samples for zipf if bag_size is very small.
zipf_oversample_ratio=3 if bag_size > 5 else 10,
zipf_oversample_ratio=10 if bag_size > 5 else 20,
weighted=weighted,
use_cpu=use_cpu,
)
Expand Down Expand Up @@ -919,14 +922,12 @@ def nbit_device_with_spec( # noqa C901

# free up memory
del requests
result_msg = (
f"Iteration {i}: "
f"{weights_precision} Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {cpu_copies * float(read_write_bytes) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"Time: {time_per_iter * 1.0e6:.0f}us, "
f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"
)
result_msg = f"Iteration {i}: "
f"{weights_precision} Forward, B: {B}, "
f"E: {E}, T: {T}, D: {D}, L: {L}, W: {weighted}, "
f"BW: {cpu_copies * float(read_write_bytes) / time_per_iter / 1.0e9: .2f} GB/s, " # noqa: B950
f"Time: {time_per_iter * 1.0e6:.0f}us, "
f"Memory Usage For Pruning: {mem_for_pruning / 1.0e9:.0f} GB"

if use_cpu and cpu_copies > 1:
result_msg += f", Parallel Copies: {cpu_copies}"
Expand Down Expand Up @@ -1111,6 +1112,7 @@ def nbit_uvm(
)
for d in Ds[:T_uvm]
],
Ls=L,
output_dtype=output_dtype,
cache_load_factor=cache_load_factor,
cache_algorithm=cache_alg,
Expand All @@ -1133,6 +1135,7 @@ def nbit_uvm(
)
for d in Ds[T_uvm:]
],
Ls=L,
output_dtype=output_dtype,
fp8_exponent_bits=fp8_exponent_bits,
fp8_exponent_bias=fp8_exponent_bias,
Expand All @@ -1155,6 +1158,7 @@ def nbit_uvm(
[managed_type] * T_uvm + [EmbeddingLocation.DEVICE] * T_gpu,
)
],
Ls=L,
output_dtype=output_dtype,
cache_load_factor=cache_load_factor,
cache_algorithm=cache_alg,
Expand Down Expand Up @@ -1456,6 +1460,7 @@ def bench_uvm_cls(
)
for d in Ds[:T]
],
Ls=L,
output_dtype=output_dtype,
cache_load_factor=cache_load_factor,
cache_algorithm=cache_alg,
Expand Down Expand Up @@ -1637,6 +1642,7 @@ def nbit_cache( # noqa C901
)
for d in Ds
],
Ls=L,
output_dtype=output_dtype,
enforce_hbm=enforce_hbm,
fp8_exponent_bits=fp8_exponent_bits,
Expand All @@ -1661,6 +1667,7 @@ def nbit_cache( # noqa C901
record_cache_metrics=RecordCacheMetrics(
record_cache_miss_counter, record_tablewise_cache_miss
),
Ls=L,
gather_uvm_cache_stats=gather_uvm_cache_stats,
cache_load_factor=cache_load_factor,
cache_algorithm=cache_alg,
Expand Down
83 changes: 70 additions & 13 deletions fbgemm_gpu/codegen/inference/embedding_forward_quantized_host.cpp
@@ -1,10 +1,10 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <ATen/ATen.h>
#include <ATen/TypeDefault.h>
Expand Down Expand Up @@ -200,6 +200,12 @@ Tensor int_nbit_split_embedding_codegen_forward_unweighted_cuda(
int64_t max_int8_D,
int64_t max_float16_D,
int64_t max_float32_D,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Expand All @@ -224,6 +230,12 @@ Tensor int_nbit_split_embedding_codegen_forward_weighted_cuda(
int64_t max_int8_D,
int64_t max_float16_D,
int64_t max_float32_D,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Expand All @@ -248,6 +260,12 @@ Tensor int_nbit_split_embedding_nobag_codegen_forward_unweighted_cuda(
int64_t max_int8_D,
int64_t max_float16_D,
int64_t max_float32_D,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls,
Tensor indices,
Tensor offsets,
int64_t row_alignment,
Expand All @@ -272,6 +290,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
int64_t max_int8_D,
int64_t max_float16_D,
int64_t max_float32_D,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Expand Down Expand Up @@ -308,6 +332,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
max_int8_D,
max_float16_D,
max_float32_D,
INT2_max_ls,
INT4_max_ls,
INT8_max_ls,
FP8_max_ls,
FP16_max_ls,
FP32_max_ls,
indices.to(at::kInt),
offsets.to(at::kInt),
row_alignment ? *row_alignment : 16,
Expand All @@ -316,7 +346,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
lxu_cache_locations.value_or(at::empty({0}, at::kInt)),
max_float8_D ? *max_float8_D : 0,
fp8_exponent_bits ? *fp8_exponent_bits : -1,
fp8_exponent_bias ? *fp8_exponent_bias : -1);
fp8_exponent_bias ? *fp8_exponent_bias : -1
);
}
if (!indice_weights || indice_weights->numel() == 0) {
return int_nbit_split_embedding_codegen_forward_unweighted_cuda(
Expand All @@ -332,6 +363,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
max_int8_D,
max_float16_D,
max_float32_D,
INT2_max_ls,
INT4_max_ls,
INT8_max_ls,
FP8_max_ls,
FP16_max_ls,
FP32_max_ls,
indices,
offsets,
pooling_mode,
Expand All @@ -341,7 +378,8 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
lxu_cache_locations.value_or(at::empty({0}, at::kInt)),
max_float8_D ? *max_float8_D : 0,
fp8_exponent_bits ? *fp8_exponent_bits : -1,
fp8_exponent_bias ? *fp8_exponent_bias : -1);
fp8_exponent_bias ? *fp8_exponent_bias : -1
);
}
// Force casting indice_weights to float (doing this in the backend to avoid
// JIT issue)
Expand All @@ -359,6 +397,12 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
max_int8_D,
max_float16_D,
max_float32_D,
INT2_max_ls,
INT4_max_ls,
INT8_max_ls,
FP8_max_ls,
FP16_max_ls,
FP32_max_ls,
indices,
offsets,
pooling_mode,
Expand All @@ -369,15 +413,15 @@ Tensor int_nbit_split_embedding_codegen_lookup_function(
lxu_cache_locations.value_or(at::empty({0}, at::kInt)),
max_float8_D ? *max_float8_D : 0,
fp8_exponent_bits ? *fp8_exponent_bits : -1,
fp8_exponent_bias ? *fp8_exponent_bias : -1);
fp8_exponent_bias ? *fp8_exponent_bias : -1
);
}

///@ingroup embedding-cuda
/// Similar to int_nbit_split_embedding_codegen_lookup_function, but it does
/// UVM_CACHING lookup.
Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
// First args should be the same as those of
// int_nbit_split_embedding_codegen_lookup_function.

Tensor dev_weights,
Tensor uvm_weights,
Tensor weights_placements,
Expand All @@ -390,6 +434,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
int64_t max_int8_D,
int64_t max_float16_D,
int64_t max_float32_D,
int64_t INT2_max_ls,
int64_t INT4_max_ls,
int64_t INT8_max_ls,
int64_t FP8_max_ls,
int64_t FP16_max_ls,
int64_t FP32_max_ls,
Tensor indices,
Tensor offsets,
int64_t pooling_mode,
Expand Down Expand Up @@ -547,6 +597,12 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
max_int8_D,
max_float16_D,
max_float32_D,
INT2_max_ls,
INT4_max_ls,
INT8_max_ls,
FP8_max_ls,
FP16_max_ls,
FP32_max_ls,
indices,
offsets,
pooling_mode,
Expand All @@ -557,7 +613,8 @@ Tensor int_nbit_split_embedding_uvm_caching_codegen_lookup_function(
row_alignment,
max_float8_D,
fp8_exponent_bits,
fp8_exponent_bias);
fp8_exponent_bias
);
}

///@ingroup embedding-cuda
Expand All @@ -583,4 +640,4 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
int_nbit_split_embedding_uvm_caching_codegen_lookup_function);
DISPATCH_TO_CUDA("pruned_hashmap_lookup", pruned_hashmap_lookup_cuda);
DISPATCH_TO_CUDA("pruned_array_lookup", pruned_array_lookup_cuda);
}
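
The new C++ parameters (INT2_max_ls through FP32_max_ls) carry one maximum bag length per weight precision. The diff does not show how the Python frontend derives these values; one plausible derivation (an assumption, not taken from this PR) is to group tables by precision and take the maximum L per group:

    # Assumption / illustration only: a plausible way to derive the per-precision
    # max-L values the new C++ signatures expect. Not code from this PR.
    from collections import defaultdict

    def per_precision_max_ls(precisions, Ls):
        # precisions: per-table weight precision names, e.g. ["INT4", "INT4", "FP16"]
        # Ls: per-table bag lengths, e.g. [1, 1, 5]
        max_ls = defaultdict(int)
        for prec, length in zip(precisions, Ls):
            max_ls[prec] = max(max_ls[prec], length)
        # Precisions with no tables stay at 0 (assumed sentinel for "unused").
        return [max_ls[p] for p in ("INT2", "INT4", "INT8", "FP8", "FP16", "FP32")]

    print(per_precision_max_ls(["INT4", "INT4", "FP16"], [1, 1, 5]))
    # -> [0, 1, 0, 0, 5, 0]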