8
8
import gc
9
9
import torch
10
10
import os
11
- from fastapi import FastAPI , Request
11
+ from fastapi import FastAPI , Request , HTTPException , Depends , BackgroundTasks
12
12
from fastapi .middleware .cors import CORSMiddleware
13
13
from contextlib import contextmanager
14
14
from colorama import Fore , Style
15
+ from typing import Optional , Dict , Any , List
15
16
16
17
# Try to import FastAPICache, but don't fail if not available
17
18
try :
@@ -29,7 +30,7 @@ def init(backend, **kwargs):
29
30
30
31
from .. import __version__
31
32
from ..logger import get_logger
32
- from ..logger .logger import log_request , log_model_loaded , log_model_unloaded , get_request_count
33
+ from ..logger .logger import log_request , log_model_loaded , log_model_unloaded , get_request_count , set_server_status
33
34
from ..model_manager import ModelManager
34
35
from ..config import (
35
36
ENABLE_CORS ,
@@ -38,8 +39,11 @@ def init(backend, **kwargs):
38
39
ENABLE_COMPRESSION ,
39
40
QUANTIZATION_TYPE ,
40
41
SERVER_PORT ,
42
+ DEFAULT_MAX_LENGTH ,
43
+ get_env_var ,
41
44
)
42
45
from ..cli .config import get_config_value
46
+ from ..utils .system import get_system_resources
43
47
44
48
# Get the logger
45
49
logger = get_logger ("locallab.app" )
@@ -77,10 +81,29 @@ def init(backend, **kwargs):
77
81
app .include_router (generate_router )
78
82
app .include_router (system_router )
79
83
84
+ # Startup event triggered flag
85
+ startup_event_triggered = False
80
86
87
+ # Application startup event to ensure banners are displayed
81
88
@app .on_event ("startup" )
82
89
async def startup_event ():
83
- """Initialization tasks when the server starts"""
90
+ """Event that is triggered when the application starts up"""
91
+ global startup_event_triggered
92
+
93
+ # Only log once
94
+ if startup_event_triggered :
95
+ return
96
+
97
+ logger .info ("FastAPI application startup event triggered" )
98
+ startup_event_triggered = True
99
+
100
+ # Wait a short time to ensure logs are processed
101
+ await asyncio .sleep (0.5 )
102
+
103
+ # Log a special message that our callback handler will detect
104
+ root_logger = logging .getLogger ()
105
+ root_logger .info ("Application startup complete - banner display trigger" )
106
+
84
107
logger .info (f"{ Fore .CYAN } Starting LocalLab server...{ Style .RESET_ALL } " )
85
108
86
109
# Get HuggingFace token and set it in environment if available
@@ -158,7 +181,8 @@ async def shutdown_event():
158
181
model_manager .current_model = None
159
182
160
183
# Clean up memory
161
- torch .cuda .empty_cache ()
184
+ if torch .cuda .is_available ():
185
+ torch .cuda .empty_cache ()
162
186
gc .collect ()
163
187
164
188
# Log model unloading
@@ -169,7 +193,37 @@ async def shutdown_event():
169
193
except Exception as e :
170
194
logger .error (f"Error during shutdown cleanup: { str (e )} " )
171
195
196
+ # Clean up any pending tasks
197
+ try :
198
+ tasks = [t for t in asyncio .all_tasks ()
199
+ if t is not asyncio .current_task () and not t .done ()]
200
+ if tasks :
201
+ logger .debug (f"Cancelling { len (tasks )} remaining tasks" )
202
+ for task in tasks :
203
+ task .cancel ()
204
+ await asyncio .gather (* tasks , return_exceptions = True )
205
+ except Exception as e :
206
+ logger .warning (f"Error cleaning up tasks: { str (e )} " )
207
+
208
+ # Set server status to stopped
209
+ set_server_status ("stopped" )
210
+
172
211
logger .info (f"{ Fore .GREEN } Server shutdown complete{ Style .RESET_ALL } " )
212
+
213
+ # Force exit if needed to clean up any hanging resources
214
+ import threading
215
+ def force_exit ():
216
+ import time
217
+ import os
218
+ import signal
219
+ time .sleep (3 ) # Give a little time for clean shutdown
220
+ logger .info ("Forcing exit after shutdown to ensure clean termination" )
221
+ try :
222
+ os .kill (os .getpid (), signal .SIGTERM )
223
+ except :
224
+ os ._exit (0 )
225
+
226
+ threading .Thread (target = force_exit , daemon = True ).start ()
173
227
174
228
175
229
@app .middleware ("http" )
0 commit comments