
Commit 4d66d79

Chenyu Zhang authored and facebook-github-bot committed
kv embedding inference test (#4373)
Summary:
X-link: facebookresearch/FBGEMM#1442

Rollback Plan:

Differential Revision: D76865305
1 parent 2ae43a4 commit 4d66d79

File tree

1 file changed: +95 −0 lines changed

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest import TestCase

import torch
from fbgemm_gpu.split_embedding_configs import SparseType
from fbgemm_gpu.utils.loader import load_torch_module

load_torch_module(
    "//deeplearning/fbgemm/fbgemm_gpu:dram_kv_embedding_inference",
)


class DramKvInferenceTest(TestCase):
    def test_serialize(self) -> None:
        num_shards = 32
        uniform_init_lower: float = -0.01
        uniform_init_upper: float = 0.01
        evict_trigger_mode: int = 1

        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
        )
        serialized_result = kv_embedding_cache.serialize()

        self.assertEqual(serialized_result[0][0], num_shards)
        self.assertEqual(serialized_result[0][1], evict_trigger_mode)

        self.assertEqual(serialized_result[1][0], uniform_init_lower)
        self.assertEqual(serialized_result[1][1], uniform_init_upper)

    def test_serialize_deserialize(self) -> None:
        num_shards = 32
        uniform_init_lower: float = -0.01
        uniform_init_upper: float = 0.01
        evict_trigger_mode: int = 1

        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
        )
        serialized_result = kv_embedding_cache.serialize()

        kv_embedding_cache_2 = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            0, 0.0, 0.0, 0
        )
        kv_embedding_cache_2.deserialize(serialized_result)

        self.assertEqual(str(serialized_result), str(kv_embedding_cache_2.serialize()))

    def test_set_get_embeddings(self) -> None:
        num_shards = 32
        uniform_init_lower: float = 0.0
        uniform_init_upper: float = 0.0
        evict_trigger_mode: int = 0

        kv_embedding_cache = torch.classes.fbgemm.DramKVEmbeddingInferenceWrapper(
            num_shards, uniform_init_lower, uniform_init_upper, evict_trigger_mode
        )
        kv_embedding_cache.init(
            [(20, 4, SparseType.INT8.as_int())],
            8,
            4,
        )

        kv_embedding_cache.set_embeddings(
            torch.tensor([0, 1, 2, 3], dtype=torch.int64),
            torch.tensor(
                [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
                dtype=torch.uint8,
            ),
        )

        embs = kv_embedding_cache.get_embeddings(
            torch.tensor([1, 4, 3, 0, 5, 2], dtype=torch.int64),
        )
        assert torch.equal(
            embs[:, :4],
            torch.tensor(
                [
                    [5, 6, 7, 8],
                    [0, 0, 0, 0],
                    [13, 14, 15, 16],
                    [1, 2, 3, 4],
                    [0, 0, 0, 0],
                    [9, 10, 11, 12],
                ],
                dtype=torch.uint8,
            ),
        )
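
A minimal sketch, not part of the commit: one way to run the new test class with the standard unittest runner. It assumes fbgemm_gpu is installed and that the dram_kv_embedding_inference extension referenced by load_torch_module() is built and loadable in the current environment; the added file's path is not shown on this page, so the module name used in the CLI comment below is hypothetical.

# Hypothetical runner stub appended to the test module; the commit itself may rely
# on an external test runner (e.g. Buck) instead.
import unittest

if __name__ == "__main__":
    # CLI equivalent (hypothetical module name):
    #   python -m unittest dram_kv_embedding_inference_test -v
    unittest.main(verbosity=2)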
