Skip to content

Commit abc0074

Browse files
author
Ben Elam
committed
Implement RBAC for celery
1 parent e0b4373 commit abc0074

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

orchestrator/services/celery.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import structlog
1919
from celery.result import AsyncResult
2020
from kombu.exceptions import ConnectionError, OperationalError
21+
from oauth2_lib.fastapi import OIDCUserModel
2122

2223
from orchestrator import app_settings
2324
from orchestrator.api.error_handling import raise_status
@@ -42,7 +43,11 @@ def _block_when_testing(task_result: AsyncResult) -> None:
4243

4344

4445
def _celery_start_process(
45-
workflow_key: str, user_inputs: list[State] | None, user: str = SYSTEM_USER, **kwargs: Any
46+
workflow_key: str,
47+
user_inputs: list[State] | None,
48+
user: str = SYSTEM_USER,
49+
user_model: OIDCUserModel | None = None,
50+
**kwargs: Any
4651
) -> UUID:
4752
"""Client side call of Celery."""
4853
from orchestrator.services.tasks import NEW_TASK, NEW_WORKFLOW, get_celery_task
@@ -53,7 +58,7 @@ def _celery_start_process(
5358

5459
task_name = NEW_TASK if workflow.target == Target.SYSTEM else NEW_WORKFLOW
5560
trigger_task = get_celery_task(task_name)
56-
pstat = create_process(workflow_key, user_inputs, user)
61+
pstat = create_process(workflow_key, user_inputs=user_inputs, user=user, user_model=user_model)
5762
try:
5863
result = trigger_task.delay(pstat.process_id, workflow_key, user)
5964
_block_when_testing(result)
@@ -70,6 +75,7 @@ def _celery_resume_process(
7075
*,
7176
user_inputs: list[State] | None,
7277
user: str | None,
78+
user_model: OIDCUserModel | None = None,
7379
**kwargs: Any,
7480
) -> UUID:
7581
"""Client side call of Celery."""
@@ -87,7 +93,7 @@ def _celery_resume_process(
8793
store_input_state(pstat.process_id, user_inputs, "user_input")
8894
try:
8995
_celery_set_process_status_resumed(process)
90-
result = trigger_task.delay(pstat.process_id, user)
96+
result = trigger_task.delay(pstat.process_id, user, user_model=user_model)
9197
_block_when_testing(result)
9298

9399
return pstat.process_id

orchestrator/services/processes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def resume_process(
628628
raise
629629

630630
resume_func = get_execution_context()["resume"]
631-
return resume_func(process, user_inputs=user_inputs, user=user, broadcast_func=broadcast_func)
631+
return resume_func(process, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=broadcast_func)
632632

633633

634634
def ensure_correct_callback_token(pstat: ProcessStat, *, token: str) -> None:

orchestrator/services/tasks.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
thread_resume_process,
3333
)
3434
from orchestrator.types import BroadcastFunc
35+
from orchestrator.utils.auth import Authorizer
3536
from orchestrator.utils.json import json_dumps, json_loads
3637
from orchestrator.workflow import ProcessStat, ProcessStatus, Success, runwf
3738
from orchestrator.workflows import get_workflow
@@ -103,12 +104,12 @@ def start_process(process_id: UUID, workflow_key: str, state: dict[str, Any], us
103104
else:
104105
return process_id
105106

106-
def resume_process(process_id: UUID, user_inputs: list[State] | None, user: str) -> UUID | None:
107+
def resume_process(process_id: UUID, user_inputs: list[State] | None, user: str, user_model: Authorizer | None = None) -> UUID | None:
107108
try:
108109
process = _get_process(process_id)
109110
ensure_correct_process_status(process_id, ProcessStatus.RESUMED)
110111
process_id = thread_resume_process(
111-
process, user_inputs=user_inputs, user=user, broadcast_func=process_broadcast_fn
112+
process, user_inputs=user_inputs, user=user, user_model=user_model, broadcast_func=process_broadcast_fn
112113
)
113114
except Exception as exc:
114115
local_logger.error("Worker failed to resume workflow", process_id=process_id, details=str(exc))
@@ -131,16 +132,16 @@ def new_workflow(process_id, workflow_key: str, user: str) -> UUID | None:
131132
return start_process(process_id, workflow_key, state=state, user=user)
132133

133134
@celery_task(name=RESUME_TASK) # type: ignore
134-
def resume_task(process_id: UUID, user: str) -> UUID | None:
135+
def resume_task(process_id: UUID, user: str, user_model: Authorizer | None = None) -> UUID | None:
135136
local_logger.info("Resume task", process_id=process_id)
136137
state = retrieve_input_state(process_id, "user_input").input_state
137-
return resume_process(process_id, user_inputs=state, user=user)
138+
return resume_process(process_id, user_inputs=state, user=user, user_model=user_model)
138139

139140
@celery_task(name=RESUME_WORKFLOW) # type: ignore
140-
def resume_workflow(process_id: UUID, user: str) -> UUID | None:
141+
def resume_workflow(process_id: UUID, user: str, user_model: Authorizer | None = None) -> UUID | None:
141142
local_logger.info("Resume workflow", process_id=process_id)
142143
state = retrieve_input_state(process_id, "user_input").input_state
143-
return resume_process(process_id, user_inputs=state, user=user)
144+
return resume_process(process_id, user_inputs=state, user=user, user_model=user_model)
144145

145146

146147
class CeleryJobWorkerStatus(WorkerStatus):

0 commit comments

Comments
 (0)