diff --git a/libs/labelbox/src/labelbox/__init__.py b/libs/labelbox/src/labelbox/__init__.py index 889e40158..726054046 100644 --- a/libs/labelbox/src/labelbox/__init__.py +++ b/libs/labelbox/src/labelbox/__init__.py @@ -43,4 +43,6 @@ from labelbox.schema.identifiable import UniqueId, GlobalKey from labelbox.schema.ontology_kind import OntologyKind from labelbox.schema.project_overview import ProjectOverview, ProjectOverviewDetailed -from labelbox.schema.labeling_service import LabelingService, LabelingServiceStatus +from labelbox.schema.labeling_service import LabelingService +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard +from labelbox.schema.labeling_service_status import LabelingServiceStatus diff --git a/libs/labelbox/src/labelbox/client.py b/libs/labelbox/src/labelbox/client.py index 1b0ea866f..a5c398df1 100644 --- a/libs/labelbox/src/labelbox/client.py +++ b/libs/labelbox/src/labelbox/client.py @@ -56,6 +56,7 @@ from labelbox.schema.label_score import LabelScore from labelbox.schema.ontology_kind import (OntologyKind, EditorTaskTypeMapper, EditorTaskType) +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard logger = logging.getLogger(__name__) @@ -2405,3 +2406,21 @@ def upsert_label_feedback(self, label_id: str, feedback: str, labelbox.LabelScore(name=x['name'], score=x['score']) for x in scores_raw ] + + def get_labeling_service_dashboards( + self, + after: Optional[str] = None, + search_query: Optional[List[Dict]] = None, + ) -> PaginatedCollection: + """ + Get all labeling service dashboards for a given org. + + Optional parameters: + after: The cursor to use for pagination. + search_query: A filter to apply to the query. + + NOTE: support for after and search_query is not yet implemented. 
+ """ + return LabelingServiceDashboard.get_all(self, + after, + search_query=search_query) diff --git a/libs/labelbox/src/labelbox/schema/labeling_service.py b/libs/labelbox/src/labelbox/schema/labeling_service.py index cbc65232c..41e3e559b 100644 --- a/libs/labelbox/src/labelbox/schema/labeling_service.py +++ b/libs/labelbox/src/labelbox/schema/labeling_service.py @@ -1,5 +1,4 @@ from datetime import datetime -from enum import Enum import json from typing import Any from typing_extensions import Annotated @@ -8,22 +7,12 @@ from labelbox.pydantic_compat import BaseModel, Field from labelbox.utils import _CamelCaseMixin +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard +from labelbox.schema.labeling_service_status import LabelingServiceStatus Cuid = Annotated[str, Field(min_length=25, max_length=25)] -class LabelingServiceStatus(Enum): - """" - The status of the labeling service. - """ - Accepted = 'ACCEPTED' - Calibration = 'CALIBRATION' - Complete = 'COMPLETE' - Production = 'PRODUCTION' - Requested = 'REQUESTED' - SetUp = 'SET_UP' - - class LabelingService(BaseModel): """ Labeling service for a project. This is a service that can be requested to label data for a project. @@ -66,6 +55,34 @@ def start(cls, client, project_id: Cuid) -> 'LabelingService': raise Exception("Failed to start labeling service") return cls.get(client, project_id) + @classmethod + def get(cls, client, project_id: Cuid) -> 'LabelingService': + """ + Returns the labeling service associated with the project. + + Raises: + ResourceNotFoundError: If the project does not have a labeling service. + """ + query = """ + query GetProjectBoostWorkforcePyApi($projectId: ID!) 
{ + projectBoostWorkforce(data: { projectId: $projectId }) { + id + projectId + createdAt + updatedAt + createdById + status + } + } + """ + result = client.execute(query, {"projectId": project_id}) + if result["projectBoostWorkforce"] is None: + raise ResourceNotFoundError( + message="The project does not have a labeling service.") + data = result["projectBoostWorkforce"] + data["client"] = client + return LabelingService(**data) + def request(self) -> 'LabelingService': """ Creates a request to labeling service to start labeling for the project. @@ -124,30 +141,11 @@ def getOrCreate(cls, client, project_id: Cuid) -> 'LabelingService': except ResourceNotFoundError: return cls.start(client, project_id) - @classmethod - def get(cls, client, project_id: Cuid) -> 'LabelingService': + def dashboard(self) -> LabelingServiceDashboard: """ - Returns the labeling service associated with the project. + Returns the dashboard for the labeling service associated with the project. Raises: ResourceNotFoundError: If the project does not have a labeling service. """ - query = """ - query GetProjectBoostWorkforcePyApi($projectId: ID!) 
{ - projectBoostWorkforce(data: { projectId: $projectId }) { - id - projectId - createdAt - updatedAt - createdById - status - } - } - """ - result = client.execute(query, {"projectId": project_id}) - if result["projectBoostWorkforce"] is None: - raise ResourceNotFoundError( - message="The project does not have a labeling service.") - data = result["projectBoostWorkforce"] - data["client"] = client - return LabelingService(**data) + return LabelingServiceDashboard.get(self.client, self.project_id) \ No newline at end of file diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py new file mode 100644 index 000000000..56f28c865 --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/labeling_service_dashboard.py @@ -0,0 +1,139 @@ +from string import Template +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +from labelbox.exceptions import ResourceNotFoundError +from labelbox.pagination import PaginatedCollection +from labelbox.pydantic_compat import BaseModel, root_validator, Field +from labelbox.utils import _CamelCaseMixin +from labelbox.schema.labeling_service_status import LabelingServiceStatus + +GRAPHQL_QUERY_SELECTIONS = """ + id + name + # serviceType + # createdAt + # updatedAt + # createdById + boostStatus + dataRowsCount + dataRowsInReviewCount + dataRowsInReworkCount + dataRowsDoneCount + """ + + +class LabelingServiceDashboard(BaseModel): + """ + Represent labeling service data for a project + + Attributes: + id (str): project id + name (str): project name + status (LabelingServiceStatus): status of the labeling service + tasks_completed (int): number of data rows completed + tasks_remaining (int): number of data rows not yet completed + client (Any): labelbox client + """ + id: str = Field(frozen=True) + name: str = Field(frozen=True) + service_type: Optional[str] = Field(frozen=True, default=None) + created_at: 
Optional[datetime] = Field(frozen=True, default=None) + updated_at: Optional[datetime] = Field(frozen=True, default=None) + created_by_id: Optional[str] = Field(frozen=True, default=None) + status: LabelingServiceStatus = Field(frozen=True, + default=LabelingServiceStatus.Missing) + data_rows_count: int = Field(frozen=True) + data_rows_in_review_count: int = Field(frozen=True) + data_rows_in_rework_count: int = Field(frozen=True) + data_rows_done_count: int = Field(frozen=True) + + client: Any # type Any to avoid circular import from client + + def __init__(self, **kwargs): + super().__init__(**kwargs) + if not self.client.enable_experimental: + raise RuntimeError( + "Please enable experimental in client to use LabelingService") + + @property + def tasks_completed(self): + return self.data_rows_done_count + + @property + def tasks_remaining(self): + return self.data_rows_count - self.data_rows_done_count + + class Config(_CamelCaseMixin.Config): + ... + + @classmethod + def get(cls, client, project_id: str) -> 'LabelingServiceDashboard': + """ + Returns the labeling service dashboard associated with the project. + + Raises: + ResourceNotFoundError: If the query returns no project for the given id. + """ + query = f""" + query GetProjectByIdPyApi($id: ID!) 
{{ + getProjectById(input: {{id: $id}}) {{ + {GRAPHQL_QUERY_SELECTIONS} + }} + }} + """ + result = client.execute(query, {"id": project_id}, experimental=True) + if result["getProjectById"] is None: + raise ResourceNotFoundError( + message="The project does not have a labeling service.") + data = result["getProjectById"] + data["client"] = client + return cls(**data) + + @classmethod + def get_all( + cls, + client, + after: Optional[str] = None, + search_query: Optional[List[Dict]] = None, + ) -> PaginatedCollection: + template = Template( + """query SearchProjectsPyApi($$first: Int, $$from: String) { + searchProjects(input: {after: $$from, searchQuery: $search_query, size: $$first}) + { + nodes { $labeling_dashboard_selections } + pageInfo { endCursor } + } + } + """) + organization_id = client.get_organization().uid + query_str = template.substitute( + labeling_dashboard_selections=GRAPHQL_QUERY_SELECTIONS, + search_query= + f"[{{type: \"organization\", operator: \"is\", values: [\"{organization_id}\"]}}]" + ) + + params: Dict[str, Union[str, int]] = {} + if after: + params = {"from": after} + + def convert_to_labeling_service_dashboard(client, data): + data['client'] = client + return LabelingServiceDashboard(**data) + + return PaginatedCollection( + client=client, + query=query_str, + params=params, + dereferencing=['searchProjects', 'nodes'], + obj_class=convert_to_labeling_service_dashboard, + cursor_path=['searchProjects', 'pageInfo', 'endCursor'], + experimental=True, + ) + + @root_validator(pre=True) + def convert_boost_status_to_enum(cls, data): + if 'boostStatus' in data: + data['status'] = LabelingServiceStatus(data.pop('boostStatus')) + + return data diff --git a/libs/labelbox/src/labelbox/schema/labeling_service_status.py b/libs/labelbox/src/labelbox/schema/labeling_service_status.py new file mode 100644 index 000000000..62cfd938e --- /dev/null +++ b/libs/labelbox/src/labelbox/schema/labeling_service_status.py @@ -0,0 +1,30 @@ +from enum import Enum + 
+ +class LabelingServiceStatus(Enum): + Accepted = 'ACCEPTED' + Calibration = 'CALIBRATION' + Complete = 'COMPLETE' + Production = 'PRODUCTION' + Requested = 'REQUESTED' + SetUp = 'SET_UP' + Missing = None + + @classmethod + def is_supported(cls, value): + return isinstance(value, cls) + + @classmethod + def _missing_(cls, value) -> 'LabelingServiceStatus': + """Handle null or unrecognized status values + Handle upper case names for compatibility with + the GraphQL API""" + + if value is None: + return cls.Missing + + for name, member in cls.__members__.items(): + if value == name.upper(): + return member + + return cls.Missing diff --git a/libs/labelbox/src/labelbox/schema/project.py b/libs/labelbox/src/labelbox/schema/project.py index 4605c3aa5..ff093567c 100644 --- a/libs/labelbox/src/labelbox/schema/project.py +++ b/libs/labelbox/src/labelbox/schema/project.py @@ -10,6 +10,7 @@ from urllib.parse import urlparse from labelbox.schema.labeling_service import LabelingService, LabelingServiceStatus +from labelbox.schema.labeling_service_dashboard import LabelingServiceDashboard import requests from labelbox import parser @@ -1941,6 +1942,15 @@ def get_labeling_service_status(self) -> LabelingServiceStatus: """ return self.get_labeling_service().status + @experimental + def labeling_service_dashboard(self) -> LabelingServiceDashboard: + """Get the labeling service dashboard for this project. + + Returns: + LabelingServiceDashboard: The labeling service dashboard for this project. 
+ """ + return LabelingServiceDashboard.get(self.client, self.uid) + class ProjectMember(DbObject): user = Relationship.ToOne("User", cache=True) diff --git a/libs/labelbox/tests/integration/test_labeling_dashboard.py b/libs/labelbox/tests/integration/test_labeling_dashboard.py new file mode 100644 index 000000000..a45334bcc --- /dev/null +++ b/libs/labelbox/tests/integration/test_labeling_dashboard.py @@ -0,0 +1,24 @@ +from labelbox.schema.labeling_service import LabelingServiceStatus + + +def test_request_labeling_service_moe_offline_project( + rand_gen, offline_chat_evaluation_project, chat_evaluation_ontology, + offline_conversational_data_row): + project = offline_chat_evaluation_project + project.connect_ontology(chat_evaluation_ontology) + + project.create_batch( + rand_gen(str), + [offline_conversational_data_row.uid], # sample of data row objects + ) + labeling_service_dashboard = project.labeling_service_dashboard() + assert labeling_service_dashboard.status == LabelingServiceStatus.Missing + assert labeling_service_dashboard.tasks_completed == 0 + assert labeling_service_dashboard.tasks_remaining == 0 + + labeling_service_dashboard = [ + ld for ld in project.client.get_labeling_service_dashboards() + ][0] + assert labeling_service_dashboard.status == LabelingServiceStatus.Missing + assert labeling_service_dashboard.tasks_completed == 0 + assert labeling_service_dashboard.tasks_remaining == 0