Skip to content

Commit 2ce3f43

Browse files
committed
Improved internal representation of Bit class
1 parent 2d1b754 commit 2ce3f43

File tree

2 files changed

+32
-25
lines changed

2 files changed

+32
-25
lines changed

pgvector/bit.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,51 +5,58 @@
55

66
class Bit:
77
def __init__(self, value):
8-
if isinstance(value, str):
9-
self._value = self.from_text(value)._value
10-
elif isinstance(value, bytes):
11-
self._value = np.unpackbits(np.frombuffer(value, dtype=np.uint8)).astype(bool)
8+
if isinstance(value, bytes):
9+
self._len = 8 * len(value)
10+
self._data = value
1211
else:
13-
value = np.asarray(value)
12+
if isinstance(value, str):
13+
value = [v != '0' for v in value]
14+
else:
15+
value = np.asarray(value)
1416

15-
if value.dtype != np.bool:
16-
warn('expected elements to be boolean', stacklevel=2)
17-
value = value.astype(bool)
17+
if value.dtype != np.bool:
18+
warn('expected elements to be boolean', stacklevel=2)
19+
value = value.astype(bool)
1820

19-
if value.ndim != 1:
20-
raise ValueError('expected ndim to be 1')
21+
if value.ndim != 1:
22+
raise ValueError('expected ndim to be 1')
2123

22-
self._value = value
24+
self._len = len(value)
25+
self._data = np.packbits(value).tobytes()
2326

2427
def __repr__(self):
2528
return f'Bit({self.to_text()})'
2629

2730
def __eq__(self, other):
2831
if isinstance(other, self.__class__):
29-
return np.array_equal(self.to_numpy(), other.to_numpy())
32+
return self._len == other._len and self._data == other._data
3033
return False
3134

3235
def to_list(self):
33-
return self._value.tolist()
36+
return self.to_numpy().tolist()
3437

3538
def to_numpy(self):
36-
return self._value
39+
return np.unpackbits(np.frombuffer(self._data, dtype=np.uint8), count=self._len).astype(bool)
3740

3841
def to_text(self):
39-
return ''.join(self._value.astype(np.uint8).astype(str))
42+
return ''.join(format(v, '08b') for v in self._data)[:self._len]
4043

4144
def to_binary(self):
42-
return pack('>i', len(self._value)) + np.packbits(self._value).tobytes()
45+
return pack('>i', self._len) + self._data
4346

4447
@classmethod
4548
def from_text(cls, value):
46-
return cls(np.asarray([v != '0' for v in value], dtype=bool))
49+
return cls(str(value))
4750

4851
@classmethod
4952
def from_binary(cls, value):
50-
count = unpack_from('>i', value)[0]
51-
buf = np.frombuffer(value, dtype=np.uint8, offset=4)
52-
return cls(np.unpackbits(buf, count=count).astype(bool))
53+
if not isinstance(value, bytes):
54+
raise ValueError('expected bytes')
55+
56+
bit = cls.__new__(cls)
57+
bit._len = unpack_from('>i', value)[0]
58+
bit._data = value[4:]
59+
return bit
5360

5461
@classmethod
5562
def _to_db(cls, value):

tests/test_bit.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ def test_bytes(self):
2525
assert Bit(b'\xff\x00').to_list() == [True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False]
2626
assert Bit(b'\xfe\x07').to_list() == [True, True, True, True, True, True, True, False, False, False, False, False, False, True, True, True]
2727

28+
def test_ndarray(self):
29+
arr = np.array([True, False, True])
30+
assert Bit(arr).to_list() == [True, False, True]
31+
assert np.array_equal(Bit(arr).to_numpy(), arr)
32+
2833
def test_ndarray_uint8(self):
2934
arr = np.array([254, 7, 0], dtype=np.uint8)
3035
with pytest.warns(UserWarning, match='expected elements to be boolean'):
@@ -35,11 +40,6 @@ def test_ndarray_uint16(self):
3540
with pytest.warns(UserWarning, match='expected elements to be boolean'):
3641
assert Bit(arr).to_text() == '110'
3742

38-
def test_ndarray_same_object(self):
39-
arr = np.array([True, False, True])
40-
assert Bit(arr).to_list() == [True, False, True]
41-
assert Bit(arr).to_numpy() is arr
42-
4343
def test_ndim_two(self):
4444
with pytest.raises(ValueError) as error:
4545
Bit([[True, False], [True, False]])

0 commit comments

Comments
 (0)