diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index af3b5b3fc..a216bb0d3 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -1809,17 +1809,14 @@ def _format_failed_rows( def get_catalog(self) -> Catalog: return Catalog(client=self) - def get_catalog_slice(self, slice_id) -> CatalogSlice: + def get_catalog_slices(self) -> List[CatalogSlice]: """ - Fetches a Catalog Slice by ID. - - Args: - slice_id (str): The ID of the Slice + Fetches all slices of the given entity type. Returns: - CatalogSlice + List[CatalogSlice]: A list of CatalogSlice objects. """ - query_str = """query getSavedQueryPyApi($id: ID!) { - getSavedQuery(id: $id) { + query_str = """query GetCatalogSavedQueriesPyApi { + catalogSavedQueries { id name description @@ -1828,9 +1825,62 @@ def get_catalog_slice(self, slice_id) -> CatalogSlice: updatedAt } } + """ + res = self.execute(query_str) + return [CatalogSlice(self, sl) for sl in res["catalogSavedQueries"]] + + def get_catalog_slice( + self, slice_id: Optional[str] = None, slice_name: Optional[str] = None + ) -> Union[CatalogSlice, List[CatalogSlice]]: """ - res = self.execute(query_str, {"id": slice_id}) - return Entity.CatalogSlice(self, res["getSavedQuery"]) + Fetches a Slice using either the slice ID or the slice name. + + Args: + slice_id (Optional[str]): The ID of the Slice. + slice_name (Optional[str]): The name of the Slice. + + Returns: + Union[CatalogSlice, List[CatalogSlice], ModelSlice, List[ModelSlice]]: + The corresponding Slice object or list of Slice objects. + + Raises: + ValueError: If neither or both id and name are provided. + ResourceNotFoundError: If the slice is not found. + """ + if (slice_id is None and slice_name is None) or ( + slice_id is not None and slice_name is not None + ): + raise ValueError("Provide exactly one of id or name") + + if slice_id is not None: + query_str = """query getSavedQueryPyApi($id: ID!) { + getSavedQuery(id: $id) { + id + name + description + filter + createdAt + updatedAt + } + } + """ + + res = self.execute(query_str, {"id": slice_id}) + if res is None: + raise ResourceNotFoundError(CatalogSlice, {"id": slice_id}) + + return CatalogSlice(self, res["getSavedQuery"]) + + else: + slices = self.get_catalog_slices() + matches = [s for s in slices if s.name == slice_name] + + if not matches: + raise ResourceNotFoundError(CatalogSlice, {"name": slice_name}) + elif len(matches) > 1: + return matches + else: + return matches[0] def is_feature_schema_archived( self, ontology_id: str, feature_schema_id: str diff --git a/libs/labelbox/tests/integration/test_slice.py b/libs/labelbox/tests/integration/test_slice.py new file mode 100644 index 000000000..0521d24dd --- /dev/null +++ b/libs/labelbox/tests/integration/test_slice.py @@ -0,0 +1,118 @@ +from typing import Optional +from labelbox import Client, CatalogSlice + + +def _create_catalog_slice( + client: Client, name: str, description: Optional[str] = None +) -> str: + """Creates a catalog slice for testing purposes. + + Args: + client (Client): Labelbox client instance + name (str): Name of the catalog slice + description (str): Description of the catalog slice + + Returns: + str: ID of the created catalog slice + """ + + mutation = """mutation CreateCatalogSlicePyApi($name: String!, $description: String, $query: SearchServiceQuery!, $sorting: [SearchServiceSorting!]) { + createCatalogSavedQuery( + args: {name: $name, description: $description, filter: $query, sorting: $sorting} + ) { + id + name + description + filter + sorting + catalogCount { + count + } + } + } + """ + + params = { + "description": description, + "name": name, + "query": [ + { + "type": "media_attribute_asset_type", + "assetType": {"type": "asset_type", "assetTypes": ["image"]}, + } + ], + "sorting": [ + { + "field": { + "field": "dataRowCreatedAt", + "verboseName": "Created At", + }, + "direction": "DESC", + "metadataSchemaId": None, + } + ], + } + + result = client.execute(mutation, params, experimental=True) + + return result["createCatalogSavedQuery"].get("id") + + +def _delete_catalog_slice(client, slice_id: str) -> bool: + mutation = """mutation DeleteCatalogSlicePyApi($id: ID!) { + deleteSavedQuery(args: { id: $id }) { + success + } + } + """ + + params = {"id": slice_id} + + operation_done = True + try: + client.execute(mutation, params, experimental=True) + except Exception as ex: + operation_done = False + + return operation_done + + +def test_get_slice(client): + # Pre-cleaning + slices = ( + s + for s in client.get_catalog_slices() + if s.name in ["Test Slice 1", "Test Slice 2"] + ) + for slice in slices: + _delete_catalog_slice(client, slice.id) + + # Create slices + slice_id_1 = _create_catalog_slice( + client, "Test Slice 1", "Slice created for SDK test." + ) + slice_id_2 = _create_catalog_slice( + client, "Test Slice 2", "Slice created for SDK test." + ) + # Create slice 2b - with the same name as slice 2 + slice_id_2b = _create_catalog_slice( + client, "Test Slice 2", "Slice created for SDK test." + ) + + # Assert get slice 1 by ID + slice_1 = client.get_catalog_slice(slice_id_1) + assert isinstance(slice_1, CatalogSlice) + + slice_1 = client.get_catalog_slice(slice_name="Test Slice 1") + assert isinstance(slice_1, CatalogSlice) + + slices_2 = client.get_catalog_slice(slice_name="Test Slice 2") + assert len(slices_2) == 2 + assert isinstance(slices_2, list) and all( + [isinstance(item, CatalogSlice) for item in slices_2] + ) + + # Cleaning - Delete slices + _delete_catalog_slice(client, slice_id_1) + _delete_catalog_slice(client, slice_id_2) + _delete_catalog_slice(client, slice_id_2b)