
Commit 127420d

generatedunixname89002005232357 authored and facebook-github-bot committed
Revert D75462895: Multisect successfully blamed "D75462895: [fbgemm_gpu] Add TBE data configuration reporter to TBE forward (v2)" for one test failure (#4381)
Summary:
X-link: facebookresearch/FBGEMM#1451
Pull Request resolved: #4381

This diff reverts D75462895. Depends on D76973522.

D75462895: [fbgemm_gpu] Add TBE data configuration reporter to TBE forward (v2) by gchalump causes the following test failure:

Tests affected:
- [cogwheel:cogwheel_aps_ads_icvr_test_on_mast_fm#test_aps_ads_icvr_mast_fm_in_trainer_publish](https://www.internalfb.com/intern/test/281475084377136/)

Here's the Multisect link: https://www.internalfb.com/multisect/29905580

Here are the tasks that are relevant to this breakage:
T226571912: Some build rules, one test, one sandcastle job unhealthy for aps_ads_release

The backout may land if someone accepts it. If this diff has been generated in error, you can Commandeer and Abandon it.

Depends on D75462895

Reviewed By: gchalump

Differential Revision: D76973532

fbshipit-source-id: 9ee731e2f336a404b0e985f1c9dfbe535e2d32ae
1 parent 360479b commit 127420d

File tree

12 files changed (+264, -614 lines)


fbgemm_gpu/bench/tbe/tbe_training_benchmark.py

Lines changed: 3 additions & 8 deletions
@@ -33,11 +33,6 @@
     TBEBenchmarkingConfigLoader,
     TBEDataConfigLoader,
 )
-from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import (
-    generate_embedding_dims,
-    generate_feature_requires_grad,
-    generate_requests,
-)
 from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
 from fbgemm_gpu.tbe.utils import get_device
 from torch.profiler import profile
@@ -107,13 +102,13 @@ def device( # noqa C901
 
     # Generate feature_requires_grad
     feature_requires_grad = (
-        generate_feature_requires_grad(tbeconfig, weighted_num_requires_grad)
+        tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
         if weighted_num_requires_grad
         else None
     )
 
     # Generate embedding dims
-    effective_D, Ds = generate_embedding_dims(tbeconfig)
+    effective_D, Ds = tbeconfig.generate_embedding_dims()
 
     # Determine the optimizer
     optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD
@@ -217,7 +212,7 @@ def device( # noqa C901
         f"Accessed weights per batch: {tbeconfig.batch_params.B * sum(Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB"
     )
 
-    requests = generate_requests(tbeconfig, benchconfig.num_requests)
+    requests = tbeconfig.generate_requests(benchconfig.num_requests)
 
     # pyre-ignore[53]
     def _kineto_trace_handler(p: profile, phase: str) -> None:
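The net effect of this hunk is that the benchmark goes back to calling the request and dimension generators as methods on the TBEDataConfig object rather than as free functions from tbe_data_config_bench_helper. A minimal sketch of the restored call pattern, assuming `tbeconfig` is an already-populated TBEDataConfig instance (the surrounding benchmark and config-loading plumbing is omitted):

def build_benchmark_inputs(tbeconfig, num_requests: int, weighted_num_requires_grad: int = 0):
    # Embedding dims now come from the config object itself
    effective_D, Ds = tbeconfig.generate_embedding_dims()

    # Optional per-feature requires_grad mask for weighted features
    feature_requires_grad = (
        tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
        if weighted_num_requires_grad
        else None
    )

    # Input batches consumed by the benchmark loop
    requests = tbeconfig.generate_requests(num_requests)
    return effective_D, Ds, feature_requires_grad, requests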

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 0 additions & 3 deletions
@@ -63,9 +63,6 @@ def foo():
     # Disable FP8 quantization vectorization
     DISABLE_FP8_QUANT_VECTORIZATION = auto()
 
-    # Enable TBE input parameters extraction
-    TBE_REPORT_INPUT_PARAMS = auto()
-
     def is_enabled(self) -> bool:
         return FeatureGate.is_enabled(self)
 

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 0 additions & 51 deletions
@@ -51,7 +51,6 @@
     generate_vbe_metadata,
     is_torchdynamo_compiling,
 )
-from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
 from fbgemm_gpu.tbe_input_multiplexer import (
     TBEInfo,
     TBEInputInfo,
@@ -1442,11 +1441,6 @@ def __init__( # noqa C901
             self._debug_print_input_stats_factory()
         )
 
-        # Get a reporter function pointer
-        self._report_input_params: Callable[..., None] = (
-            self.__report_input_params_factory()
-        )
-
         if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
             # Register writeback hook for Exact_SGD optimizer
             self.log(
@@ -1959,18 +1953,6 @@ def forward( # noqa: C901
         # Print input stats if enable (for debugging purpose only)
         self._debug_print_input_stats(indices, offsets, per_sample_weights)
 
-        # Extract and Write input stats if enable
-        self._report_input_params(
-            feature_rows=self.rows_per_table,
-            feature_dims=self.feature_dims,
-            iteration=self.iter.item() if hasattr(self, "iter") else 0,
-            indices=indices,
-            offsets=offsets,
-            op_id=self.uuid,
-            per_sample_weights=per_sample_weights,
-            batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
-        )
-
         if not is_torchdynamo_compiling():
             # Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time
 
@@ -3822,39 +3804,6 @@ def _debug_print_input_stats_factory_null(
             return _debug_print_input_stats_factory_impl
         return _debug_print_input_stats_factory_null
 
-    @torch.jit.ignore
-    def __report_input_params_factory(
-        self,
-    ) -> Callable[..., None]:
-        """
-        This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.
-
-        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
-        - Reports input parameters (TBEDataConfig).
-        - Writes the output as a JSON file.
-
-        If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
-        """
-
-        @torch.jit.ignore
-        def __report_input_params_factory_null(
-            feature_rows: Tensor,
-            feature_dims: Tensor,
-            iteration: int,
-            indices: Tensor,
-            offsets: Tensor,
-            op_id: Optional[str] = None,
-            per_sample_weights: Optional[Tensor] = None,
-            batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
-        ) -> None:
-            pass
-
-        if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
-            reporter = TBEBenchmarkParamsReporter.create()
-            return reporter.report_stats
-        return __report_input_params_factory_null
-
 
 class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """

fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py

Lines changed: 182 additions & 2 deletions
@@ -9,11 +9,19 @@
 
 import dataclasses
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
+import numpy as np
 import torch
 
-from fbgemm_gpu.tbe.utils.common import get_device
+from fbgemm_gpu.tbe.utils.common import get_device, round_up
+from fbgemm_gpu.tbe.utils.requests import (
+    generate_batch_sizes_from_stats,
+    generate_pooling_factors_from_stats,
+    get_table_batched_offsets_from_dense,
+    maybe_to_dtype,
+    TBERequest,
+)
 
 from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams
 
@@ -96,3 +104,175 @@ def variable_L(self) -> bool:
     def _new_weights(self, size: int) -> Optional[torch.Tensor]:
         # Per-sample weights will always be FP32
         return None if not self.weighted else torch.randn(size, device=get_device())
+
+    def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
+        if self.variable_B():
+            assert (
+                self.batch_params.vbe_num_ranks is not None
+            ), "vbe_num_ranks must be set for variable batch size generation"
+            return generate_batch_sizes_from_stats(
+                self.batch_params.B,
+                self.T,
+                # pyre-ignore [6]
+                self.batch_params.sigma_B,
+                self.batch_params.vbe_num_ranks,
+                # pyre-ignore [6]
+                self.batch_params.vbe_distribution,
+            )
+        else:
+            return ([self.batch_params.B] * self.T, None)
+
+    def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor:
+        if self.variable_L():
+            # Generate L from stats
+            _, L_offsets = generate_pooling_factors_from_stats(
+                iters,
+                Bs,
+                self.pooling_params.L,
+                # pyre-ignore [6]
+                self.pooling_params.sigma_L,
+                # pyre-ignore [6]
+                self.pooling_params.length_distribution,
+            )
+        else:
+            Ls = [self.pooling_params.L] * (sum(Bs) * iters)
+            L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
+
+        return L_offsets
+
+    def _generate_indices(
+        self,
+        iters: int,
+        Bs: List[int],
+        L_offsets: torch.Tensor,
+    ) -> torch.Tensor:
+        total_B = sum(Bs)
+        L_offsets_list = L_offsets.tolist()
+        indices_list = []
+        for it in range(iters):
+            # L_offsets is defined over the entire set of batches for a single iteration
+            start_offset = L_offsets_list[it * total_B]
+            end_offset = L_offsets_list[(it + 1) * total_B]
+
+            indices_list.append(
+                torch.ops.fbgemm.tbe_generate_indices_from_distribution(
+                    self.indices_params.heavy_hitters,
+                    self.indices_params.zipf_q,
+                    self.indices_params.zipf_s,
+                    # max_index = dimensions of the embedding table
+                    self.E,
+                    # num_indices = number of indices to generate
+                    end_offset - start_offset,
+                )
+            )
+
+        return torch.cat(indices_list)
+
+    def _build_requests_jagged(
+        self,
+        iters: int,
+        Bs: List[int],
+        Bs_feature_rank: Optional[List[List[int]]],
+        L_offsets: torch.Tensor,
+        all_indices: torch.Tensor,
+    ) -> List[TBERequest]:
+        total_B = sum(Bs)
+        all_indices = all_indices.flatten()
+        requests = []
+        for it in range(iters):
+            start_offset = L_offsets[it * total_B]
+            it_L_offsets = torch.concat(
+                [
+                    torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
+                    L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
+                ]
+            )
+            requests.append(
+                TBERequest(
+                    maybe_to_dtype(
+                        all_indices[start_offset : L_offsets[(it + 1) * total_B]],
+                        self.indices_params.index_dtype,
+                    ),
+                    maybe_to_dtype(
+                        it_L_offsets.to(get_device()), self.indices_params.offset_dtype
+                    ),
+                    self._new_weights(int(it_L_offsets[-1].item())),
+                    Bs_feature_rank if self.variable_B() else None,
+                )
+            )
+        return requests
+
+    def _build_requests_dense(
+        self, iters: int, all_indices: torch.Tensor
+    ) -> List[TBERequest]:
+        # NOTE: We're using existing code from requests.py to build the
+        # requests, and since the existing code requires a 2D view of all_indices,
+        # the existing all_indices must be reshaped
+        all_indices = all_indices.reshape(iters, -1)
+
+        requests = []
+        for it in range(iters):
+            indices, offsets = get_table_batched_offsets_from_dense(
+                all_indices[it].view(
+                    self.T, self.batch_params.B, self.pooling_params.L
+                ),
+                use_cpu=self.use_cpu,
+            )
+            requests.append(
+                TBERequest(
+                    maybe_to_dtype(indices, self.indices_params.index_dtype),
+                    maybe_to_dtype(offsets, self.indices_params.offset_dtype),
+                    self._new_weights(
+                        self.T * self.batch_params.B * self.pooling_params.L
+                    ),
+                )
+            )
+        return requests
+
+    def generate_requests(
+        self,
+        iters: int = 1,
+    ) -> List[TBERequest]:
+        # Generate batch sizes
+        Bs, Bs_feature_rank = self._generate_batch_sizes()
+
+        # Generate pooling info
+        L_offsets = self._generate_pooling_info(iters, Bs)
+
+        # Generate indices
+        all_indices = self._generate_indices(iters, Bs, L_offsets)
+
+        # Build TBE requests
+        if self.variable_B() or self.variable_L():
+            return self._build_requests_jagged(
+                iters, Bs, Bs_feature_rank, L_offsets, all_indices
+            )
+        else:
+            return self._build_requests_dense(iters, all_indices)
+
+    def generate_embedding_dims(self) -> Tuple[int, List[int]]:
+        if self.mixed_dim:
+            Ds = [
+                round_up(
+                    np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4
+                )
+                for _ in range(self.T)
+            ]
+            return (int(np.average(Ds)), Ds)
+        else:
+            return (self.D, [self.D] * self.T)
+
+    def generate_feature_requires_grad(self, size: int) -> torch.Tensor:
+        assert size <= self.T, "size of feature_requires_grad must be less than T"
+        weighted_requires_grad_tables = np.random.choice(
+            self.T, replace=False, size=(size,)
+        ).tolist()
+        return (
+            torch.tensor(
+                [1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)]
+            )
+            .to(get_device())
+            .int()
+        )
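The methods restored here generate synthetic TBE inputs directly from the config's table, batch, pooling, and indices parameters. A hedged usage sketch; construction of the TBEDataConfig instance is assumed and omitted, since its dataclass fields are defined outside this diff:

def summarize_tbe_config(config) -> None:
    # Per-table embedding dims; with mixed_dim the dims are drawn around D
    # and rounded up to a multiple of 4
    effective_D, Ds = config.generate_embedding_dims()

    # Mark two randomly chosen features as requiring grad on per-sample weights
    requires_grad = config.generate_feature_requires_grad(size=2)

    # One iteration's worth of TBERequest objects (indices, offsets, weights)
    requests = config.generate_requests(iters=1)

    print(f"effective_D={effective_D}, Ds={Ds}")
    print(f"feature_requires_grad={requires_grad.tolist()}")
    print(f"num requests={len(requests)}")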
