Skip to content

Commit 718f9af

Browse files
committed
Enforce at registration time
1 parent 12572b0 commit 718f9af

File tree

5 files changed

+119
-7
lines changed

5 files changed

+119
-7
lines changed

temporalio/worker/_replayer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def on_eviction_hook(
216216
on_eviction_hook=on_eviction_hook,
217217
disable_eager_activity_execution=False,
218218
disable_safe_eviction=self._config["disable_safe_workflow_eviction"],
219+
should_enforce_versioning_behavior=False,
219220
)
220221
# Create bridge worker
221222
bridge_worker, pusher = temporalio.bridge.worker.Worker.for_replay(

temporalio/worker/_worker.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
Optional,
1919
Sequence,
2020
Type,
21-
TypeAlias,
22-
Union,
2321
cast,
2422
)
2523

@@ -331,6 +329,12 @@ def __init__(
331329
)
332330
self._workflow_worker: Optional[_WorkflowWorker] = None
333331
if workflows:
332+
should_enforce_versioning_behavior = (
333+
deployment_options is not None
334+
and deployment_options.use_worker_versioning
335+
and deployment_options.default_versioning_behavior
336+
== temporalio.common.VersioningBehavior.UNSPECIFIED
337+
)
334338
self._workflow_worker = _WorkflowWorker(
335339
bridge_worker=lambda: self._bridge_worker,
336340
namespace=client.namespace,
@@ -348,6 +352,7 @@ def __init__(
348352
metric_meter=self._runtime.metric_meter,
349353
on_eviction_hook=None,
350354
disable_safe_eviction=disable_safe_workflow_eviction,
355+
should_enforce_versioning_behavior=should_enforce_versioning_behavior,
351356
)
352357

353358
if tuner is not None:

temporalio/worker/_workflow.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import os
99
import sys
1010
import threading
11-
from dataclasses import dataclass
1211
from datetime import timezone
1312
from types import TracebackType
1413
from typing import (
@@ -78,6 +77,7 @@ def __init__(
7877
]
7978
],
8079
disable_safe_eviction: bool,
80+
should_enforce_versioning_behavior: bool,
8181
) -> None:
8282
self._bridge_worker = bridge_worker
8383
self._namespace = namespace
@@ -135,6 +135,25 @@ def __init__(
135135
# Confirm name unique
136136
if defn.name in self._workflows:
137137
raise ValueError(f"More than one workflow named {defn.name}")
138+
if should_enforce_versioning_behavior:
139+
not_in_annotation = defn.versioning_behavior in [
140+
None,
141+
temporalio.common.VersioningBehavior.UNSPECIFIED,
142+
]
143+
if defn.name:
144+
if not_in_annotation:
145+
raise ValueError(
146+
f"Workflow {defn.name} must specify a versioning behavior using "
147+
"the `versioning_behavior` argument to `@workflow.run`."
148+
)
149+
else:
150+
if defn.dynamic_versioning_behavior is None:
151+
raise ValueError(
152+
f"Dynamic Workflow {defn.name} must specify a versioning behavior "
153+
"using `@workflow.dynamic_versioning_behavior` or the "
154+
"`versioning_behavior` argument to `@workflow.run`."
155+
)
156+
138157
# Prepare the workflow with the runner (this will error in the
139158
# sandbox if an import fails somehow)
140159
try:

tests/test_workflow.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,9 @@ def test_parameters_identical_up_to_naming():
430430
for f1, f2 in itertools.combinations(fns, 2):
431431
name1, name2 = f1.__name__, f2.__name__
432432
expect_equal = name1[0] == name2[0]
433-
assert workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal), (
434-
f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal"
435-
)
433+
assert (
434+
workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal)
435+
), f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal"
436436

437437

438438
@workflow.defn

tests/worker/test_worker.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,94 @@ async def test_worker_deployment_dynamic_workflow_getter(
843843
)
844844

845845

846-
# TODO: Test for fail at registration time if deployment versioning on, no default, no behavior
846+
@workflow.defn
847+
class NoVersioningAnnotationWorkflow:
848+
@workflow.run
849+
async def run(self) -> str:
850+
return "whee"
851+
852+
853+
@workflow.defn(dynamic=True)
854+
class NoVersioningAnnotationDynamicWorkflow:
855+
@workflow.run
856+
async def run(self, args: Sequence[RawValue]) -> str:
857+
return "whee"
858+
859+
860+
async def test_workflows_must_have_versioning_behavior_when_feature_turned_on(
861+
client: Client, env: WorkflowEnvironment
862+
):
863+
with pytest.raises(ValueError) as exc_info:
864+
Worker(
865+
client,
866+
task_queue=f"task-queue-{uuid.uuid4()}",
867+
workflows=[NoVersioningAnnotationWorkflow],
868+
deployment_options=WorkerDeploymentOptions(
869+
version=WorkerDeploymentVersion(
870+
deployment_name="whatever", build_id="1.0"
871+
),
872+
use_worker_versioning=True,
873+
),
874+
)
875+
876+
assert "must specify a versioning behavior" in str(exc_info.value)
877+
878+
with pytest.raises(ValueError) as exc_info:
879+
Worker(
880+
client,
881+
task_queue=f"task-queue-{uuid.uuid4()}",
882+
workflows=[NoVersioningAnnotationDynamicWorkflow],
883+
deployment_options=WorkerDeploymentOptions(
884+
version=WorkerDeploymentVersion(
885+
deployment_name="whatever", build_id="1.0"
886+
),
887+
use_worker_versioning=True,
888+
),
889+
)
890+
891+
assert "must specify a versioning behavior" in str(exc_info.value)
892+
893+
894+
async def test_workflows_can_use_default_versioning_behavior(
895+
client: Client, env: WorkflowEnvironment
896+
):
897+
if env.supports_time_skipping:
898+
pytest.skip("Test Server doesn't support worker versioning")
899+
900+
deployment_name = f"deployment-default-versioning-{uuid.uuid4()}"
901+
worker_v1 = WorkerDeploymentVersion(deployment_name=deployment_name, build_id="1.0")
902+
903+
async with new_worker(
904+
client,
905+
NoVersioningAnnotationWorkflow,
906+
deployment_options=WorkerDeploymentOptions(
907+
version=worker_v1,
908+
use_worker_versioning=True,
909+
default_versioning_behavior=VersioningBehavior.PINNED,
910+
),
911+
) as w:
912+
describe_resp = await wait_until_worker_deployment_visible(
913+
client,
914+
worker_v1,
915+
)
916+
await set_current_deployment_version(
917+
client, describe_resp.conflict_token, worker_v1
918+
)
919+
920+
wf = await client.start_workflow(
921+
NoVersioningAnnotationWorkflow.run,
922+
id=f"default-versioning-behavior-{uuid.uuid4()}",
923+
task_queue=w.task_queue,
924+
)
925+
await wf.result()
926+
927+
history = await wf.fetch_history()
928+
assert any(
929+
event.HasField("workflow_task_completed_event_attributes")
930+
and event.workflow_task_completed_event_attributes.versioning_behavior
931+
== temporalio.api.enums.v1.VersioningBehavior.VERSIONING_BEHAVIOR_PINNED
932+
for event in history.events
933+
)
847934

848935

849936
async def wait_until_worker_deployment_visible(

0 commit comments

Comments
 (0)