@@ -4,6 +4,8 @@
 import nucleus
 import os
 
+from itertools import zip_longest
+
 import time
 
 
@@ -21,6 +23,8 @@
     "API Key to use. Defaults to NUCLEUS_PYTEST_API_KEY environment variable",
 )
 
+flags.DEFINE_integer("job_parallelism", 8, "Number of concurrent jobs to use.")
+
 # Dataset upload flags
 flags.DEFINE_enum(
     "create_or_reuse_dataset",
@@ -35,12 +39,12 @@
 )
 flags.DEFINE_integer(
     "num_dataset_items",
-    100000,
+    10000000,
     "Number of dataset items to create if creating a dataset",
     lower_bound=0,
 )
 flags.DEFINE_bool(
-    "cleanup_dataset", True, "Whether to delete the dataset after the test."
+    "cleanup_dataset", False, "Whether to delete the dataset after the test."
 )
 
 # Annotation upload flags
@@ -54,11 +58,21 @@
 # Prediction upload flags
 flags.DEFINE_integer(
     "num_predictions_per_dataset_item",
-    0,
+    1,
    "Number of predictions per dataset item",
     lower_bound=0,
 )
 
+TIMINGS = {}
+
+
+def chunk(iterable, chunk_size, fillvalue=None):
+    "Collect data into fixed-length chunks or blocks"
+    args = [iter(iterable)] * chunk_size
+
+    for chunk_iterable in zip_longest(*args, fillvalue=fillvalue):
+        yield filter(lambda x: x is not None, chunk_iterable)
+
 
 def client():
     return nucleus.NucleusClient(api_key=FLAGS.api_key)
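Note: the chunk() helper added above batches an arbitrary iterable into fixed-size groups so that each group can be handed to its own asynchronous job; zip_longest pads the final group with the fill value and the filter drops the padding again, so the last batch is simply shorter. A quick sketch of the intended behaviour (illustrative only, not output captured from a real run):

    >>> [list(c) for c in chunk(range(10), 4)]
    [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

Because the filter removes every None, an iterable that legitimately yields None would lose items; the generators in this script presumably never do.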
@@ -126,15 +140,23 @@ def create_or_get_dataset():
         dataset = client().create_dataset("Privacy Mode Load Test Dataset")
         print("Starting dataset item upload")
         tic = time.time()
-        job = dataset.append(
-            dataset_item_generator(), update=True, asynchronous=True
-        )
-        try:
-            job.sleep_until_complete(False)
-        except JobError:
-            print(job.errors())
+        chunk_size = FLAGS.num_dataset_items // FLAGS.job_parallelism
+        jobs = []
+        for dataset_item_chunk in chunk(dataset_item_generator(), chunk_size):
+            jobs.append(
+                dataset.append(
+                    dataset_item_chunk, update=True, asynchronous=True
+                )
+            )
+
+        for job in jobs:
+            try:
+                job.sleep_until_complete(False)
+            except JobError:
+                print(job.errors())
         toc = time.time()
         print("Finished dataset item upload: %s" % (toc - tic))
+        TIMINGS[f"Dataset Item Upload {FLAGS.num_dataset_items}"] = toc - tic
     else:
         print(f"Reusing dataset {FLAGS.dataset_id}")
         dataset = client().get_dataset(FLAGS.dataset_id)
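Note: chunk_size is derived by integer division, so under the new defaults (num_dataset_items=10000000, job_parallelism=8) each append job receives exactly 1,250,000 items and exactly 8 jobs are submitted. If the item count were not an exact multiple of job_parallelism, chunk() would emit one extra, shorter batch, i.e. job_parallelism + 1 jobs. An illustrative check with stand-in numbers (not taken from a real run):

    >>> 10_000_000 // 8                                 # chunk_size under the new defaults
    1250000
    >>> [list(c) for c in chunk(range(10), 10 // 4)]    # 10 items, parallelism 4
    [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]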
@@ -144,15 +166,26 @@ def create_or_get_dataset():
 def upload_annotations(dataset: Dataset):
     print("Starting annotation upload")
     tic = time.time()
-    job = dataset.annotate(
-        list(annotation_generator()), update=False, asynchronous=True
+    jobs = []
+    num_annotations = (
+        FLAGS.num_dataset_items * FLAGS.num_annotations_per_dataset_item
     )
-    try:
-        job.sleep_until_complete(False)
-    except JobError:
-        print(job.errors())
+    chunk_size = num_annotations // FLAGS.job_parallelism
+    for annotation_chunk in chunk(annotation_generator(), chunk_size):
+        jobs.append(
+            dataset.annotate(
+                list(annotation_chunk), update=False, asynchronous=True
+            )
+        )
+
+    for job in jobs:
+        try:
+            job.sleep_until_complete(False)
+        except JobError:
+            print(job.errors())
     toc = time.time()
     print("Finished annotation upload: %s" % (toc - tic))
+    TIMINGS[f"Annotation Upload {num_annotations}"] = toc - tic
 
 
 def upload_predictions(dataset: Dataset):
@@ -167,16 +200,24 @@ def upload_predictions(dataset: Dataset):
 
     print("Starting prediction upload")
 
-    job = run.predict(
-        list(prediction_generator()), update=True, asynchronous=True
+    num_predictions = (
+        FLAGS.num_dataset_items * FLAGS.num_predictions_per_dataset_item
     )
+    chunk_size = num_predictions // FLAGS.job_parallelism
+    jobs = []
+    for prediction_chunk in chunk(prediction_generator(), chunk_size):
+        jobs.append(
+            run.predict(list(prediction_chunk), update=True, asynchronous=True)
+        )
 
-    try:
-        job.sleep_until_complete(False)
-    except JobError:
-        print(job.errors())
+    for job in jobs:
+        try:
+            job.sleep_until_complete(False)
+        except JobError:
+            print(job.errors())
     toc = time.time()
     print("Finished prediction upload: %s" % (toc - tic))
+    TIMINGS[f"Prediction Upload {num_predictions}"] = toc - tic
 
 
 def main(unused_argv):
@@ -194,6 +235,8 @@ def main(unused_argv):
     if FLAGS.cleanup_dataset and FLAGS.create_or_reuse_dataset == "create":
         client().delete_dataset(dataset.id)
 
+    print(TIMINGS)
+
 
 if __name__ == "__main__":
     app.run(main)
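For reference, a typical invocation of the load test under the new defaults might look like the following (the script filename is a placeholder; the flags are the ones defined above, and --api_key can be omitted when NUCLEUS_PYTEST_API_KEY is set):

    python privacy_mode_load_test.py \
        --api_key="$NUCLEUS_PYTEST_API_KEY" \
        --create_or_reuse_dataset=create \
        --num_dataset_items=10000000 \
        --job_parallelism=8 \
        --num_predictions_per_dataset_item=1 \
        --cleanup_dataset=false

absl-py also accepts the boolean flag in the --cleanup_dataset / --nocleanup_dataset form.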