diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index 3b0ef9445..48d2b6241 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -87,6 +87,19 @@ def request(self) -> 'LabelingService': raise Exception("Failed to start labeling service") return LabelingService.get(self.client, self.project_id) + @classmethod + def getOrCreate(cls, client, project_id: Cuid) -> 'LabelingService': + """ + Returns the labeling service associated with the project. If the project does not have a labeling service, it will create one. + + Returns: + LabelingService: The labeling service for the project. + """ + try: + return cls.get(client, project_id) + except ResourceNotFoundError: + return cls.start(client, project_id) + @classmethod def get(cls, client, project_id: Cuid) -> 'LabelingService': """ diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 761d0e391..baf75ee65 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -1920,13 +1920,12 @@ def clone(self) -> "Project": def get_labeling_service(self) -> LabelingService: """Get the labeling service for this project. - Raises: - ResourceNotFoundError if the project does not have a labeling service. + Will automatically create a labeling service if one does not exist. Returns: LabelingService: The labeling service for this project. """ - return LabelingService.get(self.client, self.uid) + return LabelingService.getOrCreate(self.client, self.uid) @experimental def get_labeling_service_status(self) -> LabelingServiceStatus: @@ -1940,15 +1939,6 @@ def get_labeling_service_status(self) -> LabelingServiceStatus: """ return self.get_labeling_service().status - @experimental - def request_labeling_service(self) -> LabelingService: - """Get the labeling service for this project. - - Returns: - LabelingService: The labeling service for this project. - """ - return LabelingService.start(self.client, self.uid) # type: ignore - class ProjectMember(DbObject): user = Relationship.ToOne("User", cache=True) diff --git a/libs/labelbox/tests/integration/test_labeling_service.py b/libs/labelbox/tests/integration/test_labeling_service.py index a35777fe9..670f20ac6 100644 --- a/libs/labelbox/tests/integration/test_labeling_service.py +++ b/libs/labelbox/tests/integration/test_labeling_service.py @@ -4,20 +4,9 @@ from labelbox.schema.labeling_service import LabelingServiceStatus -def test_get_labeling_service_throws_exception(project): - with pytest.raises(ResourceNotFoundError): # No labeling service by default - project.get_labeling_service() - with pytest.raises(ResourceNotFoundError): # No labeling service by default - project.get_labeling_service_status() - - def test_start_labeling_service(project): - labeling_service = project.request_labeling_service() - assert labeling_service.status == LabelingServiceStatus.SetUp - assert labeling_service.project_id == project.uid - # Check that the labeling service is now available - labeling_service = project.get_labeling_service() + labeling_service = project.get_labeling_service() # creates and gets it assert labeling_service.status == LabelingServiceStatus.SetUp assert labeling_service.project_id == project.uid @@ -25,19 +14,6 @@ def test_start_labeling_service(project): assert labeling_service_status == LabelingServiceStatus.SetUp -def test_request_labeling_service( - configured_batch_project_for_labeling_service): - project = configured_batch_project_for_labeling_service - - project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - - labeling_service = project.request_labeling_service( - ) # project fixture is an Image type project - labeling_service.request() - assert project.get_labeling_service_status( - ) == LabelingServiceStatus.Requested - - def test_request_labeling_service_moe_offline_project( rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology, offline_conversational_data_row, model_config): @@ -51,7 +27,7 @@ def test_request_labeling_service_moe_offline_project( project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - labeling_service = project.request_labeling_service() + labeling_service = project.get_labeling_service() labeling_service.request() assert project.get_labeling_service_status( ) == LabelingServiceStatus.Requested @@ -65,7 +41,7 @@ def test_request_labeling_service_moe_project( project.upsert_instructions('tests/integration/media/sample_pdf.pdf') - labeling_service = project.request_labeling_service() + labeling_service = project.get_labeling_service() with pytest.raises( LabelboxError, match= @@ -81,7 +57,7 @@ def test_request_labeling_service_moe_project( def test_request_labeling_service_incomplete_requirements(project, ontology): - labeling_service = project.request_labeling_service( + labeling_service = project.get_labeling_service( ) # project fixture is an Image type project with pytest.raises(ResourceNotFoundError, match="Associated ontology id could not be found"