Skip to content

Commit cfdc548

Browse files
authored
Add StrEnum conversion support (#177)
Fixes #176
1 parent 656b77b commit cfdc548

File tree

3 files changed

+36
-3
lines changed

3 files changed

+36
-3
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ The default data converter supports converting multiple types including:
188188
* Iterables including ones JSON dump may not support by default, e.g. `set`
189189
* Any class with a `dict()` method and a static `parse_obj()` method, e.g.
190190
[Pydantic models](https://pydantic-docs.helpmanual.io/usage/models)
191-
* [IntEnum](https://docs.python.org/3/library/enum.html) based enumerates
191+
* [IntEnum, StrEnum](https://docs.python.org/3/library/enum.html) based enumerates
192192

193193
For converting from JSON, the workflow/activity type hint is taken into account to convert to the proper type. Care has
194194
been taken to support all common typings including `Optional`, `Union`, all forms of iterables and mappings, `NewType`,

temporalio/converter.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import dataclasses
88
import inspect
99
import json
10+
import sys
1011
from abc import ABC, abstractmethod
1112
from dataclasses import dataclass
1213
from datetime import datetime
@@ -33,6 +34,10 @@
3334
import temporalio.api.common.v1
3435
import temporalio.common
3536

37+
# StrEnum is available in 3.11+
38+
if sys.version_info >= (3, 11):
39+
from enum import StrEnum
40+
3641

3742
class PayloadConverter(ABC):
3843
"""Base payload converter to/from multiple payloads/values."""
@@ -874,6 +879,15 @@ def value_to_type(hint: Type, value: Any) -> Any:
874879
)
875880
return hint(value)
876881

882+
# StrEnum, available in 3.11+
883+
if sys.version_info >= (3, 11):
884+
if inspect.isclass(hint) and issubclass(hint, StrEnum):
885+
if not isinstance(value, str):
886+
raise TypeError(
887+
f"Cannot convert to enum {hint}, value not a string, value is {type(value)}"
888+
)
889+
return hint(value)
890+
877891
# Iterable. We intentionally put this last as it catches several others.
878892
if inspect.isclass(origin) and issubclass(origin, collections.abc.Iterable):
879893
if not isinstance(value, collections.abc.Iterable):

tests/test_converter.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
import temporalio.converter
3434
from temporalio.api.common.v1 import Payload as AnotherNameForPayload
3535

36+
# StrEnum is available in 3.11+
37+
if sys.version_info >= (3, 11):
38+
from enum import StrEnum
39+
3640

3741
class NonSerializableClass:
3842
pass
@@ -46,6 +50,12 @@ class SerializableEnum(IntEnum):
4650
FOO = 1
4751

4852

53+
if sys.version_info >= (3, 11):
54+
55+
class SerializableStrEnum(StrEnum):
56+
FOO = "foo"
57+
58+
4959
@dataclass
5060
class MyDataClass:
5161
foo: str
@@ -107,8 +117,8 @@ async def assert_payload(
107117
await assert_payload(NonSerializableClass(), None, None)
108118
assert "not JSON serializable" in str(excinfo.value)
109119

110-
# Bad enum type. We do not allow non-int enums due to ambiguity in
111-
# rebuilding and other confusion.
120+
# Bad enum type. We do not allow non-int or non-str enums due to ambiguity
121+
# in rebuilding and other confusion.
112122
with pytest.raises(TypeError) as excinfo:
113123
await assert_payload(NonSerializableEnum.FOO, None, None)
114124
assert "not JSON serializable" in str(excinfo.value)
@@ -295,6 +305,15 @@ def fail(hint: Type, value: Any) -> None:
295305
ok(SerializableEnum, SerializableEnum.FOO)
296306
ok(List[SerializableEnum], [SerializableEnum.FOO, SerializableEnum.FOO])
297307

308+
# StrEnum is available in 3.11+
309+
if sys.version_info >= (3, 11):
310+
# StrEnum
311+
ok(SerializableStrEnum, SerializableStrEnum.FOO)
312+
ok(
313+
List[SerializableStrEnum],
314+
[SerializableStrEnum.FOO, SerializableStrEnum.FOO],
315+
)
316+
298317
# 3.10+ checks
299318
if sys.version_info >= (3, 10):
300319
ok(list[int], [1, 2])

0 commit comments

Comments
 (0)