Skip to content

Commit 1486cfe

Browse files
author
gdj0nes
committed
WIP batches
1 parent dba0ce8 commit 1486cfe

File tree

3 files changed

+109
-66
lines changed

3 files changed

+109
-66
lines changed

labelbox/client.py

Lines changed: 64 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,25 @@
1-
# type: ignore
2-
from datetime import datetime, timezone
31
import json
4-
from typing import List, Dict
5-
from collections import defaultdict
6-
72
import logging
83
import mimetypes
94
import os
5+
from collections import defaultdict
6+
from datetime import datetime, timezone
7+
from typing import List, Dict
108

11-
from google.api_core import retry
129
import requests
1310
import requests.exceptions
11+
from google.api_core import retry
1412

1513
import labelbox.exceptions
16-
from labelbox import utils
1714
from labelbox import __version__ as SDK_VERSION
15+
from labelbox import utils
1816
from labelbox.orm import query
1917
from labelbox.orm.db_object import DbObject
2018
from labelbox.orm.model import Entity
2119
from labelbox.pagination import PaginatedCollection
20+
from labelbox.schema import role
2221
from labelbox.schema.data_row_metadata import DataRowMetadataOntology
2322
from labelbox.schema.iam_integration import IAMIntegration
24-
from labelbox.schema import role
2523
from labelbox.schema.ontology import Tool, Classification
2624

2725
logger = logging.getLogger(__name__)
@@ -354,7 +352,7 @@ def upload_data(self,
354352
data=request_data,
355353
files={
356354
"1": (filename, content, content_type) if
357-
(filename and content_type) else content
355+
(filename and content_type) else content
358356
})
359357

360358
if response.status_code == 502:
@@ -518,7 +516,7 @@ def _create(self, db_object_type, data):
518516
# Also convert Labelbox object values to their UIDs.
519517
data = {
520518
db_object_type.attribute(attr) if isinstance(attr, str) else attr:
521-
value.uid if isinstance(value, DbObject) else value
519+
value.uid if isinstance(value, DbObject) else value
522520
for attr, value in data.items()
523521
}
524522

@@ -702,8 +700,8 @@ def get_data_row_ids_for_external_ids(
702700
for i in range(0, len(external_ids), max_ids_per_request):
703701
for row in self.execute(
704702
query_str,
705-
{'externalId_in': external_ids[i:i + max_ids_per_request]
706-
})['externalIdsToDataRowIds']:
703+
{'externalId_in': external_ids[i:i + max_ids_per_request]
704+
})['externalIdsToDataRowIds']:
707705
result[row['externalId']].append(row['dataRowId'])
708706
return result
709707

@@ -896,3 +894,57 @@ def create_feature_schema(self, normalized):
896894
# But the features are the same so we just grab the feature schema id
897895
res['id'] = res['normalized']['featureSchemaId']
898896
return Entity.FeatureSchema(self, res)
897+
898+
def get_batch(self, batch_id: str):
    """Fetch one Batch by its unique identifier.

    Args:
        batch_id: Id of the batch to look up

    Returns:
        The sought Batch
    """
    # Delegate to the generic single-entity lookup used by the other
    # `get_*` helpers on this client.
    batch_type = Entity.Batch
    return self._get_single(batch_type, batch_id)
909+
910+
def create_batch(self, name: str, project, data_rows: List[str], priority: int):
    """Create a batch of data rows to send to a project.

    >>> data_rows = ['<data-row-id>', ...]
    >>> project = client.get("<project-id>")
    >>> client.create_batch(name="low-confidence-images", project=project, data_rows=data_rows)

    Args:
        name: Descriptive name for the batch, must be unique per project
        project: The project (a Project entity or its id string) to send the batch to
        data_rows: A list of data row ids (DataRow entities also accepted)
        priority: The default priority for the data rows, lowest priority by default

    Returns:
        The created batch

    Raises:
        ValueError: If `project` is neither a Project entity nor an id string.
    """
    # Accept either a Project entity or a raw project id.
    if isinstance(project, Entity.Project):
        project_id = project.uid
    elif isinstance(project, str):
        project_id = project
    else:
        raise ValueError("You must pass a project id or a Project")

    # BUG FIX: the original looped over `data_rows` with `pass`, so
    # `data_row_ids` was always sent empty. Normalize entities to ids.
    data_row_ids = [
        dr.uid if isinstance(dr, DbObject) else dr for dr in data_rows
    ]

    # BUG FIX: the original mutation never declared $projectId, passed no
    # arguments to createBatch (so no variable was ever used), and left
    # the result-fields placeholder empty.
    # NOTE(review): field names assumed from the variable names here —
    # confirm against the live GraphQL schema.
    query_str = """mutation createBatchPyApi($name: String!, $projectId: ID!, $dataRowIds: [ID!]!, $priority: Int!){
      createBatch(data: {name: $name, projectId: $projectId, dataRowIds: $dataRowIds, priority: $priority}){
        %s
      }
    }""" % query.results_query_part(Entity.Batch)

    result = self.execute(
        query_str, {
            "name": name,
            "projectId": project_id,
            "dataRowIds": data_row_ids,
            "priority": priority
        })
    # BUG FIX: result was read from the wrong key ('createModel').
    return Entity.Batch(self, result['createBatch'])

labelbox/schema/batch.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from typing import Dict, List
2+
3+
from labelbox.orm.db_object import DbObject
4+
from labelbox.orm.model import Field, Relationship
5+
6+
7+
class Batch(DbObject):
    """A Batch is a group of data rows submitted to a project for labeling.

    Attributes:
        name (str)
        created_at (datetime)
        updated_at (datetime)
        deleted (bool)
        size (int)

        project (Relationship): `ToOne` relationship to Project
        created_by (Relationship): `ToOne` relationship to User
    """
    name = Field.String("name")
    created_at = Field.DateTime("created_at")
    updated_at = Field.DateTime("updated_at")
    # BUG FIX: Field.Boolean was called without the field name, unlike
    # every other field declared on this class.
    deleted = Field.Boolean("deleted")
    size = Field.Int("size")

    # Relationships
    project = Relationship.ToOne("Project")
    created_by = Relationship.ToOne("User")

    def export_data_rows(self) -> List[Dict]:
        """Get the data rows associated with a batch.

        Raises:
            NotImplementedError: always — the export query is not wired
                up yet.
        """
        # BUG FIX: the original stub called `self.client.execute()` with
        # no arguments (a guaranteed TypeError) and then returned None
        # despite the List[Dict] annotation. Fail loudly and clearly
        # until the query is implemented.
        # TODO(WIP): implement the batch data-row export query.
        raise NotImplementedError(
            "Batch.export_data_rows is not implemented yet")

labelbox/schema/project.py

Lines changed: 6 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,10 @@
22
import json
33
import logging
44
import time
5-
import warnings
65
from collections import namedtuple
76
from datetime import datetime, timezone
87
from pathlib import Path
9-
from typing import Dict, Union, Iterable, List, Optional
8+
from typing import Dict, Union, Iterable, Optional
109
from urllib.parse import urlparse
1110

1211
import ndjson
@@ -35,8 +34,6 @@
3534

3635
logger = logging.getLogger(__name__)
3736

38-
MAX_QUEUE_BATCH_SIZE = 1000
39-
4037

4138
class QueueMode(enum.Enum):
4239
Batch = "Batch"
@@ -154,6 +151,10 @@ def labels(self, datasets=None, order_by=None):
154151
return PaginatedCollection(self.client, query_str, {id_param: self.uid},
155152
["project", "labels"], Label)
156153

154+
def batches(self):
    """Returns a generator of batches that are queued for the project"""
    # TODO(WIP): not implemented — as written this returns None, not a
    # generator. Presumably it should follow the PaginatedCollection
    # pattern used by Project.labels(); confirm before relying on it.
157+
157158
def export_queued_data_rows(self, timeout_seconds=120):
158159
""" Returns all data rows that are currently enqueued for this project.
159160
@@ -484,55 +485,6 @@ def setup(self, labeling_frontend, labeling_frontend_options):
484485
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
485486
self.update(setup_complete=timestamp)
486487

487-
def queue(self, data_row_ids: List[str]):
488-
"""Add Data Rows to the Project queue"""
489-
490-
method = "submitBatchOfDataRows"
491-
return self._post_batch(method, data_row_ids)
492-
493-
def dequeue(self, data_row_ids: List[str]):
494-
"""Remove Data Rows from the Project queue"""
495-
496-
method = "removeBatchOfDataRows"
497-
return self._post_batch(method, data_row_ids)
498-
499-
def _post_batch(self, method, data_row_ids: List[str]):
500-
"""Post batch methods"""
501-
502-
if self.queue_mode() != QueueMode.Batch:
503-
raise ValueError("Project must be in batch mode")
504-
505-
if len(data_row_ids) > MAX_QUEUE_BATCH_SIZE:
506-
raise ValueError(
507-
f"Batch exceeds max size of {MAX_QUEUE_BATCH_SIZE}, consider breaking it into parts"
508-
)
509-
510-
query = """mutation %sPyApi($projectId: ID!, $dataRowIds: [ID!]!) {
511-
project(where: {id: $projectId}) {
512-
%s(data: {dataRowIds: $dataRowIds}) {
513-
dataRows {
514-
dataRowId
515-
error
516-
}
517-
}
518-
}
519-
}
520-
""" % (method, method)
521-
522-
res = self.client.execute(query, {
523-
"projectId": self.uid,
524-
"dataRowIds": data_row_ids
525-
})["project"][method]["dataRows"]
526-
527-
# TODO: figure out error messaging
528-
if len(data_row_ids) == len(res):
529-
raise ValueError("No dataRows were submitted successfully")
530-
531-
if len(data_row_ids) > 0:
532-
warnings.warn("Some Data Rows were not submitted successfully")
533-
534-
return res
535-
536488
def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
537489

538490
if self.queue_mode() == mode:
@@ -846,7 +798,7 @@ class LabelingParameterOverride(DbObject):
846798

847799
LabelerPerformance = namedtuple(
848800
"LabelerPerformance", "user count seconds_per_label, total_time_labeling "
849-
"consensus average_benchmark_agreement last_activity_time")
801+
"consensus average_benchmark_agreement last_activity_time")
850802
LabelerPerformance.__doc__ = (
851803
"Named tuple containing info about a labeler's performance.")
852804

0 commit comments

Comments
 (0)