6
6
from app .plugins .worker import run_function_async
7
7
from django .utils .translation import gettext_lazy as _
8
8
9
- def detect (orthophoto , model , progress_callback = None ):
9
+ def detect (orthophoto , model , classes = None , progress_callback = None ):
10
10
import os
11
11
from webodm import settings
12
12
@@ -17,7 +17,7 @@ def detect(orthophoto, model, progress_callback=None):
17
17
return {'error' : "GeoDeep library is missing" }
18
18
19
19
try :
20
- return {'output' : gdetect (orthophoto , model , output_type = 'geojson' , max_threads = settings .WORKERS_MAX_THREADS , progress_callback = progress_callback )}
20
+ return {'output' : gdetect (orthophoto , model , output_type = 'geojson' , classes = classes , max_threads = settings .WORKERS_MAX_THREADS , progress_callback = progress_callback )}
21
21
except Exception as e :
22
22
return {'error' : str (e )}
23
23
@@ -31,10 +31,20 @@ def post(self, request, pk=None):
31
31
orthophoto = os .path .abspath (task .get_asset_download_path ("orthophoto.tif" ))
32
32
model = request .data .get ('model' , 'cars' )
33
33
34
- if not model in ['cars' , 'trees' ]:
34
+ # model --> (modelID, classes)
35
+ model_map = {
36
+ 'cars' : ('cars' , None ),
37
+ 'trees' : ('trees' , None ),
38
+ 'athletic' : ('aerovision' , ['tennis-court' , 'track-field' , 'soccer-field' , 'baseball-field' , 'swimming-pool' , 'basketball-court' ]),
39
+ 'boats' : ('aerovision' , ['boat' ]),
40
+ 'planes' : ('aerovision' , ['plane' ]),
41
+ }
42
+
43
+ if not model in model_map :
35
44
return Response ({'error' : 'Invalid model' }, status = status .HTTP_200_OK )
36
45
37
- celery_task_id = run_function_async (detect , orthophoto , model , with_progress = True ).task_id
46
+ model_id , classes = model_map [model ]
47
+ celery_task_id = run_function_async (detect , orthophoto , model_id , classes , with_progress = True ).task_id
38
48
39
49
return Response ({'celery_task_id' : celery_task_id }, status = status .HTTP_200_OK )
40
50
0 commit comments