37
37
38
38
class AttrTestOp (CustomOp ):
39
39
def get_nodeattr_types (self ):
40
- return {"tensor_attr" : ("t" , True , np .asarray ([]))}
40
+ my_attrs = {"tensor_attr" : ("t" , True , np .asarray ([])), "strings_attr" : ("strings" , True , ["" ])}
41
+ return my_attrs
41
42
42
43
def make_shape_compatible_op (self , model ):
43
44
param_tensor = self .get_nodeattr ("tensor_attr" )
@@ -70,6 +71,7 @@ def test_attr():
70
71
strarr = np .array2string (w , separator = ", " )
71
72
w_str = strarr .replace ("[" , "{" ).replace ("]" , "}" ).replace (" " , "" )
72
73
tensor_attr_str = f"int8{ wshp_str } { w_str } "
74
+ strings_attr = ["a" , "bc" , "def" ]
73
75
74
76
input = f"""
75
77
<
@@ -86,9 +88,17 @@ def test_attr():
86
88
model = oprs .parse_model (input )
87
89
model = ModelWrapper (model )
88
90
inst = getCustomOp (model .graph .node [0 ])
91
+
89
92
w_prod = inst .get_nodeattr ("tensor_attr" )
90
93
assert (w_prod == w ).all ()
91
94
w = w - 1
92
95
inst .set_nodeattr ("tensor_attr" , w )
93
96
w_prod = inst .get_nodeattr ("tensor_attr" )
94
97
assert (w_prod == w ).all ()
98
+
99
+ inst .set_nodeattr ("strings_attr" , strings_attr )
100
+ strings_attr_prod = inst .get_nodeattr ("strings_attr" )
101
+ assert strings_attr_prod == strings_attr
102
+ strings_attr_prod [0 ] = "test"
103
+ inst .set_nodeattr ("strings_attr" , strings_attr_prod )
104
+ assert inst .get_nodeattr ("strings_attr" ) == ["test" ] + strings_attr [1 :]
0 commit comments