Skip to content

Commit e500630

Browse files
committed
fill in task stuff
1 parent 3b9f501 commit e500630

File tree

3 files changed

+40
-27
lines changed

3 files changed

+40
-27
lines changed

README.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ client = scale.ScaleClient('YOUR_API_KEY_HERE', callback_key='OPTIONAL_CALLBACK_
1313

1414
### Tasks
1515

16-
Most of these methods will return a `Scale::Resources::Task` object, which will contain information
16+
Most of these methods will return a `scale.Task` object, which will contain information
1717
about the json response (task_id, status...).
1818

1919
Any parameter available in the [documentation](https://docs.scaleapi.com) can be passed as an argument
@@ -147,6 +147,5 @@ The api initialization accepts the following options:
147147
| Name | Description | Default |
148148
| ---- | ----------- | ------- |
149149
| `endpoint` | Endpoint used in the http requests. | `'https://api.scaleapi.com/v1/'` |
150-
| `api_key` | API key used in the http requests. | `nil` |
151-
| `callback_key` | API key used to validate callback POST requests. | `nil` |
150+
| `api_key` | API key used in the http requests. | required |
152151

scale/__init__.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from .tasks import Task
44

5-
DEFAULT_FIELDS = {'callback_url', 'instruction', 'urgency'}
5+
DEFAULT_FIELDS = {'callback_url', 'instruction', 'urgency', 'metadata'}
66
ALLOWED_FIELDS = {'categorization': {'attachment', 'attachment_type', 'categories',
77
'category_ids', 'allow_multiple'},
88
'transcription': {'attachment', 'attachment_type',
@@ -14,24 +14,28 @@
1414
'annotation': {'attachment', 'attachment_type', 'instruction',
1515
'objects_to_annotate', 'with_labels'}}
1616

17+
1718
def validate_payload(task_type, kwargs):
1819
allowed_fields = DEFAULT_FIELDS + ALLOWED_FIELDS[task_type]
1920
for k in kwargs:
2021
if k not in allowed_fields:
21-
raise ScaleException('Illegal parameter %s for task_type %s'
22-
% (k, task_type))
22+
raise ScaleInvalidRequest('Illegal parameter %s for task_type %s'
23+
% (k, task_type), None)
24+
2325

2426
class ScaleException(Exception):
25-
pass
27+
def __init__(self, message, errcode):
28+
super(ScaleException, self).__init__(message)
29+
self.code = errcode
2630

27-
class ScaleInvalidRequest(ScaleException):
31+
32+
class ScaleInvalidRequest(ScaleException, ValueError):
2833
pass
2934

35+
3036
class ScaleClient(object):
31-
def __init__(self, api_key, callback_key=None,
32-
endpoint='https://api.scaleapi.com/v1/'):
37+
def __init__(self, api_key, endpoint='https://api.scaleapi.com/v1/'):
3338
self.api_key = api_key
34-
self.callback_key = callback_key
3539
self.endpoint = endpoint
3640

3741
def _getrequest(self, endpoint):
@@ -47,7 +51,7 @@ def _getrequest(self, endpoint):
4751

4852
if r.status_code == 200:
4953
return r.json()
50-
raise ScaleException(r.json()['error'])
54+
raise ScaleException(r.json()['error'], r.status_code)
5155

5256
def _postrequest(self, endpoint, payload=None):
5357
"""Makes a post request to an endpoint.
@@ -63,51 +67,50 @@ def _postrequest(self, endpoint, payload=None):
6367

6468
if r.status_code == 200:
6569
return r.json()
66-
if r.status_code == 401:
67-
raise ScaleException(r.json()['error'])
6870
if r.status_code == 400:
69-
raise ScaleInvalidRequest(r.json()['error'])
70-
raise ScaleException(r.json()['error'])
71+
raise ScaleInvalidRequest(r.json()['error'], r.status_code)
72+
raise ScaleException(r.json()['error'], r.status_code)
7173

7274
def fetch_task(self, task_id):
7375
"""Fetches a task.
7476
7577
Returns the associated task.
7678
"""
77-
return Task(self._getrequest('task/%s' % task_id))
79+
return Task(self._getrequest('task/%s' % task_id), self)
7880

7981
def cancel_task(self, task_id):
8082
"""Cancels a task.
8183
8284
Returns the associated task.
85+
Raises a ScaleException if it has already been canceled.
8386
"""
84-
return Task(self._postrequest('task/%s/cancel' % task_id))
87+
return Task(self._postrequest('task/%s/cancel' % task_id), self)
8588

8689
def tasks(self):
8790
"""Returns a list of all your tasks."""
88-
return [Task(json) for json in self._getrequest('tasks')]
91+
return [Task(json, self) for json in self._getrequest('tasks')]
8992

9093
def create_categorization_task(self, **kwargs):
9194
validate_payload('categorization', kwargs)
9295
taskdata = self._postrequest('task/categorize', payload=kwargs)
93-
return Task(taskdata)
96+
return Task(taskdata, self)
9497

9598
def create_transcription_task(self, **kwargs):
9699
validate_payload('transcription', kwargs)
97100
taskdata = self._postrequest('task/transcription', payload=kwargs)
98-
return Task(taskdata)
101+
return Task(taskdata, self)
99102

100103
def create_phonecall_task(self, **kwargs):
101104
validate_payload('phonecall', kwargs)
102105
taskdata = self._postrequest('task/phonecall', payload=kwargs)
103-
return Task(taskdata)
106+
return Task(taskdata, self)
104107

105108
def create_comparison_task(self, **kwargs):
106109
validate_payload('comparison', kwargs)
107110
taskdata = self._postrequest('task/comparison', payload=kwargs)
108-
return Task(taskdata)
111+
return Task(taskdata, self)
109112

110113
def create_annotation_task(self, **kwargs):
111114
validate_payload('annotation', kwargs)
112115
taskdata = self._postrequest('task/annotation', payload=kwargs)
113-
return Task(taskdata)
116+
return Task(taskdata, self)

scale/tasks.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
11
class Task(object):
2-
def __init__(self, param_dict):
3-
# TODO(calvin)
4-
pass
2+
"""Task class, containing task information."""
3+
def __init__(self, param_dict, client):
4+
self.client = client
5+
self.param_dict = param_dict
6+
self.id = param_dict['task_id']
7+
8+
def __getattr__(self, name):
9+
if name in self.param_dict:
10+
return self.param_dict[name]
11+
raise AttributeError("'%s' object has no attribute %s"
12+
% (type(self).__name__, name))
13+
14+
def cancel(self):
15+
self.client.cancel_task(self.id)

0 commit comments

Comments
 (0)