Skip to content

Commit 8b7713d

Browse files
author
Val Brodsky
committed
Add DisconncetedTask
1 parent ca3f880 commit 8b7713d

File tree

4 files changed

+116
-1
lines changed

4 files changed

+116
-1
lines changed

docs/labelbox/disconnected-task.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
Disconnected Task
2+
===============================================================================================
3+
4+
.. automodule:: labelbox.schema.disconnected_task
5+
:members:
6+
:exclude-members: SupportsTaskQueries, TaskFactory
7+
:show-inheritance:
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Any, Dict, List, Optional, Protocol, Union
2+
3+
from labelbox.exceptions import OperationNotSupportedException, ResourceNotFoundError
4+
from labelbox.orm.db_object import experimental
5+
from labelbox.orm.model import Field
6+
from labelbox.schema.task import Task, DataUpsertTask
7+
8+
TaskDataType = Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]
9+
10+
11+
class SupportsTaskQueries(Protocol):
12+
13+
@property
14+
def status(self) -> Field:
15+
...
16+
17+
@property
18+
def errors(self) -> TaskDataType:
19+
...
20+
21+
@property
22+
def result(self) -> TaskDataType:
23+
...
24+
25+
26+
class TaskFactory:
27+
28+
@classmethod
29+
def get(cls, client, task_id) -> SupportsTaskQueries:
30+
return cls.get_task_by_id(client, task_id)
31+
32+
@classmethod
33+
def get_task_by_id(cls, client, task_id) -> SupportsTaskQueries:
34+
user = client.get_user()
35+
query = """
36+
query GetUserCreatedTasksPyApi($userId: ID!, $taskId: ID!) {
37+
user(where: {id: $userId}) {
38+
createdTasks(where: {id: $taskId} skip: 0 first: 1) {
39+
completionPercentage
40+
createdAt
41+
errors
42+
metadata
43+
name
44+
result
45+
status
46+
type
47+
id
48+
updatedAt
49+
}
50+
}
51+
}
52+
"""
53+
result = client.execute(query, {"userId": user.uid, "taskId": task_id})
54+
data = result.get("user", {}).get("createdTasks", [])
55+
if not data:
56+
raise ResourceNotFoundError(
57+
message=f"The task {task_id} does not exist.")
58+
task_data = data[0]
59+
if task_data["type"].lower() == 'adv-upsert-data-rows':
60+
task = DataUpsertTask(client, task_data)
61+
else:
62+
task = Task(client, task_data)
63+
64+
task._user = user
65+
return task
66+
67+
68+
class DisconnectedTask:
69+
"""
70+
A class to interact with a task by task id.
71+
72+
This class can be used to avoid waiting to tasks that take too long
73+
"""
74+
75+
@experimental
76+
def __init__(self, client, task_id: str):
77+
self.task = TaskFactory.get(client, task_id)
78+
79+
def is_completed(self) -> bool:
80+
return str(self.task.status) != "IN_PROGRESS"
81+
82+
def status(self) -> str:
83+
return str(self.task.status)
84+
85+
def errors(self) -> TaskDataType:
86+
if not self.is_completed():
87+
raise OperationNotSupportedException("Task is not completed yet")
88+
return self.task.errors
89+
90+
def result(self) -> TaskDataType:
91+
if not self.is_completed():
92+
raise OperationNotSupportedException("Task is not completed yet")
93+
return self.task.result

libs/labelbox/src/labelbox/schema/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import requests
44
import time
5-
from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union, Final
5+
from typing import TYPE_CHECKING, Callable, Optional, Dict, Any, List, Union
66
from labelbox import parser
77

88
from labelbox.exceptions import ResourceNotFoundError

libs/labelbox/tests/integration/test_task.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from labelbox.schema.disconnected_task import DisconnectedTask
23
import pytest
34
import collections.abc
45
from labelbox import DataRow
@@ -31,6 +32,14 @@ def test_task_errors(dataset, image_url, snapshot):
3132
0]['message']
3233
assert len(task.failed_data_rows[0]['failedDataRows'][0]['metadata']) == 2
3334

35+
dt = DisconnectedTask(client, task.uid)
36+
assert dt.is_completed()
37+
assert dt.status() == "COMPLETE"
38+
assert len(dt.errors()) == 1
39+
assert dt.errors()[0]['message'].startswith(
40+
"A schemaId can only be specified once per DataRow")
41+
assert dt.result() is None
42+
3443

3544
def test_task_success_json(dataset, image_url, snapshot):
3645
client = dataset.client
@@ -57,3 +66,9 @@ def test_task_success_json(dataset, image_url, snapshot):
5766
snapshot.assert_match(json.dumps(task_result),
5867
'test_task.test_task_success_json.json')
5968
assert len(task.result)
69+
70+
dt = DisconnectedTask(client, task.uid)
71+
assert dt.is_completed()
72+
assert dt.status() == "COMPLETE"
73+
assert len(dt.result()) == 1
74+
assert dt.errors() is None

0 commit comments

Comments
 (0)