Skip to content

Commit 6a79d46

Browse files
committed
Add cancelable types and additional check to cancel_task()
1 parent f4ca457 commit 6a79d46

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

libs/labelbox/src/labelbox/client.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
from labelbox.schema.slice import CatalogSlice, ModelSlice
8080
from labelbox.schema.task import DataUpsertTask, Task
8181
from labelbox.schema.user import User
82+
from labelbox.schema.taskstatus import TaskStatus
8283

8384
logger = logging.getLogger(__name__)
8485

@@ -90,6 +91,9 @@ class Client:
9091
top-level data objects (Projects, Datasets).
9192
"""
9293

94+
# Class variable to cache task types
95+
_cancelable_task_types = None
96+
9397
def __init__(
9498
self,
9599
api_key=None,
@@ -2390,9 +2394,31 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]:
23902394
task._user = user
23912395
return task
23922396

2397+
def _get_cancelable_task_types(self):
2398+
"""Internal method that returns a list of task types that can be canceled.
2399+
2400+
The result is cached after the first call to avoid unnecessary API requests.
2401+
2402+
Returns:
2403+
List[str]: List of cancelable task types in snake_case format
2404+
"""
2405+
if self._cancelable_task_types is None:
2406+
query = """query GetCancelableTaskTypes {
2407+
cancelableTaskTypes
2408+
}"""
2409+
2410+
result = self.execute(query).get("cancelableTaskTypes", [])
2411+
# Reformat to kebab case
2412+
self._cancelable_task_types = [
2413+
utils.snake_case(task_type).replace("_", "-")
2414+
for task_type in result
2415+
]
2416+
2417+
return self._cancelable_task_types
2418+
23932419
def cancel_task(self, task_id: str) -> bool:
23942420
"""
2395-
Cancels a task with the given ID.
2421+
Cancels a task with the given ID if the task type is cancelable and the task is in progress.
23962422
23972423
Args:
23982424
task_id (str): The ID of the task to cancel.
@@ -2401,8 +2427,26 @@ def cancel_task(self, task_id: str) -> bool:
24012427
bool: True if the task was successfully cancelled.
24022428
24032429
Raises:
2404-
LabelboxError: If the task could not be cancelled.
2430+
LabelboxError: If the task could not be cancelled, if the task type is not cancelable,
2431+
or if the task is not in progress.
2432+
ResourceNotFoundError: If the task does not exist (raised by get_task_by_id).
24052433
"""
2434+
# Get the task object to check its type and status
2435+
task = self.get_task_by_id(task_id)
2436+
2437+
# Check if task type is cancelable
2438+
cancelable_types = self._get_cancelable_task_types()
2439+
if task.type not in cancelable_types:
2440+
raise LabelboxError(
2441+
f"Task type '{task.type}' cannot be cancelled. Cancelable types are: {cancelable_types}"
2442+
)
2443+
2444+
# Check if task is in progress
2445+
if task.status_type != TaskStatus.In_Progress:
2446+
raise LabelboxError(
2447+
f"Task cannot be cancelled because it is not in progress. Current status: {task.status}"
2448+
)
2449+
24062450
mutation_str = """
24072451
mutation CancelTaskPyApi($id: ID!) {
24082452
cancelBulkOperationJob(id: $id) {

0 commit comments

Comments
 (0)