6
6
import numpy as np
7
7
import ray
8
8
import torch
9
- import yaml
10
9
from fastapi import FastAPI , File , UploadFile
11
10
from fastapi .responses import JSONResponse
12
11
from PIL import Image
13
12
from ray import serve
13
+ import argparse
14
14
15
15
from . import ImagesInput , to_base64_nparray
16
16
from .core import ImageMatchingAPI
17
17
from ..hloc import DEVICE
18
+ from ..hloc .utils .io import read_yaml
18
19
from ..ui import get_version
19
20
20
21
warnings .simplefilter ("ignore" )
21
22
app = FastAPI ()
22
23
if ray .is_initialized ():
23
24
ray .shutdown ()
25
+
26
+
27
+ # read some configs
28
+ parser = argparse .ArgumentParser ()
29
+ parser .add_argument (
30
+ "--config" ,
31
+ type = Path ,
32
+ required = False ,
33
+ default = Path (__file__ ).parent / "config/api.yaml" ,
34
+ )
35
+ args = parser .parse_args ()
36
+ config_path = args .config
37
+ config = read_yaml (config_path )
38
+ num_gpus = 1 if torch .cuda .is_available () else 0
39
+ ray_actor_options = config ["service" ].get ("ray_actor_options" , {})
40
+ ray_actor_options .update ({"num_gpus" : num_gpus })
41
+ dashboard_port = config ["service" ].get ("dashboard_port" , 8265 )
42
+ http_options = config ["service" ].get (
43
+ "http_options" ,
44
+ {
45
+ "host" : "0.0.0.0" ,
46
+ "port" : 8001 ,
47
+ },
48
+ )
49
+ num_replicas = config ["service" ].get ("num_replicas" , 4 )
24
50
ray .init (
25
- dashboard_port = 8265 ,
51
+ dashboard_port = dashboard_port ,
26
52
ignore_reinit_error = True ,
27
53
)
28
- serve .start (
29
- http_options = {"host" : "0.0.0.0" , "port" : 8001 },
30
- )
31
-
32
- num_gpus = 1 if torch .cuda .is_available () else 0
54
+ serve .start (http_options = http_options )
33
55
34
56
35
57
@serve .deployment (
36
- num_replicas = 4 , ray_actor_options = {"num_cpus" : 2 , "num_gpus" : num_gpus }
58
+ num_replicas = num_replicas ,
59
+ ray_actor_options = ray_actor_options ,
37
60
)
38
61
@serve .ingress (app )
39
62
class ImageMatchingService :
40
- def __init__ (self , conf : dict , device : str ):
63
+ def __init__ (self , conf : dict , device : str , ** kwargs ):
41
64
self .conf = conf
42
65
self .api = ImageMatchingAPI (conf = conf , device = device )
43
66
@@ -137,7 +160,7 @@ def load_image(self, file_path: Union[str, UploadFile]) -> np.ndarray:
137
160
image_array = np .array (img )
138
161
return image_array
139
162
140
- def postprocess (self , output : dict , skip_keys : list , binarize : bool = True ) -> dict :
163
+ def postprocess (self , output : dict , skip_keys : list , ** kwargs ) -> dict :
141
164
pred = {}
142
165
for key , value in output .items ():
143
166
if key in skip_keys :
@@ -152,19 +175,12 @@ def run(self, host: str = "0.0.0.0", port: int = 8001):
152
175
uvicorn .run (app , host = host , port = port )
153
176
154
177
155
- def read_config (config_path : Path ) -> dict :
156
- with open (config_path , "r" ) as f :
157
- conf = yaml .safe_load (f )
158
- return conf
159
-
160
-
161
- # api server
162
- conf = read_config (Path (__file__ ).parent / "config/api.yaml" )
163
- service = ImageMatchingService .bind (conf = conf ["api" ], device = DEVICE )
164
- handle = serve .run (service , route_prefix = "/" )
178
+ if __name__ == "__main__" :
179
+ # api server
180
+ service = ImageMatchingService .bind (conf = config ["api" ], device = DEVICE )
181
+ handle = serve .run (service , route_prefix = "/" , blocking = False )
165
182
166
183
# serve run api.server_ray:service
167
-
168
184
# build to generate config file
169
185
# serve build api.server_ray:service -o api/config/ray.yaml
170
186
# serve run api/config/ray.yaml
0 commit comments