Skip to content

Commit f4ca457

Browse files
committed
Add tasks to organization and TaskStatus to filter tasks
1 parent 59fbc85 commit f4ca457

File tree

6 files changed

+63
-1
lines changed

6 files changed

+63
-1
lines changed

libs/labelbox/src/labelbox/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,4 @@
9898
PromptResponseClassification,
9999
)
100100
from lbox.exceptions import *
101+
from labelbox.schema.taskstatus import TaskStatus

libs/labelbox/src/labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,4 @@
2626
import labelbox.schema.catalog
2727
import labelbox.schema.ontology_kind
2828
import labelbox.schema.project_overview
29+
import labelbox.schema.taskstatus

libs/labelbox/src/labelbox/schema/organization.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def __init__(self, *args, **kwargs):
5252
projects = Relationship.ToMany("Project", True)
5353
webhooks = Relationship.ToMany("Webhook", False)
5454
resource_tags = Relationship.ToMany("ResourceTags", False)
55+
tasks = Relationship.ToMany("Task", False, "tasks")
5556

5657
def invite_user(
5758
self,

libs/labelbox/src/labelbox/schema/task.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from labelbox.schema.internal.datarow_upload_constants import (
1515
DOWNLOAD_RESULT_PAGE_SIZE,
1616
)
17+
from labelbox.schema.taskstatus import TaskStatus
1718

1819
if TYPE_CHECKING:
1920
from labelbox import User
@@ -45,6 +46,9 @@ class Task(DbObject):
4546
created_at = Field.DateTime("created_at")
4647
name = Field.String("name")
4748
status = Field.String("status")
49+
status_type = Field.Enum(
50+
TaskStatus, "status_type", "status"
51+
) # additional status for filtering
4852
completion_percentage = Field.Float("completion_percentage")
4953
result_url = Field.String("result_url", "result")
5054
errors_url = Field.String("errors_url", "errors")
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from enum import Enum
2+
3+
4+
class TaskStatus(str, Enum):
5+
In_Progress = "IN_PROGRESS"
6+
Complete = "COMPLETE"
7+
Canceling = "CANCELLING"
8+
Canceled = "CANCELED"
9+
Failed = "FAILED"
10+
Unknown = "UNKNOWN"
11+
12+
@classmethod
13+
def _missing_(cls, value):
14+
"""Handle missing or unknown task status values.
15+
16+
If a task status value is not found in the enum, this method returns
17+
the Unknown status instead of raising an error.
18+
19+
Args:
20+
value: The status value that doesn't match any enum member
21+
22+
Returns:
23+
TaskStatus.Unknown: The default status for unrecognized values
24+
"""
25+
return cls.Unknown

libs/labelbox/tests/data/export/streamable/test_export_data_rows_streamable.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22

3-
from labelbox import DataRow, ExportTask, StreamType
3+
from labelbox import DataRow, ExportTask, StreamType, Task, TaskStatus
44

55

66
class TestExportDataRow:
@@ -135,3 +135,33 @@ def test_cancel_export_task(
135135
# Verify the task was cancelled
136136
cancelled_task = client.get_task_by_id(export_task.uid)
137137
assert cancelled_task.status in ["CANCELING", "CANCELED"]
138+
139+
def test_task_filter(self, client, data_row, wait_for_data_row_processing):
140+
organization = client.get_organization()
141+
user = client.get_user()
142+
143+
export_task = DataRow.export(
144+
client=client,
145+
data_rows=[data_row],
146+
task_name="TestExportDataRow:test_task_filter",
147+
)
148+
149+
# Check if task is listed "in progress" in organization's tasks
150+
org_tasks_in_progress = organization.tasks(
151+
where=Task.status_type == TaskStatus.In_Progress
152+
)
153+
retrieved_task_in_progress = next(
154+
(t for t in org_tasks_in_progress if t.uid == export_task.uid), ""
155+
)
156+
assert getattr(retrieved_task_in_progress, "uid", "") == export_task.uid
157+
158+
export_task.wait_till_done()
159+
160+
# Check if task is listed "complete" in user's created tasks
161+
user_tasks_complete = user.created_tasks(
162+
where=Task.status_type == TaskStatus.Complete
163+
)
164+
retrieved_task_complete = next(
165+
(t for t in user_tasks_complete if t.uid == export_task.uid), ""
166+
)
167+
assert getattr(retrieved_task_complete, "uid", "") == export_task.uid

0 commit comments

Comments
 (0)