
Commit 00b9976

Merge pull request #262 from Labelbox/ms/attachments
attachments
2 parents e5c6a58 + 16841e8

5 files changed: +149 additions, −44 deletions

labelbox/schema/asset_attachment.py

Lines changed: 18 additions & 0 deletions
@@ -1,4 +1,5 @@
 from enum import Enum
+from typing import Dict

 from labelbox.orm.db_object import DbObject
 from labelbox.orm.model import Field
@@ -24,3 +25,20 @@ class AttachmentType(Enum):

     attachment_type = Field.String("attachment_type", "type")
     attachment_value = Field.String("attachment_value", "value")
+
+    @classmethod
+    def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None:
+        for required_key in ['type', 'value']:
+            if required_key not in attachment_json:
+                raise ValueError(
+                    f"Must provide a `{required_key}` key for each attachment. Found {attachment_json}."
+                )
+        cls.validate_attachment_type(attachment_json['type'])
+
+    @classmethod
+    def validate_attachment_type(cls, attachment_type: str) -> None:
+        valid_types = set(cls.AttachmentType.__members__)
+        if attachment_type not in valid_types:
+            raise ValueError(
+                f"meta_type must be one of {valid_types}. Found {attachment_type}"
+            )
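The two new classmethods centralize attachment validation so malformed payloads fail before any network call. A minimal sketch of the behavior, assuming this version of the SDK is importable (the example dicts are illustrative):

    from labelbox.schema.asset_attachment import AssetAttachment

    # A well-formed attachment passes silently: both keys present, type known.
    AssetAttachment.validate_attachment_json({"type": "TEXT", "value": "a note"})

    # A missing `value` key raises ValueError with the offending payload.
    try:
        AssetAttachment.validate_attachment_json({"type": "TEXT"})
    except ValueError as e:
        print(e)  # Must provide a `value` key for each attachment. Found {'type': 'TEXT'}.

    # Unknown types are rejected against AttachmentType.__members__.
    try:
        AssetAttachment.validate_attachment_type("INVALID")
    except ValueError as e:
        print(e)  # meta_type must be one of {...}. Found INVALID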

labelbox/schema/data_row.py

Lines changed: 3 additions & 9 deletions
@@ -42,10 +42,8 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     labels = Relationship.ToMany("Label", True)
     attachments = Relationship.ToMany("AssetAttachment", False, "attachments")

-    supported_meta_types = supported_attachment_types = {
-        attachment_type.value
-        for attachment_type in AssetAttachment.AttachmentType
-    }
+    supported_meta_types = supported_attachment_types = set(
+        AssetAttachment.AttachmentType.__members__)

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -103,11 +101,7 @@ def create_attachment(self, attachment_type, attachment_value):
         Raises:
             ValueError: asset_type must be one of the supported types.
         """
-
-        if attachment_type not in self.supported_attachment_types:
-            raise ValueError(
-                f"meta_type must be one of {self.supported_attachment_types}. Found {attachment_type}"
-            )
+        AssetAttachment.validate_attachment_type(attachment_type)

         attachment_type_param = "type"
         attachment_value_param = "value"
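With the check delegated to `AssetAttachment.validate_attachment_type`, `DataRow.create_attachment` and bulk creation now reject unsupported types with one shared error message. A hedged usage sketch; the API key and data row ID are placeholders, and the `Client.get_data_row` lookup is an assumption for illustration:

    from labelbox import Client

    client = Client(api_key="<YOUR_API_KEY>")        # placeholder credentials
    data_row = client.get_data_row("<DATA_ROW_ID>")  # placeholder ID

    # "TEXT" is a member of AssetAttachment.AttachmentType, so this succeeds.
    data_row.create_attachment(attachment_type="TEXT",
                               attachment_value="reviewer note")

    # An unsupported type now raises locally, before any request is sent.
    try:
        data_row.create_attachment(attachment_type="INVALID", attachment_value="x")
    except ValueError as e:
        print(e)  # meta_type must be one of {...}. Found INVALID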

labelbox/schema/dataset.py

Lines changed: 79 additions & 33 deletions
@@ -1,3 +1,4 @@
+from labelbox import utils
 import os
 import json
 import logging
@@ -81,13 +82,17 @@ def create_data_rows(self, items):
         is uploaded to Labelbox and a DataRow referencing it is created.

         If an item is a `dict`, then it could support one of the two following structures
-            1. For static imagery, video, and text it should map `DataRow` fields (or their names) to values.
-               At the minimum an `item` passed as a `dict` must contain a `DataRow.row_data` key and value.
+            1. For static imagery, video, and text it should map `DataRow` field names to values.
+               At the minimum an `item` passed as a `dict` must contain a `row_data` key and value.
+               If the value for row_data is a local file path and the path exists,
+               then the local file will be uploaded to labelbox.
+
             2. For tiled imagery the dict must match the import structure specified in the link below
                https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

         >>> dataset.create_data_rows([
         >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
+        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
         >>>     "path/to/file2.jpg",
         >>>     {"tileLayerUrl" : "http://", ...}
         >>>     ])
@@ -115,64 +120,105 @@ def create_data_rows(self, items):
         DataRow = Entity.DataRow

         def upload_if_necessary(item):
-            if isinstance(item, str):
-                item_url = self.client.upload_file(item)
-                # Convert item from str into a dict so it gets processed
-                # like all other dicts.
-                item = {DataRow.row_data: item_url, DataRow.external_id: item}
+            row_data = item['row_data']
+            if os.path.exists(row_data):
+                item_url = self.client.upload_file(item['row_data'])
+                item = {
+                    "row_data": item_url,
+                    "external_id": item.get('external_id', item['row_data']),
+                    "attachments": item.get('attachments', [])
+                }
             return item

-        with ThreadPoolExecutor(file_upload_thread_count) as executor:
-            futures = [
-                executor.submit(upload_if_necessary, item) for item in items
-            ]
-            items = [future.result() for future in as_completed(futures)]
-
-        def convert_item(item):
-            # Don't make any changes to tms data
-            if "tileLayerUrl" in item:
-                return item
-            # Convert string names to fields.
-            item = {
-                key if isinstance(key, Field) else DataRow.field(key): value
-                for key, value in item.items()
-            }
+        def validate_attachments(item):
+            attachments = item.get('attachments')
+            if attachments:
+                if isinstance(attachments, list):
+                    for attachment in attachments:
+                        Entity.AssetAttachment.validate_attachment_json(
+                            attachment)
+                else:
+                    raise ValueError(
+                        f"Attachments must be a list. Found {type(attachments)}"
+                    )
+            return attachments
+
+        def format_row(item):
+            # Formats user input into a consistent dict structure
+            if isinstance(item, dict):
+                # Convert fields to strings
+                item = {
+                    key.name if isinstance(key, Field) else key: value
+                    for key, value in item.items()
+                }
+            elif isinstance(item, str):
+                # The main advantage of using a string over a dict is that the user is specifying
+                # that the file should exist locally.
+                # That info is lost after this section so we should check for it here.
+                if not os.path.exists(item):
+                    raise ValueError(f"Filepath {item} does not exist.")
+                item = {"row_data": item, "external_id": item}
+            return item

-            if DataRow.row_data not in item:
+        def validate_keys(item):
+            if 'row_data' not in item:
                 raise InvalidQueryError(
-                    "DataRow.row_data missing when creating DataRow.")
+                    "`row_data` missing when creating DataRow.")

-            invalid_keys = set(item) - set(DataRow.fields())
+            invalid_keys = set(item) - {
+                *{f.name for f in DataRow.fields()}, 'attachments'
+            }
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
+            return item
+
+        def convert_item(item):
+            # Don't make any changes to tms data
+            if "tileLayerUrl" in item:
+                validate_attachments(item)
+                return item
+            # Convert all payload variations into the same dict format
+            item = format_row(item)
+            # Make sure required keys exist (and there are no extra keys)
+            validate_keys(item)
+            # Make sure attachments are valid
+            validate_attachments(item)
+            # Upload any local file paths
+            item = upload_if_necessary(item)

-            # Item is valid, convert it to a dict {graphql_field_name: value}
-            # Need to change the name of DataRow.row_data to "data"
             return {
-                "data" if key == DataRow.row_data else key.graphql_name: value
+                "data" if key == "row_data" else utils.camel_case(key): value
                 for key, value in item.items()
             }

+        if not isinstance(items, list):
+            raise ValueError(
+                f"Must pass a list to create_data_rows. Found {type(items)}")
+
+        with ThreadPoolExecutor(file_upload_thread_count) as executor:
+            futures = [executor.submit(convert_item, item) for item in items]
+            items = [future.result() for future in as_completed(futures)]
+
         # Prepare and upload the desciptor file
-        items = [convert_item(item) for item in items]
         data = json.dumps(items)
         descriptor_url = self.client.upload_data(data)
-
         # Create data source
         dataset_param = "datasetId"
         url_param = "jsonUrl"
         query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
             appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
-            ){ taskId accepted } } """ % (dataset_param, url_param,
-                                          dataset_param, url_param)
+            ){ taskId accepted errorMessage } } """ % (dataset_param, url_param,
+                                                       dataset_param, url_param)
+
         res = self.client.execute(query_str, {
             dataset_param: self.uid,
             url_param: descriptor_url
         })
         res = res["appendRowsToDataset"]
         if not res["accepted"]:
+            msg = res['errorMessage']
             raise InvalidQueryError(
-                "Server did not accept DataRow creation request")
+                f"Server did not accept DataRow creation request. {msg}")

         # Fetch and return the task.
         task_id = res["taskId"]
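`convert_item` now runs every item through `format_row`, `validate_keys`, `validate_attachments`, and `upload_if_necessary`, so all payload variations collapse into one dict shape before the descriptor file is uploaded; since the mutation also selects `errorMessage`, a rejected request surfaces the server's reason instead of a generic message. A sketch of the call patterns this enables, assuming `dataset` is a `Dataset` bound to an authenticated client (URLs and paths are illustrative):

    task = dataset.create_data_rows([
        # Hosted asset, given as a dict of field names to values.
        {"row_data": "http://my_site.com/photos/img_01.jpg", "external_id": "img_01"},
        # Bare string: must be an existing local file; it is uploaded for you
        # and the path doubles as the external_id.
        "path/to/file2.jpg",
        # Hosted asset with attachments, validated client-side first.
        {
            "row_data": "http://my_site.com/photos/img_02.jpg",
            "attachments": [{"type": "TEXT", "value": "a note"}],
        },
    ])
    task.wait_till_done()
    assert task.status == "COMPLETE"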

tests/integration/test_data_rows.py

Lines changed: 44 additions & 2 deletions
@@ -44,8 +44,19 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url):
     task.wait_till_done()
     assert task.status == "COMPLETE"

+    task = dataset.create_data_rows([{
+        "row_data": fp.name,
+        'external_id': 'some_name'
+    }])
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+
+    task = dataset.create_data_rows([{"row_data": fp.name}])
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+
     data_rows = list(dataset.data_rows())
-    assert len(data_rows) == 3
+    assert len(data_rows) == 5
     url = ({data_row.row_data for data_row in data_rows} - {image_url}).pop()
     assert requests.get(url).content == data

@@ -64,7 +75,7 @@ def test_data_row_large_bulk_creation(dataset, image_url):
     assert task.status == "IN_PROGRESS"
     task.wait_till_done(timeout_seconds=120)
     assert task.status == "COMPLETE"
-    data_rows = len(list(dataset.data_rows())) == 5003
+    assert len(list(dataset.data_rows())) == 1000


 @pytest.mark.xfail(reason="DataRow.dataset() relationship not set")
@@ -168,3 +179,34 @@ def test_data_row_iteration(dataset, image_url) -> None:
     ])
     task.wait_till_done()
     assert next(dataset.data_rows())
+
+
+def test_data_row_attachments(dataset, image_url):
+    attachments = [("IMAGE", image_url), ("TEXT", "test-text"),
+                   ("IMAGE_OVERLAY", image_url), ("HTML", image_url)]
+    task = dataset.create_data_rows([{
+        "row_data": image_url,
+        "external_id": "test-id",
+        "attachments": [{
+            "type": attachment_type,
+            "value": attachment_value
+        }]
+    } for attachment_type, attachment_value in attachments])
+
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+    data_rows = list(dataset.data_rows())
+    assert len(data_rows) == len(attachments)
+    for data_row in data_rows:
+        assert len(list(data_row.attachments())) == 1
+        assert data_row.external_id == "test-id"
+
+    with pytest.raises(ValueError) as exc:
+        task = dataset.create_data_rows([{
+            "row_data": image_url,
+            "external_id": "test-id",
+            "attachments": [{
+                "type": "INVALID",
+                "value": "123"
+            }]
+        }])
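These are live integration tests, so running them is assumed to require valid Labelbox API credentials configured the way the repository's test setup expects; with that in place, the new test can be selected with a standard pytest filter:

    pytest tests/integration/test_data_rows.py -k attachments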

tests/integration/test_dataset.py

Lines changed: 5 additions & 0 deletions
@@ -75,6 +75,11 @@ def test_get_data_row_for_external_id(dataset, rand_gen, image_url):
     dataset.create_data_row(row_data=image_url, external_id=external_id)
     assert len(dataset.data_rows_for_external_id(external_id)) == 2

+    task = dataset.create_data_rows(
+        [dict(row_data=image_url, external_id=external_id)])
+    task.wait_till_done()
+    assert len(dataset.data_rows_for_external_id(external_id)) == 3
+

 def test_upload_video_file(dataset, sample_video: str) -> None:
     """
