diff --git a/taskiq_pipelines/middleware.py b/taskiq_pipelines/middleware.py index 49692d5..27f8232 100644 --- a/taskiq_pipelines/middleware.py +++ b/taskiq_pipelines/middleware.py @@ -1,5 +1,5 @@ from logging import getLogger -from typing import Any, List +from typing import Any, List, Optional import pydantic from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult @@ -108,9 +108,13 @@ async def on_error( return if current_step_num == len(steps) - 1: return - await self.fail_pipeline(steps[-1].task_id) + await self.fail_pipeline(steps[-1].task_id, result.error) - async def fail_pipeline(self, last_task_id: str) -> None: + async def fail_pipeline( + self, + last_task_id: str, + abort: Optional[BaseException] = None, + ) -> None: """ This function aborts pipeline. @@ -118,13 +122,14 @@ async def fail_pipeline(self, last_task_id: str) -> None: the last task in the pipeline. :param last_task_id: id of the last task. + :param abort: caught earlier exception or default """ await self.broker.result_backend.set_result( last_task_id, TaskiqResult( is_err=True, return_value=None, # type: ignore - error=AbortPipeline("Execution aborted."), + error=abort or AbortPipeline("Execution aborted."), execution_time=0, log="Error found while executing pipeline.", ), diff --git a/tests/test_steps.py b/tests/test_steps.py index 1a02d16..28292c2 100644 --- a/tests/test_steps.py +++ b/tests/test_steps.py @@ -3,7 +3,7 @@ import pytest from taskiq import InMemoryBroker -from taskiq_pipelines import Pipeline, PipelineMiddleware +from taskiq_pipelines import AbortPipeline, Pipeline, PipelineMiddleware @pytest.mark.anyio @@ -42,3 +42,32 @@ def double(i: int) -> int: sent = await pipe.kiq(4) res = await sent.wait_result() assert res.return_value == list(map(double, ranger(4))) + + +@pytest.mark.anyio +async def test_abort_pipeline() -> None: + """Test AbortPipeline.""" + broker = InMemoryBroker().with_middlewares(PipelineMiddleware()) + text = "task was aborted" + + @broker.task + def normal_task(i: bool) -> bool: + return i + + @broker.task + def aborting_task(i: int) -> bool: + if i: + raise AbortPipeline(text) + return True + + pipe = Pipeline(broker, aborting_task).call_next(normal_task) + sent = await pipe.kiq(0) + res = await sent.wait_result() + assert res.is_err is False + assert res.return_value is True + assert res.error is None + sent = await pipe.kiq(1) + res = await sent.wait_result() + assert res.is_err is True + assert res.return_value is None + assert res.error.args[0] == text