2
2
3
3
from .tasks import Task
4
4
5
- DEFAULT_FIELDS = {'callback_url' , 'instruction' , 'urgency' }
5
+ DEFAULT_FIELDS = {'callback_url' , 'instruction' , 'urgency' , 'metadata' }
6
6
ALLOWED_FIELDS = {'categorization' : {'attachment' , 'attachment_type' , 'categories' ,
7
7
'category_ids' , 'allow_multiple' },
8
8
'transcription' : {'attachment' , 'attachment_type' ,
14
14
'annotation' : {'attachment' , 'attachment_type' , 'instruction' ,
15
15
'objects_to_annotate' , 'with_labels' }}
16
16
17
+
17
18
def validate_payload (task_type , kwargs ):
18
19
allowed_fields = DEFAULT_FIELDS + ALLOWED_FIELDS [task_type ]
19
20
for k in kwargs :
20
21
if k not in allowed_fields :
21
- raise ScaleException ('Illegal parameter %s for task_type %s'
22
- % (k , task_type ))
22
+ raise ScaleInvalidRequest ('Illegal parameter %s for task_type %s'
23
+ % (k , task_type ), None )
24
+
23
25
24
26
class ScaleException (Exception ):
25
- pass
27
+ def __init__ (self , message , errcode ):
28
+ super (ScaleException , self ).__init__ (message )
29
+ self .code = errcode
26
30
27
- class ScaleInvalidRequest (ScaleException ):
31
+
32
+ class ScaleInvalidRequest (ScaleException , ValueError ):
28
33
pass
29
34
35
+
30
36
class ScaleClient (object ):
31
- def __init__ (self , api_key , callback_key = None ,
32
- endpoint = 'https://api.scaleapi.com/v1/' ):
37
+ def __init__ (self , api_key , endpoint = 'https://api.scaleapi.com/v1/' ):
33
38
self .api_key = api_key
34
- self .callback_key = callback_key
35
39
self .endpoint = endpoint
36
40
37
41
def _getrequest (self , endpoint ):
@@ -47,7 +51,7 @@ def _getrequest(self, endpoint):
47
51
48
52
if r .status_code == 200 :
49
53
return r .json ()
50
- raise ScaleException (r .json ()['error' ])
54
+ raise ScaleException (r .json ()['error' ], r . status_code )
51
55
52
56
def _postrequest (self , endpoint , payload = None ):
53
57
"""Makes a post request to an endpoint.
@@ -63,51 +67,50 @@ def _postrequest(self, endpoint, payload=None):
63
67
64
68
if r .status_code == 200 :
65
69
return r .json ()
66
- if r .status_code == 401 :
67
- raise ScaleException (r .json ()['error' ])
68
70
if r .status_code == 400 :
69
- raise ScaleInvalidRequest (r .json ()['error' ])
70
- raise ScaleException (r .json ()['error' ])
71
+ raise ScaleInvalidRequest (r .json ()['error' ], r . status_code )
72
+ raise ScaleException (r .json ()['error' ], r . status_code )
71
73
72
74
def fetch_task (self , task_id ):
73
75
"""Fetches a task.
74
76
75
77
Returns the associated task.
76
78
"""
77
- return Task (self ._getrequest ('task/%s' % task_id ))
79
+ return Task (self ._getrequest ('task/%s' % task_id ), self )
78
80
79
81
def cancel_task (self , task_id ):
80
82
"""Cancels a task.
81
83
82
84
Returns the associated task.
85
+ Raises a ScaleException if it has already been canceled.
83
86
"""
84
- return Task (self ._postrequest ('task/%s/cancel' % task_id ))
87
+ return Task (self ._postrequest ('task/%s/cancel' % task_id ), self )
85
88
86
89
def tasks (self ):
87
90
"""Returns a list of all your tasks."""
88
- return [Task (json ) for json in self ._getrequest ('tasks' )]
91
+ return [Task (json , self ) for json in self ._getrequest ('tasks' )]
89
92
90
93
def create_categorization_task (self , ** kwargs ):
91
94
validate_payload ('categorization' , kwargs )
92
95
taskdata = self ._postrequest ('task/categorize' , payload = kwargs )
93
- return Task (taskdata )
96
+ return Task (taskdata , self )
94
97
95
98
def create_transcription_task (self , ** kwargs ):
96
99
validate_payload ('transcription' , kwargs )
97
100
taskdata = self ._postrequest ('task/transcription' , payload = kwargs )
98
- return Task (taskdata )
101
+ return Task (taskdata , self )
99
102
100
103
def create_phonecall_task (self , ** kwargs ):
101
104
validate_payload ('phonecall' , kwargs )
102
105
taskdata = self ._postrequest ('task/phonecall' , payload = kwargs )
103
- return Task (taskdata )
106
+ return Task (taskdata , self )
104
107
105
108
def create_comparison_task (self , ** kwargs ):
106
109
validate_payload ('comparison' , kwargs )
107
110
taskdata = self ._postrequest ('task/comparison' , payload = kwargs )
108
- return Task (taskdata )
111
+ return Task (taskdata , self )
109
112
110
113
def create_annotation_task (self , ** kwargs ):
111
114
validate_payload ('annotation' , kwargs )
112
115
taskdata = self ._postrequest ('task/annotation' , payload = kwargs )
113
- return Task (taskdata )
116
+ return Task (taskdata , self )
0 commit comments