Skip to content

Commit 77cd2c5

Browse files
author
Gareth
authored
Merge pull request #505 from Labelbox/gj/batch
Batch SDK Beta
2 parents 76a5665 + e4738c2 commit 77cd2c5

File tree

8 files changed

+132
-25
lines changed

8 files changed

+132
-25
lines changed

labelbox/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "labelbox"
22
__version__ = "3.16.0"
33

4-
from labelbox.schema.project import Project
54
from labelbox.client import Client
5+
from labelbox.schema.project import Project
66
from labelbox.schema.model import Model
77
from labelbox.schema.bulk_import_request import BulkImportRequest
88
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport
99
from labelbox.schema.dataset import Dataset
1010
from labelbox.schema.data_row import DataRow
1111
from labelbox.schema.label import Label
12+
from labelbox.schema.batch import Batch
1213
from labelbox.schema.review import Review
1314
from labelbox.schema.user import User
1415
from labelbox.schema.organization import Organization

labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,7 @@ class Entity(metaclass=EntityMeta):
347347
Invite: Type[labelbox.Invite]
348348
InviteLimit: Type[labelbox.InviteLimit]
349349
ProjectRole: Type[labelbox.ProjectRole]
350+
Batch: Type[labelbox.Batch]
350351

351352
@classmethod
352353
def _attributes_of_type(cls, attr_type):

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
import labelbox.schema.user
1919
import labelbox.schema.webhook
2020
import labelbox.schema.data_row_metadata
21+
import labelbox.schema.batch
2122
import labelbox.schema.iam_integration

labelbox/schema/batch.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from labelbox.orm.db_object import DbObject
2+
from labelbox.orm.model import Field, Relationship
3+
4+
5+
class Batch(DbObject):
    """A Batch is a group of data rows submitted to a project for labeling.

    Attributes:
        name (str)
        created_at (datetime)
        updated_at (datetime)
        size (int): number of data rows in the batch

        project (Relationship): `ToOne` relationship to Project
        created_by (Relationship): `ToOne` relationship to User

    """
    name = Field.String("name")
    created_at = Field.DateTime("created_at")
    updated_at = Field.DateTime("updated_at")
    size = Field.Int("size")

    # Relationships
    project = Relationship.ToOne("Project")
    created_by = Relationship.ToOne("User")

labelbox/schema/project.py

Lines changed: 67 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import time
5+
import warnings
56
from collections import namedtuple
67
from datetime import datetime, timezone
78
from pathlib import Path
@@ -37,11 +38,6 @@
3738
logger = logging.getLogger(__name__)
3839

3940

40-
class QueueMode(enum.Enum):
41-
Batch = "Batch"
42-
Dataset = "Dataset"
43-
44-
4541
class Project(DbObject, Updateable, Deletable):
4642
""" A Project is a container that includes a labeling frontend, an ontology,
4743
datasets and labels.
@@ -89,9 +85,12 @@ class Project(DbObject, Updateable, Deletable):
8985
benchmarks = Relationship.ToMany("Benchmark", False)
9086
ontology = Relationship.ToOne("Ontology", True)
9187

92-
def update(self, **kwargs):
88+
@enum.unique
class QueueMode(enum.Enum):
    """How a project's labeling queue is populated: by batches or by datasets."""
    Batch = "Batch"
    Dataset = "Dataset"
9391

94-
mode: Optional[QueueMode] = kwargs.pop("queue_mode", None)
92+
def update(self, **kwargs):
93+
mode: Optional[Project.QueueMode] = kwargs.pop("queue_mode", None)
9594
if mode:
9695
self._update_queue_mode(mode)
9796

@@ -569,14 +568,69 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
569568
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
570569
self.update(setup_complete=timestamp)
571570

572-
def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
571+
def create_batch(self, name: str, data_rows: List[str], priority: int = 5):
    """Create a new batch for a project. Batches are in Beta and subject to change.

    Args:
        name: a name for the batch, must be unique within a project
        data_rows: a list of `DataRow` objects or Data Row ids (both kinds
            may appear in the same list)
        priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest

    Returns:
        The created `Batch`. Its `size` is filled in locally from the number
        of data rows submitted.

    Raises:
        ValueError: if the project is not in batch queue mode, if an element
            of `data_rows` is neither a `DataRow` nor a `str`, or if the
            batch is empty or holds more than 25,000 data rows.

    """
    max_batch_size = 25_000

    # @TODO: make this automatic?
    if self.queue_mode() != Project.QueueMode.Batch:
        raise ValueError("Project must be in batch mode")

    # Normalize every element to a data row id.
    dr_ids = []
    for dr in data_rows:
        if isinstance(dr, Entity.DataRow):
            dr_ids.append(dr.uid)
        elif isinstance(dr, str):
            dr_ids.append(dr)
        else:
            raise ValueError(
                "You can only pass DataRow ids or DataRow objects")

    if len(dr_ids) > max_batch_size:
        raise ValueError(
            "Batch exceeds max size, break into smaller batches")
    if not dr_ids:
        raise ValueError("You need at least one data row in a batch")

    method = 'createBatch'
    query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
          project(where: {id: $projectId}) {
            %s(input: $batchInput) {
              %s
            }
          }
        }
    """ % (method, method, query.results_query_part(Entity.Batch))

    params = {
        "projectId": self.uid,
        "batchInput": {
            "name": name,
            "dataRowIds": dr_ids,
            "priority": priority
        }
    }

    res = self.client.execute(query_str, params,
                              experimental=True)["project"][method]

    # Set the size from the submitted ids so the returned Batch reports it.
    res['size'] = len(dr_ids)
    return Entity.Batch(self.client, res)
624+
625+
def _update_queue_mode(self,
626+
mode: "Project.QueueMode") -> "Project.QueueMode":
573627

574628
if self.queue_mode() == mode:
575629
return mode
576630

577-
if mode == QueueMode.Batch:
631+
if mode == Project.QueueMode.Batch:
578632
status = "ENABLED"
579-
elif mode == QueueMode.Dataset:
633+
elif mode == Project.QueueMode.Dataset:
580634
status = "DISABLED"
581635
else:
582636
raise ValueError(
@@ -598,7 +652,7 @@ def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
598652

599653
return mode
600654

601-
def queue_mode(self) -> QueueMode:
655+
def queue_mode(self) -> "Project.QueueMode":
602656
"""Provides the status of if queue mode is enabled in the project."""
603657

604658
query_str = """query %s($projectId: ID!) {
@@ -612,9 +666,9 @@ def queue_mode(self) -> QueueMode:
612666
query_str, {'projectId': self.uid})["project"]["tagSetStatus"]
613667

614668
if status == "ENABLED":
615-
return QueueMode.Batch
669+
return Project.QueueMode.Batch
616670
elif status == "DISABLED":
617-
return QueueMode.Dataset
671+
return Project.QueueMode.Dataset
618672
else:
619673
raise ValueError("Status not known")
620674

tests/integration/conftest.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
import os
22
import re
3-
import uuid
43
import time
5-
from datetime import datetime
4+
import uuid
65
from enum import Enum
7-
from random import randint
8-
from string import ascii_letters
96
from types import SimpleNamespace
107

118
import pytest
129
import requests
1310

1411
from labelbox import Client
1512
from labelbox import LabelingFrontend
13+
from labelbox import OntologyBuilder, Tool, Option, Classification
1614
from labelbox.orm import query
1715
from labelbox.pagination import PaginatedCollection
16+
from labelbox.schema.annotation_import import LabelImport
1817
from labelbox.schema.invite import Invite
1918
from labelbox.schema.user import User
20-
from labelbox import OntologyBuilder, Tool, Option, Classification
21-
from labelbox.schema.annotation_import import LabelImport
2219

2320
IMG_URL = "https://picsum.photos/200/300"
2421

@@ -256,7 +253,7 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset,
256253
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
257254
])
258255
project.setup(editor, ontology_builder.asdict())
259-
#TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
256+
# TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
260257
time.sleep(2)
261258

262259
ontology = ontology_builder.from_project(project)

tests/integration/test_batch.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import pytest
2+
3+
from labelbox import Dataset, Project
4+
5+
IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg"
6+
7+
8+
@pytest.fixture
def big_dataset(dataset: Dataset):
    """Load 250 copies of one image row into `dataset`, yield it, then delete it."""
    row = {"row_data": IMAGE_URL, "external_id": "my-image"}
    upload_task = dataset.create_data_rows([row] * 250)
    upload_task.wait_till_done()

    yield dataset
    dataset.delete()
20+
21+
22+
def test_create_batch(configured_project: Project, big_dataset: Dataset):
    """A batch built from exported data row ids reports the given name and size."""
    configured_project.update(queue_mode=Project.QueueMode.Batch)

    # Iterate the export directly; list() only built a throwaway copy.
    data_rows = [dr.uid for dr in big_dataset.export_data_rows()]
    batch = configured_project.create_batch("test-batch", data_rows, 3)
    assert batch.name == 'test-batch'
    assert batch.size == len(data_rows)

tests/integration/test_project.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from labelbox import Project, LabelingFrontend
77
from labelbox.exceptions import InvalidQueryError
8-
from labelbox.schema.project import QueueMode
98

109

1110
def test_project(client, rand_gen):
@@ -181,6 +180,7 @@ def test_queued_data_row_export(configured_project):
181180

182181

183182
def test_queue_mode(configured_project: Project):
184-
assert configured_project.queue_mode() == QueueMode.Dataset
185-
configured_project.update(queue_mode=QueueMode.Batch)
186-
assert configured_project.queue_mode() == QueueMode.Batch
183+
assert configured_project.queue_mode(
184+
) == configured_project.QueueMode.Dataset
185+
configured_project.update(queue_mode=configured_project.QueueMode.Batch)
186+
assert configured_project.queue_mode() == configured_project.QueueMode.Batch

0 commit comments

Comments
 (0)