1
1
import requests
2
2
3
- from . import tasks
3
+ from .tasks import Task
4
4
5
5
DEFAULT_FIELDS = {'callback_url' , 'instruction' , 'urgency' }
6
+ ALLOWED_FIELDS = {'categorization' : {'attachment' , 'attachment_type' , 'categories' ,
7
+ 'category_ids' , 'allow_multiple' },
8
+ 'transcription' : {'attachment' , 'attachment_type' ,
9
+ 'fields' , 'row_fields' },
10
+ 'phonecall' : {'attachment' , 'attachment_type' , 'phone_number' ,
11
+ 'script' , 'entity_name' , 'fields' , 'choices' },
12
+ 'comparison' : {'attachments' , 'attachment_type' ,
13
+ 'fields' , 'choices' },
14
+ 'annotation' : {'attachment' , 'attachment_type' , 'instruction' ,
15
+ 'objects_to_annotate' , 'with_labels' }}
16
+
17
+ def validate_payload (task_type , kwargs ):
18
+ allowed_fields = DEFAULT_FIELDS + ALLOWED_FIELDS [task_type ]
19
+ for k in kwargs :
20
+ if k not in allowed_fields :
21
+ raise ScaleException ('Illegal parameter %s for task_type %s'
22
+ % (k , task_type ))
23
+
24
+ class ScaleException (Exception ):
25
+ pass
26
+
27
+ class ScaleInvalidRequest (ScaleException ):
28
+ pass
6
29
7
30
class ScaleClient (object ):
8
31
def __init__ (self , api_key , callback_key = None ,
@@ -11,26 +34,80 @@ def __init__(self, api_key, callback_key=None,
11
34
self .callback_key = callback_key
12
35
self .endpoint = endpoint
13
36
14
- def create_comparison_task (** kwargs ):
15
- payload = generic_payload (kwargs )
16
- allowed_fields = DEFAULT_FIELDS + \
17
- {'attachment' , 'attachment_type' , 'categories' , 'category_ids' , 'allow_multiple' }
18
- if field in kwargs :
19
- payload [field ] = kwargs [field ]
20
- payload = {
21
- 'callback_url' : callback_url ,
22
- 'instruction' : instruction ,
23
- }
24
- return self ._dotask (tasks .ComparisonTask (* args , ** kwargs ))
25
-
26
- def create_transcription_task (* args , ** kwargs ):
27
- return self ._dotask (tasks .TranscriptionTask (* args , ** kwargs ))
28
-
29
- def create_phonecall_task (* args , ** kwargs ):
30
- return self ._dotask (tasks .PhonecallTask (* args , ** kwargs ))
31
-
32
- def create_comparison_task (* args , ** kwargs ):
33
- return self ._dotask (tasks .TranscriptionTask (* args , ** kwargs ))
34
-
35
- def create_annotation_task (* args , ** kwargs ):
36
- return self ._dotask (tasks .AnnotationTask (* args , ** kwargs ))
37
+ def _getrequest (self , endpoint ):
38
+ """Makes a get request to an endpoint.
39
+
40
+ If an error occurs, assumes that endpoint returns JSON as:
41
+ { 'status_code': XXX,
42
+ 'error': 'I failed' }
43
+ """
44
+ r = requests .get (self .endpoint + endpoint ,
45
+ headers = {"Content-Type" : "application/json" },
46
+ auth = (self .api_key , '' ))
47
+
48
+ if r .status_code == 200 :
49
+ return r .json ()
50
+ raise ScaleException (r .json ()['error' ])
51
+
52
+ def _postrequest (self , endpoint , payload = None ):
53
+ """Makes a post request to an endpoint.
54
+
55
+ If an error occurs, assumes that endpoint returns JSON as:
56
+ { 'status_code': XXX,
57
+ 'error': 'I failed' }
58
+ """
59
+ payload = payload or {}
60
+ r = requests .post (self .endpoint + endpoint , json = payload ,
61
+ headers = {"Content-Type" : "application/json" },
62
+ auth = (self .api_key , '' ))
63
+
64
+ if r .status_code == 200 :
65
+ return r .json ()
66
+ if r .status_code == 401 :
67
+ raise ScaleException (r .json ()['error' ])
68
+ if r .status_code == 400 :
69
+ raise ScaleInvalidRequest (r .json ()['error' ])
70
+ raise ScaleException (r .json ()['error' ])
71
+
72
+ def fetch_task (self , task_id ):
73
+ """Fetches a task.
74
+
75
+ Returns the associated task.
76
+ """
77
+ return Task (self ._getrequest ('task/%s' % task_id ))
78
+
79
+ def cancel_task (self , task_id ):
80
+ """Cancels a task.
81
+
82
+ Returns the associated task.
83
+ """
84
+ return Task (self ._postrequest ('task/%s/cancel' % task_id ))
85
+
86
+ def tasks (self ):
87
+ """Returns a list of all your tasks."""
88
+ return [Task (json ) for json in self ._getrequest ('tasks' )]
89
+
90
+ def create_categorization_task (self , ** kwargs ):
91
+ validate_payload ('categorization' , kwargs )
92
+ taskdata = self ._postrequest ('task/categorize' , payload = kwargs )
93
+ return Task (taskdata )
94
+
95
+ def create_transcription_task (self , ** kwargs ):
96
+ validate_payload ('transcription' , kwargs )
97
+ taskdata = self ._postrequest ('task/transcription' , payload = kwargs )
98
+ return Task (taskdata )
99
+
100
+ def create_phonecall_task (self , ** kwargs ):
101
+ validate_payload ('phonecall' , kwargs )
102
+ taskdata = self ._postrequest ('task/phonecall' , payload = kwargs )
103
+ return Task (taskdata )
104
+
105
+ def create_comparison_task (self , ** kwargs ):
106
+ validate_payload ('comparison' , kwargs )
107
+ taskdata = self ._postrequest ('task/comparison' , payload = kwargs )
108
+ return Task (taskdata )
109
+
110
+ def create_annotation_task (self , ** kwargs ):
111
+ validate_payload ('annotation' , kwargs )
112
+ taskdata = self ._postrequest ('task/annotation' , payload = kwargs )
113
+ return Task (taskdata )
0 commit comments