
Commit bd7faf5

Chenyu Zhang authored and facebook-github-bot committed
kv embedding inference test
Differential Revision: D76865305
1 parent: f2da3ab · commit: bd7faf5

File tree

1 file changed: +100 -0 lines changed

@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest import TestCase

import fbgemm_gpu
import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.utils.loader import load_torch_module

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if not open_source:
    load_torch_module(
        "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference",
    )


class DramKvInferenceTest(TestCase):
    def test_serialize(self) -> None:
        num_shards = 32
        uniform_init_lower: float = -0.01
        uniform_init_upper: float = 0.01
        evict_trigger_mode: int = 1

        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
        )
        serialized_result = kv_embedding_cache.serialize()

        # The serialized payload groups the integer args first and the float args second.
        self.assertEqual(serialized_result[0][0], num_shards)
        self.assertEqual(serialized_result[0][1], evict_trigger_mode)

        self.assertEqual(serialized_result[1][0], uniform_init_lower)
        self.assertEqual(serialized_result[1][1], uniform_init_upper)

    def test_serialize_deserialize(self) -> None:
        num_shards = 32
        uniform_init_lower: float = -0.01
        uniform_init_upper: float = 0.01
        evict_trigger_mode: int = 1

        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
        )
        serialized_result = kv_embedding_cache.serialize()

        # A default-constructed wrapper should round-trip the serialized config exactly.
        kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            0, 0.0, 0.0, 0
        )
        kv_embedding_cache_2.deserialize(serialized_result)

        self.assertEqual(str(serialized_result), str(kv_embedding_cache_2.serialize()))

    def test_set_get_embeddings(self) -> None:
        num_shards = 32
        uniform_init_lower: float = 0.0
        uniform_init_upper: float = 0.0
        evict_trigger_mode: int = 0

        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
        )
        kv_embedding_cache.init(
            [(20, 4, SparseType.INT8.as_int())],
            8,
            4,
        )

        kv_embedding_cache.set_embeddings(
            torch.tensor([0, 1, 2, 3], dtype=torch.int64),
            torch.tensor(
                [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
                dtype=torch.uint8,
            ),
        )

        embs = kv_embedding_cache.get_embeddings(
            torch.tensor([1, 4, 3, 0, 5, 2], dtype=torch.int64),
        )
        # IDs 4 and 5 were never set; with zero uniform-init bounds their rows read back as zeros.
        assert torch.equal(
            embs[:, :4],
            torch.tensor(
                [
                    [5, 6, 7, 8],
                    [0, 0, 0, 0],
                    [13, 14, 15, 16],
                    [1, 2, 3, 4],
                    [0, 0, 0, 0],
                    [9, 10, 11, 12],
                ],
                dtype=torch.uint8,
            ),
        )
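For reference, a minimal way to run this suite with the stock unittest runner, assuming an open-source fbgemm_gpu build that already registers the DRAM KV inference ops and that the file is saved as dram_kv_inference_test.py (both the file name and the invocation are illustrative assumptions, not part of this commit):

    # hypothetical file name; run from the directory containing the test
    python -m unittest dram_kv_inference_test -v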
