3
3
4
4
import pickle
5
5
from collections .abc import Iterable , Mapping
6
+ from typing import Union
6
7
7
8
import numpy as np
8
9
import torch
23
24
class MultiModalHasher :
24
25
25
26
@classmethod
26
- def serialize_item (cls , obj : object ) -> bytes :
27
+ def serialize_item (cls , obj : object ) -> Union [ bytes , memoryview ] :
27
28
# Simple cases
28
29
if isinstance (obj , str ):
29
30
return obj .encode ("utf-8" )
30
- if isinstance (obj , bytes ):
31
+ if isinstance (obj , ( bytes , memoryview ) ):
31
32
return obj
32
33
if isinstance (obj , (int , float )):
33
34
return np .array (obj ).tobytes ()
@@ -38,12 +39,13 @@ def serialize_item(cls, obj: object) -> bytes:
38
39
if isinstance (obj , torch .Tensor ):
39
40
return cls .item_to_bytes ("tensor" , obj .numpy ())
40
41
if isinstance (obj , np .ndarray ):
41
- return cls .item_to_bytes (
42
- "ndarray" , {
43
- "dtype" : obj .dtype .str ,
44
- "shape" : obj .shape ,
45
- "data" : obj .tobytes (),
46
- })
42
+ # If the array is non-contiguous, we need to copy it first
43
+ arr_data = obj .data if obj .flags .c_contiguous else obj .tobytes ()
44
+ return cls .item_to_bytes ("ndarray" , {
45
+ "dtype" : obj .dtype .str ,
46
+ "shape" : obj .shape ,
47
+ "data" : arr_data ,
48
+ })
47
49
48
50
logger .warning (
49
51
"No serialization method found for %s. "
@@ -64,7 +66,7 @@ def iter_item_to_bytes(
64
66
cls ,
65
67
key : str ,
66
68
obj : object ,
67
- ) -> Iterable [tuple [bytes , bytes ]]:
69
+ ) -> Iterable [tuple [bytes , Union [ bytes , memoryview ] ]]:
68
70
# Recursive cases
69
71
if isinstance (obj , (list , tuple )):
70
72
for i , elem in enumerate (obj ):
@@ -73,7 +75,7 @@ def iter_item_to_bytes(
73
75
for k , v in obj .items ():
74
76
yield from cls .iter_item_to_bytes (f"{ key } .{ k } " , v )
75
77
else :
76
- key_bytes = cls . serialize_item ( key )
78
+ key_bytes = key . encode ( "utf-8" )
77
79
value_bytes = cls .serialize_item (obj )
78
80
yield key_bytes , value_bytes
79
81
0 commit comments