Commit 866a2b1

gchalump authored and facebook-github-bot committed
Add TBE data configuration reporter to TBE forward (v3) (#4455)
Summary:
X-link: facebookresearch/FBGEMM#1516

Re-land attempt of D75462895

# Add TBE data configuration reporter to TBE forward call.

The reporter reports the TBE data configuration at the `SplitTableBatchedEmbeddingBagsCodegen` ***forward*** call. The output is a `TBEDataConfig` object, which is written to one or more JSON files. The reporter's environment variables and an example of its usage are described below.

## JustKnobs for enablement

- `fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS` is added to gate the reporter (https://www.internalfb.com/intern/justknobs/?name=fbgemm_gpu%2Ffeatures)
  - The default is `False`; enable this flag to turn the reporter on.
  - To enable it locally, use:

```
jk canary set fbgemm_gpu/features:TBE_REPORT_INPUT_PARAMS --on --ttl 600
```

## Environment Variables

The reporter relies on several environment variables to control its behavior. Each variable is described below:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL**:
  - **Description**: Determines the interval, in number of iterations, at which reports are generated.
  - **Example Value**: `1` (report every iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START**:
  - **Description**: Specifies the start of the iteration range in which to capture reports. Default: `0`.
  - **Example Value**: `0` (start reporting from the first iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END**:
  - **Description**: Specifies the end of the iteration range in which to capture reports. Use `-1` to report until the last iteration. Default: `-1`.
  - **Example Value**: `-1` (report until the last iteration)
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET**:
  - **Description**: Specifies the name of the Manifold bucket where the report data will be saved.
  - **Example Value**: `tlparse_reports`
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX**:
  - **Description**: Defines the path prefix where the report files will be stored. The path will be created if it does not exist.
  - **Example Value**: `tree/tests/`

## Example Usage

Below is an example command demonstrating how to use the FBGEMM reporter with specific environment variable settings:

```
FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2 FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3 FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/ buck2 run mode/opt //deeplearning/fbgemm/fbgemm_gpu/bench:split_table_batched_embeddings -- device --iters 2
```

**Explanation**

The above settings will report `iter 3` and `iter 5`:

- **FBGEMM_REPORT_INPUT_PARAMS_INTERVAL=2**: The reporter will generate a report every 2 iterations.
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_START=3**: The reporter will start generating reports at iteration 3.
- **FBGEMM_REPORT_INPUT_PARAMS_ITER_END=-1 (default)**: The reporter will continue generating reports through the last iteration.
- **FBGEMM_REPORT_INPUT_PARAMS_BUCKET=tlparse_reports**: The reports will be saved in the `tlparse_reports` bucket.
- **FBGEMM_REPORT_INPUT_PARAMS_PATH_PREFIX=tree/tests/**: The reports will be stored under the path prefix `tree/tests/`. For Manifold, make sure all folders within the path exist.

**Note on the benchmark example**

With the `--iters 2` option, the benchmark executes 6 forward calls in total: 3 calls (2 iterations plus 1 warmup) for the forward benchmark and another 3 calls (2 iterations plus 1 warmup) for the backward benchmark. Iteration numbering starts from 0.
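To make the iteration gating concrete, the following is a minimal sketch of how the three iteration-related variables combine. It illustrates only the documented semantics, not the reporter's actual implementation; the helper name `should_report` and the parsing defaults are assumptions:

```python
import os

# Hypothetical helper illustrating the documented env-var semantics;
# the real reporter's internal logic may differ.
def should_report(iteration: int) -> bool:
    interval = int(os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_INTERVAL", "0"))
    start = int(os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_START", "0"))
    end = int(os.environ.get("FBGEMM_REPORT_INPUT_PARAMS_ITER_END", "-1"))

    if interval <= 0:
        return False  # reporting disabled
    if iteration < start:
        return False  # before the reporting window
    if end != -1 and iteration > end:
        return False  # past the reporting window
    # Report every `interval`-th iteration, counting from `start`
    return (iteration - start) % interval == 0

# With INTERVAL=2 and ITER_START=3, six forward calls (iterations 0..5)
# yield reports at iterations 3 and 5, matching the example above.
```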
---

## Other changes included in this diff

- Updates the build dependencies of the tbe_data_config* files
- Removes the `shutil` and `numpy.random` libraries, as they caused incompatibility errors
- Adds a non-OSS test that writes the extracted config data JSON file to Manifold

Differential Revision: D76992907
1 parent b60e109 · commit 866a2b1

File tree

11 files changed: +607 −251 lines


fbgemm_gpu/bench/tbe/tbe_training_benchmark.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -33,6 +33,11 @@
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
+from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import (
+    generate_embedding_dims,
+    generate_feature_requires_grad,
+    generate_requests,
+)
 from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
 from fbgemm_gpu.tbe.utils import get_device
 from torch.profiler import profile
@@ -102,13 +107,13 @@ def device( # noqa C901
 
     # Generate feature_requires_grad
     feature_requires_grad = (
-        tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
+        generate_feature_requires_grad(tbeconfig, weighted_num_requires_grad)
         if weighted_num_requires_grad
         else None
     )
 
     # Generate embedding dims
-    effective_D, Ds = tbeconfig.generate_embedding_dims()
+    effective_D, Ds = generate_embedding_dims(tbeconfig)
 
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD
@@ -212,7 +217,7 @@ def device( # noqa C901
         f"Accessed weights per batch: {tbeconfig.batch_params.B * sum(Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB"
     )
 
-    requests = tbeconfig.generate_requests(benchconfig.num_requests)
+    requests = generate_requests(tbeconfig, benchconfig.num_requests)
 
     # pyre-ignore[53]
     def _kineto_trace_handler(p: profile, phase: str) -> None:
```

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -60,6 +60,9 @@ def foo():
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
+    # Enable TBE input parameters extraction
+    TBE_REPORT_INPUT_PARAMS = auto()
+
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)
 
```
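For illustration, here is a minimal sketch of querying the new gate through the `is_enabled()` method shown above. The `fbgemm_gpu.config` import path follows how `FeatureGateName` is used elsewhere in this commit, but the surrounding script is hypothetical:

```python
# Hypothetical usage sketch; assumes FeatureGateName is exported from
# fbgemm_gpu.config, as used in split_table_batched_embeddings_ops_training.py.
from fbgemm_gpu.config import FeatureGateName

if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
    print("TBE input-parameter reporting is enabled")
else:
    print("TBE input-parameter reporting is disabled (the default)")
```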

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 52 additions & 0 deletions

```diff
@@ -51,6 +51,7 @@
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
+from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
 from fbgemm_gpu.tbe_input_multiplexer import (
     TBEInfo,
     TBEInputInfo,
@@ -1441,6 +1442,11 @@ def __init__( # noqa C901
             self._debug_print_input_stats_factory()
         )
 
+        # Get a reporter function pointer
+        self._report_input_params: Callable[..., None] = (
+            self.__report_input_params_factory()
+        )
+
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
@@ -1953,6 +1959,19 @@ def forward( # noqa: C901
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)
 
+        # Extract and Write input stats if enable
+        if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
+            self._report_input_params(
+                feature_rows=self.rows_per_table,
+                feature_dims=self.feature_dims,
+                iteration=self.iter_cpu if hasattr(self, "iter_cpu") else 0,
+                indices=indices,
+                offsets=offsets,
+                op_id=self.uuid,
+                per_sample_weights=per_sample_weights,
+                batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
+            )
+
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
 
@@ -3804,6 +3823,39 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
 
+    @torch.jit.ignore
+    def __report_input_params_factory(
+        self,
+    ) -> Callable[..., None]:
+        """
+        This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
+        - Reports input parameters (TBEDataConfig).
+        - Writes the output as a JSON file.
+
+        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
+        """
+
+        @torch.jit.ignore
+        def __report_input_params_factory_null(
+            feature_rows: Tensor,
+            feature_dims: Tensor,
+            iteration: int,
+            indices: Tensor,
+            offsets: Tensor,
+            op_id: Optional[str] = None,
+            per_sample_weights: Optional[Tensor] = None,
+            batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+        ) -> None:
+            pass
+
+        if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
+
+            reporter = TBEBenchmarkParamsReporter.create()
+            return reporter.report_stats
+        return __report_input_params_factory_null
+
 
 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
```

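A design note on the hunk above: `__report_input_params_factory` picks the reporting callable once, at module construction, so `forward` always invokes `self._report_input_params` through a single indirection instead of re-checking environment state on every call. Below is a minimal, self-contained sketch of this null-object factory pattern; the names are illustrative and not part of the FBGEMM API:

```python
from typing import Callable, Optional

def make_reporter(enabled: bool) -> Callable[..., None]:
    """Return a real reporter when enabled, else a no-op with the same signature."""

    def _null_reporter(iteration: int, op_id: Optional[str] = None) -> None:
        # No-op: keeps the call site branch-free when the feature is off
        pass

    def _real_reporter(iteration: int, op_id: Optional[str] = None) -> None:
        print(f"reporting iteration={iteration}, op_id={op_id}")

    return _real_reporter if enabled else _null_reporter

# Decide once at construction time, then call unconditionally:
report = make_reporter(enabled=True)
report(iteration=3, op_id="tbe-0")  # -> reporting iteration=3, op_id=tbe-0
```

In the actual diff, the factory additionally carries `@torch.jit.ignore` so the environment-dependent setup stays out of TorchScript compilation.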
fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py

Lines changed: 2 additions & 182 deletions

```diff
@@ -9,19 +9,11 @@
 
 import dataclasses
 import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional
 
-import numpy as np
 import torch
 
-from fbgemm_gpu.tbe.utils.common import get_device, round_up
-from fbgemm_gpu.tbe.utils.requests import (
-    generate_batch_sizes_from_stats,
-    generate_pooling_factors_from_stats,
-    get_table_batched_offsets_from_dense,
-    maybe_to_dtype,
-    TBERequest,
-)
+from fbgemm_gpu.tbe.utils.common import get_device
 
 from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
 
@@ -104,175 +96,3 @@ def variable_L(self) -> bool:
     def _new_weights(self, size: int) -> Optional[torch.Tensor]:
         # Per-sample weights will always be FP32
         return None if not self.weighted else torch.randn(size, device=get_device())
-
-    def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
-        if self.variable_B():
-            assert (
-                self.batch_params.vbe_num_ranks is not None
-            ), "vbe_num_ranks must be set for varaible batch size generation"
-            return generate_batch_sizes_from_stats(
-                self.batch_params.B,
-                self.T,
-                # pyre-ignore [6]
-                self.batch_params.sigma_B,
-                self.batch_params.vbe_num_ranks,
-                # pyre-ignore [6]
-                self.batch_params.vbe_distribution,
-            )
-
-        else:
-            return ([self.batch_params.B] * self.T, None)
-
-    def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor:
-        if self.variable_L():
-            # Generate L from stats
-            _, L_offsets = generate_pooling_factors_from_stats(
-                iters,
-                Bs,
-                self.pooling_params.L,
-                # pyre-ignore [6]
-                self.pooling_params.sigma_L,
-                # pyre-ignore [6]
-                self.pooling_params.length_distribution,
-            )
-
-        else:
-            Ls = [self.pooling_params.L] * (sum(Bs) * iters)
-            L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
-
-        return L_offsets
-
-    def _generate_indices(
-        self,
-        iters: int,
-        Bs: List[int],
-        L_offsets: torch.Tensor,
-    ) -> torch.Tensor:
-        total_B = sum(Bs)
-        L_offsets_list = L_offsets.tolist()
-        indices_list = []
-        for it in range(iters):
-            # L_offsets is defined over the entire set of batches for a single iteration
-            start_offset = L_offsets_list[it * total_B]
-            end_offset = L_offsets_list[(it + 1) * total_B]
-
-            indices_list.append(
-                torch.ops.fbgemm.tbe_generate_indices_from_distribution(
-                    self.indices_params.heavy_hitters,
-                    self.indices_params.zipf_q,
-                    self.indices_params.zipf_s,
-                    # max_index = dimensions of the embedding table
-                    self.E,
-                    # num_indices = number of indices to generate
-                    end_offset - start_offset,
-                )
-            )
-
-        return torch.cat(indices_list)
-
-    def _build_requests_jagged(
-        self,
-        iters: int,
-        Bs: List[int],
-        Bs_feature_rank: Optional[List[List[int]]],
-        L_offsets: torch.Tensor,
-        all_indices: torch.Tensor,
-    ) -> List[TBERequest]:
-        total_B = sum(Bs)
-        all_indices = all_indices.flatten()
-        requests = []
-        for it in range(iters):
-            start_offset = L_offsets[it * total_B]
-            it_L_offsets = torch.concat(
-                [
-                    torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
-                    L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
-                ]
-            )
-            requests.append(
-                TBERequest(
-                    maybe_to_dtype(
-                        all_indices[start_offset : L_offsets[(it + 1) * total_B]],
-                        self.indices_params.index_dtype,
-                    ),
-                    maybe_to_dtype(
-                        it_L_offsets.to(get_device()), self.indices_params.offset_dtype
-                    ),
-                    self._new_weights(int(it_L_offsets[-1].item())),
-                    Bs_feature_rank if self.variable_B() else None,
-                )
-            )
-        return requests
-
-    def _build_requests_dense(
-        self, iters: int, all_indices: torch.Tensor
-    ) -> List[TBERequest]:
-        # NOTE: We're using existing code from requests.py to build the
-        # requests, and since the existing code requires 2D view of all_indices,
-        # the existing all_indices must be reshaped
-        all_indices = all_indices.reshape(iters, -1)
-
-        requests = []
-        for it in range(iters):
-            indices, offsets = get_table_batched_offsets_from_dense(
-                all_indices[it].view(
-                    self.T, self.batch_params.B, self.pooling_params.L
-                ),
-                use_cpu=self.use_cpu,
-            )
-            requests.append(
-                TBERequest(
-                    maybe_to_dtype(indices, self.indices_params.index_dtype),
-                    maybe_to_dtype(offsets, self.indices_params.offset_dtype),
-                    self._new_weights(
-                        self.T * self.batch_params.B * self.pooling_params.L
-                    ),
-                )
-            )
-        return requests
-
-    def generate_requests(
-        self,
-        iters: int = 1,
-    ) -> List[TBERequest]:
-        # Generate batch sizes
-        Bs, Bs_feature_rank = self._generate_batch_sizes()
-
-        # Generate pooling info
-        L_offsets = self._generate_pooling_info(iters, Bs)
-
-        # Generate indices
-        all_indices = self._generate_indices(iters, Bs, L_offsets)
-
-        # Build TBE requests
-        if self.variable_B() or self.variable_L():
-            return self._build_requests_jagged(
-                iters, Bs, Bs_feature_rank, L_offsets, all_indices
-            )
-        else:
-            return self._build_requests_dense(iters, all_indices)
-
-    def generate_embedding_dims(self) -> Tuple[int, List[int]]:
-        if self.mixed_dim:
-            Ds = [
-                round_up(
-                    np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4
-                )
-                for _ in range(self.T)
-            ]
-            return (int(np.average(Ds)), Ds)
-        else:
-            return (self.D, [self.D] * self.T)
-
-    def generate_feature_requires_grad(self, size: int) -> torch.Tensor:
-        assert size <= self.T, "size of feature_requires_grad must be less than T"
-        weighted_requires_grad_tables = np.random.choice(
-            self.T, replace=False, size=(size,)
-        ).tolist()
-        return (
-            torch.tensor(
-                [1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)]
-            )
-            .to(get_device())
-            .int()
-        )
```
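The generators removed here correspond to the free functions (`generate_embedding_dims`, `generate_feature_requires_grad`, `generate_requests`) that the benchmark now imports from `fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper` and calls with the config as the first argument. As a rough sketch of the relocated shape, adapted from the deleted `generate_embedding_dims` method body; the actual helper may differ, e.g. in how randomness is generated, given that this diff removes `numpy.random` from `tbe_data_config.py`:

```python
from typing import List, Tuple

import numpy as np

from fbgemm_gpu.tbe.bench import TBEDataConfig  # assumed export path
from fbgemm_gpu.tbe.utils.common import round_up

# Sketch only: same logic as the deleted method, rewritten as a free
# function taking the config explicitly, as the new call sites suggest.
def generate_embedding_dims(tbeconfig: TBEDataConfig) -> Tuple[int, List[int]]:
    if tbeconfig.mixed_dim:
        # Draw per-table dims around D, rounded up to a multiple of 4
        Ds = [
            round_up(
                np.random.randint(
                    low=int(0.5 * tbeconfig.D), high=int(1.5 * tbeconfig.D)
                ),
                4,
            )
            for _ in range(tbeconfig.T)
        ]
        return (int(np.average(Ds)), Ds)
    return (tbeconfig.D, [tbeconfig.D] * tbeconfig.T)
```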
