
Commit 5096fc1

Author: Matt Sokoloff
Message: clean up
Parent: 4416109

File tree: 5 files changed (+121 additions, -65 deletions)

labelbox/schema/asset_attachment.py

Lines changed: 18 additions & 0 deletions

@@ -1,4 +1,5 @@
 from enum import Enum
+from typing import Dict

 from labelbox.orm.db_object import DbObject
 from labelbox.orm.model import Field

@@ -24,3 +25,20 @@ class AttachmentType(Enum):

     attachment_type = Field.String("attachment_type", "type")
     attachment_value = Field.String("attachment_value", "value")
+
+    @classmethod
+    def validate_attachment_json(cls, attachment_json: Dict[str, str]) -> None:
+        for required_key in ['type', 'value']:
+            if required_key not in attachment_json:
+                raise ValueError(
+                    f"Must provide a `{required_key}` key for each attachment. Found {attachment_json}."
+                )
+        cls.validate_attachment_type(attachment_json['type'])
+
+    @classmethod
+    def validate_attachment_type(cls, attachment_type: str) -> None:
+        valid_types = set(cls.AttachmentType.__members__)
+        if attachment_type not in valid_types:
+            raise ValueError(
+                f"meta_type must be one of {valid_types}. Found {attachment_type}"
+            )
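
The two new classmethods give callers a single place to validate attachment payloads before any request is made. A minimal usage sketch, assuming a standard labelbox install; the payload dicts below are illustrative, not from the commit:

from labelbox.schema.asset_attachment import AssetAttachment

# A well-formed attachment passes silently.
AssetAttachment.validate_attachment_json({"type": "TEXT", "value": "a note"})

# Missing the `value` key raises:
#   ValueError: Must provide a `value` key for each attachment. Found {'type': 'TEXT'}.
AssetAttachment.validate_attachment_json({"type": "TEXT"})

# An unknown type name raises:
#   ValueError: meta_type must be one of {...}. Found NOT_A_TYPE
AssetAttachment.validate_attachment_type("NOT_A_TYPE")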

labelbox/schema/data_row.py

Lines changed: 3 additions & 9 deletions

@@ -42,10 +42,8 @@ class DataRow(DbObject, Updateable, BulkDeletable):
     labels = Relationship.ToMany("Label", True)
     attachments = Relationship.ToMany("AssetAttachment", False, "attachments")

-    supported_meta_types = supported_attachment_types = {
-        attachment_type.value
-        for attachment_type in AssetAttachment.AttachmentType
-    }
+    supported_meta_types = supported_attachment_types = set(
+        AssetAttachment.AttachmentType.__members__)

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

@@ -103,11 +101,7 @@ def create_attachment(self, attachment_type, attachment_value):
         Raises:
             ValueError: asset_type must be one of the supported types.
         """
-
-        if attachment_type not in self.supported_attachment_types:
-            raise ValueError(
-                f"meta_type must be one of {self.supported_attachment_types}. Found {attachment_type}"
-            )
+        AssetAttachment.validate_attachment_type(attachment_type)

         attachment_type_param = "type"
         attachment_value_param = "value"
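
With validation delegated to the shared classmethod, create_attachment rejects bad types with the same error message as the bulk path. A minimal sketch, assuming a connected client; the API key and dataset uid are placeholders:

from labelbox import Client

client = Client(api_key="<API_KEY>")           # placeholder key
dataset = client.get_dataset("<DATASET_UID>")  # placeholder uid
data_row = dataset.create_data_row(
    row_data="http://my_site.com/photos/img_01.jpg")

# "TEXT" is a member of AssetAttachment.AttachmentType, so validation passes.
data_row.create_attachment("TEXT", "reviewer note")

# "GIF" is not a member: ValueError is raised before any query is sent.
data_row.create_attachment("GIF", "http://example.com/a.gif")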

labelbox/schema/dataset.py

Lines changed: 51 additions & 53 deletions

@@ -75,27 +75,24 @@ def create_data_row(self, **kwargs):
         return self.client._create(DataRow, kwargs)

     def create_data_rows(self, items):
-
-        ## NOTE TODOS
-        """
-        Add attachments (works with all types)
-        Add external ids to bulk imports
-        improved error handling (why job was accepted or not)
-        """
         """ Creates multiple DataRow objects based on the given `items`.

         Each element in `items` can be either a `str` or a `dict`. If
         it is a `str`, then it is interpreted as a local file path. The file
         is uploaded to Labelbox and a DataRow referencing it is created.

         If an item is a `dict`, then it could support one of the two following structures
-            1. For static imagery, video, and text it should map `DataRow` fields (or their names) to values.
-               At the minimum an `item` passed as a `dict` must contain a `DataRow.row_data` key and value.
+            1. For static imagery, video, and text it should map `DataRow` field names to values.
+               At the minimum an `item` passed as a `dict` must contain a `row_data` key and value.
+               If the value for row_data is a local file path and the path exists,
+               then the local file will be uploaded to labelbox.
+
             2. For tiled imagery the dict must match the import structure specified in the link below
                https://docs.labelbox.com/data-model/en/index-en#tiled-imagery-import

         >>> dataset.create_data_rows([
         >>>     {DataRow.row_data:"http://my_site.com/photos/img_01.jpg"},
+        >>>     {DataRow.row_data:"/path/to/file1.jpg"},
         >>>     "path/to/file2.jpg",
         >>>     {"tileLayerUrl" : "http://", ...}
         >>>     ])

@@ -123,72 +120,72 @@ def create_data_rows(self, items):
         DataRow = Entity.DataRow

         def upload_if_necessary(item):
-            if isinstance(item, str):
-                item_url = self.client.upload_file(item)
-                item = {DataRow.row_data: item_url, DataRow.external_id: item}
-            elif isinstance(item, dict):
-                if os.path.exists(item['row_data']):
-                    item_url = self.client.upload_file(item['row_data'])
-                    parts = {
-                        DataRow.row_data:
-                            item_url,
-                        DataRow.external_id:
-                            item.get('external_id', item['row_data'])
-                    }
-                    attachments = item.get('attachments')
-                    if attachments:
-                        item = {**parts, **{'attachments': attachments}}
-                    else:
-                        item = parts
+            row_data = item['row_data']
+            if os.path.exists(row_data):
+                item_url = self.client.upload_file(item['row_data'])
+                item = {
+                    "row_data": item_url,
+                    "external_id": item.get('external_id', item['row_data']),
+                    "attachments": item.get('attachments', [])
+                }
             return item

         def validate_attachments(item):
             attachments = item.get('attachments')
             if attachments:
                 if isinstance(attachments, list):
                     for attachment in attachments:
-                        for required_key in ['type', 'value']:
-                            if required_key not in attachment:
-                                raise ValueError(
-                                    f"Must provide a `{required_key}` key for each attachment. Found {attachment}."
-                                )
-                        attachment_type = attachment.get('type')
-                        if attachment_type not in DataRow.supported_attachment_types:
-                            raise ValueError(
-                                f"meta_type must be one of {DataRow.supported_attachment_types}. Found {attachment_type}"
-                            )
+                        Entity.AssetAttachment.validate_attachment_json(
+                            attachment)
                 else:
                     raise ValueError(
                         f"Attachments must be a list. Found {type(attachments)}"
                     )
             return attachments

-        def convert_item(item):
-            # Don't make any changes to tms data
-            validate_attachments(item)
-            if "tileLayerUrl" in item:
-                return item
-
-            item = upload_if_necessary(item)
-            # Convert fields to string names.
-            item = {
-                key.name if isinstance(key, Field) else key: value
-                for key, value in item.items()
-            }
+        def format_row(item):
+            # Formats user input into a consistent dict structure
+            if isinstance(item, dict):
+                # Convert fields to strings
+                item = {
+                    key.name if isinstance(key, Field) else key: value
+                    for key, value in item.items()
+                }
+            elif isinstance(item, str):
+                # The main advantage of using a string over a dict is that the user is specifying
+                # that the file should exist locally.
+                # That info is lost after this section so we should check for it here.
+                if not os.path.exists(item):
+                    raise ValueError(f"Filepath {item} does not exist.")
+                item = {"row_data": item, "external_id": item}
+            return item

+        def validate_keys(item):
             if 'row_data' not in item:
                 raise InvalidQueryError(
                     "`row_data` missing when creating DataRow.")

-            # TODO: This is technically breaking. but also idt anyone is using the other fields.
             invalid_keys = set(item) - {
-                'row_data', 'external_id', 'attachments'
+                *{f.name for f in DataRow.fields()}, 'attachments'
             }
             if invalid_keys:
                 raise InvalidAttributeError(DataRow, invalid_keys)
+            return item
+
+        def convert_item(item):
+            # Don't make any changes to tms data
+            if "tileLayerUrl" in item:
+                validate_attachments(item)
+                return item
+            # Convert all payload variations into the same dict format
+            item = format_row(item)
+            # Make sure required keys exist (and there are no extra keys)
+            validate_keys(item)
+            # Make sure attachments are valid
+            validate_attachments(item)
+            # Upload any local file paths
+            item = upload_if_necessary(item)

-            # Item is valid, convert it to a dict {graphql_field_name: value}
-            # Need to change the name of DataRow.row_data to "data"
             return {
                 "data" if key == "row_data" else utils.camel_case(key): value
                 for key, value in item.items()

@@ -207,7 +204,8 @@ def convert_item(item):
         query_str = """mutation AppendRowsToDatasetPyApi($%s: ID!, $%s: String!){
             appendRowsToDataset(data:{datasetId: $%s, jsonFileUrl: $%s}
             ){ taskId accepted errorMessage } } """ % (dataset_param, url_param,
-                                                       dataset_param, url_param)
+                                                      dataset_param, url_param)
+
         res = self.client.execute(query_str, {
             dataset_param: self.uid,
             url_param: descriptor_url
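
After this change every payload shape funnels through the same pipeline: format_row normalizes strings and Field keys into plain dicts, validate_keys and validate_attachments reject malformed rows, and upload_if_necessary uploads local files. A usage sketch, assuming a connected client and dataset; the key, uid, paths, and URLs are placeholders:

from labelbox import Client

client = Client(api_key="<API_KEY>")           # placeholder key
dataset = client.get_dataset("<DATASET_UID>")  # placeholder uid

task = dataset.create_data_rows([
    # Remote asset with an attachment, checked via AssetAttachment.validate_attachment_json.
    {
        "row_data": "http://my_site.com/photos/img_01.jpg",
        "external_id": "img_01",
        "attachments": [{"type": "TEXT", "value": "a caption"}],
    },
    # Local file given as a dict: uploaded because the path exists on disk.
    {"row_data": "/path/to/file1.jpg"},
    # Local file given as a str: raises ValueError up front if the path does not exist.
    "path/to/file2.jpg",
])
task.wait_till_done()
assert task.status == "COMPLETE"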

tests/integration/test_data_rows.py

Lines changed: 44 additions & 2 deletions

@@ -44,8 +44,19 @@ def test_data_row_bulk_creation(dataset, rand_gen, image_url):
     task.wait_till_done()
     assert task.status == "COMPLETE"

+    task = dataset.create_data_rows([{
+        "row_data": fp.name,
+        'external_id': 'some_name'
+    }])
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+
+    task = dataset.create_data_rows([{"row_data": fp.name}])
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+
     data_rows = list(dataset.data_rows())
-    assert len(data_rows) == 3
+    assert len(data_rows) == 5
     url = ({data_row.row_data for data_row in data_rows} - {image_url}).pop()
     assert requests.get(url).content == data

@@ -64,7 +75,7 @@ def test_data_row_large_bulk_creation(dataset, image_url):
     assert task.status == "IN_PROGRESS"
     task.wait_till_done(timeout_seconds=120)
     assert task.status == "COMPLETE"
-    data_rows = len(list(dataset.data_rows())) == 5003
+    assert len(list(dataset.data_rows())) == 1000


 @pytest.mark.xfail(reason="DataRow.dataset() relationship not set")

@@ -168,3 +179,34 @@ def test_data_row_iteration(dataset, image_url) -> None:
     ])
     task.wait_till_done()
     assert next(dataset.data_rows())
+
+
+def test_data_row_attachments(dataset, image_url):
+    attachments = [("IMAGE", image_url), ("TEXT", "test-text"),
+                   ("IMAGE_OVERLAY", image_url), ("HTML", image_url)]
+    task = dataset.create_data_rows([{
+        "row_data": image_url,
+        "external_id": "test-id",
+        "attachments": [{
+            "type": attachment_type,
+            "value": attachment_value
+        }]
+    } for attachment_type, attachment_value in attachments])
+
+    task.wait_till_done()
+    assert task.status == "COMPLETE"
+    data_rows = list(dataset.data_rows())
+    assert len(data_rows) == len(attachments)
+    for data_row in data_rows:
+        assert len(list(data_row.attachments())) == 1
+        assert data_row.external_id == "test-id"
+
+    with pytest.raises(ValueError) as exc:
+        task = dataset.create_data_rows([{
+            "row_data": image_url,
+            "external_id": "test-id",
+            "attachments": [{
+                "type": "INVALID",
+                "value": "123"
+            }]
+        }])

tests/integration/test_dataset.py

Lines changed: 5 additions & 1 deletion

@@ -75,6 +75,11 @@ def test_get_data_row_for_external_id(dataset, rand_gen, image_url):
     dataset.create_data_row(row_data=image_url, external_id=external_id)
     assert len(dataset.data_rows_for_external_id(external_id)) == 2

+    task = dataset.create_data_rows(
+        [dict(row_data=image_url, external_id=external_id)])
+    task.wait_till_done()
+    assert len(dataset.data_rows_for_external_id(external_id)) == 3
+

 def test_upload_video_file(dataset, sample_video: str) -> None:
     """

@@ -104,4 +109,3 @@ def test_data_row_export(dataset, image_url):
     result = list(dataset.export_data_rows())
     assert len(result) == n_data_rows
     assert set(result) == ids
-
