Skip to content

Commit f7165d8

Browse files
Chenyu Zhangfacebook-github-bot
authored andcommitted
kv embedding inference test (#4373)
Summary: X-link: facebookresearch/FBGEMM#1442 Rollback Plan: Differential Revision: D76865305
1 parent bbabc32 commit f7165d8

File tree

1 file changed

+102
-0
lines changed

1 file changed

+102
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
import unittest
10+
11+
import fbgemm_gpu
12+
13+
import torch
14+
from fbgemm_gpu.split_embedding_configs import SparseType
15+
from fbgemm_gpu.utils.loader import load_torch_module
16+
17+
# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
18+
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
19+
20+
if not open_source:
21+
load_torch_module(
22+
"//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference",
23+
)
24+
25+
26+
@unittest.skipIf(open_source, "Not supported in open source yet")
27+
class DramKvInferenceTest(unittest.TestCase):
28+
def test_serialize(self) -> None:
29+
num_shards = 32
30+
uniform_init_lower: float = -0.01
31+
uniform_init_upper: float = 0.01
32+
evict_trigger_mode: int = 1
33+
34+
kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
35+
num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
36+
)
37+
serialized_result = kv_embedding_cache.serialize()
38+
39+
self.assertEqual(serialized_result[0][0], num_shards)
40+
self.assertEqual(serialized_result[0][1], evict_trigger_mode)
41+
42+
self.assertEqual(serialized_result[1][0], uniform_init_lower)
43+
self.assertEqual(serialized_result[1][1], uniform_init_upper)
44+
45+
def test_serialize_deserialize(self) -> None:
46+
num_shards = 32
47+
uniform_init_lower: float = -0.01
48+
uniform_init_upper: float = 0.01
49+
evict_trigger_mode: int = 1
50+
51+
kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
52+
num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
53+
)
54+
serialized_result = kv_embedding_cache.serialize()
55+
56+
kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
57+
0, 0.0, 0.0, 0
58+
)
59+
kv_embedding_cache_2.deserialize(serialized_result)
60+
61+
self.assertEqual(str(serialized_result), str(kv_embedding_cache_2.serialize()))
62+
63+
def test_set_get_embeddings(self) -> None:
64+
num_shards = 32
65+
uniform_init_lower: float = 0.0
66+
uniform_init_upper: float = 0.0
67+
evict_trigger_mode: int = 0
68+
69+
kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
70+
num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
71+
)
72+
kv_embedding_cache.init(
73+
[(20, 4, SparseType.INT8.as_int())],
74+
8,
75+
4,
76+
)
77+
78+
kv_embedding_cache.set_embeddings(
79+
torch.tensor([0, 1, 2, 3], dtype=torch.int64),
80+
torch.tensor(
81+
[[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
82+
dtype=torch.uint8,
83+
),
84+
)
85+
86+
embs = kv_embedding_cache.get_embeddings(
87+
torch.tensor([1, 4, 3, 0, 5, 2], dtype=torch.int64),
88+
)
89+
assert torch.equal(
90+
embs[:, :4],
91+
torch.tensor(
92+
[
93+
[5, 6, 7, 8],
94+
[0, 0, 0, 0],
95+
[13, 14, 15, 16],
96+
[1, 2, 3, 4],
97+
[0, 0, 0, 0],
98+
[9, 10, 11, 12],
99+
],
100+
dtype=torch.uint8,
101+
),
102+
)

0 commit comments

Comments
 (0)