|
1 |
| -# type: ignore |
2 |
| -from datetime import datetime, timezone |
3 | 1 | import json
|
4 |
| -from typing import List, Dict |
5 |
| -from collections import defaultdict |
6 |
| - |
7 | 2 | import logging
|
8 | 3 | import mimetypes
|
9 | 4 | import os
|
| 5 | +from collections import defaultdict |
| 6 | +from datetime import datetime, timezone |
| 7 | +from typing import List, Dict |
10 | 8 |
|
11 |
| -from google.api_core import retry |
12 | 9 | import requests
|
13 | 10 | import requests.exceptions
|
| 11 | +from google.api_core import retry |
14 | 12 |
|
15 | 13 | import labelbox.exceptions
|
16 |
| -from labelbox import utils |
17 | 14 | from labelbox import __version__ as SDK_VERSION
|
| 15 | +from labelbox import utils |
18 | 16 | from labelbox.orm import query
|
19 | 17 | from labelbox.orm.db_object import DbObject
|
20 | 18 | from labelbox.orm.model import Entity
|
21 | 19 | from labelbox.pagination import PaginatedCollection
|
| 20 | +from labelbox.schema import role |
22 | 21 | from labelbox.schema.data_row_metadata import DataRowMetadataOntology
|
23 | 22 | from labelbox.schema.iam_integration import IAMIntegration
|
24 |
| -from labelbox.schema import role |
25 | 23 | from labelbox.schema.ontology import Tool, Classification
|
26 | 24 |
|
27 | 25 | logger = logging.getLogger(__name__)
|
@@ -354,7 +352,7 @@ def upload_data(self,
|
354 | 352 | data=request_data,
|
355 | 353 | files={
|
356 | 354 | "1": (filename, content, content_type) if
|
357 |
| - (filename and content_type) else content |
| 355 | + (filename and content_type) else content |
358 | 356 | })
|
359 | 357 |
|
360 | 358 | if response.status_code == 502:
|
@@ -518,7 +516,7 @@ def _create(self, db_object_type, data):
|
518 | 516 | # Also convert Labelbox object values to their UIDs.
|
519 | 517 | data = {
|
520 | 518 | db_object_type.attribute(attr) if isinstance(attr, str) else attr:
|
521 |
| - value.uid if isinstance(value, DbObject) else value |
| 519 | + value.uid if isinstance(value, DbObject) else value |
522 | 520 | for attr, value in data.items()
|
523 | 521 | }
|
524 | 522 |
|
@@ -702,8 +700,8 @@ def get_data_row_ids_for_external_ids(
|
702 | 700 | for i in range(0, len(external_ids), max_ids_per_request):
|
703 | 701 | for row in self.execute(
|
704 | 702 | query_str,
|
705 |
| - {'externalId_in': external_ids[i:i + max_ids_per_request] |
706 |
| - })['externalIdsToDataRowIds']: |
| 703 | + {'externalId_in': external_ids[i:i + max_ids_per_request] |
| 704 | + })['externalIdsToDataRowIds']: |
707 | 705 | result[row['externalId']].append(row['dataRowId'])
|
708 | 706 | return result
|
709 | 707 |
|
@@ -896,3 +894,57 @@ def create_feature_schema(self, normalized):
|
896 | 894 | # But the features are the same so we just grab the feature schema id
|
897 | 895 | res['id'] = res['normalized']['featureSchemaId']
|
898 | 896 | return Entity.FeatureSchema(self, res)
|
| 897 | + |
| 898 | + def get_batch(self, batch_id: str): |
| 899 | + """Gets a single Batch using its ID |
| 900 | +
|
| 901 | + Args: |
| 902 | + batch_id: Id of the batch |
| 903 | +
|
| 904 | + Returns: |
| 905 | + The sought Batch |
| 906 | + """ |
| 907 | + |
| 908 | + return self._get_single(Entity.Batch, batch_id) |
| 909 | + |
| 910 | + def create_batch(self, name: str, project, data_rows: List[str], priority: int): |
| 911 | + """Create a batch of data rows to send to a project |
| 912 | +
|
| 913 | + >>> data_rows = ['<data-row-id>', ...] |
| 914 | + >>> project = client.get("<project-id>") |
| 915 | + >>> client.create_batch(name="low-confidence-images", project=project, data_rows=data_rows) |
| 916 | +
|
| 917 | + Args: |
| 918 | + name: Descriptive name for the batch, must be unique per project |
| 919 | + project: The project to send the batch to |
| 920 | + data_rows: A list of data rows ids |
| 921 | + priority: the default priority for the datarows, lowest priority by default |
| 922 | +
|
| 923 | + Returns: |
| 924 | + The created batch |
| 925 | + """ |
| 926 | + |
| 927 | + if isinstance(project, Entity.Project): |
| 928 | + project_id = project.uid |
| 929 | + elif isinstance(project, str): |
| 930 | + project_id = project |
| 931 | + else: |
| 932 | + raise ValueError("You must pass a project id or a Project") |
| 933 | + |
| 934 | + data_row_ids = [] |
| 935 | + for dr in data_rows: |
| 936 | + pass |
| 937 | + |
| 938 | + query_str = """mutation createBatchPyApi($name: String!, $dataRowIds: [ID!]!, $priority: Int!){ |
| 939 | + createBatch(){ |
| 940 | + %s |
| 941 | + } |
| 942 | + }""" |
| 943 | + |
| 944 | + result = self.execute(query_str, { |
| 945 | + "name": name, |
| 946 | + "projectId": project_id, |
| 947 | + "dataRowIds": data_row_ids, |
| 948 | + "priority": priority |
| 949 | + }) |
| 950 | + return Entity.Batch(self, result['createModel']) |
0 commit comments