Skip to content

Commit c4c16f7

Browse files
authored
Merge pull request #128 from mdanilow/feature/strings_attr
Feature/strings attr
2 parents 326a525 + 654bf15 commit c4c16f7

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/qonnx/custom_op/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def get_nodeattr(self, name):
7474
if dtype == "s":
7575
# decode string attributes
7676
ret = ret.decode("utf-8")
77+
elif dtype == "strings":
78+
ret = [x.decode("utf-8") for x in ret]
7779
elif dtype == "t":
7880
# use numpy helper to convert TensorProto -> np array
7981
ret = np_helper.to_array(ret)
@@ -123,13 +125,15 @@ def set_nodeattr(self, name, value):
123125
# encode string attributes
124126
value = value.encode("utf-8")
125127
attr.__setattr__(dtype, value)
128+
elif dtype == "strings":
129+
attr.strings[:] = [x.encode("utf-8") for x in value]
126130
elif dtype == "floats": # list of floats
127131
attr.floats[:] = value
128132
elif dtype == "ints": # list of integers
129133
attr.ints[:] = value
130134
elif dtype == "t": # single tensor
131135
attr.t.CopyFrom(value)
132-
elif dtype in ["strings", "tensors", "graphs", "sparse_tensors"]:
136+
elif dtype in ["tensors", "graphs", "sparse_tensors"]:
133137
# untested / unsupported attribute types
134138
# add testcases & appropriate getters before enabling
135139
raise Exception("Attribute type %s not yet supported" % dtype)

tests/custom_op/test_attr.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737

3838
class AttrTestOp(CustomOp):
3939
def get_nodeattr_types(self):
40-
return {"tensor_attr": ("t", True, np.asarray([]))}
40+
my_attrs = {"tensor_attr": ("t", True, np.asarray([])), "strings_attr": ("strings", True, [""])}
41+
return my_attrs
4142

4243
def make_shape_compatible_op(self, model):
4344
param_tensor = self.get_nodeattr("tensor_attr")
@@ -70,6 +71,7 @@ def test_attr():
7071
strarr = np.array2string(w, separator=", ")
7172
w_str = strarr.replace("[", "{").replace("]", "}").replace(" ", "")
7273
tensor_attr_str = f"int8{wshp_str} {w_str}"
74+
strings_attr = ["a", "bc", "def"]
7375

7476
input = f"""
7577
<
@@ -86,9 +88,17 @@ def test_attr():
8688
model = oprs.parse_model(input)
8789
model = ModelWrapper(model)
8890
inst = getCustomOp(model.graph.node[0])
91+
8992
w_prod = inst.get_nodeattr("tensor_attr")
9093
assert (w_prod == w).all()
9194
w = w - 1
9295
inst.set_nodeattr("tensor_attr", w)
9396
w_prod = inst.get_nodeattr("tensor_attr")
9497
assert (w_prod == w).all()
98+
99+
inst.set_nodeattr("strings_attr", strings_attr)
100+
strings_attr_prod = inst.get_nodeattr("strings_attr")
101+
assert strings_attr_prod == strings_attr
102+
strings_attr_prod[0] = "test"
103+
inst.set_nodeattr("strings_attr", strings_attr_prod)
104+
assert inst.get_nodeattr("strings_attr") == ["test"] + strings_attr[1:]

0 commit comments

Comments
 (0)