Skip to content

Commit 51965ab

Browse files
authored
Merge pull request #121 from makoeppel/resolve_datatype/add_typeCheck_for_input
add type check for input name of resolve_datatype()
2 parents 364f995 + 7e747d7 commit 51965ab

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

src/qonnx/core/datatype.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,9 @@ def get_canonical_name(self):
376376

377377

378378
def resolve_datatype(name):
379+
if not isinstance(name, str):
380+
raise TypeError(f"Input 'name' must be of type 'str', but got type '{type(name).__name__}'")
381+
379382
_special_types = {
380383
"BINARY": IntType(1, False),
381384
"BIPOLAR": BipolarType(),

tests/core/test_datatypes.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import numpy as np
3030

31-
from qonnx.core.datatype import DataType
31+
from qonnx.core.datatype import DataType, resolve_datatype
3232

3333

3434
def test_datatypes():
@@ -97,3 +97,52 @@ def test_smallest_possible():
9797
assert DataType.get_smallest_possible(-1) == DataType["BIPOLAR"]
9898
assert DataType.get_smallest_possible(-3) == DataType["INT3"]
9999
assert DataType.get_smallest_possible(-3.2) == DataType["FLOAT32"]
100+
101+
102+
def test_resolve_datatype():
103+
assert resolve_datatype("BIPOLAR")
104+
assert resolve_datatype("BINARY")
105+
assert resolve_datatype("TERNARY")
106+
assert resolve_datatype("UINT2")
107+
assert resolve_datatype("UINT3")
108+
assert resolve_datatype("UINT4")
109+
assert resolve_datatype("UINT8")
110+
assert resolve_datatype("UINT16")
111+
assert resolve_datatype("UINT32")
112+
assert resolve_datatype("INT2")
113+
assert resolve_datatype("INT3")
114+
assert resolve_datatype("INT4")
115+
assert resolve_datatype("INT8")
116+
assert resolve_datatype("INT16")
117+
assert resolve_datatype("INT32")
118+
assert resolve_datatype("FLOAT32")
119+
120+
121+
def test_input_type_error():
122+
def test_resolve_datatype(input):
123+
# test with invalid input to check if the TypeError works
124+
try:
125+
resolve_datatype(input) # This should raise a TypeError
126+
except TypeError:
127+
pass
128+
else:
129+
assert False, "Test with invalid input failed: No TypeError was raised."
130+
131+
test_resolve_datatype(123)
132+
test_resolve_datatype(1.23)
133+
test_resolve_datatype(DataType["BIPOLAR"])
134+
test_resolve_datatype(DataType["BINARY"])
135+
test_resolve_datatype(DataType["TERNARY"])
136+
test_resolve_datatype(DataType["UINT2"])
137+
test_resolve_datatype(DataType["UINT3"])
138+
test_resolve_datatype(DataType["UINT4"])
139+
test_resolve_datatype(DataType["UINT8"])
140+
test_resolve_datatype(DataType["UINT16"])
141+
test_resolve_datatype(DataType["UINT32"])
142+
test_resolve_datatype(DataType["INT2"])
143+
test_resolve_datatype(DataType["INT3"])
144+
test_resolve_datatype(DataType["INT4"])
145+
test_resolve_datatype(DataType["INT8"])
146+
test_resolve_datatype(DataType["INT16"])
147+
test_resolve_datatype(DataType["INT32"])
148+
test_resolve_datatype(DataType["FLOAT32"])

0 commit comments

Comments
 (0)