Skip to content

Commit 54d1527

Browse files
authored
Merge pull request #307 from Labelbox/ms/metadata-tests
metadata updates
2 parents f154a53 + 026e4e8 commit 54d1527

File tree

5 files changed

+89
-45
lines changed

5 files changed

+89
-45
lines changed

examples/basics/data_row_metadata.ipynb

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -169,18 +169,6 @@
169169
"train_field = mdo.reserved_by_name[\"split\"][\"train\"]"
170170
]
171171
},
172-
{
173-
"cell_type": "code",
174-
"execution_count": null,
175-
"id": "uOS2QlHmqAIs",
176-
"metadata": {
177-
"id": "uOS2QlHmqAIs"
178-
},
179-
"outputs": [],
180-
"source": [
181-
"split_field.options"
182-
]
183-
},
184172
{
185173
"cell_type": "code",
186174
"execution_count": null,
@@ -567,7 +555,6 @@
567555
"fields = []\n",
568556
"# iterate through the fields you want to delete\n",
569557
"for field in md.fields:\n",
570-
" schema = mdo.field_by_index[field.schema_id]\n",
571558
" fields.append(field.schema_id)\n",
572559
"\n",
573560
"deletes = DeleteDataRowMetadata(\n",
@@ -649,4 +636,4 @@
649636
},
650637
"nbformat": 4,
651638
"nbformat_minor": 5
652-
}
639+
}

labelbox/client.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# type: ignore
22
from datetime import datetime, timezone
33
import json
4+
from typing import List, Dict
5+
from collections import defaultdict
46

57
import logging
68
import mimetypes
@@ -658,3 +660,29 @@ def create_model(self, name, ontology_id):
658660
"ontologyId": ontology_id
659661
})
660662
return Model(self, result['createModel'])
663+
664+
def get_data_row_ids_for_external_ids(
        self, external_ids: List[str]) -> Dict[str, List[str]]:
    """ Looks up the data row ids that correspond to a list of external ids.

    Requests are issued in batches of 100 external ids; the backend
    returns at most 1500 matching rows per request.

    Args:
        external_ids: List of external ids to fetch data row ids for.

    Returns:
        A dict mapping each matching external id to the list of data row
        ids that share it. External ids with no matching data rows are
        omitted from the result.
    """
    query_str = """query externalIdsToDataRowIdsPyApi($externalId_in: [String!]!){
        externalIdsToDataRowIds(externalId_in: $externalId_in) { dataRowId externalId }
        }
    """
    max_n_per_request = 100
    result = defaultdict(list)
    for i in range(0, len(external_ids), max_n_per_request):
        batch = external_ids[i:i + max_n_per_request]
        rows = self.execute(
            query_str,
            {'externalId_in': batch})['externalIdsToDataRowIds']
        for row in rows:
            result[row['externalId']].append(row['dataRowId'])
    # Return a plain dict so that looking up a missing external id raises
    # KeyError instead of silently inserting an empty list (the surprising
    # auto-vivification behavior of defaultdict).
    return dict(result)

labelbox/schema/data_row_metadata.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# type: ignore
2-
import datetime
2+
from datetime import datetime
33
import warnings
44
from copy import deepcopy
55
from enum import Enum
@@ -42,7 +42,7 @@ def id(self):
4242

4343
# Constraints for metadata values
4444
Embedding: Type[List[float]] = conlist(float, min_items=128, max_items=128)
45-
DateTime: Type[datetime.datetime] = datetime.datetime # must be in UTC
45+
DateTime: Type[datetime] = datetime # must be in UTC
4646
String: Type[str] = constr(max_length=500)
4747
OptionId: Type[SchemaId] = SchemaId # enum option
4848
Number: Type[float] = float
@@ -62,7 +62,7 @@ class Config:
6262
# Metadata base class
6363
class DataRowMetadataField(_CamelCaseMixin):
6464
schema_id: SchemaId
65-
value: Any
65+
value: Union[DataRowMetadataValue, _DataRowMetadataValuePrimitives]
6666

6767

6868
class DataRowMetadata(_CamelCaseMixin):
@@ -489,7 +489,6 @@ def _validate_parse_number(
489489

490490
def _validate_parse_datetime(
491491
field: DataRowMetadataField) -> List[Dict[str, Union[SchemaId, str]]]:
492-
# TODO: better validate tzinfo
493492
return [{
494493
"schemaId": field.schema_id,
495494
"value": field.value.isoformat() + "Z", # needs to be UTC

tests/integration/test_data_row_metadata.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from labelbox.schema.data_row_metadata import DataRowMetadataField, DataRowMetadata, DeleteDataRowMetadata, \
88
DataRowMetadataOntology
99

10+
INVALID_SCHEMA_ID = "1" * 25
1011
FAKE_SCHEMA_ID = "0" * 25
1112
FAKE_DATAROW_ID = "D" * 25
1213
SPLIT_SCHEMA_ID = "cko8sbczn0002h2dkdaxb5kal"
@@ -15,6 +16,7 @@
1516
EMBEDDING_SCHEMA_ID = "ckpyije740000yxdk81pbgjdc"
1617
TEXT_SCHEMA_ID = "cko8s9r5v0001h2dk9elqdidh"
1718
CAPTURE_DT_SCHEMA_ID = "cko8sdzv70006h2dk8jg64zvb"
19+
PRE_COMPUTED_EMBEDDINGS_ID = 'ckrzang79000008l6hb5s6za1'
1820

1921
FAKE_NUMBER_FIELD = {
2022
"id": FAKE_SCHEMA_ID,
@@ -40,24 +42,13 @@ def big_dataset(dataset: Dataset, image_url):
4042
"row_data": image_url,
4143
"external_id": "my-image"
4244
},
43-
] * 250)
45+
] * 5)
4446
task.wait_till_done()
4547

4648
yield dataset
4749
dataset.delete()
4850

4951

50-
def wait_for_embeddings_svc(data_row_ids, mdo):
51-
for idx in range(5):
52-
if all([
53-
len(metadata.fields)
54-
for metadata in mdo.bulk_export(data_row_ids)
55-
]):
56-
return
57-
time.sleep((idx + 1)**2)
58-
raise Exception("Embedding svc failed to update metadata.")
59-
60-
6152
def make_metadata(dr_id) -> DataRowMetadata:
6253
embeddings = [0.0] * 128
6354
msg = "A message"
@@ -97,18 +88,20 @@ def test_get_datarow_metadata_ontology(mdo):
9788

9889

9990
def test_bulk_upsert_datarow_metadata(datarow, mdo: DataRowMetadataOntology):
100-
wait_for_embeddings_svc([datarow.uid], mdo)
10191
metadata = make_metadata(datarow.uid)
10292
mdo.bulk_upsert([metadata])
103-
assert len(mdo.bulk_export([datarow.uid]))
104-
assert len(mdo.bulk_export([datarow.uid])[0].fields) == 5
93+
exported = mdo.bulk_export([datarow.uid])
94+
assert len(exported)
95+
assert len([
96+
field for field in exported[0].fields
97+
if field.schema_id != PRE_COMPUTED_EMBEDDINGS_ID
98+
]) == 4
10599

106100

107101
@pytest.mark.slow
108102
def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo):
109103
metadata = []
110104
data_row_ids = [dr.uid for dr in big_dataset.data_rows()]
111-
wait_for_embeddings_svc(data_row_ids, mdo)
112105
for data_row_id in data_row_ids:
113106
metadata.append(make_metadata(data_row_id))
114107
errors = mdo.bulk_upsert(metadata)
@@ -119,14 +112,16 @@ def test_large_bulk_upsert_datarow_metadata(big_dataset, mdo):
119112
for metadata in mdo.bulk_export(data_row_ids)
120113
}
121114
for data_row_id in data_row_ids:
122-
assert len(metadata_lookup.get(data_row_id).fields)
115+
assert len([
116+
f for f in metadata_lookup.get(data_row_id).fields
117+
if f.schema_id != PRE_COMPUTED_EMBEDDINGS_ID
118+
]), metadata_lookup.get(data_row_id).fields
123119

124120

125121
def test_bulk_delete_datarow_metadata(datarow, mdo):
126122
"""test bulk deletes for all fields"""
127123
metadata = make_metadata(datarow.uid)
128124
mdo.bulk_upsert([metadata])
129-
130125
assert len(mdo.bulk_export([datarow.uid])[0].fields)
131126
upload_ids = [m.schema_id for m in metadata.fields[:-2]]
132127
mdo.bulk_delete(
@@ -155,7 +150,6 @@ def test_bulk_partial_delete_datarow_metadata(datarow, mdo):
155150
def test_large_bulk_delete_datarow_metadata(big_dataset, mdo):
156151
metadata = []
157152
data_row_ids = [dr.uid for dr in big_dataset.data_rows()]
158-
wait_for_embeddings_svc(data_row_ids, mdo)
159153
for data_row_id in data_row_ids:
160154
metadata.append(
161155
DataRowMetadata(data_row_id=data_row_id,
@@ -181,29 +175,33 @@ def test_large_bulk_delete_datarow_metadata(big_dataset, mdo):
181175
errors = mdo.bulk_delete(deletes)
182176
assert len(errors) == 0
183177
for data_row_id in data_row_ids:
184-
# 2 remaining because we delete the user provided embedding but text and labelbox generated embeddings still exist
185-
fields = mdo.bulk_export([data_row_id])[0].fields
186-
assert len(fields) == 2
178+
fields = [
179+
f for f in mdo.bulk_export([data_row_id])[0].fields
180+
if f.schema_id != PRE_COMPUTED_EMBEDDINGS_ID
181+
]
182+
assert len(fields) == 1, fields
187183
assert EMBEDDING_SCHEMA_ID not in [field.schema_id for field in fields]
188184

189185

190186
def test_bulk_delete_datarow_enum_metadata(datarow: DataRow, mdo):
191187
"""test bulk deletes for non non fields"""
192-
wait_for_embeddings_svc([datarow.uid], mdo)
193188
metadata = make_metadata(datarow.uid)
194189
metadata.fields = [
195190
m for m in metadata.fields if m.schema_id == SPLIT_SCHEMA_ID
196191
]
197192
mdo.bulk_upsert([metadata])
198193

199-
assert len(mdo.bulk_export([datarow.uid])[0].fields) == len(
194+
exported = mdo.bulk_export([datarow.uid])[0].fields
195+
assert len(exported) == len(
200196
set([x.schema_id for x in metadata.fields] +
201-
[x.schema_id for x in mdo.bulk_export([datarow.uid])[0].fields]))
197+
[x.schema_id for x in exported]))
202198

203199
mdo.bulk_delete([
204200
DeleteDataRowMetadata(data_row_id=datarow.uid, fields=[SPLIT_SCHEMA_ID])
205201
])
206-
assert len(mdo.bulk_export([datarow.uid])[0].fields) == 1
202+
exported = mdo.bulk_export([datarow.uid])[0].fields
203+
assert len(
204+
[f for f in exported if f.schema_id != PRE_COMPUTED_EMBEDDINGS_ID]) == 0
207205

208206

209207
def test_raise_enum_upsert_schema_error(datarow, mdo):
@@ -223,7 +221,7 @@ def test_upsert_non_existent_schema_id(datarow, mdo):
223221
metadata = DataRowMetadata(data_row_id=datarow.uid,
224222
fields=[
225223
DataRowMetadataField(
226-
schema_id=FAKE_SCHEMA_ID,
224+
schema_id=INVALID_SCHEMA_ID,
227225
value="message"),
228226
])
229227
with pytest.raises(ValueError):

tests/integration/test_data_rows.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from tempfile import NamedTemporaryFile
2+
import uuid
3+
import time
24

35
import pytest
46
import requests
@@ -11,6 +13,36 @@ def test_get_data_row(datarow, client):
1113
assert client.get_data_row(datarow.uid)
1214

1315

16+
def test_lookup_data_rows(client, dataset):
    """Exercises Client.get_data_row_ids_for_external_ids across the
    1:1, N:1, 1:N, empty-input, and no-match cases."""
    uid = str(uuid.uuid4())
    # 1 external id : 1 data row
    dr = dataset.create_data_row(row_data="123", external_id=uid)
    lookup = client.get_data_row_ids_for_external_ids([uid])
    assert len(lookup) == 1
    assert lookup[uid][0] == dr.uid
    # 2 external ids : 1 data row each
    uid2 = str(uuid.uuid4())
    dr2 = dataset.create_data_row(row_data="123", external_id=uid2)
    lookup = client.get_data_row_ids_for_external_ids([uid, uid2])
    assert len(lookup) == 2
    assert all(len(ids) == 1 for ids in lookup.values())
    assert lookup[uid][0] == dr.uid
    assert lookup[uid2][0] == dr2.uid
    # 1 external id : 2 data rows
    dr3 = dataset.create_data_row(row_data="123", external_id=uid2)
    lookup = client.get_data_row_ids_for_external_ids([uid2])
    assert len(lookup) == 1
    assert len(lookup[uid2]) == 2
    # The backend does not guarantee result ordering, so compare as a set
    # rather than by index to avoid a flaky test.
    assert set(lookup[uid2]) == {dr2.uid, dr3.uid}
    # Empty args
    lookup = client.get_data_row_ids_for_external_ids([])
    assert len(lookup) == 0
    # Non matching
    lookup = client.get_data_row_ids_for_external_ids([str(uuid.uuid4())])
    assert len(lookup) == 0
1446
def test_data_row_bulk_creation(dataset, rand_gen, image_url):
1547
client = dataset.client
1648
assert len(list(dataset.data_rows())) == 0

0 commit comments

Comments
 (0)