Skip to content

Commit 1fc9577

Browse files
authored
[BugFix] Fix output of SipHash(as_tensor=False) (#2664)
1 parent 7bbd7e3 commit 1fc9577

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

test/test_storage_map.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,15 @@ def test_sip_hash(self):
4646
hash_b = torch.tensor(hash_module(b))
4747
assert (hash_a == hash_b).all()
4848

49+
def test_sip_hash_nontensor(self):
50+
a = torch.rand((3, 2))
51+
b = a.clone()
52+
hash_module = SipHash(as_tensor=False)
53+
hash_a = hash_module(a)
54+
hash_b = hash_module(b)
55+
assert len(hash_a) == 3
56+
assert hash_a == hash_b
57+
4958
@pytest.mark.parametrize("n_components", [None, 14])
5059
@pytest.mark.parametrize("scale", [0.001, 0.01, 1, 100, 1000])
5160
def test_randomprojection_hash(self, n_components, scale):

torchrl/data/map/hash.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | List[bytes]:
111111
hash_value = x_i.tobytes()
112112
hash_values.append(hash_value)
113113
if not self.as_tensor:
114-
return hash_value
114+
return hash_values
115115
result = torch.tensor([hash(x) for x in hash_values], dtype=torch.int64)
116116
return result
117117

0 commit comments

Comments
 (0)