1
- from typing import List
1
+ from __future__ import annotations
2
+
3
+ from typing import List , Iterable , Set , Tuple , Optional
4
+ from nucleus .dataset_item import DatasetItem
5
+ from nucleus .annotation import Annotation
6
+
7
+ from .constants import DEFAULT_ANNOTATION_UPDATE_MODE
8
+
9
+
10
+ def check_annotations_are_in_slice (
11
+ annotations : List [Annotation ], slice_to_check : Slice
12
+ ) -> Tuple [bool , Set [str ], Set [str ]]:
13
+ """Check membership of the annotation targets within this slice.
14
+
15
+ annotations: Annnotations with ids referring to targets.
16
+ slice: The slice to check against.
17
+ """
18
+ info = slice_to_check .info ()
19
+ item_ids_not_found_in_slice = {
20
+ annotation .item_id
21
+ for annotation in annotations
22
+ if annotation .item_id is not None
23
+ }.difference ({item_metadata ["id" ] for item_metadata in info })
24
+ reference_ids_not_found_in_slice = {
25
+ annotation .reference_id
26
+ for annotation in annotations
27
+ if annotation .reference_id is not None
28
+ }.difference ({item_metadata ["reference_id" ] for item_metadata in info })
29
+ if item_ids_not_found_in_slice or reference_ids_not_found_in_slice :
30
+ annotations_are_in_slice = False
31
+ else :
32
+ annotations_are_in_slice = True
33
+
34
+ return (
35
+ annotations_are_in_slice ,
36
+ item_ids_not_found_in_slice ,
37
+ reference_ids_not_found_in_slice ,
38
+ )
2
39
3
40
4
41
class Slice :
@@ -9,6 +46,7 @@ class Slice:
9
46
def __init__ (self , slice_id : str , client ):
10
47
self .slice_id = slice_id
11
48
self ._client = client
49
+ self ._dataset_id = None
12
50
13
51
def __repr__ (self ):
14
52
return f"Slice(slice_id='{ self .slice_id } ', client={ self ._client } )"
@@ -19,6 +57,13 @@ def __eq__(self, other):
19
57
return True
20
58
return False
21
59
60
+ @property
61
+ def dataset_id (self ):
62
+ """The id of the dataset this slice belongs to."""
63
+ if self ._dataset_id is None :
64
+ self .info ()
65
+ return self ._dataset_id
66
+
22
67
def info (self ) -> dict :
23
68
"""
24
69
This endpoint provides information about specified slice.
@@ -30,7 +75,9 @@ def info(self) -> dict:
30
75
"dataset_items",
31
76
}
32
77
"""
33
- return self ._client .slice_info (self .slice_id )
78
+ info = self ._client .slice_info (self .slice_id )
79
+ self ._dataset_id = info ["dataset_id" ]
80
+ return info
34
81
35
82
def append (
36
83
self ,
@@ -57,3 +104,52 @@ def append(
57
104
reference_ids = reference_ids ,
58
105
)
59
106
return response
107
+
108
+ def items_generator (self ) -> Iterable [DatasetItem ]:
109
+ """Returns an iterable of DatasetItems in this slice."""
110
+ info = self .info ()
111
+ for item_metadata in info ["dataset_items" ]:
112
+ yield self ._client .dataitem_loc (
113
+ self ,
114
+ dataset_id = info ["dataset_id" ],
115
+ dataset_item_id = item_metadata ["id" ],
116
+ )
117
+
118
+ def items (self ) -> List [DatasetItem ]:
119
+ """Returns a list of all DatasetItems in this slice."""
120
+ return list (self .items_generator ())
121
+
122
+ def annotate (
123
+ self ,
124
+ annotations : List [Annotation ],
125
+ update : Optional [bool ] = DEFAULT_ANNOTATION_UPDATE_MODE ,
126
+ batch_size : int = 5000 ,
127
+ strict = True ,
128
+ ):
129
+ """Update annotations within this slice.
130
+
131
+ Args:
132
+ annotations: List of annotations to upload
133
+ batch_size: How many annotations to send per request.
134
+ strict: Whether to first check that the annotations belong to this slice.
135
+ Set to false to avoid this check and speed up upload.
136
+ """
137
+ if strict :
138
+ (
139
+ annotations_are_in_slice ,
140
+ item_ids_not_found_in_slice ,
141
+ reference_ids_not_found_in_slice ,
142
+ ) = check_annotations_are_in_slice (annotations , self )
143
+ if not annotations_are_in_slice :
144
+ message = "Not all annotations are in this slice.\n "
145
+ if item_ids_not_found_in_slice :
146
+ message += f"Item ids not found in slice: { item_ids_not_found_in_slice } \n "
147
+ if reference_ids_not_found_in_slice :
148
+ message += f"Reference ids not found in slice: { reference_ids_not_found_in_slice } "
149
+ raise ValueError (message )
150
+ self ._client .annotate_dataset (
151
+ dataset_id = self .dataset_id ,
152
+ annotations = annotations ,
153
+ update = update ,
154
+ batch_size = batch_size ,
155
+ )
0 commit comments