From 0eb74c3a914fe875af5d2b3740ecb0f6fb11cff8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=AE=D0=B4=D1=8B=D1=86=D0=BA=D0=B8=D0=B9=20=D0=98=D0=B3?= =?UTF-8?q?=D0=BE=D1=80=D1=8C?= Date: Tue, 29 Oct 2024 16:09:55 +0300 Subject: [PATCH 1/2] feat: AbortPipeline error propagated to the last step task result of pipeline --- taskiq_pipelines/middleware.py | 7 ++++--- tests/test_steps.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/taskiq_pipelines/middleware.py b/taskiq_pipelines/middleware.py index 49692d5..4ac88f8 100644 --- a/taskiq_pipelines/middleware.py +++ b/taskiq_pipelines/middleware.py @@ -108,9 +108,9 @@ async def on_error( return if current_step_num == len(steps) - 1: return - await self.fail_pipeline(steps[-1].task_id) + await self.fail_pipeline(steps[-1].task_id, result.error) - async def fail_pipeline(self, last_task_id: str) -> None: + async def fail_pipeline(self, last_task_id: str, abort: AbortPipeline | None = None) -> None: """ This function aborts pipeline. @@ -118,13 +118,14 @@ async def fail_pipeline(self, last_task_id: str) -> None: the last task in the pipeline. :param last_task_id: id of the last task. + :param abort: caught earlier exception or default """ await self.broker.result_backend.set_result( last_task_id, TaskiqResult( is_err=True, return_value=None, # type: ignore - error=AbortPipeline("Execution aborted."), + error=abort or AbortPipeline("Execution aborted."), execution_time=0, log="Error found while executing pipeline.", ), diff --git a/tests/test_steps.py b/tests/test_steps.py index 1a02d16..7e72187 100644 --- a/tests/test_steps.py +++ b/tests/test_steps.py @@ -1,9 +1,10 @@ +import sys from typing import List import pytest from taskiq import InMemoryBroker -from taskiq_pipelines import Pipeline, PipelineMiddleware +from taskiq_pipelines import Pipeline, PipelineMiddleware, AbortPipeline @pytest.mark.anyio @@ -42,3 +43,32 @@ def double(i: int) -> int: sent = await pipe.kiq(4) res = await sent.wait_result() assert res.return_value == list(map(double, ranger(4))) + + +@pytest.mark.anyio +async def test_abort_pipeline() -> None: + """Test AbortPipeline.""" + broker = InMemoryBroker().with_middlewares(PipelineMiddleware()) + text = 'task was aborted' + + @broker.task + def normal_task(i: bool) -> bool: + return i + + @broker.task + def aborting_task(i: int) -> bool: + if i: + raise AbortPipeline(text) + return True + + pipe = Pipeline(broker, aborting_task).call_next(normal_task) + sent = await pipe.kiq(0) + res = await sent.wait_result() + assert res.is_err is False + assert res.return_value is True + assert res.error is None + sent = await pipe.kiq(1) + res = await sent.wait_result() + assert res.is_err is True + assert res.return_value is None + assert res.error.args[0] == text From 9e0285ca52c4dd103efc352c0e77d12b74adf8ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=AE=D0=B4=D1=8B=D1=86=D0=BA=D0=B8=D0=B9=20=D0=98=D0=B3?= =?UTF-8?q?=D0=BE=D1=80=D1=8C?= Date: Wed, 30 Oct 2024 16:55:21 +0300 Subject: [PATCH 2/2] Fixed for linters --- taskiq_pipelines/middleware.py | 8 ++++++-- tests/test_steps.py | 5 ++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/taskiq_pipelines/middleware.py b/taskiq_pipelines/middleware.py index 4ac88f8..27f8232 100644 --- a/taskiq_pipelines/middleware.py +++ b/taskiq_pipelines/middleware.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any, List +from typing import Any, List, Optional import pydantic from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult @@ -110,7 +110,11 @@ async def on_error( return await self.fail_pipeline(steps[-1].task_id, result.error) - async def fail_pipeline(self, last_task_id: str, abort: AbortPipeline | None = None) -> None: + async def fail_pipeline( + self, + last_task_id: str, + abort: Optional[BaseException] = None, + ) -> None: """ This function aborts pipeline. diff --git a/tests/test_steps.py b/tests/test_steps.py index 7e72187..28292c2 100644 --- a/tests/test_steps.py +++ b/tests/test_steps.py @@ -1,10 +1,9 @@ -import sys from typing import List import pytest from taskiq import InMemoryBroker -from taskiq_pipelines import Pipeline, PipelineMiddleware, AbortPipeline +from taskiq_pipelines import AbortPipeline, Pipeline, PipelineMiddleware @pytest.mark.anyio @@ -49,7 +48,7 @@ def double(i: int) -> int: async def test_abort_pipeline() -> None: """Test AbortPipeline.""" broker = InMemoryBroker().with_middlewares(PipelineMiddleware()) - text = 'task was aborted' + text = "task was aborted" @broker.task def normal_task(i: bool) -> bool: