From 54c699dbd2f35d91656077309847d0e58f6946e4 Mon Sep 17 00:00:00 2001 From: Gabefire <33893811+Gabefire@users.noreply.github.com> Date: Wed, 5 Jun 2024 16:13:02 -0500 Subject: [PATCH] added get_model_configs methods --- libs/labelbox/src/labelbox/client.py | 23 +++++++++++++++++++ .../tests/integration/test_model_config.py | 5 ++++ 2 files changed, 28 insertions(+) diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index a2fb09186..8988e2551 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -654,6 +654,29 @@ def delete_model_config(self, id: str) -> bool: if not result: raise labelbox.exceptions.ResourceNotFoundError(Entity.ModelConfig, params) return result['deleteModelConfig']['success'] + + def get_model_configs(self, model_id: str) -> List[ModelConfig]: + """ Gets all model configs attached to a given model id + + Args: + model_id (str): ID of the model associated with the model configs + + Returns: + List[ModelConfig], list of ModelConfigs if the operation was a success. + """ + + query = """query SearchModelConfigsPyApi($modelId: ID!) { + modelConfigs( + where: {modelId: $modelId} + ) { + id + inferenceParams + name + } + }""" + params = {"modelId": model_id} + result = self.execute(query, params) + return [ModelConfig(self, {**model_config, "modelId":model_id}) for model_config in result["modelConfigs"]] def create_dataset(self, iam_integration=IAMIntegration._DEFAULT, diff --git a/libs/labelbox/tests/integration/test_model_config.py b/libs/labelbox/tests/integration/test_model_config.py index 960b096c6..412a2afae 100644 --- a/libs/labelbox/tests/integration/test_model_config.py +++ b/libs/labelbox/tests/integration/test_model_config.py @@ -15,3 +15,8 @@ def test_delete_model_config(client, valid_model_id): def test_delete_nonexistant_model_config(client): with pytest.raises(ResourceNotFoundError): client.delete_model_config("invalid_model_id") + +def test_get_model_configs(client, valid_model_id): + model_config = client.create_model_config("model_config_1", valid_model_id, {"param": "value"}) + model_config = client.create_model_config("model_config_2", valid_model_id, {"param": "value"}) + assert len(client.get_model_configs(valid_model_id)) == 2