From e4fd0fe7a221019da2f6de71f07138579b1f55a5 Mon Sep 17 00:00:00 2001
From: Gantaphon Chalumporn
Date: Fri, 11 Jul 2025 10:35:01 -0700
Subject: [PATCH] Add TBE data configuration reporter to TBE forward (v3) (#4455)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1516

Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/4455

Re-land attempt of D75462895

# Add TBE data configuration reporter to TBE forward call.

The reporter captures the TBE data configuration at the `SplitTableBatchedEmbeddingBagsCodegen` ***forward*** call. The output is a `TBEDataConfig` object, which is written to one or more JSON files. The environment variables that configure the reporter, along with an example of its usage, are described below.

## JustKnobs for enablement

- `fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS` is added to enable the reporter (https://www.internalfb.com/intern/justknobs/?name=fbgemm_gpu%2Ffeatures)
  - The default is `False`; enable this flag to turn the reporter on.
- To enable it locally, use:
```
jk canary set fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS --on --ttl 600
```

## Environment Variables

The reporter relies on several environment variables to control its behavior. Below is a description of each variable:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL**:
  - **Description**: Determines the interval, in number of iterations, at which reports are generated.
  - **Example Value**: `1` (report every iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START**:
  - **Description**: Specifies the start of the iteration range to capture reports. Defaults to `0`.
  - **Example Value**: `0` (start reporting from the first iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END**:
  - **Description**: Specifies the end of the iteration range to capture reports. Use `-1` to report until the last iteration. Defaults to `-1`.
  - **Example Value**: `-1` (report until the last iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET**:
  - **Description**: Specifies the name of the Manifold bucket where the report data will be saved.
  - **Example Value**: `tlparse_reports`
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX**:
  - **Description**: Defines the path prefix where the report files will be stored. The path will be created if it does not exist.
  - **Example Value**: `tree/tests/`

## Example Usage

Below is an example command demonstrating how to use the FBGEMM Reporter with specific environment variable settings:

```
FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2 FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3 FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/ buck2 run mode/opt //deeplearning/fbgemm/fbgemm_gpu/bench:split_table_batched_embeddings -- device --iters 2
```

**Explanation**

With these settings, reports are generated at `iter 3` and `iter 5`:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2**: The reporter generates a report every 2 iterations.
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3**: The reporter starts generating reports at iteration 3.
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END=-1 (default)**: The reporter continues to generate reports until the last iteration.
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports**: The reports are saved in the `tlparse_reports` bucket.
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/**: The reports are stored under the path prefix `tree/tests/`. For Manifold, make sure all folders within the path exist.

The two sketches below illustrate these semantics.
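To make the interval/start/end semantics concrete, here is a minimal sketch of the reporter's iteration gating. `should_report` is a hypothetical helper written for illustration only; it mirrors the gating condition used by `TBEBenchmarkParamsReporter.report_stats()` in this diff.

```
def should_report(iteration: int, interval: int, start: int = 0, end: int = -1) -> bool:
    # Mirrors the gating check in TBEBenchmarkParamsReporter.report_stats()
    return (
        (iteration - start) % interval == 0
        and iteration >= start
        and (end == -1 or iteration <= end)
    )

# With INTERVAL=2 and ITER_START=3, the 6 forward calls (iterations 0-5)
# yield reports at iterations 3 and 5 only:
assert [i for i in range(6) if should_report(i, interval=2, start=3)] == [3, 5]
```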
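The reporter can also be constructed directly. Below is a minimal usage sketch (for illustration only; not part of this diff's test plan). It assumes the same environment variable values as the example above and relies on the `TBEBenchmarkParamsReporter.create()` factory added in this diff, which reads those variables.

```
import os

# Configure the reporter before it is created; values mirror the example above.
os.environ["FBGEMM_REPORT_INPUT_PARAMS_INTERVAL"] = "2"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_ITER_START"] = "3"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_BUCKET"] = "tlparse_reports"
os.environ["FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX"] = "tree/tests/"

from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

# create() reads the FBGEMM_REPORT_INPUT_PARAMS_* variables set above
reporter = TBEBenchmarkParamsReporter.create()
```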
**Note on the benchmark example**

Note that with the `--iters 2` option, the benchmark executes 6 forward calls in total: 3 (2 iterations plus 1 warmup) for the forward benchmark and another 3 (2 iterations plus 1 warmup) for the backward benchmark. Iteration numbering starts from 0.

---

## Other changes included in this diff:
- Update the build dependencies of the tbe_data_config* files
- Remove the `shutil` and `numpy.random` libraries, as they cause incompatibility errors
- Add a non-OSS test that writes the extracted config data JSON file to Manifold

Differential Revision: D76992907
---
 .../bench/tbe/tbe_training_benchmark.py | 11 +-
 fbgemm_gpu/fbgemm_gpu/config/feature_list.py | 3 +
 ...t_table_batched_embeddings_ops_training.py | 38 +++
 .../fbgemm_gpu/tbe/bench/tbe_data_config.py | 184 +--------------
 .../tbe/bench/tbe_data_config_bench_helper.py | 223 ++++++++++++++++++
 .../tbe/bench/tbe_data_config_loader.py | 8 +-
 .../tbe/stats/bench_params_reporter.py | 199 ++++++++++++----
 fbgemm_gpu/fbgemm_gpu/utils/filestore.py | 8 +-
 .../include/fbgemm_gpu/config/feature_gates.h | 3 +-
 .../stats/tbe_bench_params_reporter_test.py | 166 +++++++++++--
 fbgemm_gpu/test/utils/filestore_test.py | 4 +
 11 files changed, 596 insertions(+), 251 deletions(-)
 create mode 100644 fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py
diff --git a/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py b/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py index bc47a6eed4..6b9413c8d6 100644 --- a/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py +++ b/fbgemm_gpu/bench/tbe/tbe_training_benchmark.py @@ -33,6 +33,11 @@ TBEBenchmarkingConfigLoader, TBEDataConfigLoader, ) +from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import ( + generate_embedding_dims, + generate_feature_requires_grad, + generate_requests, +) from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags from fbgemm_gpu.tbe.utils import get_device from torch.profiler import profile @@ -102,13 +107,13 @@ def device( # noqa C901 # Generate feature_requires_grad feature_requires_grad = ( - tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad) + generate_feature_requires_grad(tbeconfig, weighted_num_requires_grad) if weighted_num_requires_grad else None ) # Generate embedding dims - effective_D, Ds = tbeconfig.generate_embedding_dims() + effective_D, Ds = generate_embedding_dims(tbeconfig) # Determine the optimizer optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD @@ -212,7 +217,7 @@ def device( # noqa C901 f"Accessed weights per batch: {tbeconfig.batch_params.B * sum(Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB" ) - requests = tbeconfig.generate_requests(benchconfig.num_requests) + requests = generate_requests(tbeconfig, benchconfig.num_requests) # pyre-ignore[53] def _kineto_trace_handler(p: profile, phase: str) -> None: diff --git a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py index 14b2bbe2a9..8f8fd6f495 100644 --- a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py +++ b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py @@ -60,6 +60,9 @@ def foo(): # Enable bounds_check_indices_v2 BOUNDS_CHECK_INDICES_V2 = auto() + # Enable TBE input parameters extraction + TBE_REPORT_INPUT_PARAMS = auto() + def is_enabled(self) -> bool: return FeatureGate.is_enabled(self) diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py index fe8fad0af1..a43371eecb 100644 ---
a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py +++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py @@ -51,6 +51,7 @@ generate_vbe_metadata, is_torchdynamo_compiling, ) +from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter from fbgemm_gpu.tbe_input_multiplexer import ( TBEInfo, TBEInputInfo, @@ -1441,6 +1442,11 @@ def __init__( # noqa C901 self._debug_print_input_stats_factory() ) + # Get a reporter function pointer + self._report_input_params: Callable[..., None] = ( + self.__report_input_params_factory() + ) + if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook: # Register writeback hook for Exact_SGD optimizer self.log( @@ -1953,6 +1959,19 @@ def forward( # noqa: C901 # Print input stats if enable (for debugging purpose only) self._debug_print_input_stats(indices, offsets, per_sample_weights) + # Extract and write input stats if enabled + if self._report_input_params is not None: + self._report_input_params( + feature_rows=self.rows_per_table, + feature_dims=self.feature_dims, + iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0, + indices=indices, + offsets=offsets, + op_id=self.uuid, + per_sample_weights=per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, + ) + if not is_torchdynamo_compiling(): # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time @@ -3804,6 +3823,25 @@ def _debug_print_input_stats_factory_null( return _debug_print_input_stats_factory_impl return _debug_print_input_stats_factory_null + @torch.jit.ignore + def __report_input_params_factory( + self, + ) -> Optional[Callable[..., None]]: + """ + This function returns a reporter function pointer if the `TBE_REPORT_INPUT_PARAMS` feature gate is enabled. + + If the feature gate is enabled, it returns a function pointer that: + - Reports input parameters (TBEDataConfig). + - Writes the output as a JSON file. + + If the feature gate is disabled, it returns `None`, and no reporting is performed.
+ """ + + if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS): + reporter = TBEBenchmarkParamsReporter.create() + return reporter.report_stats + return None + class DenseTableBatchedEmbeddingBagsCodegen(nn.Module): """ diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py index 58a0d19baa..156b903256 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py @@ -9,19 +9,11 @@ import dataclasses import json -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, Optional -import numpy as np import torch -from fbgemm_gpu.tbe.utils.common import get_device, round_up -from fbgemm_gpu.tbe.utils.requests import ( - generate_batch_sizes_from_stats, - generate_pooling_factors_from_stats, - get_table_batched_offsets_from_dense, - maybe_to_dtype, - TBERequest, -) +from fbgemm_gpu.tbe.utils.common import get_device from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams @@ -104,175 +96,3 @@ def variable_L(self) -> bool: def _new_weights(self, size: int) -> Optional[torch.Tensor]: # Per-sample weights will always be FP32 return None if not self.weighted else torch.randn(size, device=get_device()) - - def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]: - if self.variable_B(): - assert ( - self.batch_params.vbe_num_ranks is not None - ), "vbe_num_ranks must be set for varaible batch size generation" - return generate_batch_sizes_from_stats( - self.batch_params.B, - self.T, - # pyre-ignore [6] - self.batch_params.sigma_B, - self.batch_params.vbe_num_ranks, - # pyre-ignore [6] - self.batch_params.vbe_distribution, - ) - - else: - return ([self.batch_params.B] * self.T, None) - - def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor: - if self.variable_L(): - # Generate L from stats - _, L_offsets = generate_pooling_factors_from_stats( - iters, - Bs, - self.pooling_params.L, - # pyre-ignore [6] - self.pooling_params.sigma_L, - # pyre-ignore [6] - self.pooling_params.length_distribution, - ) - - else: - Ls = [self.pooling_params.L] * (sum(Bs) * iters) - L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0) - - return L_offsets - - def _generate_indices( - self, - iters: int, - Bs: List[int], - L_offsets: torch.Tensor, - ) -> torch.Tensor: - total_B = sum(Bs) - L_offsets_list = L_offsets.tolist() - indices_list = [] - for it in range(iters): - # L_offsets is defined over the entire set of batches for a single iteration - start_offset = L_offsets_list[it * total_B] - end_offset = L_offsets_list[(it + 1) * total_B] - - indices_list.append( - torch.ops.fbgemm.tbe_generate_indices_from_distribution( - self.indices_params.heavy_hitters, - self.indices_params.zipf_q, - self.indices_params.zipf_s, - # max_index = dimensions of the embedding table - self.E, - # num_indices = number of indices to generate - end_offset - start_offset, - ) - ) - - return torch.cat(indices_list) - - def _build_requests_jagged( - self, - iters: int, - Bs: List[int], - Bs_feature_rank: Optional[List[List[int]]], - L_offsets: torch.Tensor, - all_indices: torch.Tensor, - ) -> List[TBERequest]: - total_B = sum(Bs) - all_indices = all_indices.flatten() - requests = [] - for it in range(iters): - start_offset = L_offsets[it * total_B] - it_L_offsets = torch.concat( - [ - torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device), - L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - 
start_offset, - ] - ) - requests.append( - TBERequest( - maybe_to_dtype( - all_indices[start_offset : L_offsets[(it + 1) * total_B]], - self.indices_params.index_dtype, - ), - maybe_to_dtype( - it_L_offsets.to(get_device()), self.indices_params.offset_dtype - ), - self._new_weights(int(it_L_offsets[-1].item())), - Bs_feature_rank if self.variable_B() else None, - ) - ) - return requests - - def _build_requests_dense( - self, iters: int, all_indices: torch.Tensor - ) -> List[TBERequest]: - # NOTE: We're using existing code from requests.py to build the - # requests, and since the existing code requires 2D view of all_indices, - # the existing all_indices must be reshaped - all_indices = all_indices.reshape(iters, -1) - - requests = [] - for it in range(iters): - indices, offsets = get_table_batched_offsets_from_dense( - all_indices[it].view( - self.T, self.batch_params.B, self.pooling_params.L - ), - use_cpu=self.use_cpu, - ) - requests.append( - TBERequest( - maybe_to_dtype(indices, self.indices_params.index_dtype), - maybe_to_dtype(offsets, self.indices_params.offset_dtype), - self._new_weights( - self.T * self.batch_params.B * self.pooling_params.L - ), - ) - ) - return requests - - def generate_requests( - self, - iters: int = 1, - ) -> List[TBERequest]: - # Generate batch sizes - Bs, Bs_feature_rank = self._generate_batch_sizes() - - # Generate pooling info - L_offsets = self._generate_pooling_info(iters, Bs) - - # Generate indices - all_indices = self._generate_indices(iters, Bs, L_offsets) - - # Build TBE requests - if self.variable_B() or self.variable_L(): - return self._build_requests_jagged( - iters, Bs, Bs_feature_rank, L_offsets, all_indices - ) - else: - return self._build_requests_dense(iters, all_indices) - - def generate_embedding_dims(self) -> Tuple[int, List[int]]: - if self.mixed_dim: - Ds = [ - round_up( - np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4 - ) - for _ in range(self.T) - ] - return (int(np.average(Ds)), Ds) - else: - return (self.D, [self.D] * self.T) - - def generate_feature_requires_grad(self, size: int) -> torch.Tensor: - assert size <= self.T, "size of feature_requires_grad must be less than T" - weighted_requires_grad_tables = np.random.choice( - self.T, replace=False, size=(size,) - ).tolist() - return ( - torch.tensor( - [1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)] - ) - .to(get_device()) - .int() - ) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py new file mode 100644 index 0000000000..51ea6bbc5e --- /dev/null +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_bench_helper.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ + # pyre-strict + + from typing import List, Optional, Tuple + + import torch + + from fbgemm_gpu.tbe.bench.tbe_data_config import TBEDataConfig + from fbgemm_gpu.tbe.utils.common import get_device, round_up + + from fbgemm_gpu.tbe.utils.requests import ( + generate_batch_sizes_from_stats, + generate_pooling_factors_from_stats, + get_table_batched_offsets_from_dense, + maybe_to_dtype, + TBERequest, + ) + + + def _generate_batch_sizes( + tbe_data_config: TBEDataConfig, + ) -> Tuple[List[int], Optional[List[List[int]]]]: + if tbe_data_config.variable_B(): + assert ( + tbe_data_config.batch_params.vbe_num_ranks is not None + ), "vbe_num_ranks must be set for variable batch size generation" + return generate_batch_sizes_from_stats( + tbe_data_config.batch_params.B, + tbe_data_config.T, + # pyre-ignore [6] + tbe_data_config.batch_params.sigma_B, + tbe_data_config.batch_params.vbe_num_ranks, + # pyre-ignore [6] + tbe_data_config.batch_params.vbe_distribution, + ) + + else: + return ([tbe_data_config.batch_params.B] * tbe_data_config.T, None) + + + def _generate_pooling_info( + tbe_data_config: TBEDataConfig, iters: int, Bs: List[int] + ) -> torch.Tensor: + if tbe_data_config.variable_L(): + # Generate L from stats + _, L_offsets = generate_pooling_factors_from_stats( + iters, + Bs, + tbe_data_config.pooling_params.L, + # pyre-ignore [6] + tbe_data_config.pooling_params.sigma_L, + # pyre-ignore [6] + tbe_data_config.pooling_params.length_distribution, + ) + else: + Ls = [tbe_data_config.pooling_params.L] * (sum(Bs) * iters) + L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0) + + return L_offsets + + + def _generate_indices( + tbe_data_config: TBEDataConfig, + iters: int, + Bs: List[int], + L_offsets: torch.Tensor, + ) -> torch.Tensor: + total_B = sum(Bs) + L_offsets_list = L_offsets.tolist() + indices_list = [] + for it in range(iters): + # L_offsets is defined over the entire set of batches for a single iteration + start_offset = L_offsets_list[it * total_B] + end_offset = L_offsets_list[(it + 1) * total_B] + + indices_list.append( + torch.ops.fbgemm.tbe_generate_indices_from_distribution( + tbe_data_config.indices_params.heavy_hitters, + tbe_data_config.indices_params.zipf_q, + tbe_data_config.indices_params.zipf_s, + # max_index = dimensions of the embedding table + tbe_data_config.E, + # num_indices = number of indices to generate + end_offset - start_offset, + ) + ) + + return torch.cat(indices_list) + + + def _build_requests_jagged( + tbe_data_config: TBEDataConfig, + iters: int, + Bs: List[int], + Bs_feature_rank: Optional[List[List[int]]], + L_offsets: torch.Tensor, + all_indices: torch.Tensor, + ) -> List[TBERequest]: + total_B = sum(Bs) + all_indices = all_indices.flatten() + requests = [] + for it in range(iters): + start_offset = L_offsets[it * total_B] + it_L_offsets = torch.concat( + [ + torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device), + L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset, + ] + ) + requests.append( + TBERequest( + maybe_to_dtype( + all_indices[start_offset : L_offsets[(it + 1) * total_B]], + tbe_data_config.indices_params.index_dtype, + ), + maybe_to_dtype( + it_L_offsets.to(get_device()), + tbe_data_config.indices_params.offset_dtype, + ), + tbe_data_config._new_weights(int(it_L_offsets[-1].item())), + Bs_feature_rank if tbe_data_config.variable_B() else None, + ) + ) + return requests + + + def _build_requests_dense( + tbe_data_config: TBEDataConfig, iters: int, all_indices: torch.Tensor + ) -> List[TBERequest]: + # NOTE: We're using
existing code from requests.py to build the + # requests, and since the existing code requires 2D view of all_indices, + # the existing all_indices must be reshaped + all_indices = all_indices.reshape(iters, -1) + + requests = [] + for it in range(iters): + indices, offsets = get_table_batched_offsets_from_dense( + all_indices[it].view( + tbe_data_config.T, + tbe_data_config.batch_params.B, + tbe_data_config.pooling_params.L, + ), + use_cpu=tbe_data_config.use_cpu, + ) + requests.append( + TBERequest( + maybe_to_dtype(indices, tbe_data_config.indices_params.index_dtype), + maybe_to_dtype(offsets, tbe_data_config.indices_params.offset_dtype), + tbe_data_config._new_weights( + tbe_data_config.T + * tbe_data_config.batch_params.B + * tbe_data_config.pooling_params.L + ), + ) + ) + return requests + + +def generate_requests( + tbe_data_config: TBEDataConfig, + iters: int = 1, +) -> List[TBERequest]: + # Generate batch sizes + Bs, Bs_feature_rank = _generate_batch_sizes(tbe_data_config) + + # Generate pooling info + L_offsets = _generate_pooling_info(tbe_data_config, iters, Bs) + + # Generate indices + all_indices = _generate_indices(tbe_data_config, iters, Bs, L_offsets) + + # Build TBE requests + if tbe_data_config.variable_B() or tbe_data_config.variable_L(): + return _build_requests_jagged( + tbe_data_config, iters, Bs, Bs_feature_rank, L_offsets, all_indices + ) + else: + return _build_requests_dense(tbe_data_config, iters, all_indices) + + +def generate_embedding_dims(tbe_data_config: TBEDataConfig) -> Tuple[int, List[int]]: + if tbe_data_config.mixed_dim: + Ds = [ + round_up( + int( + torch.randint( + low=int(0.5 * tbe_data_config.D), + high=int(1.5 * tbe_data_config.D), + size=(1,), + ).item() + ), + 4, + ) + for _ in range(tbe_data_config.T) + ] + return (sum(Ds) // len(Ds), Ds) + else: + return (tbe_data_config.D, [tbe_data_config.D] * tbe_data_config.T) + + +def generate_feature_requires_grad( + tbe_data_config: TBEDataConfig, size: int +) -> torch.Tensor: + assert ( + size <= tbe_data_config.T + ), "size of feature_requires_grad must be less than T" + weighted_requires_grad_tables = torch.randperm(tbe_data_config.T)[:size].tolist() + return ( + torch.tensor( + [ + 1 if t in weighted_requires_grad_tables else 0 + for t in range(tbe_data_config.T) + ] + ) + .to(get_device()) + .int() + ) diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py index 52bde55308..5149104bcc 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config_loader.py @@ -11,8 +11,12 @@ import torch import yaml -from .tbe_data_config import TBEDataConfig -from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams +from fbgemm_gpu.tbe.bench.tbe_data_config import ( + BatchParams, + IndicesParams, + PoolingParams, + TBEDataConfig, +) class TBEDataConfigLoader: diff --git a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py index 794b38a20e..bc2a52a465 100644 --- a/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py +++ b/fbgemm_gpu/fbgemm_gpu/tbe/stats/bench_params_reporter.py @@ -8,31 +8,36 @@ # pyre-strict import io +import json import logging import os from typing import List, Optional import fbgemm_gpu # noqa F401 -import numpy as np # usort:skip import torch # usort:skip -from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( - SplitTableBatchedEmbeddingBagsCodegen, 
-) -from fbgemm_gpu.tbe.bench import ( +from fbgemm_gpu.tbe.bench.tbe_data_config import ( BatchParams, IndicesParams, PoolingParams, TBEDataConfig, ) -# pyre-ignore[16] -open_source: bool = getattr(fbgemm_gpu, "open_source", False) +open_source: bool = False +try: + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. + if getattr(fbgemm_gpu, "open_source", False): + open_source = True +except Exception: + pass if open_source: from fbgemm_gpu.utils import FileStore else: - from fbgemm_gpu.fb.utils import FileStore + try: + from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore + except ImportError: + from fbgemm_gpu.utils import FileStore class TBEBenchmarkParamsReporter: @@ -43,7 +48,8 @@ class TBEBenchmarkParamsReporter: def __init__( self, report_interval: int, - report_once: bool = False, + report_iter_start: int = 0, + report_iter_end: int = -1, bucket: Optional[str] = None, path_prefix: Optional[str] = None, ) -> None: @@ -52,13 +58,30 @@ def __init__( Args: report_interval (int): The interval at which reports are generated. - report_once (bool, optional): If True, reporting occurs only once. Defaults to False. + report_iter_start (int): The start of the iteration range to capture. Defaults to 0. + report_iter_end (int): The end of the iteration range to capture. Defaults to -1 (last iteration). bucket (Optional[str], optional): The storage bucket for reports. Defaults to None. path_prefix (Optional[str], optional): The path prefix for report storage. Defaults to None. """ + assert report_interval > 0, "report_interval must be greater than 0" + assert ( + report_iter_start >= 0 + ), "report_iter_start must be greater than or equal to 0" + assert ( + report_iter_end >= -1 + ), "report_iter_end must be greater than or equal to -1" + assert ( + report_iter_end == -1 or report_iter_start <= report_iter_end + ), "report_iter_start must be less than or equal to report_iter_end" + self.report_interval = report_interval - self.report_once = report_once - self.has_reported = False + self.report_iter_start = report_iter_start + self.report_iter_end = report_iter_end + + if path_prefix is not None and path_prefix.endswith("/"): + path_prefix = path_prefix[:-1] + + self.path_prefix = path_prefix default_bucket = "/tmp" if open_source else "tlparse_reports" bucket = ( @@ -68,22 +91,65 @@ def __init__( ) self.filestore = FileStore(bucket) + if self.path_prefix is not None and not self.filestore.exists(self.path_prefix): + self.filestore.create_directory(self.path_prefix) + self.logger: logging.Logger = logging.getLogger(__name__) self.logger.setLevel(logging.INFO) + @classmethod + def create(cls) -> "TBEBenchmarkParamsReporter": + """ + This method returns an instance of TBEBenchmarkParamsReporter based on environment variables. + + If the `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` environment variable is set to a value greater than 0, it creates an instance that: + - Reports input parameters (TBEDataConfig). + - Writes the output as a JSON file. + + Additionally, the following environment variables are considered: + - `FBGEMM_REPORT_INPUT_PARAMS_ITER_START`: Specifies the start of the iteration range to capture. + - `FBGEMM_REPORT_INPUT_PARAMS_ITER_END`: Specifies the end of the iteration range to capture. + - `FBGEMM_REPORT_INPUT_PARAMS_BUCKET`: Specifies the bucket for reporting. + - `FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX`: Specifies the path prefix for reporting. + + Returns: + TBEBenchmarkParamsReporter: An instance configured based on the environment variables. 
+ """ + report_interval = int( + os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "1") + ) + report_iter_start = int( + os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0") + ) + report_iter_end = int( + os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1") + ) + bucket = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_BUCKET", "") + path_prefix = os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX", "") + + return cls( + report_interval=report_interval, + report_iter_start=report_iter_start, + report_iter_end=report_iter_end, + bucket=bucket, + path_prefix=path_prefix, + ) + def extract_params( self, - embedding_op: SplitTableBatchedEmbeddingBagsCodegen, + feature_rows: torch.Tensor, + feature_dims: torch.Tensor, indices: torch.Tensor, offsets: torch.Tensor, per_sample_weights: Optional[torch.Tensor] = None, batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> TBEDataConfig: """ - Extracts parameters from the embedding operation, input indices and offsets to create a TBEDataConfig. + Extracts parameters from the embedding operation, input indices, and offsets to create a TBEDataConfig. Args: - embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation. + feature_rows (torch.Tensor): Number of rows in each feature. + feature_dims (torch.Tensor): Number of dimensions in each feature. indices (torch.Tensor): The input indices tensor. offsets (torch.Tensor): The input offsets tensor. per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None. @@ -92,24 +158,33 @@ def extract_params( Returns: TBEDataConfig: The configuration data for TBE benchmarking. """ + + Es = feature_rows.tolist() + Ds = feature_dims.tolist() + + assert len(Es) == len( + Ds + ), "feature_rows and feature_dims must have the same length" + # Transfer indices back to CPU for EEG analysis indices_cpu = indices.cpu() - # Extract embedding table specs - embedding_specs = [ - embedding_op.embedding_specs[t] for t in embedding_op.feature_table_map - ] - rowcounts = [embedding_spec[0] for embedding_spec in embedding_specs] - dims = [embedding_spec[1] for embedding_spec in embedding_specs] - # Set T to be the number of features we are looking at - T = len(embedding_op.feature_table_map) + T = len(Ds) # Set E to be the mean of the rowcounts to avoid biasing - E = rowcounts[0] if len(set(rowcounts)) == 1 else np.ceil((np.mean(rowcounts))) + E = ( + Es[0] + if len(set(Es)) == 1 + else torch.ceil(torch.mean(torch.tensor(feature_rows))) + ) # Set mixed_dim to be True if there are multiple dims - mixed_dim = len(set(dims)) > 1 + mixed_dim = len(set(Ds)) > 1 # Set D to be the mean of the dims to avoid biasing - D = dims[0] if not mixed_dim else np.ceil((np.mean(dims))) + D = ( + Ds[0] + if not mixed_dim + else torch.ceil(torch.mean(torch.tensor(feature_dims))) + ) # Compute indices distribution parameters heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution( @@ -123,8 +198,18 @@ def extract_params( batch_params = BatchParams( B=((offsets.numel() - 1) // T), sigma_B=( - np.ceil( - np.std([b for bs in batch_size_per_feature_per_rank for b in bs]) + int( + torch.ceil( + torch.std( + torch.tensor( + [ + b + for bs in batch_size_per_feature_per_rank + for b in bs + ] + ) + ) + ) ) if batch_size_per_feature_per_rank else None @@ -138,11 +223,19 @@ def extract_params( ) # Compute pooling parameters - bag_sizes = (offsets[1:] - offsets[:-1]).tolist() + bag_sizes = offsets[1:] - offsets[:-1] mixed_bag_sizes = len(set(bag_sizes)) > 
1 pooling_params = PoolingParams( - L=np.ceil(np.mean(bag_sizes)) if mixed_bag_sizes else bag_sizes[0], - sigma_L=(np.ceil(np.std(bag_sizes)) if mixed_bag_sizes else None), + L=( + int(torch.ceil(torch.mean(bag_sizes.float()))) + if mixed_bag_sizes + else int(bag_sizes[0]) + ), + sigma_L=( + int(torch.ceil(torch.std(bag_sizes.float()))) + if mixed_bag_sizes + else None + ), length_distribution=("normal" if mixed_bag_sizes else None), ) @@ -160,34 +253,58 @@ def extract_params( def report_stats( self, - embedding_op: SplitTableBatchedEmbeddingBagsCodegen, + feature_rows: torch.Tensor, + feature_dims: torch.Tensor, + iteration: int, indices: torch.Tensor, offsets: torch.Tensor, + op_id: str = "", per_sample_weights: Optional[torch.Tensor] = None, batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, ) -> None: """ - Reports the configuration of the embedding operation and input data then writes the TBE configuration to the filestore. + Reports the configuration of the embedding operation and input data, then writes the TBE configuration to the filestore. Args: - embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The embedding operation. + feature_rows (torch.Tensor): Number of rows in each feature. + feature_dims (torch.Tensor): Number of dimensions in each feature. + iteration (int): The current iteration number. indices (torch.Tensor): The input indices tensor. offsets (torch.Tensor): The input offsets tensor. + op_id (str, optional): The operation identifier. Defaults to an empty string. per_sample_weights (Optional[torch.Tensor], optional): Weights for each sample. Defaults to None. batch_size_per_feature_per_rank (Optional[List[List[int]]], optional): Batch sizes per feature per rank. Defaults to None. """ - if embedding_op.iter.item() % self.report_interval == 0 and ( - not self.report_once or (self.report_once and not self.has_reported) + if ( + (iteration - self.report_iter_start) % self.report_interval == 0 + and (iteration >= self.report_iter_start) + and (self.report_iter_end == -1 or iteration <= self.report_iter_end) ): # Extract TBE config config = self.extract_params( - embedding_op, indices, offsets, per_sample_weights + feature_rows=feature_rows, + feature_dims=feature_dims, + indices=indices, + offsets=offsets, + per_sample_weights=per_sample_weights, + batch_size_per_feature_per_rank=batch_size_per_feature_per_rank, ) + config.json() + + # Ad-hoc fix for adding Es and Ds to JSON output + # TODO: Remove this once we moved Es and Ds to be part of TBEDataConfig + adhoc_config = config.dict() + adhoc_config["Es"] = feature_rows.tolist() + adhoc_config["Ds"] = feature_dims.tolist() + if batch_size_per_feature_per_rank: + adhoc_config["Bs"] = [ + sum(batch_size_per_feature_per_rank[f]) + for f in range(len(adhoc_config["Es"])) + ] + # Write the TBE config to FileStore self.filestore.write( - f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json", - io.BytesIO(config.json(format=True).encode()), + f"{self.path_prefix}/tbe-{op_id}-config-estimation-{iteration}.json", + io.BytesIO(json.dumps(adhoc_config, indent=2).encode()), ) - - self.has_reported = True diff --git a/fbgemm_gpu/fbgemm_gpu/utils/filestore.py b/fbgemm_gpu/fbgemm_gpu/utils/filestore.py index d5639c5639..3bd2878d95 100644 --- a/fbgemm_gpu/fbgemm_gpu/utils/filestore.py +++ b/fbgemm_gpu/fbgemm_gpu/utils/filestore.py @@ -11,7 +11,6 @@ import io import logging import os -import shutil from dataclasses import dataclass from pathlib import Path from typing import BinaryIO, Union @@ 
-76,7 +75,12 @@ def write( elif isinstance(raw_input, Path): if not os.path.exists(raw_input): raise FileNotFoundError(f"File {raw_input} does not exist") - shutil.copyfile(raw_input, filepath) + # Open the source file and destination file, and copy the contents + with open(raw_input, "rb") as src_file, open( + filepath, "wb" + ) as dst_file: + while chunk := src_file.read(4096): # Read 4 KB at a time + dst_file.write(chunk) elif isinstance(raw_input, io.BytesIO) or isinstance(raw_input, BinaryIO): with open(filepath, "wb") as file: diff --git a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h index 11c4d55763..9018e68603 100644 --- a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h +++ b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h @@ -61,7 +61,8 @@ namespace fbgemm_gpu::config { X(TBE_ANNOTATE_KINETO_TRACE) \ X(TBE_ROCM_INFERENCE_PACKED_BAGS) \ X(TBE_ROCM_HIP_BACKWARD_KERNEL) \ - X(BOUNDS_CHECK_INDICES_V2) + X(BOUNDS_CHECK_INDICES_V2) \ + X(TBE_REPORT_INPUT_PARAMS) // X(EXAMPLE_FEATURE_FLAG) /// @ingroup fbgemm-gpu-config diff --git a/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py b/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py index 112a174091..4474ad4adb 100644 --- a/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py +++ b/fbgemm_gpu/test/tbe/stats/tbe_bench_params_reporter_test.py @@ -8,13 +8,20 @@ # pyre-strict import unittest +from typing import Optional +from unittest.mock import patch + +import fbgemm_gpu import hypothesis.strategies as st import torch -from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation -from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( +from fbgemm_gpu.config import FeatureGateName +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( ComputeDevice, + EmbeddingLocation, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( SplitTableBatchedEmbeddingBagsCodegen, ) from fbgemm_gpu.tbe.bench import ( @@ -23,10 +30,25 @@ PoolingParams, TBEDataConfig, ) + +from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import ( + generate_embedding_dims, + generate_requests, +) + from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter from fbgemm_gpu.tbe.utils import get_device from hypothesis import given, settings +from .. import common # noqa E402 +from ..common import running_in_oss + +try: + # pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`. 
open_source: bool = getattr(fbgemm_gpu, "open_source", False) + except Exception: + open_source: bool = False + + class TestTBEBenchmarkParamsReporter(unittest.TestCase): # pyre-ignore[56] @@ -74,23 +96,25 @@ def test_report_stats( ) # Generate the embedding dimension list - _, Ds = tbeconfig.generate_embedding_dims() + _, Ds = generate_embedding_dims(tbeconfig) + + embedding_specs = [ + ( + tbeconfig.E, + D, + embedding_location, + ( + ComputeDevice.CUDA + if torch.cuda.is_available() + else ComputeDevice.CPU + ), + ) + for D in Ds + ] # Generate the embedding operation embedding_op = SplitTableBatchedEmbeddingBagsCodegen( - [ - ( - tbeconfig.E, - D, - embedding_location, - ( - ComputeDevice.CUDA - if torch.cuda.is_available() - else ComputeDevice.CPU - ), - ) - for D in Ds - ], + embedding_specs, embedding_table_index_type=tbeconfig.indices_params.index_dtype or torch.int64, embedding_table_offset_type=tbeconfig.indices_params.offset_dtype @@ -103,11 +127,12 @@ def test_report_stats( reporter = TBEBenchmarkParamsReporter(report_interval=1) # Generate indices and offsets - request = tbeconfig.generate_requests(1)[0] + request = generate_requests(tbeconfig, 1)[0] - # Call the report_stats method + # Call the extract_params method extracted_config = reporter.extract_params( - embedding_op=embedding_op, + feature_rows=embedding_op.rows_per_table, + feature_dims=embedding_op.feature_dims, indices=request.indices, offsets=request.offsets, ) @@ -125,4 +150,105 @@ def test_report_stats( and extracted_config.indices_params.offset_dtype == tbeconfig.indices_params.offset_dtype ), "Extracted config does not match the original TBEDataConfig" - # Attempt to reconstruct TBEDataConfig from extracted_json_config + + # pyre-ignore[56] + @given( + T=st.integers(1, 10), + E=st.integers(100, 10000), + D=st.sampled_from([32, 64, 128, 256]), + L=st.integers(1, 10), + B=st.integers(20, 100), + ) + @settings(max_examples=1, deadline=None) + @unittest.skipIf(*running_in_oss) + def test_report_fb_files( + self, + T: int, + E: int, + D: int, + L: int, + B: int, + ) -> None: + """ + Test writing extracted TBEDataConfig to FB FileStore + """ + from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore + + # Initialize the reporter + bucket = "tlparse_reports" + path_prefix = "tree/unit_tests/" + + # Generate a TBEDataConfig + tbeconfig = TBEDataConfig( + T=T, + E=E, + D=D, + mixed_dim=False, + weighted=False, + batch_params=BatchParams(B=B), + indices_params=IndicesParams( + heavy_hitters=torch.tensor([]), + zipf_q=0.1, + zipf_s=0.1, + index_dtype=torch.int64, + offset_dtype=torch.int64, + ), + pooling_params=PoolingParams(L=L), + use_cpu=not torch.cuda.is_available(), + ) + + embedding_location = ( + EmbeddingLocation.DEVICE + if torch.cuda.is_available() + else EmbeddingLocation.HOST + ) + + # Generate the embedding dimension list + _, Ds = generate_embedding_dims(tbeconfig) + + with patch( + "torch.ops.fbgemm.check_feature_gate_key" + ) as mock_check_feature_gate_key: + # Mock the return value for TBE_REPORT_INPUT_PARAMS + def side_effect(feature_name: str) -> Optional[bool]: + if feature_name == FeatureGateName.TBE_REPORT_INPUT_PARAMS.name: + return True + + mock_check_feature_gate_key.side_effect = side_effect + + # Generate the embedding operation + embedding_op = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + tbeconfig.E, + D, + embedding_location, + ( + ComputeDevice.CUDA + if torch.cuda.is_available() + else ComputeDevice.CPU + ), + ) + for D in Ds + ], + ) + + embedding_op = embedding_op.to(get_device()) +
+ # Generate indices and offsets + request = generate_requests(tbeconfig, 1)[0] + + # Execute the embedding operation with the reporting flag enabled + embedding_op.forward(request.indices, request.offsets) + + # Check if the file was written to Manifold + store = FileStore(bucket) + path = f"{path_prefix}tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter_cpu.item()}.json" + assert store.exists(path), f"{path} does not exist" + + # Cleanup: delete the file + store.remove(path) + + +if __name__ == "__main__": + unittest.main() diff --git a/fbgemm_gpu/test/utils/filestore_test.py b/fbgemm_gpu/test/utils/filestore_test.py index 3c2485338c..b30e5100a0 100644 --- a/fbgemm_gpu/test/utils/filestore_test.py +++ b/fbgemm_gpu/test/utils/filestore_test.py @@ -209,3 +209,7 @@ def test_filestore_fb_directory(self) -> None: from fbgemm_gpu.fb.utils.manifold_wrapper import FileStore self._test_filestore_directory(FileStore("tlparse_reports"), "tree/unit_tests") + + +if __name__ == "__main__": + unittest.main()