Skip to content

Commit a5c455b

Browse files
authored
Support IntEnum in converter (#74)
1 parent d91593d commit a5c455b

File tree

2 files changed

+52
-15
lines changed

2 files changed

+52
-15
lines changed

temporalio/converter.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from abc import ABC, abstractmethod
1010
from dataclasses import dataclass
1111
from datetime import datetime
12+
from enum import IntEnum
1213
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Type
1314

1415
import dacite
@@ -386,6 +387,7 @@ class JSONPlainPayloadConverter(EncodingPayloadConverter):
386387
_encoder: Optional[Type[json.JSONEncoder]]
387388
_decoder: Optional[Type[json.JSONDecoder]]
388389
_encoding: str
390+
_dacite_config: dacite.Config
389391

390392
def __init__(
391393
self,
@@ -405,6 +407,7 @@ def __init__(
405407
self._encoder = encoder
406408
self._decoder = decoder
407409
self._encoding = encoding
410+
self._dacite_config = dacite.Config(cast=[IntEnum])
408411

409412
@property
410413
def encoding(self) -> str:
@@ -431,6 +434,11 @@ def from_payload(
431434
"""See base class."""
432435
try:
433436
obj = json.loads(payload.data, cls=self._decoder)
437+
438+
# If the object is an int and the type hint is an IntEnum, convert
439+
if isinstance(obj, int) and type_hint and issubclass(type_hint, IntEnum):
440+
obj = type_hint(obj)
441+
434442
# If the object is a dict and the type hint is present for a data
435443
# class, we instantiate the data class with the value
436444
if (
@@ -439,7 +447,8 @@ def from_payload(
439447
and dataclasses.is_dataclass(type_hint)
440448
):
441449
# We have to use dacite here to handle nested dataclasses
442-
obj = dacite.from_dict(type_hint, obj)
450+
obj = dacite.from_dict(type_hint, obj, self._dacite_config)
451+
443452
return obj
444453
except json.JSONDecodeError as err:
445454
raise RuntimeError("Failed parsing") from err

tests/test_converter.py

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass
22
from datetime import datetime
3+
from enum import Enum, IntEnum
34

45
import pytest
56

@@ -8,7 +9,26 @@
89
from temporalio.api.common.v1 import Payload as AnotherNameForPayload
910

1011

11-
async def test_default():
12+
class NonSerializableClass:
13+
pass
14+
15+
16+
class NonSerializableEnum(Enum):
17+
FOO = "foo"
18+
19+
20+
class SerializableEnum(IntEnum):
21+
FOO = 1
22+
23+
24+
@dataclass
25+
class MyDataClass:
26+
foo: str
27+
bar: int
28+
baz: SerializableEnum
29+
30+
31+
async def test_converter_default():
1232
async def assert_payload(
1333
input,
1434
expected_encoding,
@@ -33,6 +53,7 @@ async def assert_payload(
3353
assert len(actual_inputs) == 1
3454
if expected_decoded_input is None:
3555
expected_decoded_input = input
56+
assert type(actual_inputs[0]) is type(expected_decoded_input)
3657
assert actual_inputs[0] == expected_decoded_input
3758
return payloads[0]
3859

@@ -58,31 +79,38 @@ async def assert_payload(
5879

5980
# Unknown type
6081
with pytest.raises(TypeError) as excinfo:
61-
62-
class NonSerializableClass:
63-
pass
64-
6582
await assert_payload(NonSerializableClass(), None, None)
6683
assert "not JSON serializable" in str(excinfo.value)
6784

68-
@dataclass
69-
class MyDataClass:
70-
foo: str
71-
bar: int
85+
# Bad enum type. We do not allow non-int enums due to ambiguity in
86+
# rebuilding and other confusion.
87+
with pytest.raises(TypeError) as excinfo:
88+
await assert_payload(NonSerializableEnum.FOO, None, None)
89+
assert "not JSON serializable" in str(excinfo.value)
90+
91+
# Good enum no type hint
92+
await assert_payload(
93+
SerializableEnum.FOO, "json/plain", "1", expected_decoded_input=1
94+
)
95+
96+
# Good enum type hint
97+
await assert_payload(
98+
SerializableEnum.FOO, "json/plain", "1", type_hint=SerializableEnum
99+
)
72100

73101
# Data class without type hint is just dict
74102
await assert_payload(
75-
MyDataClass(foo="somestr", bar=123),
103+
MyDataClass(foo="somestr", bar=123, baz=SerializableEnum.FOO),
76104
"json/plain",
77-
'{"bar":123,"foo":"somestr"}',
78-
expected_decoded_input={"foo": "somestr", "bar": 123},
105+
'{"bar":123,"baz":1,"foo":"somestr"}',
106+
expected_decoded_input={"foo": "somestr", "bar": 123, "baz": 1},
79107
)
80108

81109
# Data class with type hint reconstructs the class
82110
await assert_payload(
83-
MyDataClass(foo="somestr", bar=123),
111+
MyDataClass(foo="somestr", bar=123, baz=SerializableEnum.FOO),
84112
"json/plain",
85-
'{"bar":123,"foo":"somestr"}',
113+
'{"bar":123,"baz":1,"foo":"somestr"}',
86114
type_hint=MyDataClass,
87115
)
88116

0 commit comments

Comments
 (0)