Add TBE data configuration reporter to TBE forward (v3) #4455

Status: Open · wants to merge 1 commit into base: main
11 changes: 8 additions & 3 deletions fbgemm_gpu/bench/tbe/tbe_training_benchmark.py
@@ -33,6 +33,11 @@
TBEBenchmarkingConfigLoader,
TBEDataConfigLoader,
)
from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import (
generate_embedding_dims,
generate_feature_requires_grad,
generate_requests,
)
from fbgemm_gpu.tbe.ssd import SSDTableBatchedEmbeddingBags
from fbgemm_gpu.tbe.utils import get_device
from torch.profiler import profile
@@ -102,13 +107,13 @@ def device( # noqa C901

# Generate feature_requires_grad
feature_requires_grad = (
tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
generate_feature_requires_grad(tbeconfig, weighted_num_requires_grad)
if weighted_num_requires_grad
else None
)

# Generate embedding dims
effective_D, Ds = tbeconfig.generate_embedding_dims()
effective_D, Ds = generate_embedding_dims(tbeconfig)

# Determine the optimizer
optimizer = OptimType.EXACT_ROWWISE_ADAGRAD if row_wise else OptimType.EXACT_ADAGRAD
@@ -212,7 +217,7 @@ def device( # noqa C901
f"Accessed weights per batch: {tbeconfig.batch_params.B * sum(Ds) * tbeconfig.pooling_params.L * param_size_multiplier / 1.0e9: .2f} GB"
)

requests = tbeconfig.generate_requests(benchconfig.num_requests)
requests = generate_requests(tbeconfig, benchconfig.num_requests)

# pyre-ignore[53]
def _kineto_trace_handler(p: profile, phase: str) -> None:
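
For context, the benchmark now calls module-level helpers from `tbe_data_config_bench_helper` instead of methods on the `TBEDataConfig` object, passing the config as the first argument. A minimal sketch of the new call pattern, assuming a `TBEDataConfig` instance named `tbeconfig` has already been built by `TBEDataConfigLoader` (construction is not shown in this diff):

    from fbgemm_gpu.tbe.bench.tbe_data_config_bench_helper import (
        generate_embedding_dims,
        generate_feature_requires_grad,
        generate_requests,
    )

    # Previously: effective_D, Ds = tbeconfig.generate_embedding_dims()
    effective_D, Ds = generate_embedding_dims(tbeconfig)

    # Previously: tbeconfig.generate_feature_requires_grad(weighted_num_requires_grad)
    feature_requires_grad = generate_feature_requires_grad(tbeconfig, weighted_num_requires_grad)

    # Previously: tbeconfig.generate_requests(benchconfig.num_requests)
    requests = generate_requests(tbeconfig, benchconfig.num_requests)

The effect is that `TBEDataConfig` stays a plain, serializable description of the workload, while the request and dimension generation code moves out of the config class (see the `tbe_data_config.py` deletions later in this diff).
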
3 changes: 3 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/config/feature_list.py
@@ -60,6 +60,9 @@ def foo():
# Enable bounds_check_indices_v2
BOUNDS_CHECK_INDICES_V2 = auto()

# Enable TBE input parameters extraction
TBE_REPORT_INPUT_PARAMS = auto()

def is_enabled(self) -> bool:
return FeatureGate.is_enabled(self)

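
The new enum value hooks into the existing feature-gate machinery, so the reporter can be toggled without code changes. A small usage sketch, assuming `FeatureGateName` is exported from `fbgemm_gpu.config` (the exact import path is an assumption; only the enum member and `is_enabled()` appear in this diff):

    from fbgemm_gpu.config import FeatureGateName  # assumed import path

    # True only when the TBE input-parameter reporting gate is turned on;
    # how the gate is toggled follows the existing FeatureGate mechanism.
    if FeatureGateName.TBE_REPORT_INPUT_PARAMS.is_enabled():
        print("TBE input-parameter reporting is active")
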
@@ -51,6 +51,7 @@
generate_vbe_metadata,
is_torchdynamo_compiling,
)
from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter
from fbgemm_gpu.tbe_input_multiplexer import (
TBEInfo,
TBEInputInfo,
@@ -1441,6 +1442,11 @@ def __init__( # noqa C901
self._debug_print_input_stats_factory()
)

# Get a reporter function pointer
self._report_input_params: Optional[Callable[..., None]] = (
self.__report_input_params_factory()
)

if optimizer == OptimType.EXACT_SGD and self.use_writeback_bwd_prehook:
# Register writeback hook for Exact_SGD optimizer
self.log(
@@ -1953,6 +1959,19 @@ def forward( # noqa: C901
# Print input stats if enabled (for debugging purposes only)
self._debug_print_input_stats(indices, offsets, per_sample_weights)

# Extract and write input stats if enabled
if self._report_input_params is not None:
self._report_input_params(
feature_rows=self.rows_per_table,
feature_dims=self.feature_dims,
iteration=self.iter_cpu.item() if hasattr(self, "iter_cpu") else 0,
indices=indices,
offsets=offsets,
op_id=self.uuid,
per_sample_weights=per_sample_weights,
batch_size_per_feature_per_rank=batch_size_per_feature_per_rank,
)

if not is_torchdynamo_compiling():
# Mutations of nn.Module attr forces dynamo restart of Analysis which increases compilation time

Expand Down Expand Up @@ -3804,6 +3823,25 @@ def _debug_print_input_stats_factory_null(
return _debug_print_input_stats_factory_impl
return _debug_print_input_stats_factory_null

@torch.jit.ignore
def __report_input_params_factory(
self,
) -> Optional[Callable[..., None]]:
"""
This function returns a function pointer based on the environment variable `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL`.

If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is set to a value greater than 0, it returns a function pointer that:
- Reports input parameters (TBEDataConfig).
- Writes the output as a JSON file.

If `FBGEMM_REPORT_INPUT_PARAMS_INTERVAL` is not set or is set to 0, it returns a dummy function pointer that performs no action.
"""

if self._feature_is_enabled(FeatureGateName.TBE_REPORT_INPUT_PARAMS):
reporter = TBEBenchmarkParamsReporter.create()
return reporter.report_stats
return None


class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
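
Taken together, the changes in this file wire reporting into the op: `__init__` resolves a reporter via `__report_input_params_factory`, and `forward()` passes the per-iteration indices and offsets to it when the gate is enabled. A rough standalone sketch of the same reporter call, assuming `report_stats` accepts the keyword arguments used in the `forward()` hunk above (the tensor values below are made up for illustration):

    import torch
    from fbgemm_gpu.tbe.stats import TBEBenchmarkParamsReporter

    reporter = TBEBenchmarkParamsReporter.create()

    # Two tables, batch size 2 per table: offsets has T * B + 1 entries.
    feature_rows = torch.tensor([100, 200])          # rows per embedding table
    feature_dims = torch.tensor([16, 32])            # embedding dim per table
    indices = torch.tensor([3, 7, 42, 1, 99, 5], dtype=torch.long)
    offsets = torch.tensor([0, 2, 3, 5, 6], dtype=torch.long)

    reporter.report_stats(
        feature_rows=feature_rows,
        feature_dims=feature_dims,
        iteration=1,
        indices=indices,
        offsets=offsets,
        op_id="standalone-example",                  # forward() passes self.uuid here
        per_sample_weights=None,
        batch_size_per_feature_per_rank=None,
    )
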
184 changes: 2 additions & 182 deletions fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py
@@ -9,19 +9,11 @@

import dataclasses
import json
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Optional

import numpy as np
import torch

from fbgemm_gpu.tbe.utils.common import get_device, round_up
from fbgemm_gpu.tbe.utils.requests import (
generate_batch_sizes_from_stats,
generate_pooling_factors_from_stats,
get_table_batched_offsets_from_dense,
maybe_to_dtype,
TBERequest,
)
from fbgemm_gpu.tbe.utils.common import get_device

from .tbe_data_config_param_models import BatchParams, IndicesParams, PoolingParams

@@ -104,175 +96,3 @@ def variable_L(self) -> bool:
def _new_weights(self, size: int) -> Optional[torch.Tensor]:
# Per-sample weights will always be FP32
return None if not self.weighted else torch.randn(size, device=get_device())

def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
if self.variable_B():
assert (
self.batch_params.vbe_num_ranks is not None
), "vbe_num_ranks must be set for varaible batch size generation"
return generate_batch_sizes_from_stats(
self.batch_params.B,
self.T,
# pyre-ignore [6]
self.batch_params.sigma_B,
self.batch_params.vbe_num_ranks,
# pyre-ignore [6]
self.batch_params.vbe_distribution,
)

else:
return ([self.batch_params.B] * self.T, None)

def _generate_pooling_info(self, iters: int, Bs: List[int]) -> torch.Tensor:
if self.variable_L():
# Generate L from stats
_, L_offsets = generate_pooling_factors_from_stats(
iters,
Bs,
self.pooling_params.L,
# pyre-ignore [6]
self.pooling_params.sigma_L,
# pyre-ignore [6]
self.pooling_params.length_distribution,
)

else:
Ls = [self.pooling_params.L] * (sum(Bs) * iters)
L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)

return L_offsets

def _generate_indices(
self,
iters: int,
Bs: List[int],
L_offsets: torch.Tensor,
) -> torch.Tensor:
total_B = sum(Bs)
L_offsets_list = L_offsets.tolist()
indices_list = []
for it in range(iters):
# L_offsets is defined over the entire set of batches for a single iteration
start_offset = L_offsets_list[it * total_B]
end_offset = L_offsets_list[(it + 1) * total_B]

indices_list.append(
torch.ops.fbgemm.tbe_generate_indices_from_distribution(
self.indices_params.heavy_hitters,
self.indices_params.zipf_q,
self.indices_params.zipf_s,
# max_index = dimensions of the embedding table
self.E,
# num_indices = number of indices to generate
end_offset - start_offset,
)
)

return torch.cat(indices_list)

def _build_requests_jagged(
self,
iters: int,
Bs: List[int],
Bs_feature_rank: Optional[List[List[int]]],
L_offsets: torch.Tensor,
all_indices: torch.Tensor,
) -> List[TBERequest]:
total_B = sum(Bs)
all_indices = all_indices.flatten()
requests = []
for it in range(iters):
start_offset = L_offsets[it * total_B]
it_L_offsets = torch.concat(
[
torch.zeros(1, dtype=L_offsets.dtype, device=L_offsets.device),
L_offsets[it * total_B + 1 : (it + 1) * total_B + 1] - start_offset,
]
)
requests.append(
TBERequest(
maybe_to_dtype(
all_indices[start_offset : L_offsets[(it + 1) * total_B]],
self.indices_params.index_dtype,
),
maybe_to_dtype(
it_L_offsets.to(get_device()), self.indices_params.offset_dtype
),
self._new_weights(int(it_L_offsets[-1].item())),
Bs_feature_rank if self.variable_B() else None,
)
)
return requests

def _build_requests_dense(
self, iters: int, all_indices: torch.Tensor
) -> List[TBERequest]:
# NOTE: We're using existing code from requests.py to build the
# requests, and since the existing code requires 2D view of all_indices,
# the existing all_indices must be reshaped
all_indices = all_indices.reshape(iters, -1)

requests = []
for it in range(iters):
indices, offsets = get_table_batched_offsets_from_dense(
all_indices[it].view(
self.T, self.batch_params.B, self.pooling_params.L
),
use_cpu=self.use_cpu,
)
requests.append(
TBERequest(
maybe_to_dtype(indices, self.indices_params.index_dtype),
maybe_to_dtype(offsets, self.indices_params.offset_dtype),
self._new_weights(
self.T * self.batch_params.B * self.pooling_params.L
),
)
)
return requests

def generate_requests(
self,
iters: int = 1,
) -> List[TBERequest]:
# Generate batch sizes
Bs, Bs_feature_rank = self._generate_batch_sizes()

# Generate pooling info
L_offsets = self._generate_pooling_info(iters, Bs)

# Generate indices
all_indices = self._generate_indices(iters, Bs, L_offsets)

# Build TBE requests
if self.variable_B() or self.variable_L():
return self._build_requests_jagged(
iters, Bs, Bs_feature_rank, L_offsets, all_indices
)
else:
return self._build_requests_dense(iters, all_indices)

def generate_embedding_dims(self) -> Tuple[int, List[int]]:
if self.mixed_dim:
Ds = [
round_up(
np.random.randint(low=int(0.5 * self.D), high=int(1.5 * self.D)), 4
)
for _ in range(self.T)
]
return (int(np.average(Ds)), Ds)
else:
return (self.D, [self.D] * self.T)

def generate_feature_requires_grad(self, size: int) -> torch.Tensor:
assert size <= self.T, "size of feature_requires_grad must be at most T"
weighted_requires_grad_tables = np.random.choice(
self.T, replace=False, size=(size,)
).tolist()
return (
torch.tensor(
[1 if t in weighted_requires_grad_tables else 0 for t in range(self.T)]
)
.to(get_device())
.int()
)
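
For readers tracing the removed generation logic (which now lives in the bench helper module), here is a tiny worked example of the fixed-`L` offset layout that `_generate_pooling_info` builds and `_generate_indices` slices per iteration; the concrete numbers are made up for illustration:

    import torch

    # Assume T=2 tables with batch size B=3 each (Bs=[3, 3]), pooling factor L=2, iters=2.
    Bs = [3, 3]
    iters, L = 2, 2
    total_B = sum(Bs)                        # 6 bags per iteration

    Ls = [L] * (total_B * iters)             # one pooling factor per bag per iteration
    L_offsets = torch.tensor([0] + Ls, dtype=torch.long).cumsum(0)
    # tensor([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24])

    # Per-iteration slice used when generating indices:
    it = 1
    start = int(L_offsets[it * total_B])     # 12
    end = int(L_offsets[(it + 1) * total_B]) # 24
    num_indices = end - start                # 12 indices to generate for iteration 1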