Skip to content

Commit a3fe401

Browse files
authored
Merge pull request #19 from scaleapi/felix/batch
Add support for creating and finalizing task batch
2 parents f0ccf03 + 42ef954 commit a3fe401

File tree

2 files changed

+68
-8
lines changed

2 files changed

+68
-8
lines changed

scaleapi/__init__.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import requests
33

44
from .tasks import Task
5+
from .batches import Batch
56

67
TASK_TYPES = [
78
'annotation',
@@ -10,15 +11,15 @@
1011
'comparison',
1112
'cuboidannotation',
1213
'datacollection',
13-
'imageannotation',
14+
'imageannotation',
1415
'lineannotation',
1516
'namedentityrecognition',
1617
'pointannotation',
1718
'polygonannotation',
1819
'segmentannotation',
1920
'transcription',
20-
'videoannotation',
21-
'videoboxannotation',
21+
'videoannotation',
22+
'videoboxannotation',
2223
'videocuboidannotation'
2324
]
2425
SCALE_ENDPOINT = 'https://api.scale.com/v1/'
@@ -35,27 +36,37 @@ class ScaleInvalidRequest(ScaleException, ValueError):
3536
pass
3637

3738

38-
class Tasklist(list):
39+
class Paginator(list):
3940
def __init__(self, docs, total, limit, offset, has_more, next_token=None):
40-
super(Tasklist, self).__init__(docs)
41+
super(Paginator, self).__init__(docs)
4142
self.docs = docs
4243
self.total = total
4344
self.limit = limit
4445
self.offset = offset
4546
self.has_more = has_more
4647
self.next_token = next_token
4748

49+
50+
class Tasklist(Paginator):
51+
pass
52+
53+
54+
class Batchlist(Paginator):
55+
pass
56+
57+
4858
class ScaleClient(object):
4959
def __init__(self, api_key):
5060
self.api_key = api_key
5161

52-
def _getrequest(self, endpoint, params={}):
62+
def _getrequest(self, endpoint, params=None):
5363
"""Makes a get request to an endpoint.
5464
5565
If an error occurs, assumes that endpoint returns JSON as:
5666
{ 'status_code': XXX,
5767
'error': 'I failed' }
5868
"""
69+
params = params or {}
5970
r = requests.get(SCALE_ENDPOINT + endpoint,
6071
headers={"Content-Type": "application/json"},
6172
auth=(self.api_key, ''), params=params)
@@ -114,7 +125,7 @@ def cancel_task(self, task_id):
114125
def tasks(self, **kwargs):
115126
"""Returns a list of your tasks.
116127
Returns up to 100 at a time, to get more, use the next_token param passed back.
117-
128+
118129
Note that offset is deprecated.
119130
120131
start/end_time are ISO8601 dates, the time range of tasks to fetch.
@@ -125,7 +136,7 @@ def tasks(self, **kwargs):
125136
offset (deprecated) is the number of results to skip (for showing more pages).
126137
"""
127138
allowed_kwargs = {'start_time', 'end_time', 'status', 'type', 'project',
128-
'batch', 'limit', 'offset', 'completed_before', 'completed_after',
139+
'batch', 'limit', 'offset', 'completed_before', 'completed_after',
129140
'next_token'}
130141
for key in kwargs:
131142
if key not in allowed_kwargs:
@@ -140,6 +151,29 @@ def create_task(self, task_type, **kwargs):
140151
taskdata = self._postrequest(endpoint, payload=kwargs)
141152
return Task(taskdata, self)
142153

154+
def create_batch(self, project, batch_name, callback):
155+
payload = dict(project=project, name=batch_name, callback=callback)
156+
batchdata = self._postrequest('batches', payload)
157+
return Batch(batchdata, self)
158+
159+
def get_batch(self, batch_name: str):
160+
batchdata = self._getrequest('batches/%s' % batch_name)
161+
return Batch(batchdata, self)
162+
163+
def list_batches(self, **kwargs):
164+
allowed_kwargs = { 'start_time', 'end_time', 'status', 'project',
165+
'batch', 'limit', 'offset', }
166+
for key in kwargs:
167+
if key not in allowed_kwargs:
168+
raise ScaleInvalidRequest('Illegal parameter %s for ScaleClient.tasks()'
169+
% key, None)
170+
response = self._getrequest('tasks', params=kwargs)
171+
docs = [Batch(doc, self) for doc in response['docs']]
172+
return Batchlist(
173+
docs, response['total'], response['limit'], response['offset'],
174+
response['has_more'], response.get('next_token'),
175+
)
176+
143177

144178
def _AddTaskTypeCreator(task_type):
145179
def create_task_wrapper(self, **kwargs):

scaleapi/batches.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
class Batch(object):
2+
def __init__(self, param_dict, client):
3+
self.param_dict = param_dict
4+
self.name = param_dict['name']
5+
self.pending = None
6+
self.completed = None
7+
self.error = None
8+
self.canceled = None
9+
self.client = client
10+
11+
def __hash__(self):
12+
return hash(self.name)
13+
14+
def __str__(self):
15+
return 'Batch(name=%s)' % self.name
16+
17+
def __repr__(self):
18+
return 'Batch(%s)' % self.param_dict
19+
20+
def finalize(self):
21+
return self.client._postrequest("batches/%s/finalize" % self.name)
22+
23+
def get_status(self):
24+
res = self.client._getrequest("batches/%s/status" % self.name)
25+
for stat in ["pending", "completed", "error", "canceled"]:
26+
setattr(self, stat, res.get(stat, 0))

0 commit comments

Comments
 (0)