From 794d11fa0492097922370bf027e2f2059c48e50d Mon Sep 17 00:00:00 2001 From: Chenyu Zhang Date: Thu, 19 Jun 2025 20:13:31 -0700 Subject: [PATCH 1/3] kvzch inference python operator (#4344) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1412 Python operator change for kv embedding. https://docs.google.com/document/d/1TNJMnj-PPKWitMgwB8HJIsFT3OwotiEnqgh60fI9P48/edit?tab=t.0#heading=h.o9irumwgl8gj Reviewed By: emlin Differential Revision: D73219651 --- fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py | 181 +++++++++ fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py | 1 + .../tbe/cache/kv_embedding_ops_inference.py | 358 ++++++++++++++++++ .../test/tbe/inference/kv_embedding_test.py | 120 ++++++ 4 files changed, 660 insertions(+) create mode 100644 fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py create mode 100644 fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py create mode 100644 fbgemm_gpu/test/tbe/inference/kv_embedding_test.py diff --git a/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py new file mode 100644 index 0000000000..7f9b5ee3e1 --- /dev/null +++ b/fbgemm_gpu/bench/tbe/tbe_kv_benchmark.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import gc +import logging +import time +from typing import Callable, Dict, Type + +import click +import numpy as np +import psutil +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, +) +from fbgemm_gpu.tbe.bench import benchmark_requests +from fbgemm_gpu.tbe.cache import KVEmbeddingInference +from fbgemm_gpu.tbe.utils import generate_requests, round_up, TBERequest + +OptionCommandType = Callable[..., Callable[..., None]] + +iters: OptionCommandType = click.option( + "--iters", + default=200, + type=int, + help="Number of iterations to benchmark", +) +num_embeddings: OptionCommandType = click.option( + "--num-embeddings", + default=int(1e8), + type=int, + help="Number of embedding to benchmark", +) +dim: OptionCommandType = click.option( + "--dim", default=256, type=int, help="Dimension of embedding to benchmark" +) +num_tables: OptionCommandType = click.option( + "--num-tables", default=4, type=int, help="Number of tables to benchmark" +) +output_dtype: OptionCommandType = click.option( + "--output-dtype", type=SparseType, default=SparseType.FP16 +) +weights_precision: OptionCommandType = click.option( + "--weights-precision", type=SparseType, default=SparseType.INT8 +) +batch_size: OptionCommandType = click.option("--batch-size", default=128) +bag_size: OptionCommandType = click.option("--bag-size", default=1) +mixed_dim: OptionCommandType = click.option("--mixed-dim", is_flag=True, default=False) +tbe_class: OptionCommandType = click.option( + "--tbe-class", type=str, default="KVEmbeddingInference" +) + + +TBE_CLASS_MAP: Dict[str, Type[IntNBitTableBatchedEmbeddingBagsCodegen]] = { + "KVEmbeddingInference": KVEmbeddingInference, + "IntNBitTableBatchedEmbeddingBagsCodegen": IntNBitTableBatchedEmbeddingBagsCodegen, +} + + +@click.group() +def cli() -> None: + pass + + +@cli.command() +@iters +@num_embeddings +@dim +@num_tables +@output_dtype +@weights_precision +@batch_size +@bag_size +@mixed_dim 
+@tbe_class +def forward_benchmark( + iters: int, + num_embeddings: int, + dim: int, + num_tables: int, + output_dtype: SparseType, + weights_precision: SparseType, + batch_size: int, + bag_size: int, + mixed_dim: bool, + tbe_class: str, +) -> None: + logging.info( + f"Running forward benchmark with {iters} iterations, {num_embeddings} embeddings, {dim} dim, {num_tables} tables, {output_dtype} output dtype, {weights_precision} weights precision, {batch_size} batch" + ) + + stats = [] + + if mixed_dim: + dimentions = [ + round_up(np.random.randint(low=int(0.5 * dim), high=int(1.5 * dim)), 4) + for _ in range(num_tables) + ] + else: + dimentions = [dim] * num_tables + + process = psutil.Process() + + clazz = TBE_CLASS_MAP[tbe_class] + + time.sleep(5) + mem_util_before = process.memory_info().rss / (1024 * 1024) + logging.info(f"Memory util before emb init: {mem_util_before} MB") + tbe = clazz( + [ + ( + "", + num_embeddings, + d, + weights_precision, + EmbeddingLocation.HOST, + ) + for d in dimentions + ], + output_dtype=output_dtype, + device="cpu", + ) + tbe.fill_random_weights() + + gc.collect() + time.sleep(5) + mem_util_after = process.memory_info().rss / (1024 * 1024) + logging.info(f"Memory util after emb fill: {mem_util_after} MB") + logging.info(f"Memory util diff: {mem_util_after - mem_util_before} MB") + + for batch_size in [10240, 20480, 40960]: + requests = generate_requests( + iters, + batch_size, + num_tables, + bag_size, + num_embeddings, + use_cpu=True, + ) + + requests_cpu = [ + TBERequest( + req.indices.int().cpu(), + req.offsets.int().cpu(), + req.per_sample_weights, + ) + for req in requests + ] + + logging.info(f"Running forward benchmark with {len(requests_cpu)} requests") + time_per_iter = benchmark_requests( + requests_cpu, + lambda indices, offsets, per_sample_weights: tbe.forward( + indices.int().cpu(), + offsets.int().cpu(), + per_sample_weights, + ), + num_warmups=10, + ) + logging.info(f"{clazz} CPU Time: {time_per_iter * 1.0e6:.0f}us") + stats.append( + [ + clazz, + num_tables, + batch_size, + f"{time_per_iter * 1.0e6:.0f}us", + f"{mem_util_after - mem_util_before} MB", + ] + ) + for stat in stats: + logging.info(stat) + + +if __name__ == "__main__": + cli() diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py index 4ef4e1bd9d..e262a2b42b 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py @@ -7,4 +7,5 @@ # pyre-unsafe +from .kv_embedding_ops_inference import KVEmbeddingInference # noqa: F401 from .split_embeddings_cache_ops import get_unique_indices_v2 # noqa: F401 diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py new file mode 100644 index 0000000000..cbd5fa1d5b --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/tbe/cache/kv_embedding_ops_inference.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +# pyre-ignore-all-errors[56] + + +from typing import List, Optional, Tuple, Union + +import torch # usort:skip +from torch import Tensor # usort:skip +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BoundsCheckMode, + CacheAlgorithm, + DEFAULT_SCALE_BIAS_SIZE_IN_BYTES, + EmbeddingLocation, + PoolingMode, + RecordCacheMetrics, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + inputs_to_device, + IntNBitTableBatchedEmbeddingBagsCodegen, + random_quant_scaled_tensor, + rounded_row_size_in_bytes, +) +from fbgemm_gpu.utils.loader import load_torch_module + +try: + load_torch_module( + "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference", + ) +except Exception: + pass + + +class KVEmbeddingInference(IntNBitTableBatchedEmbeddingBagsCodegen): + """ + KV Table-batched version of nn.EmbeddingBag(sparse=False) + Inference version, with support for FP32/FP16/FP8/INT8/INT4/INT2 weights + """ + + def __init__( # noqa C901 + self, + embedding_specs: List[ + Tuple[str, int, int, SparseType, EmbeddingLocation] + ], # tuple of (feature_names, rows, dims, SparseType, EmbeddingLocation/placement) + feature_table_map: Optional[List[int]] = None, # [T] + index_remapping: Optional[List[Tensor]] = None, + pooling_mode: PoolingMode = PoolingMode.SUM, + device: Optional[Union[str, int, torch.device]] = None, + bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, + weight_lists: Optional[List[Tuple[Tensor, Optional[Tensor]]]] = None, + pruning_hash_load_factor: float = 0.5, + use_array_for_index_remapping: bool = True, + output_dtype: SparseType = SparseType.FP16, + cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, + cache_load_factor: float = 0.2, + cache_sets: int = 0, + cache_reserved_memory: float = 0.0, + enforce_hbm: bool = False, # place all weights/momentums in HBM when using cache + record_cache_metrics: Optional[RecordCacheMetrics] = None, + gather_uvm_cache_stats: Optional[bool] = False, + row_alignment: Optional[int] = None, + fp8_exponent_bits: Optional[int] = None, + fp8_exponent_bias: Optional[int] = None, + cache_assoc: int = 32, + scale_bias_size_in_bytes: int = DEFAULT_SCALE_BIAS_SIZE_IN_BYTES, + cacheline_alignment: bool = True, + uvm_host_mapped: bool = False, # True to use cudaHostAlloc; False to use cudaMallocManaged. + reverse_qparam: bool = False, # True to load qparams at end of each row; False to load qparam at begnning of each row. + feature_names_per_table: Optional[List[List[str]]] = None, + indices_dtype: torch.dtype = torch.int32, # Used for construction of the remap_indices tensors. Should match the dtype of the indices passed in the forward() call (INT32 or INT64). 
+ ) -> None: # noqa C901 # tuple of (rows, dims,) + super(KVEmbeddingInference, self).__init__( + embedding_specs=embedding_specs, + feature_table_map=feature_table_map, + index_remapping=index_remapping, + pooling_mode=pooling_mode, + device=device, + bounds_check_mode=bounds_check_mode, + weight_lists=weight_lists, + pruning_hash_load_factor=pruning_hash_load_factor, + use_array_for_index_remapping=use_array_for_index_remapping, + output_dtype=output_dtype, + cache_algorithm=cache_algorithm, + cache_load_factor=cache_load_factor, + cache_sets=cache_sets, + cache_reserved_memory=cache_reserved_memory, + enforce_hbm=enforce_hbm, + record_cache_metrics=record_cache_metrics, + gather_uvm_cache_stats=gather_uvm_cache_stats, + row_alignment=row_alignment, + fp8_exponent_bits=fp8_exponent_bits, + fp8_exponent_bias=fp8_exponent_bias, + cache_assoc=cache_assoc, + scale_bias_size_in_bytes=scale_bias_size_in_bytes, + cacheline_alignment=cacheline_alignment, + uvm_host_mapped=uvm_host_mapped, + reverse_qparam=reverse_qparam, + feature_names_per_table=feature_names_per_table, + indices_dtype=indices_dtype, + ) + self.register_buffer( + "weights_ids", + torch.tensor(0, device=self.current_device, dtype=torch.int64), + ) + + num_shards = 32 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + evict_trigger_mode: int = 0 + # pyre-fixme[4]: Attribute must be annotated. + self.kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode + ) + + self.specs: List[Tuple[int, int, int]] = [ + (rows, dims, sparse_type.as_int()) + for (_, rows, dims, sparse_type, _) in self.embedding_specs + ] + # table shard offset if inference sharding is enabled, otherwise, should be all zeros + self.table_sharding_offset: List[int] = [0] * len(self.embedding_specs) + self.kv_embedding_cache_initialized = False + self.hash_size_cumsum: torch.Tensor = torch.zeros( + 0, + device=self.current_device, + dtype=torch.int64, + ) + self.feature_hash_size_cumsum: torch.Tensor = torch.zeros( + 0, + device=self.current_device, + dtype=torch.int64, + ) + + def construct_hash_size_cumsum(self) -> List[int]: + hash_size_cumsum = [0] + for spec in self.embedding_specs: + rows = spec[1] + hash_size_cumsum.append(hash_size_cumsum[-1] + rows) + return hash_size_cumsum + + def calculate_indices_and_weights_offsets( + self, indices: Tensor, offsets: Tensor + ) -> Tuple[Tensor, Tensor]: + if self.pooling_mode is not PoolingMode.NONE: + T = self.weights_offsets.numel() + else: + T = self.D_offsets.numel() - 1 + B = int((offsets.size(0) - 1) / T) + + total_bytes_added = 0 + new_indices = torch.tensor( + [0] * indices.size(0), device=self.current_device, dtype=indices.dtype + ) + new_weights_offsets = torch.tensor( + [0] * T, device=self.current_device, dtype=self.weights_offsets.dtype + ) + for t in range(T): + new_weights_offsets[t] = total_bytes_added + start, end = int(offsets[t * B]), int(offsets[(t + 1) * B]) + index_size = end - start + new_indices[start:end] = torch.arange(index_size) + table_id = self.feature_table_map[t] + total_bytes_added += index_size * rounded_row_size_in_bytes( + self.embedding_specs[table_id][2], # dim + self.embedding_specs[table_id][3], # weight_ty + self.row_alignment, + self.scale_bias_size_in_bytes, + ) + return new_indices, new_weights_offsets + + def linearize_cache_indices( + self, + indices: torch.Tensor, + offsets: torch.Tensor, + ) -> torch.Tensor: + """ + Linearize cache indices for KV cache. 
+ """ + linearized_indices = torch.zeros( + indices.numel(), + device=indices.device, + dtype=torch.int64, + ) + + T = self.feature_hash_size_cumsum.numel() - 1 + B = int((offsets.size(0) - 1) / T) + + for t in range(T): + start, end = int(offsets[t * B]), int(offsets[(t + 1) * B]) + linearized_indices[start:end] = ( + indices[start:end] + self.feature_hash_size_cumsum[t] + ) + + return linearized_indices + + def forward( + self, + indices: Tensor, + offsets: Tensor, + per_sample_weights: Optional[Tensor] = None, + ) -> Tensor: + assert ( + self.weight_initialized + ), "weight needs to be initialized before forward function" + + indices, offsets, per_sample_weights = inputs_to_device( + indices, offsets, per_sample_weights, self.bounds_check_warning + ) + + lxu_cache_locations = self.lxu_cache_locations_list.pop() + + weights_offsets = self.weights_offsets + weights = self.weights_host if self.host_size > 0 else self.weights_dev + + if self.kv_embedding_cache_initialized: + indices = self.linearize_cache_indices( + indices, + offsets, + ) + + weights = self.kv_embedding_cache.get_embeddings(indices) + + indices, weights_offsets = self.calculate_indices_and_weights_offsets( + indices, offsets + ) + + return torch.ops.fbgemm.int_nbit_split_embedding_codegen_lookup_function( + dev_weights=weights, + uvm_weights=self.weights_uvm, + weights_placements=self.weights_placements, + weights_offsets=weights_offsets, + weights_tys=self.weights_tys, + D_offsets=self.D_offsets, + total_D=self.total_D, + max_int2_D=self.max_int2_D, + max_int4_D=self.max_int4_D, + max_int8_D=self.max_int8_D, + max_float16_D=self.max_float16_D, + max_float32_D=self.max_float32_D, + indices=indices, + offsets=offsets, + pooling_mode=int(self.pooling_mode), + indice_weights=per_sample_weights, + output_dtype=self.output_dtype, + lxu_cache_weights=self.lxu_cache_weights, + lxu_cache_locations=lxu_cache_locations, + row_alignment=self.row_alignment, + max_float8_D=self.max_float8_D, + fp8_exponent_bits=self.fp8_exponent_bits, + fp8_exponent_bias=self.fp8_exponent_bias, + ) + + def fill_random_weights(self) -> None: + """ + Fill the buffer with random weights, table by table + """ + self.initialize_kv_embedding_cache() + for i, (_, num_embeddings, embedding_dim, weight_ty, _) in enumerate( + self.embedding_specs + ): + embedding_dim = rounded_row_size_in_bytes( + embedding_dim, weight_ty, self.row_alignment + ) + indices = torch.range(0, num_embeddings - 1, dtype=torch.int64) + weights = random_quant_scaled_tensor( + shape=torch.Size([num_embeddings, embedding_dim]), + device=self.current_device, + ) + self.embedding_inplace_update_per_table( + i, + indices, + weights, + ) + self.weight_initialized = True + + @torch.jit.export + def init_tbe_config(self, table_sharding_offset: List[int]) -> None: + """ + Initialize the dynamic TBE table configs, e.g. sharded table offsets, etc. + Should be called before loading weights. 
+ """ + self.table_sharding_offset = table_sharding_offset + + @torch.jit.export + def embedding_inplace_update( + self, + update_table_indices: List[int], + update_row_indices: List[List[int]], + update_weights: List[Tensor], + ) -> None: + for i in range(len(update_table_indices)): + self.embedding_inplace_update_per_table( + update_table_indices[i], + torch.tensor( + update_row_indices[i], device=self.current_device, dtype=torch.int64 + ), + update_weights[i], + ) + + @torch.jit.export + def embedding_inplace_update_per_table( + self, + table_id: int, + update_row_indices: Tensor, + update_weights: Tensor, + ) -> None: + assert table_id < len( + self.embedding_specs + ), f"table index {table_id} is out of range {len(self.embedding_specs)}" + # pyre-ignore [29] + table_offset = self.hash_size_cumsum[table_id] + sharding_offset = self.table_sharding_offset[table_id] + + row_size = update_row_indices.numel() + if row_size == 0: + return + + # convert global weight index to fused local weight index + row_indices = update_row_indices + table_offset - sharding_offset + # set weight by id + self.kv_embedding_cache.set_embeddings(row_indices, update_weights) + + @torch.jit.export + def initialize_kv_embedding_cache(self) -> None: + if not self.kv_embedding_cache_initialized: + self.initialize_logical_weights_placements_and_offsets() + + self.row_alignment = ( + 8 if self.use_cpu else self.row_alignment + ) # in order to use mempool implementation for kv embedding it needs to be divisible by 8 + + hash_size_cumsum = self.construct_hash_size_cumsum() + self.hash_size_cumsum = torch.tensor( + hash_size_cumsum, + dtype=torch.int64, + device=self.current_device, + ) + + self.feature_hash_size_cumsum = torch.tensor( + [hash_size_cumsum[t] for t in self.feature_table_map] + + [hash_size_cumsum[-1]], + dtype=torch.int64, + device=self.current_device, + ) + + self.kv_embedding_cache.init( + self.specs, + self.row_alignment, + self.scale_bias_size_in_bytes, + ) + self.kv_embedding_cache_initialized = True diff --git a/fbgemm_gpu/test/tbe/inference/kv_embedding_test.py b/fbgemm_gpu/test/tbe/inference/kv_embedding_test.py new file mode 100644 index 0000000000..d7b3b727b6 --- /dev/null +++ b/fbgemm_gpu/test/tbe/inference/kv_embedding_test.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from unittest import skipIf, TestCase + +import fbgemm_gpu + +import torch +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation +from fbgemm_gpu.split_table_batched_embeddings_ops_inference import ( + IntNBitTableBatchedEmbeddingBagsCodegen, + random_quant_scaled_tensor, +) +from fbgemm_gpu.tbe.cache.kv_embedding_ops_inference import KVEmbeddingInference +from fbgemm_gpu.tbe.utils import generate_requests + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
+open_source: bool = getattr(fbgemm_gpu, "open_source", False) + + +@skipIf(open_source, "Not supported in open source yet") +class KVEmbeddingTest(TestCase): + def test_forward(self) -> None: + dim = 256 + num_tables = 4 + num_embeddings = 100 + batch_size = 2 + bag_size = 1 + num_requests = 1 + weights_precision = SparseType.INT8 + output_dtype = SparseType.FP16 + + dimentions = [dim] * num_tables + + nbit_emb_cpu = IntNBitTableBatchedEmbeddingBagsCodegen( + [ + ( + "", + num_embeddings, + d, + weights_precision, + EmbeddingLocation.HOST, + ) + for d in dimentions + ], + output_dtype=output_dtype, + device="cpu", + ) + nbit_emb_cpu.fill_random_weights() + # fill random scale bias + nbit_weights = nbit_emb_cpu.split_embedding_weights() + for dest_weight in nbit_weights: + scale_bias = dest_weight[1] + if scale_bias is not None: + random_quant_scaled_tensor( + shape=scale_bias.shape, + device=nbit_emb_cpu.current_device, + output_tensor=scale_bias, + ) + + kv_emb_cpu = KVEmbeddingInference( + [ + ( + "", + num_embeddings, + d, + weights_precision, + EmbeddingLocation.HOST, + ) + for d in dimentions + ], + output_dtype=output_dtype, + device="cpu", + ) + kv_emb_cpu.initialize_kv_embedding_cache() + + nbit_weights = nbit_emb_cpu.split_embedding_weights(split_scale_shifts=False) + for i, (nbit_weight, _) in enumerate(nbit_weights): + indices = torch.arange(0, nbit_weight.shape[0], dtype=torch.int64) + kv_emb_cpu.embedding_inplace_update_per_table( + i, + indices, + nbit_weight, + ) + kv_emb_cpu.weight_initialized = True + + requests = generate_requests( + num_requests, + batch_size, + num_tables, + bag_size, + num_embeddings, + use_cpu=True, + ) + + for req in requests: + indices = req.indices.int().cpu() + offsets = req.offsets.int().cpu() + + nbit_emb_cpu_output = nbit_emb_cpu.forward( + indices, + offsets, + ) + kv_emb_cpu_output = kv_emb_cpu.forward( + indices, + offsets, + ) + print(f"nbit_emb_cpu_output: {nbit_emb_cpu_output}") + print(f"kv_emb_cpu_output: {kv_emb_cpu_output}") + self.assertTrue( + torch.allclose( + input=nbit_emb_cpu_output, other=kv_emb_cpu_output, equal_nan=True + ) + ) From bbabc3290557d51d68f8cf83d2c8ae8098030fdb Mon Sep 17 00:00:00 2001 From: Chenyu Zhang Date: Thu, 19 Jun 2025 20:13:31 -0700 Subject: [PATCH 2/3] handle inference buck gpu deps (#4358) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1425 **Context** Add a gpu compiler flags to ignore gpu dependency for cpu buck target. This is to unblock the cpu bolt package build. 
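The change guards the GPU-only includes and the CUDA dispatch paths behind an FBGEMM_USE_GPU compile flag. A minimal sketch of the guard pattern used in the hunks below:

    #ifdef FBGEMM_USE_GPU
    #include "kv_db_cuda_utils.h"
    #endif

With the flag unset, the bodies of get_cuda/set_cuda/stream_cuda are compiled out to no-ops, so a CPU-only Buck target can build kv_db_table_batched_embeddings.cpp without pulling in the CUDA utilities.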
Reviewed By: emlin Differential Revision: D76228086 --- .../kv_db_table_batched_embeddings.cpp | 10 ++++++++-- .../kv_db_table_batched_embeddings.h | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp index 990abfddc5..ec50e28931 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp @@ -11,7 +11,9 @@ #include #include #include "common/time/Time.h" +#ifdef FBGEMM_USE_GPU #include "kv_db_cuda_utils.h" +#endif #include "torch/csrc/autograd/record_function_ops.h" #ifdef FBGEMM_FBCODE #include @@ -411,6 +413,7 @@ void EmbeddingKVDB::get_cuda( const at::Tensor& indices, const at::Tensor& weights, const at::Tensor& count) { +#ifdef FBGEMM_USE_GPU auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::get_cuda ##"); check_tensor_type_consistency(indices, weights); @@ -424,6 +427,7 @@ void EmbeddingKVDB::get_cuda( functor, 0)); rec->record.end(); +#endif } void EmbeddingKVDB::set_cuda( @@ -432,6 +436,7 @@ void EmbeddingKVDB::set_cuda( const at::Tensor& count, const int64_t timestep, const bool is_bwd) { +#ifdef FBGEMM_USE_GPU auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::set_cuda ##"); check_tensor_type_consistency(indices, weights); @@ -447,6 +452,7 @@ void EmbeddingKVDB::set_cuda( functor, 0)); rec->record.end(); +#endif } void EmbeddingKVDB::stream_cuda( @@ -454,7 +460,7 @@ void EmbeddingKVDB::stream_cuda( const at::Tensor& weights, const at::Tensor& count, bool blocking_tensor_copy) { -#ifdef FBGEMM_FBCODE +#ifdef FBGEMM_USE_GPU auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::stream_cuda ##"); check_tensor_type_consistency(indices, weights); @@ -472,7 +478,7 @@ void EmbeddingKVDB::stream_cuda( } void EmbeddingKVDB::stream_sync_cuda() { -#ifdef FBGEMM_FBCODE +#ifdef FBGEMM_USE_GPU auto rec = torch::autograd::profiler::record_function_enter_new( "## EmbeddingKVDB::stream_sync_cuda ##"); // take reference to self to avoid lifetime issues. 
diff --git a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h index 471d799df6..b6c0f1f7c6 100644 --- a/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h +++ b/fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h @@ -34,9 +34,10 @@ #include #include #include - +#ifdef FBGEMM_USE_GPU #include #include +#endif #include #include "../dram_kv_embedding_cache/feature_evict.h" From f7165d8fa9b01713e10c3b2231fc01d7d94e76ec Mon Sep 17 00:00:00 2001 From: Chenyu Zhang Date: Thu, 19 Jun 2025 20:13:31 -0700 Subject: [PATCH 3/3] kv embedding inference test (#4373) Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/1442 Rollback Plan: Differential Revision: D76865305 --- .../tbe/dram_kv/dram_kv_inference_test.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py diff --git a/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py new file mode 100644 index 0000000000..0b3e6d6d4f --- /dev/null +++ b/fbgemm_gpu/test/tbe/dram_kv/dram_kv_inference_test.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest + +import fbgemm_gpu + +import torch +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.utils.loader import load_torch_module + +# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. +open_source: bool = getattr(fbgemm_gpu, "open_source", False) + +if not open_source: + load_torch_module( + "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference", + ) + + +@unittest.skipIf(open_source, "Not supported in open source yet") +class DramKvInferenceTest(unittest.TestCase): + def test_serialize(self) -> None: + num_shards = 32 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + evict_trigger_mode: int = 1 + + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode + ) + serialized_result = kv_embedding_cache.serialize() + + self.assertEqual(serialized_result[0][0], num_shards) + self.assertEqual(serialized_result[0][1], evict_trigger_mode) + + self.assertEqual(serialized_result[1][0], uniform_init_lower) + self.assertEqual(serialized_result[1][1], uniform_init_upper) + + def test_serialize_deserialize(self) -> None: + num_shards = 32 + uniform_init_lower: float = -0.01 + uniform_init_upper: float = 0.01 + evict_trigger_mode: int = 1 + + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode + ) + serialized_result = kv_embedding_cache.serialize() + + kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + 0, 0.0, 0.0, 0 + ) + kv_embedding_cache_2.deserialize(serialized_result) + + self.assertEqual(str(serialized_result), str(kv_embedding_cache_2.serialize())) + + def test_set_get_embeddings(self) -> None: + num_shards = 32 + uniform_init_lower: float = 0.0 + uniform_init_upper: float = 0.0 + evict_trigger_mode: int = 0 + + kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper( + num_shards, uniform_init_lower, 
uniform_init_upper, evict_trigger_mode + ) + kv_embedding_cache.init( + [(20, 4, SparseType.INT8.as_int())], + 8, + 4, + ) + + kv_embedding_cache.set_embeddings( + torch.tensor([0, 1, 2, 3], dtype=torch.int64), + torch.tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=torch.uint8, + ), + ) + + embs = kv_embedding_cache.get_embeddings( + torch.tensor([1, 4, 3, 0, 5, 2], dtype=torch.int64), + ) + assert torch.equal( + embs[:, :4], + torch.tensor( + [ + [5, 6, 7, 8], + [0, 0, 0, 0], + [13, 14, 15, 16], + [1, 2, 3, 4], + [0, 0, 0, 0], + [9, 10, 11, 12], + ], + dtype=torch.uint8, + ), + )