@@ -329,7 +329,7 @@ def test_raises_error_for_duplicate():
329
329
)
330
330
331
331
332
- def test_dataset_export_autotag_scores (CLIENT ):
332
+ def test_dataset_export_autotag_tagged_items (CLIENT ):
333
333
# This test can only run for the test user who has an indexed dataset.
334
334
# TODO: if/when we can create autotags via api, create one instead.
335
335
if NUCLEUS_PYTEST_USER_ID in CLIENT .api_key :
@@ -342,11 +342,51 @@ def test_dataset_export_autotag_scores(CLIENT):
342
342
in str (api_error .value )
343
343
)
344
344
345
- scores = dataset .autotag_scores (autotag_name = "PytestTestTag" )
345
+ items = dataset .autotag_items (autotag_name = "PytestTestTag" )
346
346
347
- for column in ["dataset_item_ids" , "ref_ids" , "scores" ]:
348
- assert column in scores
349
- assert len (scores [column ]) > 0
347
+ assert "autotagItems" in items
348
+ assert "autotag" in items
349
+
350
+ autotagItems = items ["autotagItems" ]
351
+ autotag = items ["autotag" ]
352
+
353
+ assert len (autotagItems ) > 0
354
+ for item in autotagItems :
355
+ for column in ["ref_id" , "score" ]:
356
+ assert column in item
357
+
358
+ for column in ["id" , "name" , "status" , "autotag_level" ]:
359
+ assert column in autotag
360
+
361
+
362
+ def test_dataset_export_autotag_training_items (CLIENT ):
363
+ # This test can only run for the test user who has an indexed dataset.
364
+ # TODO: if/when we can create autotags via api, create one instead.
365
+ if NUCLEUS_PYTEST_USER_ID in CLIENT .api_key :
366
+ dataset = CLIENT .get_dataset (DATASET_WITH_AUTOTAG )
367
+
368
+ with pytest .raises (NucleusAPIError ) as api_error :
369
+ dataset .autotag_scores (autotag_name = "NONSENSE_GARBAGE" )
370
+ assert (
371
+ f"The autotag NONSENSE_GARBAGE was not found in dataset { DATASET_WITH_AUTOTAG } "
372
+ in str (api_error .value )
373
+ )
374
+
375
+ items = dataset .autotag_training_items (autotag_name = "PytestTestTag" )
376
+
377
+ assert "autotagItems" in items
378
+ assert "autotag" in items
379
+
380
+ autotagTrainingItems = items ["autotagPositiveTrainingItems" ]
381
+ autotag = items ["autotag" ]
382
+
383
+ assert len (autotagTrainingItems ) > 0
384
+ for item in autotagTrainingItems :
385
+ for column in ["ref_id" ]:
386
+ assert column in item
387
+
388
+ for column in ["id" , "name" , "status" , "autotag_level" ]:
389
+ assert column in autotag
350
390
351
391
352
392
@pytest .mark .integration
0 commit comments