Skip to content

Commit 1edc45a

Browse files
author
gdj0nes
committed
ADD: batch creation on project
1 parent a351926 commit 1edc45a

File tree

6 files changed

+93
-9
lines changed

6 files changed

+93
-9
lines changed

labelbox/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
name = "labelbox"
22
__version__ = "3.16.0"
33

4-
from labelbox.schema.project import Project
54
from labelbox.client import Client
5+
from labelbox.schema.project import Project
66
from labelbox.schema.model import Model
77
from labelbox.schema.bulk_import_request import BulkImportRequest
88
from labelbox.schema.annotation_import import MALPredictionImport, MEAPredictionImport, LabelImport
99
from labelbox.schema.dataset import Dataset
1010
from labelbox.schema.data_row import DataRow
1111
from labelbox.schema.label import Label
12+
from labelbox.schema.batch import Batch, BatchPriority
1213
from labelbox.schema.review import Review
1314
from labelbox.schema.user import User
1415
from labelbox.schema.organization import Organization

labelbox/schema/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@
1818
import labelbox.schema.user
1919
import labelbox.schema.webhook
2020
import labelbox.schema.data_row_metadata
21+
import labelbox.schema.batch
2122
import labelbox.schema.iam_integration

labelbox/schema/batch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from typing import NewType
2+
13
from labelbox.orm.db_object import DbObject
24
from labelbox.orm.model import Field, Relationship
35

6+
BatchPriority = NewType('BatchPriority', int)
7+
48

59
class Batch(DbObject):
610
""" A Batch is a group of data rows submitted to a project for labeling

labelbox/schema/project.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,60 @@ def setup(self, labeling_frontend, labeling_frontend_options) -> None:
569569
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
570570
self.update(setup_complete=timestamp)
571571

572+
def create_batch(self, name: str, data_rows: List[Union[str, "DataRow"]], priority: "BatchPriority" = 5):
573+
"""Create a new batch for a project
574+
575+
Args:
576+
name: a name for the batch, must be unique within a project
577+
data_rows: Either a list of `DataRows` or Data Row ids
578+
priority: An optional priority for the Data Rows in the Batch. 1 highest -> 5 lowest
579+
580+
"""
581+
582+
# @TODO: make this automatic?
583+
if self.queue_mode() != QueueMode.Batch:
584+
raise ValueError("Project must be in batch mode")
585+
586+
dr_ids = []
587+
for dr in data_rows:
588+
from labelbox import DataRow
589+
if isinstance(dr, DataRow):
590+
dr_ids.append(dr.uid)
591+
elif isinstance(dr, str):
592+
dr_ids.append(dr)
593+
else:
594+
raise ValueError("Must pass")
595+
596+
if len(dr_ids) > 25_000:
597+
raise ValueError(
598+
f"Batch exceeds max size, break into smaller batches"
599+
)
600+
if not len(dr_ids):
601+
raise ValueError("You need at least one data row in a batch")
602+
603+
method = 'createBatch'
604+
query_str = """mutation %sPyApi($projectId: ID!, $batchInput: CreateBatchInput!) {
605+
project(where: {id: $projectId}) {
606+
%s(input: $batchInput) {
607+
%s
608+
}
609+
}
610+
}
611+
""" % (method, method, query.results_query_part(Entity.Batch))
612+
613+
params = {
614+
"projectId": self.uid,
615+
"batchInput": {
616+
"name": name,
617+
"dataRowIds": dr_ids,
618+
"priority": priority
619+
}
620+
}
621+
622+
res = self.client.execute(query_str, params, experimental=True)["project"][method]
623+
res['size'] = len(dr_ids)
624+
return Entity.Batch(self.client, res)
625+
572626
def _update_queue_mode(self, mode: QueueMode) -> QueueMode:
573627

574628
if self.queue_mode() == mode:
@@ -885,7 +939,7 @@ class LabelingParameterOverride(DbObject):
885939

886940
LabelerPerformance = namedtuple(
887941
"LabelerPerformance", "user count seconds_per_label, total_time_labeling "
888-
"consensus average_benchmark_agreement last_activity_time")
942+
"consensus average_benchmark_agreement last_activity_time")
889943
LabelerPerformance.__doc__ = (
890944
"Named tuple containing info about a labeler's performance.")
891945

tests/integration/conftest.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
11
import os
22
import re
3-
import uuid
43
import time
5-
from datetime import datetime
4+
import uuid
65
from enum import Enum
7-
from random import randint
8-
from string import ascii_letters
96
from types import SimpleNamespace
107

118
import pytest
129
import requests
1310

1411
from labelbox import Client
1512
from labelbox import LabelingFrontend
13+
from labelbox import OntologyBuilder, Tool, Option, Classification
1614
from labelbox.orm import query
1715
from labelbox.pagination import PaginatedCollection
16+
from labelbox.schema.annotation_import import LabelImport
1817
from labelbox.schema.invite import Invite
1918
from labelbox.schema.user import User
20-
from labelbox import OntologyBuilder, Tool, Option, Classification
21-
from labelbox.schema.annotation_import import LabelImport
2219

2320
IMG_URL = "https://picsum.photos/200/300"
2421

@@ -256,7 +253,7 @@ def configured_project_with_label(client, rand_gen, image_url, project, dataset,
256253
Tool(tool=Tool.Type.BBOX, name="test-bbox-class"),
257254
])
258255
project.setup(editor, ontology_builder.asdict())
259-
#TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
256+
# TODO: ontology may not be synchronous after setup. remove sleep when api is more consistent
260257
time.sleep(2)
261258

262259
ontology = ontology_builder.from_project(project)

tests/integration/test_batch.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from labelbox import Dataset, Project
4+
from labelbox.schema.project import QueueMode
5+
6+
IMAGE_URL = "https://storage.googleapis.com/diagnostics-demo-data/coco/COCO_train2014_000000000034.jpg"
7+
8+
9+
@pytest.fixture
10+
def big_dataset(dataset: Dataset):
11+
task = dataset.create_data_rows([
12+
{
13+
"row_data": IMAGE_URL,
14+
"external_id": "my-image"
15+
},
16+
] * 250)
17+
task.wait_till_done()
18+
19+
yield dataset
20+
dataset.delete()
21+
22+
23+
def test_create_batch(configured_project: Project, big_dataset: Dataset):
24+
configured_project.update(queue_mode=QueueMode.Batch)
25+
26+
data_rows = [dr.uid for dr in list(big_dataset.export_data_rows())]
27+
queue_res = configured_project.create_batch("test-batch", data_rows, 3)

0 commit comments

Comments
 (0)