Skip to content

Commit b8a80e0

Browse files
authored
Type refactoring and activity class/method support (#69)
1 parent 7456f44 commit b8a80e0

File tree

9 files changed

+1628
-314
lines changed

9 files changed

+1628
-314
lines changed

README.md

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,12 @@ Some things to note about the above code:
379379
capabilities are needed.
380380
* Local activities work very similarly except the functions are `workflow.start_local_activity()` and
381381
`workflow.execute_local_activity()`
382+
* Activities can be methods of a class. Invokers should use `workflow.start_activity_method()`,
383+
`workflow.execute_activity_method()`, `workflow.start_local_activity_method()`, and
384+
`workflow.execute_local_activity_method()` instead.
385+
* Activities can callable classes (i.e. that define `__call__`). Invokers should use `workflow.start_activity_class()`,
386+
`workflow.execute_activity_class()`, `workflow.start_local_activity_class()`, and
387+
`workflow.execute_local_activity_class()` instead.
382388

383389
#### Invoking Child Workflows
384390

@@ -465,7 +471,7 @@ While running in a workflow, in addition to features documented elsewhere, the f
465471

466472
#### Definition
467473

468-
Activities are functions decorated with `@activity.defn` like so:
474+
Activities are decorated with `@activity.defn` like so:
469475

470476
```python
471477
from temporalio import activity
@@ -482,6 +488,10 @@ Some things to note about activity definitions:
482488
* Long running activities should regularly heartbeat and handle cancellation
483489
* Activities can only have positional arguments. Best practice is to only take a single argument that is an
484490
object/dataclass of fields that can be added to as needed.
491+
* Activities can be defined on methods instead of top-level functions. This allows the instance to carry state that an
492+
activity may need (e.g. a DB connection). The instance method should be what is registered with the worker.
493+
* Activities can also be defined on callable classes (i.e. classes with `__call__`). An instance of the class should be
494+
what is registered with the worker.
485495

486496
#### Types of Activities
487497

@@ -721,6 +731,6 @@ poe test
721731
* We use [Black](https://github.com/psf/black) for formatting, so that takes precedence
722732
* In tests and example code, can import individual classes/functions to make it more readable. Can also do this for
723733
rarely in library code for some Python common items (e.g. `dataclass` or `partial`), but not allowed to do this for
724-
any `temporalio` packages or any classes/functions that aren't clear when unqualified.
734+
any `temporalio` packages (except `temporalio.types`) or any classes/functions that aren't clear when unqualified.
725735
* We allow relative imports for private packages
726736
* We allow `@staticmethod`

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ html-output = "build/apidocs"
136136
intersphinx = ["https://docs.python.org/3/objects.inv", "https://googleapis.dev/python/protobuf/latest/objects.inv"]
137137
privacy = [
138138
"PRIVATE:temporalio.bridge",
139+
"PRIVATE:temporalio.types",
139140
"HIDDEN:temporalio.worker.activity",
140141
"HIDDEN:temporalio.worker.interceptor",
141142
"HIDDEN:temporalio.worker.worker",

temporalio/activity.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,28 +27,27 @@
2727
NoReturn,
2828
Optional,
2929
Tuple,
30-
TypeVar,
3130
overload,
3231
)
3332

3433
import temporalio.api.common.v1
3534
import temporalio.common
3635
import temporalio.exceptions
3736

38-
ActivityFunc = TypeVar("ActivityFunc", bound=Callable[..., Any])
37+
from .types import CallableType
3938

4039

4140
@overload
42-
def defn(fn: ActivityFunc) -> ActivityFunc:
41+
def defn(fn: CallableType) -> CallableType:
4342
...
4443

4544

4645
@overload
47-
def defn(*, name: str) -> Callable[[ActivityFunc], ActivityFunc]:
46+
def defn(*, name: str) -> Callable[[CallableType], CallableType]:
4847
...
4948

5049

51-
def defn(fn: Optional[ActivityFunc] = None, *, name: Optional[str] = None):
50+
def defn(fn: Optional[CallableType] = None, *, name: Optional[str] = None):
5251
"""Decorator for activity functions.
5352
5453
Activities can be async or non-async.
@@ -58,7 +57,7 @@ def defn(fn: Optional[ActivityFunc] = None, *, name: Optional[str] = None):
5857
name: Name to use for the activity. Defaults to function ``__name__``.
5958
"""
6059

61-
def with_name(name: str, fn: ActivityFunc) -> ActivityFunc:
60+
def with_name(name: str, fn: CallableType) -> CallableType:
6261
# This performs validation
6362
_Definition._apply_to_callable(fn, name)
6463
return fn
@@ -371,16 +370,19 @@ def _apply_to_callable(fn: Callable, activity_name: str) -> None:
371370
raise ValueError("Function already contains activity definition")
372371
elif not callable(fn):
373372
raise TypeError("Activity is not callable")
374-
elif not fn.__code__:
375-
raise TypeError("Activity callable missing __code__")
376-
elif fn.__code__.co_kwonlyargcount:
377-
raise TypeError("Activity cannot have keyword-only arguments")
373+
# We do not allow keyword only arguments in activities
374+
sig = inspect.signature(fn)
375+
for param in sig.parameters.values():
376+
if param.kind == inspect.Parameter.KEYWORD_ONLY:
377+
raise TypeError("Activity cannot have keyword-only arguments")
378378
setattr(
379379
fn,
380380
"__temporal_activity_definition",
381381
_Definition(
382382
name=activity_name,
383383
fn=fn,
384-
is_async=inspect.iscoroutinefunction(fn),
384+
# iscoroutinefunction does not return true for async __call__
385+
# TODO(cretz): Why can't MyPy handle this?
386+
is_async=inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(fn.__call__), # type: ignore
385387
),
386388
)

temporalio/client.py

Lines changed: 40 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,17 @@
3737
import temporalio.workflow_service
3838
from temporalio.workflow_service import RetryConfig, RPCError, RPCStatusCode, TLSConfig
3939

40-
LocalParamType = TypeVar("LocalParamType")
41-
LocalReturnType = TypeVar("LocalReturnType")
42-
WorkflowClass = TypeVar("WorkflowClass")
43-
WorkflowReturnType = TypeVar("WorkflowReturnType")
44-
MultiParamSpec = ParamSpec("MultiParamSpec")
40+
from .types import (
41+
LocalReturnType,
42+
MethodAsyncNoParam,
43+
MethodAsyncSingleParam,
44+
MethodSyncOrAsyncNoParam,
45+
MethodSyncOrAsyncSingleParam,
46+
MultiParamSpec,
47+
ParamType,
48+
ReturnType,
49+
SelfType,
50+
)
4551

4652

4753
class Client:
@@ -198,7 +204,7 @@ def data_converter(self) -> temporalio.converter.DataConverter:
198204
@overload
199205
async def start_workflow(
200206
self,
201-
workflow: Callable[[WorkflowClass], Awaitable[WorkflowReturnType]],
207+
workflow: MethodAsyncNoParam[SelfType, ReturnType],
202208
*,
203209
id: str,
204210
task_queue: str,
@@ -213,17 +219,15 @@ async def start_workflow(
213219
header: Optional[Mapping[str, Any]] = None,
214220
start_signal: Optional[str] = None,
215221
start_signal_args: Iterable[Any] = [],
216-
) -> WorkflowHandle[WorkflowClass, WorkflowReturnType]:
222+
) -> WorkflowHandle[SelfType, ReturnType]:
217223
...
218224

219225
# Overload for single-param workflow
220226
@overload
221227
async def start_workflow(
222228
self,
223-
workflow: Callable[
224-
[WorkflowClass, LocalParamType], Awaitable[WorkflowReturnType]
225-
],
226-
arg: LocalParamType,
229+
workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType],
230+
arg: ParamType,
227231
*,
228232
id: str,
229233
task_queue: str,
@@ -238,15 +242,15 @@ async def start_workflow(
238242
header: Optional[Mapping[str, Any]] = None,
239243
start_signal: Optional[str] = None,
240244
start_signal_args: Iterable[Any] = [],
241-
) -> WorkflowHandle[WorkflowClass, WorkflowReturnType]:
245+
) -> WorkflowHandle[SelfType, ReturnType]:
242246
...
243247

244248
# Overload for multi-param workflow
245249
@overload
246250
async def start_workflow(
247251
self,
248252
workflow: Callable[
249-
Concatenate[WorkflowClass, MultiParamSpec], Awaitable[WorkflowReturnType]
253+
Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType]
250254
],
251255
*,
252256
args: Iterable[Any],
@@ -263,7 +267,7 @@ async def start_workflow(
263267
header: Optional[Mapping[str, Any]] = None,
264268
start_signal: Optional[str] = None,
265269
start_signal_args: Iterable[Any] = [],
266-
) -> WorkflowHandle[WorkflowClass, WorkflowReturnType]:
270+
) -> WorkflowHandle[SelfType, ReturnType]:
267271
...
268272

269273
# Overload for string-name workflow
@@ -377,7 +381,7 @@ async def start_workflow(
377381
@overload
378382
async def execute_workflow(
379383
self,
380-
workflow: Callable[[WorkflowClass], Awaitable[WorkflowReturnType]],
384+
workflow: MethodAsyncNoParam[SelfType, ReturnType],
381385
*,
382386
id: str,
383387
task_queue: str,
@@ -392,17 +396,15 @@ async def execute_workflow(
392396
header: Optional[Mapping[str, Any]] = None,
393397
start_signal: Optional[str] = None,
394398
start_signal_args: Iterable[Any] = [],
395-
) -> WorkflowReturnType:
399+
) -> ReturnType:
396400
...
397401

398402
# Overload for single-param workflow
399403
@overload
400404
async def execute_workflow(
401405
self,
402-
workflow: Callable[
403-
[WorkflowClass, LocalParamType], Awaitable[WorkflowReturnType]
404-
],
405-
arg: LocalParamType,
406+
workflow: MethodAsyncSingleParam[SelfType, ParamType, ReturnType],
407+
arg: ParamType,
406408
*,
407409
id: str,
408410
task_queue: str,
@@ -417,15 +419,15 @@ async def execute_workflow(
417419
header: Optional[Mapping[str, Any]] = None,
418420
start_signal: Optional[str] = None,
419421
start_signal_args: Iterable[Any] = [],
420-
) -> WorkflowReturnType:
422+
) -> ReturnType:
421423
...
422424

423425
# Overload for multi-param workflow
424426
@overload
425427
async def execute_workflow(
426428
self,
427429
workflow: Callable[
428-
Concatenate[WorkflowClass, MultiParamSpec], Awaitable[WorkflowReturnType]
430+
Concatenate[SelfType, MultiParamSpec], Awaitable[ReturnType]
429431
],
430432
*,
431433
args: Iterable[Any],
@@ -442,7 +444,7 @@ async def execute_workflow(
442444
header: Optional[Mapping[str, Any]] = None,
443445
start_signal: Optional[str] = None,
444446
start_signal_args: Iterable[Any] = [],
445-
) -> WorkflowReturnType:
447+
) -> ReturnType:
446448
...
447449

448450
# Overload for string-name workflow
@@ -546,14 +548,14 @@ def get_workflow_handle(
546548
def get_workflow_handle_for(
547549
self,
548550
workflow: Union[
549-
Callable[[WorkflowClass, LocalParamType], Awaitable[WorkflowReturnType]],
550-
Callable[[WorkflowClass], Awaitable[WorkflowReturnType]],
551+
MethodAsyncNoParam[SelfType, ReturnType],
552+
MethodAsyncSingleParam[SelfType, Any, ReturnType],
551553
],
552554
workflow_id: str,
553555
*,
554556
run_id: Optional[str] = None,
555557
first_execution_run_id: Optional[str] = None,
556-
) -> WorkflowHandle[WorkflowClass, WorkflowReturnType]:
558+
) -> WorkflowHandle[SelfType, ReturnType]:
557559
"""Get a typed workflow handle to an existing workflow by its ID.
558560
559561
This is the same as :py:meth:`get_workflow_handle` but typed. Note, the
@@ -641,7 +643,7 @@ class ClientConfig(TypedDict, total=False):
641643
type_hint_eval_str: bool
642644

643645

644-
class WorkflowHandle(Generic[WorkflowClass, WorkflowReturnType]):
646+
class WorkflowHandle(Generic[SelfType, ReturnType]):
645647
"""Handle for interacting with a workflow.
646648
647649
This is usually created via :py:meth:`Client.get_workflow_handle` or
@@ -714,7 +716,7 @@ def first_execution_run_id(self) -> Optional[str]:
714716
"""
715717
return self._first_execution_run_id
716718

717-
async def result(self, *, follow_runs: bool = True) -> WorkflowReturnType:
719+
async def result(self, *, follow_runs: bool = True) -> ReturnType:
718720
"""Wait for result of the workflow.
719721
720722
This will use :py:attr:`result_run_id` if present to base the result on.
@@ -772,10 +774,10 @@ async def result(self, *, follow_runs: bool = True) -> WorkflowReturnType:
772774
type_hints,
773775
)
774776
if not results:
775-
return cast(WorkflowReturnType, None)
777+
return cast(ReturnType, None)
776778
elif len(results) > 1:
777779
warnings.warn(f"Expected single result, got {len(results)}")
778-
return cast(WorkflowReturnType, results[0])
780+
return cast(ReturnType, results[0])
779781
elif event.HasField("workflow_execution_failed_event_attributes"):
780782
fail_attr = event.workflow_execution_failed_event_attributes
781783
# Follow execution
@@ -891,9 +893,7 @@ async def describe(
891893
@overload
892894
async def query(
893895
self,
894-
query: Callable[
895-
[WorkflowClass], Union[Awaitable[LocalReturnType], LocalReturnType]
896-
],
896+
query: MethodSyncOrAsyncNoParam[SelfType, LocalReturnType],
897897
*,
898898
reject_condition: Optional[temporalio.common.QueryRejectCondition] = None,
899899
) -> LocalReturnType:
@@ -903,11 +903,8 @@ async def query(
903903
@overload
904904
async def query(
905905
self,
906-
query: Callable[
907-
[WorkflowClass, LocalParamType],
908-
Union[Awaitable[LocalReturnType], LocalReturnType],
909-
],
910-
arg: LocalParamType,
906+
query: MethodSyncOrAsyncSingleParam[SelfType, ParamType, LocalReturnType],
907+
arg: ParamType,
911908
*,
912909
reject_condition: Optional[temporalio.common.QueryRejectCondition] = None,
913910
) -> LocalReturnType:
@@ -918,7 +915,7 @@ async def query(
918915
async def query(
919916
self,
920917
query: Callable[
921-
Concatenate[WorkflowClass, MultiParamSpec],
918+
Concatenate[SelfType, MultiParamSpec],
922919
Union[Awaitable[LocalReturnType], LocalReturnType],
923920
],
924921
*,
@@ -1005,16 +1002,16 @@ async def query(
10051002
@overload
10061003
async def signal(
10071004
self,
1008-
signal: Callable[[WorkflowClass], Union[Awaitable[None], None]],
1005+
signal: MethodSyncOrAsyncNoParam[SelfType, None],
10091006
) -> None:
10101007
...
10111008

10121009
# Overload for single-param signal
10131010
@overload
10141011
async def signal(
10151012
self,
1016-
signal: Callable[[WorkflowClass, LocalParamType], Union[Awaitable[None], None]],
1017-
arg: LocalParamType,
1013+
signal: MethodSyncOrAsyncSingleParam[SelfType, ParamType, None],
1014+
arg: ParamType,
10181015
) -> None:
10191016
...
10201017

@@ -1023,7 +1020,7 @@ async def signal(
10231020
async def signal(
10241021
self,
10251022
signal: Callable[
1026-
Concatenate[WorkflowClass, MultiParamSpec], Union[Awaitable[None], None]
1023+
Concatenate[SelfType, MultiParamSpec], Union[Awaitable[None], None]
10271024
],
10281025
*,
10291026
args: Iterable[Any],

temporalio/converter.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -690,11 +690,18 @@ def get_type_hints(self, fn: Any) -> Tuple[Optional[List[Type]], Optional[Type]]
690690
# Due to MyPy issues, we cannot type "fn" as callable
691691
if not callable(fn):
692692
return (None, None)
693-
ret = self._cache.get(fn.__qualname__)
694-
if not ret:
695-
# TODO(cretz): Do we even need to cache?
696-
ret = _type_hints_from_func(fn, eval_str=self._type_hint_eval_str)
697-
self._cache[fn.__qualname__] = ret
693+
# We base the cache key on the qualified name of the function. However,
694+
# since some callables are not functions, we assume we can never cache
695+
# these just in case the type hints are dynamic for some strange reason.
696+
cache_key = getattr(fn, "__qualname__", None)
697+
if cache_key:
698+
ret = self._cache.get(cache_key)
699+
if ret:
700+
return ret
701+
# TODO(cretz): Do we even need to cache?
702+
ret = _type_hints_from_func(fn, eval_str=self._type_hint_eval_str)
703+
if cache_key:
704+
self._cache[cache_key] = ret
698705
return ret
699706

700707

0 commit comments

Comments
 (0)