4
4
import os
5
5
import time
6
6
from typing import Any , Dict , List , BinaryIO
7
+ from tqdm import tqdm # type: ignore
7
8
8
9
import backoff
9
10
import ndjson
@@ -25,6 +26,7 @@ class AnnotationImport(DbObject):
25
26
input_file_url = Field .String ("input_file_url" )
26
27
error_file_url = Field .String ("error_file_url" )
27
28
status_file_url = Field .String ("status_file_url" )
29
+ progress = Field .String ("progress" )
28
30
29
31
created_by = Relationship .ToOne ("User" , False , "created_by" )
30
32
@@ -76,18 +78,28 @@ def statuses(self) -> List[Dict[str, Any]]:
76
78
self .wait_until_done ()
77
79
return self ._fetch_remote_ndjson (self .status_file_url )
78
80
79
- def wait_until_done (self , sleep_time_seconds : int = 10 ) -> None :
81
+ def wait_until_done (self ,
82
+ sleep_time_seconds : int = 10 ,
83
+ show_progress : bool = False ) -> None :
80
84
"""Blocks import job until certain conditions are met.
81
85
Blocks until the AnnotationImport.state changes either to
82
86
`AnnotationImportState.FINISHED` or `AnnotationImportState.FAILED`,
83
87
periodically refreshing object's state.
84
88
Args:
85
- sleep_time_seconds (str): a time to block between subsequent API calls
89
+ sleep_time_seconds (int): a time to block between subsequent API calls
90
+ show_progress (bool): should show progress bar
86
91
"""
92
+ pbar = tqdm (total = 100 ) if show_progress else None
87
93
while self .state .value == AnnotationImportState .RUNNING .value :
88
94
logger .info (f"Sleeping for { sleep_time_seconds } seconds..." )
89
95
time .sleep (sleep_time_seconds )
90
96
self .__backoff_refresh ()
97
+ if self .progress and pbar :
98
+ pbar .update (self .progress )
99
+
100
+ if pbar :
101
+ pbar .update (100 )
102
+ pbar .close ()
91
103
92
104
@backoff .on_exception (
93
105
backoff .expo ,
0 commit comments