Skip to content

Commit bf3ff3a

Browse files
authored
Add tasks to organization and TaskStatus to filter tasks (#1958)
2 parents af7e100 + 520479a commit bf3ff3a

File tree

7 files changed

+109
-3
lines changed

7 files changed

+109
-3
lines changed

libs/labelbox/src/labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,4 @@
9898
PromptResponseClassification,
9999
)
100100
from lbox.exceptions import *
101+
from labelbox.schema.taskstatus import TaskStatus

libs/labelbox/src/labelbox/client.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979
from labelbox.schema.slice import CatalogSlice, ModelSlice
8080
from labelbox.schema.task import DataUpsertTask, Task
8181
from labelbox.schema.user import User
82+
from labelbox.schema.taskstatus import TaskStatus
8283

8384
logger = logging.getLogger(__name__)
8485

@@ -90,6 +91,9 @@ class Client:
9091
top-level data objects (Projects, Datasets).
9192
"""
9293

94+
# Class variable to cache task types
95+
_cancelable_task_types = None
96+
9397
def __init__(
9498
self,
9599
api_key=None,
@@ -2390,9 +2394,31 @@ def get_task_by_id(self, task_id: str) -> Union[Task, DataUpsertTask]:
23902394
task._user = user
23912395
return task
23922396

2397+
def _get_cancelable_task_types(self):
2398+
"""Internal method that returns a list of task types that can be canceled.
2399+
2400+
The result is cached after the first call to avoid unnecessary API requests.
2401+
2402+
Returns:
2403+
List[str]: List of cancelable task types in snake_case format
2404+
"""
2405+
if self._cancelable_task_types is None:
2406+
query = """query GetCancelableTaskTypesPyApi {
2407+
cancelableTaskTypes
2408+
}"""
2409+
2410+
result = self.execute(query).get("cancelableTaskTypes", [])
2411+
# Reformat to kebab case
2412+
self._cancelable_task_types = [
2413+
utils.snake_case(task_type).replace("_", "-")
2414+
for task_type in result
2415+
]
2416+
2417+
return self._cancelable_task_types
2418+
23932419
def cancel_task(self, task_id: str) -> bool:
23942420
"""
2395-
Cancels a task with the given ID.
2421+
Cancels a task with the given ID if the task type is cancelable and the task is in progress.
23962422
23972423
Args:
23982424
task_id (str): The ID of the task to cancel.
@@ -2401,8 +2427,26 @@ def cancel_task(self, task_id: str) -> bool:
24012427
bool: True if the task was successfully cancelled.
24022428
24032429
Raises:
2404-
LabelboxError: If the task could not be cancelled.
2430+
LabelboxError: If the task could not be cancelled, if the task type is not cancelable,
2431+
or if the task is not in progress.
2432+
ResourceNotFoundError: If the task does not exist (raised by get_task_by_id).
24052433
"""
2434+
# Get the task object to check its type and status
2435+
task = self.get_task_by_id(task_id)
2436+
2437+
# Check if task type is cancelable
2438+
cancelable_types = self._get_cancelable_task_types()
2439+
if task.type not in cancelable_types:
2440+
raise LabelboxError(
2441+
f"Task type '{task.type}' cannot be cancelled. Cancelable types are: {cancelable_types}"
2442+
)
2443+
2444+
# Check if task is in progress
2445+
if task.status_as_enum != TaskStatus.In_Progress:
2446+
raise LabelboxError(
2447+
f"Task cannot be cancelled because it is not in progress. Current status: {task.status}"
2448+
)
2449+
24062450
mutation_str = """
24072451
mutation CancelTaskPyApi($id: ID!) {
24082452
cancelBulkOperationJob(id: $id) {

libs/labelbox/src/labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
import labelbox.schema.catalog
2727
import labelbox.schema.ontology_kind
2828
import labelbox.schema.project_overview
29+
import labelbox.schema.taskstatus

libs/labelbox/src/labelbox/schema/organization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
5252
projects = Relationship.ToMany("Project", True)
5353
webhooks = Relationship.ToMany("Webhook", False)
5454
resource_tags = Relationship.ToMany("ResourceTags", False)
55+
tasks = Relationship.ToMany("Task", False, "tasks")
5556

5657
def invite_user(
5758
self,

libs/labelbox/src/labelbox/schema/task.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from labelbox.schema.internal.datarow_upload_constants import (
1515
DOWNLOAD_RESULT_PAGE_SIZE,
1616
)
17+
from labelbox.schema.taskstatus import TaskStatus
1718

1819
if TYPE_CHECKING:
1920
from labelbox import User
@@ -45,6 +46,9 @@ class Task(DbObject):
4546
created_at = Field.DateTime("created_at")
4647
name = Field.String("name")
4748
status = Field.String("status")
49+
status_as_enum = Field.Enum(
50+
TaskStatus, "status_as_enum", "status"
51+
) # additional status for filtering
4852
completion_percentage = Field.Float("completion_percentage")
4953
result_url = Field.String("result_url", "result")
5054
errors_url = Field.String("errors_url", "errors")
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from enum import Enum
2+
3+
4+
class TaskStatus(str, Enum):
5+
In_Progress = "IN_PROGRESS"
6+
Complete = "COMPLETE"
7+
Canceling = "CANCELLING"
8+
Canceled = "CANCELED"
9+
Failed = "FAILED"
10+
Unknown = "UNKNOWN"
11+
12+
@classmethod
13+
def _missing_(cls, value):
14+
"""Handle missing or unknown task status values.
15+
16+
If a task status value is not found in the enum, this method returns
17+
the Unknown status instead of raising an error.
18+
19+
Args:
20+
value: The status value that doesn't match any enum member
21+
22+
Returns:
23+
TaskStatus.Unknown: The default status for unrecognized values
24+
"""
25+
return cls.Unknown

libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
from labelbox import DataRow, ExportTask, StreamType
3+
from labelbox import DataRow, ExportTask, StreamType, Task, TaskStatus
44

55

66
class TestExportDataRow:
@@ -135,3 +135,33 @@ def test_cancel_export_task(
135135
# Verify the task was cancelled
136136
cancelled_task = client.get_task_by_id(export_task.uid)
137137
assert cancelled_task.status in ["CANCELING", "CANCELED"]
138+
139+
def test_task_filter(self, client, data_row, wait_for_data_row_processing):
140+
organization = client.get_organization()
141+
user = client.get_user()
142+
143+
export_task = DataRow.export(
144+
client=client,
145+
data_rows=[data_row],
146+
task_name="TestExportDataRow:test_task_filter",
147+
)
148+
149+
# Check if task is listed "in progress" in organization's tasks
150+
org_tasks_in_progress = organization.tasks(
151+
where=Task.status_as_enum == TaskStatus.In_Progress
152+
)
153+
retrieved_task_in_progress = next(
154+
(t for t in org_tasks_in_progress if t.uid == export_task.uid), ""
155+
)
156+
assert getattr(retrieved_task_in_progress, "uid", "") == export_task.uid
157+
158+
export_task.wait_till_done()
159+
160+
# Check if task is listed "complete" in user's created tasks
161+
user_tasks_complete = user.created_tasks(
162+
where=Task.status_as_enum == TaskStatus.Complete
163+
)
164+
retrieved_task_complete = next(
165+
(t for t in user_tasks_complete if t.uid == export_task.uid), ""
166+
)
167+
assert getattr(retrieved_task_complete, "uid", "") == export_task.uid

0 commit comments

Comments
 (0)