Skip to content

Commit 7828066

Browse files
author
Diego Ardila
committed
Added some tests and they pass with updated backend
1 parent dd61dbe commit 7828066

File tree

2 files changed

+75
-18
lines changed

2 files changed

+75
-18
lines changed

nucleus/__init__.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -387,28 +387,33 @@ def populate_dataset(
387387

388388
agg_response = UploadResponse(json={DATASET_ID_KEY: dataset_id})
389389

390-
tqdm_local_batches = self.tqdm_bar(local_batches)
391-
392-
tqdm_remote_batches = self.tqdm_bar(remote_batches)
393-
394390
async_responses: List[Any] = []
395391

396-
for batch in tqdm_local_batches:
397-
payload = construct_append_payload(batch, update)
398-
responses = self._process_append_requests_local(
399-
dataset_id, payload, update
392+
if local_batches:
393+
tqdm_local_batches = self.tqdm_bar(
394+
local_batches, desc="Local file batches"
400395
)
401-
async_responses.extend(responses)
402-
403-
for batch in tqdm_remote_batches:
404-
payload = construct_append_payload(batch, update)
405-
responses = self._process_append_requests(
406-
dataset_id=dataset_id,
407-
payload=payload,
408-
update=update,
409-
batch_size=batch_size,
396+
397+
for batch in tqdm_local_batches:
398+
payload = construct_append_payload(batch, update)
399+
responses = self._process_append_requests_local(
400+
dataset_id, payload, update
401+
)
402+
async_responses.extend(responses)
403+
404+
if remote_batches:
405+
tqdm_remote_batches = self.tqdm_bar(
406+
remote_batches, desc="Remote file batches"
410407
)
411-
async_responses.extend(responses)
408+
for batch in tqdm_remote_batches:
409+
payload = construct_append_payload(batch, update)
410+
responses = self._process_append_requests(
411+
dataset_id=dataset_id,
412+
payload=payload,
413+
update=update,
414+
batch_size=batch_size,
415+
)
416+
async_responses.extend(responses)
412417

413418
for response in async_responses:
414419
agg_response.update_response(response)
@@ -423,6 +428,8 @@ def _process_append_requests_local(
423428
local_batch_size: int = 10,
424429
):
425430
def get_files(batch):
431+
for item in batch:
432+
item[UPDATE_KEY] = update
426433
request_payload = [
427434
(
428435
ITEMS_KEY,

tests/test_dataset.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,56 @@ def test_dataset_create_and_delete(CLIENT):
113113
assert response == {"message": "Beginning dataset deletion..."}
114114

115115

116+
def test_dataset_update_metadata_local(dataset):
117+
dataset.append(
118+
[
119+
DatasetItem(
120+
image_location=LOCAL_FILENAME,
121+
metadata={"snake_field": 0},
122+
reference_id="test_image",
123+
)
124+
]
125+
)
126+
dataset.append(
127+
[
128+
DatasetItem(
129+
image_location=LOCAL_FILENAME,
130+
metadata={"snake_field": 1},
131+
reference_id="test_image",
132+
)
133+
],
134+
update=True,
135+
)
136+
resulting_item = dataset.iloc(0)["item"]
137+
print(resulting_item)
138+
assert resulting_item.metadata["snake_field"] == 1
139+
140+
141+
def test_dataset_update_metadata(dataset):
142+
dataset.append(
143+
[
144+
DatasetItem(
145+
image_location=TEST_IMG_URLS[0],
146+
metadata={"snake_field": 0},
147+
reference_id="test_image",
148+
)
149+
]
150+
)
151+
dataset.append(
152+
[
153+
DatasetItem(
154+
image_location=TEST_IMG_URLS[0],
155+
metadata={"snake_field": 1},
156+
reference_id="test_image",
157+
)
158+
],
159+
update=True,
160+
)
161+
resulting_item = dataset.iloc(0)["item"]
162+
print(resulting_item)
163+
assert resulting_item.metadata["snake_field"] == 1
164+
165+
116166
def test_dataset_append(dataset):
117167
def check_is_expected_response(response):
118168
assert isinstance(response, UploadResponse)

0 commit comments

Comments
 (0)