
Commit c63f643

Chenyu Zhang authored and facebook-github-bot committed
kvzch inference python operator (#4344)
Summary:
X-link: facebookresearch/FBGEMM#1412

Python operator change for kv embedding.

https://docs.google.com/document/d/1TNJMnj-PPKWitMgwB8HJIsFT3OwotiEnqgh60fI9P48/edit?tab=t.0#heading=h.o9irumwgl8gj

Reviewed By: emlin

Differential Revision: D73219651
1 parent c572afc commit c63f643
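
For context, here is a minimal sketch of how the new operator is exercised, condensed from the benchmark added in this commit. The single-table spec, row count, dim, and index values are illustrative placeholders rather than anything in the diff:

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.tbe.cache import KVEmbeddingInference

# One host-resident table: (name, rows, dim, weights precision, location).
tbe = KVEmbeddingInference(
    [("", 1_000_000, 128, SparseType.INT8, EmbeddingLocation.HOST)],
    output_dtype=SparseType.FP16,
    device="cpu",
)
tbe.fill_random_weights()

# Four single-element bags over the one table (TBE CSR layout, int32 indices/offsets).
indices = torch.tensor([3, 7, 42, 9], dtype=torch.int32)
offsets = torch.tensor([0, 1, 2, 3, 4], dtype=torch.int32)
out = tbe.forward(indices, offsets, None)  # pooled FP16 output, one row per bag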

4 files changed, +654 -0 lines changed
Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import gc
import logging
import time
from typing import Callable, Dict, Type

import click
import numpy as np
import psutil
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_inference import (
    IntNBitTableBatchedEmbeddingBagsCodegen,
)
from fbgemm_gpu.tbe.bench import benchmark_requests
from fbgemm_gpu.tbe.cache import KVEmbeddingInference
from fbgemm_gpu.tbe.utils import generate_requests, round_up, TBERequest

OptionCommandType = Callable[..., Callable[..., None]]

iters: OptionCommandType = click.option(
    "--iters",
    default=200,
    type=int,
    help="Number of iterations to benchmark",
)
num_embeddings: OptionCommandType = click.option(
    "--num-embeddings",
    default=int(1e8),
    type=int,
    help="Number of embeddings to benchmark",
)
dim: OptionCommandType = click.option(
    "--dim", default=256, type=int, help="Dimension of embedding to benchmark"
)
num_tables: OptionCommandType = click.option(
    "--num-tables", default=4, type=int, help="Number of tables to benchmark"
)
output_dtype: OptionCommandType = click.option(
    "--output-dtype", type=SparseType, default=SparseType.FP16
)
weights_precision: OptionCommandType = click.option(
    "--weights-precision", type=SparseType, default=SparseType.INT8
)
batch_size: OptionCommandType = click.option("--batch-size", default=128)
bag_size: OptionCommandType = click.option("--bag-size", default=1)
mixed_dim: OptionCommandType = click.option("--mixed-dim", is_flag=True, default=False)
tbe_class: OptionCommandType = click.option(
    "--tbe-class", type=str, default="KVEmbeddingInference"
)


# Benchmark either the KV-backed inference TBE or the baseline int-nbit TBE.
TBE_CLASS_MAP: Dict[str, Type[IntNBitTableBatchedEmbeddingBagsCodegen]] = {
    "KVEmbeddingInference": KVEmbeddingInference,
    "IntNBitTableBatchedEmbeddingBagsCodegen": IntNBitTableBatchedEmbeddingBagsCodegen,
}


@click.group()
def cli() -> None:
    pass


@cli.command()
@iters
@num_embeddings
@dim
@num_tables
@output_dtype
@weights_precision
@batch_size
@bag_size
@mixed_dim
@tbe_class
def forward_benchmark(
    iters: int,
    num_embeddings: int,
    dim: int,
    num_tables: int,
    output_dtype: SparseType,
    weights_precision: SparseType,
    batch_size: int,
    bag_size: int,
    mixed_dim: bool,
    tbe_class: str,
) -> None:
    logging.info(
        f"Running forward benchmark with {iters} iterations, {num_embeddings} embeddings, {dim} dim, {num_tables} tables, {output_dtype} output dtype, {weights_precision} weights precision, {batch_size} batch size"
    )

    stats = []

    if mixed_dim:
        # Draw per-table dims around --dim and round each up to a multiple of 4.
        dimensions = [
            round_up(np.random.randint(low=int(0.5 * dim), high=int(1.5 * dim)), 4)
            for _ in range(num_tables)
        ]
    else:
        dimensions = [dim] * num_tables

    process = psutil.Process()

    clazz = TBE_CLASS_MAP[tbe_class]

    # Record RSS before and after table init/fill to estimate the memory footprint.
    time.sleep(5)
    mem_util_before = process.memory_info().rss / (1024 * 1024)
    logging.info(f"Memory util before emb init: {mem_util_before} MB")
    tbe = clazz(
        [
            (
                "",
                num_embeddings,
                d,
                weights_precision,
                EmbeddingLocation.HOST,
            )
            for d in dimensions
        ],
        output_dtype=output_dtype,
        device="cpu",
    )
    tbe.fill_random_weights()

    gc.collect()
    time.sleep(5)
    mem_util_after = process.memory_info().rss / (1024 * 1024)
    logging.info(f"Memory util after emb fill: {mem_util_after} MB")
    logging.info(f"Memory util diff: {mem_util_after - mem_util_before} MB")

    # Sweep large batch sizes; this overrides the --batch-size option.
    for batch_size in [10240, 20480, 40960]:
        requests = generate_requests(
            iters,
            batch_size,
            num_tables,
            bag_size,
            num_embeddings,
            use_cpu=True,
        )

        requests_cpu = [
            TBERequest(
                req.indices.int().cpu(),
                req.offsets.int().cpu(),
                req.per_sample_weights,
            )
            for req in requests
        ]

        logging.info(f"Running forward benchmark with {len(requests_cpu)} requests")
        time_per_iter = benchmark_requests(
            requests_cpu,
            lambda indices, offsets, per_sample_weights: tbe.forward(
                indices.int().cpu(),
                offsets.int().cpu(),
                per_sample_weights,
            ),
            num_warmups=10,
        )
        logging.info(f"{clazz} CPU Time: {time_per_iter * 1.0e6:.0f}us")
        stats.append(
            [
                clazz,
                num_tables,
                batch_size,
                f"{time_per_iter * 1.0e6:.0f}us",
                f"{mem_util_after - mem_util_before} MB",
            ]
        )
    for stat in stats:
        logging.info(stat)


if __name__ == "__main__":
    cli()
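
Assuming the benchmark above is saved as, say, kv_embedding_inference_benchmark.py (the file path is not shown in this view, so that name is only a placeholder), it can be run through the click group; click derives the subcommand name forward-benchmark from the function name:

python kv_embedding_inference_benchmark.py forward-benchmark \
    --num-embeddings 1000000 --dim 128 --num-tables 2 --iters 100 \
    --tbe-class KVEmbeddingInference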

fbgemm_gpu/fbgemm_gpu/tbe/cache/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,4 +7,5 @@
 
 # pyre-unsafe
 
+from .kv_embedding_ops_inference import KVEmbeddingInference  # noqa: F401
 from .split_embeddings_cache_ops import get_unique_indices_v2  # noqa: F401
