|
3 | 3 | from datetime import datetime
|
4 | 4 | import json
|
5 | 5 | import requests
|
| 6 | +import os |
6 | 7 |
|
7 | 8 | from unittest.mock import patch
|
8 | 9 | import pytest
|
@@ -171,66 +172,110 @@ def test_lookup_data_rows(client, dataset):
|
171 | 172 |
|
172 | 173 | def test_data_row_bulk_creation(dataset, rand_gen, image_url):
|
173 | 174 | client = dataset.client
|
| 175 | + data_rows = [] |
174 | 176 | assert len(list(dataset.data_rows())) == 0
|
175 | 177 |
|
176 |
| - with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', |
177 |
| - new=1): # Force chunking |
178 |
| - # Test creation using URL |
179 |
| - task = dataset.create_data_rows([ |
180 |
| - { |
181 |
| - DataRow.row_data: image_url |
182 |
| - }, |
183 |
| - { |
184 |
| - "row_data": image_url |
185 |
| - }, |
186 |
| - ]) |
187 |
| - task.wait_till_done() |
188 |
| - assert task.has_errors() is False |
189 |
| - assert task.status == "COMPLETE" |
| 178 | + try: |
| 179 | + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', |
| 180 | + new=1): # Force chunking |
| 181 | + # Test creation using URL |
| 182 | + task = dataset.create_data_rows([ |
| 183 | + { |
| 184 | + DataRow.row_data: image_url |
| 185 | + }, |
| 186 | + { |
| 187 | + "row_data": image_url |
| 188 | + }, |
| 189 | + ]) |
| 190 | + task.wait_till_done() |
| 191 | + assert task.has_errors() is False |
| 192 | + assert task.status == "COMPLETE" |
190 | 193 |
|
191 |
| - results = task.result |
192 |
| - assert len(results) == 2 |
193 |
| - row_data = [result["row_data"] for result in results] |
194 |
| - assert row_data == [image_url, image_url] |
195 |
| - results_all = task.result_all |
196 |
| - row_data = [result["row_data"] for result in results_all] |
197 |
| - assert row_data == [image_url, image_url] |
| 194 | + results = task.result |
| 195 | + assert len(results) == 2 |
| 196 | + row_data = [result["row_data"] for result in results] |
| 197 | + assert row_data == [image_url, image_url] |
| 198 | + results_all = task.result_all |
| 199 | + row_data = [result["row_data"] for result in results_all] |
| 200 | + assert row_data == [image_url, image_url] |
198 | 201 |
|
199 |
| - data_rows = list(dataset.data_rows()) |
200 |
| - assert len(data_rows) == 2 |
201 |
| - assert {data_row.row_data for data_row in data_rows} == {image_url} |
202 |
| - assert {data_row.global_key for data_row in data_rows} == {None} |
| 202 | + data_rows = list(dataset.data_rows()) |
| 203 | + assert len(data_rows) == 2 |
| 204 | + assert {data_row.row_data for data_row in data_rows} == {image_url} |
| 205 | + assert {data_row.global_key for data_row in data_rows} == {None} |
203 | 206 |
|
204 |
| - data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) |
205 |
| - assert len(data_rows) == 1 |
| 207 | + data_rows = list(dataset.data_rows(from_cursor=data_rows[0].uid)) |
| 208 | + assert len(data_rows) == 1 |
206 | 209 |
|
207 |
| - # Test creation using file name |
208 |
| - with NamedTemporaryFile() as fp: |
209 |
| - data = rand_gen(str).encode() |
210 |
| - fp.write(data) |
211 |
| - fp.flush() |
212 |
| - task = dataset.create_data_rows([fp.name]) |
| 210 | + finally: |
| 211 | + for dr in data_rows: |
| 212 | + dr.delete() |
| 213 | + |
| 214 | + |
| 215 | +@pytest.fixture |
| 216 | +def local_image_file(image_url) -> NamedTemporaryFile: |
| 217 | + response = requests.get(image_url, stream=True) |
| 218 | + response.raise_for_status() |
| 219 | + |
| 220 | + with NamedTemporaryFile(delete=False) as f: |
| 221 | + for chunk in response.iter_content(chunk_size=8192): |
| 222 | + if chunk: |
| 223 | + f.write(chunk) |
| 224 | + |
| 225 | + yield f # Return the path to the temp file |
| 226 | + |
| 227 | + os.remove(f.name) |
| 228 | + |
| 229 | + |
| 230 | +def test_data_row_bulk_creation_from_file(dataset, local_image_file, image_url): |
| 231 | + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', |
| 232 | + new=1): # Force chunking |
| 233 | + task = dataset.create_data_rows( |
| 234 | + [local_image_file.name, local_image_file.name]) |
213 | 235 | task.wait_till_done()
|
214 | 236 | assert task.status == "COMPLETE"
|
| 237 | + assert len(task.result) == 2 |
| 238 | + assert task.has_errors() is False |
| 239 | + results = [r for r in task.result_all] |
| 240 | + row_data = [result["row_data"] for result in results] |
| 241 | + assert row_data == [image_url, image_url] |
| 242 | + |
215 | 243 |
|
| 244 | +def test_data_row_bulk_creation_from_row_data_file_external_id( |
| 245 | + dataset, local_image_file, image_url): |
| 246 | + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', |
| 247 | + new=1): # Force chunking |
216 | 248 | task = dataset.create_data_rows([{
|
217 |
| - "row_data": fp.name, |
| 249 | + "row_data": local_image_file.name, |
218 | 250 | 'external_id': 'some_name'
|
| 251 | + }, { |
| 252 | + "row_data": image_url, |
| 253 | + 'external_id': 'some_name2' |
219 | 254 | }])
|
220 |
| - task.wait_till_done() |
221 | 255 | assert task.status == "COMPLETE"
|
| 256 | + assert len(task.result) == 2 |
| 257 | + assert task.has_errors() is False |
| 258 | + results = [r for r in task.result_all] |
| 259 | + row_data = [result["row_data"] for result in results] |
| 260 | + assert row_data == [image_url, image_url] |
| 261 | + |
222 | 262 |
|
223 |
| - task = dataset.create_data_rows([{"row_data": fp.name}]) |
| 263 | +def test_data_row_bulk_creation_from_row_data_file(dataset, rand_gen, |
| 264 | + local_image_file, image_url): |
| 265 | + with patch('labelbox.schema.dataset.UPSERT_CHUNK_SIZE', |
| 266 | + new=1): # Force chunking |
| 267 | + task = dataset.create_data_rows([{ |
| 268 | + "row_data": local_image_file.name |
| 269 | + }, { |
| 270 | + "row_data": local_image_file.name |
| 271 | + }]) |
224 | 272 | task.wait_till_done()
|
225 | 273 | assert task.status == "COMPLETE"
|
226 |
| - |
227 |
| - data_rows = list(dataset.data_rows()) |
228 |
| - assert len(data_rows) == 5 |
229 |
| - url = ({data_row.row_data for data_row in data_rows} - {image_url}).pop() |
230 |
| - assert requests.get(url).content == data |
231 |
| - |
232 |
| - for dr in data_rows: |
233 |
| - dr.delete() |
| 274 | + assert len(task.result) == 2 |
| 275 | + assert task.has_errors() is False |
| 276 | + results = [r for r in task.result_all] |
| 277 | + row_data = [result["row_data"] for result in results] |
| 278 | + assert row_data == [image_url, image_url] |
234 | 279 |
|
235 | 280 |
|
236 | 281 | @pytest.mark.slow
|
|
0 commit comments