|
28 | 28 |
|
29 | 29 | import numpy as np
|
30 | 30 |
|
31 |
| -from qonnx.core.datatype import DataType |
| 31 | +from qonnx.core.datatype import DataType, resolve_datatype |
32 | 32 |
|
33 | 33 |
|
34 | 34 | def test_datatypes():
|
@@ -97,3 +97,52 @@ def test_smallest_possible():
|
97 | 97 | assert DataType.get_smallest_possible(-1) == DataType["BIPOLAR"]
|
98 | 98 | assert DataType.get_smallest_possible(-3) == DataType["INT3"]
|
99 | 99 | assert DataType.get_smallest_possible(-3.2) == DataType["FLOAT32"]
|
| 100 | + |
| 101 | + |
| 102 | +def test_resolve_datatype(): |
| 103 | + assert resolve_datatype("BIPOLAR") |
| 104 | + assert resolve_datatype("BINARY") |
| 105 | + assert resolve_datatype("TERNARY") |
| 106 | + assert resolve_datatype("UINT2") |
| 107 | + assert resolve_datatype("UINT3") |
| 108 | + assert resolve_datatype("UINT4") |
| 109 | + assert resolve_datatype("UINT8") |
| 110 | + assert resolve_datatype("UINT16") |
| 111 | + assert resolve_datatype("UINT32") |
| 112 | + assert resolve_datatype("INT2") |
| 113 | + assert resolve_datatype("INT3") |
| 114 | + assert resolve_datatype("INT4") |
| 115 | + assert resolve_datatype("INT8") |
| 116 | + assert resolve_datatype("INT16") |
| 117 | + assert resolve_datatype("INT32") |
| 118 | + assert resolve_datatype("FLOAT32") |
| 119 | + |
| 120 | + |
| 121 | +def test_input_type_error(): |
| 122 | + def test_resolve_datatype(input): |
| 123 | + # test with invalid input to check if the TypeError works |
| 124 | + try: |
| 125 | + resolve_datatype(input) # This should raise a TypeError |
| 126 | + except TypeError: |
| 127 | + pass |
| 128 | + else: |
| 129 | + assert False, "Test with invalid input failed: No TypeError was raised." |
| 130 | + |
| 131 | + test_resolve_datatype(123) |
| 132 | + test_resolve_datatype(1.23) |
| 133 | + test_resolve_datatype(DataType["BIPOLAR"]) |
| 134 | + test_resolve_datatype(DataType["BINARY"]) |
| 135 | + test_resolve_datatype(DataType["TERNARY"]) |
| 136 | + test_resolve_datatype(DataType["UINT2"]) |
| 137 | + test_resolve_datatype(DataType["UINT3"]) |
| 138 | + test_resolve_datatype(DataType["UINT4"]) |
| 139 | + test_resolve_datatype(DataType["UINT8"]) |
| 140 | + test_resolve_datatype(DataType["UINT16"]) |
| 141 | + test_resolve_datatype(DataType["UINT32"]) |
| 142 | + test_resolve_datatype(DataType["INT2"]) |
| 143 | + test_resolve_datatype(DataType["INT3"]) |
| 144 | + test_resolve_datatype(DataType["INT4"]) |
| 145 | + test_resolve_datatype(DataType["INT8"]) |
| 146 | + test_resolve_datatype(DataType["INT16"]) |
| 147 | + test_resolve_datatype(DataType["INT32"]) |
| 148 | + test_resolve_datatype(DataType["FLOAT32"]) |
0 commit comments