|
1 | 1 | from datetime import datetime
|
2 | 2 | from string import Template
|
3 |
| -from typing import Any, Optional |
| 3 | +from typing import Any, Dict, List, Optional |
4 | 4 |
|
5 | 5 | from labelbox.exceptions import ResourceNotFoundError
|
6 | 6 | from labelbox.orm.comparison import Comparison
|
| 7 | +from labelbox.orm import query |
| 8 | +from ..orm.model import Field |
7 | 9 | from labelbox.pagination import PaginatedCollection
|
8 | 10 | from labelbox.pydantic_compat import BaseModel, root_validator
|
| 11 | +from .organization import Organization |
9 | 12 | from labelbox.utils import _CamelCaseMixin
|
10 | 13 | from labelbox.schema.labeling_service_status import LabelingServiceStatus
|
11 | 14 |
|
| 15 | +GRAPHQL_QUERY_SELECTIONS = """ |
| 16 | + id |
| 17 | + name |
| 18 | + # serviceType |
| 19 | + # createdAt |
| 20 | + # updatedAt |
| 21 | + # createdById |
| 22 | + boostStatus |
| 23 | + dataRowsCount |
| 24 | + dataRowsInReviewCount |
| 25 | + dataRowsInReworkCount |
| 26 | + dataRowsDoneCount |
| 27 | + """ |
| 28 | + |
12 | 29 |
|
13 | 30 | class LabelingServiceDashboard(BaseModel):
|
14 | 31 | id: str
|
@@ -69,40 +86,41 @@ def get(cls, client, project_id: str) -> 'LabelingServiceDashboard':
|
69 | 86 | def get_all(
|
70 | 87 | cls,
|
71 | 88 | client,
|
72 |
| - from_cursor: Optional[str] = None, |
73 |
| - where: Optional[Comparison] = None, |
| 89 | + after: Optional[str] = None, |
| 90 | + # where: Optional[Comparison] = None, |
| 91 | + search_query: Optional[List[Dict]] = None, |
74 | 92 | ) -> PaginatedCollection:
|
75 |
| - page_size = 500 # hardcode to avoid overloading the server |
76 |
| - # where_param = query.where_as_dict(Entity.DataRow, |
77 |
| - # where) if where is not None else None |
78 |
| - |
79 | 93 | template = Template(
|
80 |
| - """query SearchProjectsPyApi($$id: ID!, $$after: ID, $$first: Int, $$where: SearchProjectsInput) { |
81 |
| - searchProjects(id: $$id, after: $$after, first: $$first, where: $$where) |
| 94 | + """query SearchProjectsPyApi($$first: Int, $$from: String) { |
| 95 | + searchProjects(input: {after: $$from, searchQuery: $search_query, size: $$first}) |
82 | 96 | {
|
83 |
| - nodes { $datarow_selections } |
84 |
| - pageInfo { hasNextPage startCursor } |
| 97 | + nodes { $labeling_dashboard_selections } |
| 98 | + pageInfo { endCursor } |
85 | 99 | }
|
86 | 100 | }
|
87 | 101 | """)
|
| 102 | + organization_id = client.get_organization().uid |
88 | 103 | query_str = template.substitute(
|
89 |
| - datarow_selections=LabelingServiceDashboard.schema() |
90 |
| - ['properties'].keys()) |
| 104 | + labeling_dashboard_selections=GRAPHQL_QUERY_SELECTIONS, |
| 105 | + search_query= |
| 106 | + f"[{{type: \"organization\", operator: \"is\", values: [\"{organization_id}\"]}}]" |
| 107 | + ) |
91 | 108 |
|
92 | 109 | params = {
|
93 |
| - 'id': self.uid, |
94 |
| - 'from': from_cursor, |
95 |
| - 'first': page_size, |
96 |
| - 'where': where_param, |
| 110 | + 'from': after, |
97 | 111 | }
|
98 | 112 |
|
| 113 | + def convert_to_labeling_service_dashboard(client, data): |
| 114 | + data['client'] = client |
| 115 | + return LabelingServiceDashboard(**data) |
| 116 | + |
99 | 117 | return PaginatedCollection(
|
100 | 118 | client=client,
|
101 | 119 | query=query_str,
|
102 | 120 | params=params,
|
103 | 121 | dereferencing=['searchProjects', 'nodes'],
|
104 |
| - obj_class=LabelingServiceDashboard, |
105 |
| - cursor_path=['datasetDataRows', 'pageInfo', 'endCursor'], |
| 122 | + obj_class=convert_to_labeling_service_dashboard, |
| 123 | + cursor_path=['searchProjects', 'pageInfo', 'endCursor'], |
106 | 124 | )
|
107 | 125 |
|
108 | 126 | @root_validator(pre=True)
|
|
0 commit comments