Skip to content

Commit 3c1870c

Browse files
committed
WIP: make server async
1 parent 5884292 commit 3c1870c

File tree

2 files changed

+106
-55
lines changed

2 files changed

+106
-55
lines changed

src/jsi/cli.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,10 +260,22 @@ def main(args: list[str] | None = None) -> int:
260260
logger.enable(console=stderr, level=LogLevel.DEBUG)
261261

262262
if config.daemon:
263-
from jsi.server import Server
263+
import asyncio
264+
265+
import daemon
266+
267+
from jsi.server import STDERR_PATH, STDOUT_PATH, Server
268+
269+
async def run_server():
270+
server = Server(config)
271+
await server.start()
272+
273+
stdout_file = open(STDOUT_PATH, "w+") # noqa: SIM115
274+
stderr_file = open(STDERR_PATH, "w+") # noqa: SIM115
275+
276+
with daemon.DaemonContext(stdout=stdout_file, stderr=stderr_file):
277+
asyncio.run(run_server())
264278

265-
server = Server(config)
266-
server.start(detach_process=True)
267279
return 0
268280

269281
with timer("load_config"):

src/jsi/server.py

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import asyncio
2+
import contextlib
13
import os
24
import signal
35
import socket
@@ -19,7 +21,6 @@
1921
set_input_output,
2022
)
2123
from jsi.utils import pid_exists
22-
import contextlib
2324

2425
SERVER_HOME = os.path.expanduser("~/.jsi/daemon")
2526
SOCKET_PATH = os.path.join(SERVER_HOME, "server.sock")
@@ -58,27 +59,26 @@ def result(self) -> str:
5859
class PIDFile:
5960
def __init__(self, path: str):
6061
self.path = path
62+
self.pid = os.getpid()
6163

6264
def __enter__(self):
63-
if os.path.exists(self.path):
64-
print(f"pid file already exists: {self.path}")
65-
65+
try:
6666
with open(self.path) as fd:
67+
print(f"pid file already exists: {self.path}")
6768
other_pid = fd.read()
6869

6970
if pid_exists(int(other_pid)):
7071
print(f"killing existing daemon ({other_pid=})")
7172
os.kill(int(other_pid), signal.SIGKILL)
73+
except FileNotFoundError:
74+
# pid file doesn't exist, we're good to go
75+
pass
7276

73-
# the file may have been removed on termination by another instance
74-
with contextlib.suppress(FileNotFoundError):
75-
os.remove(self.path)
76-
77-
pid = os.getpid()
78-
print(f"creating pid file: {self.path} ({pid=})")
77+
# overwrite the file if it already exists
7978
with open(self.path, "w") as fd:
80-
fd.write(str(pid))
79+
fd.write(str(self.pid))
8180

81+
print(f"created pid file: {self.path} ({self.pid=})")
8282
return self.path
8383

8484
def __exit__(self, exc_type, exc_value, traceback):
@@ -99,7 +99,37 @@ def __init__(self, config: Config):
9999
self.solver_definitions = load_definitions(config)
100100
self.available_solvers = find_available_solvers(self.solver_definitions, config)
101101

102-
def solve(self, file: str) -> str:
102+
async def start(self):
103+
server = await asyncio.start_unix_server(
104+
self.handle_client, path=SOCKET_PATH
105+
)
106+
107+
async with server:
108+
await server.serve_forever()
109+
110+
async def handle_client(
111+
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
112+
):
113+
try:
114+
data: bytes = await reader.read(1024)
115+
if data:
116+
message: str = data.decode()
117+
result = await self.solve(message)
118+
writer.write(result.encode())
119+
await writer.drain()
120+
except Exception as e:
121+
print(f"Error handling client: {e}")
122+
finally:
123+
writer.close()
124+
await writer.wait_closed()
125+
126+
async def solve(self, file: str) -> str:
127+
# Assuming solve is CPU-bound, we use run_in_executor
128+
loop = asyncio.get_running_loop()
129+
result = await loop.run_in_executor(None, self.sync_solve, file)
130+
return result
131+
132+
def sync_solve(self, file: str) -> str:
103133
# initialize the controller
104134
task = Task(name=str(file))
105135

@@ -123,46 +153,55 @@ def solve(self, file: str) -> str:
123153

124154
return listener.result
125155

126-
127-
def start(self, detach_process: bool | None = None):
128-
if not os.path.exists(SERVER_HOME):
129-
print(f"creating server home: {SERVER_HOME}")
130-
os.makedirs(SERVER_HOME)
131-
132-
stdout_file = open(STDOUT_PATH, "w+") # noqa: SIM115
133-
stderr_file = open(STDERR_PATH, "w+") # noqa: SIM115
134-
135-
print(f"daemonizing... (`tail -f {STDOUT_PATH[:-4]}.{{err,out}}` to view logs)")
136-
with daemon.DaemonContext(
137-
stdout=stdout_file,
138-
stderr=stderr_file,
139-
detach_process=detach_process,
140-
pidfile=PIDFile(PID_PATH),
141-
):
142-
if os.path.exists(SOCKET_PATH):
143-
print(f"removing existing socket: {SOCKET_PATH}")
144-
os.remove(SOCKET_PATH)
145-
146-
print(f"binding socket: {SOCKET_PATH}")
147-
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
148-
server.bind(SOCKET_PATH)
149-
server.listen(1)
150-
151-
while True:
152-
try:
153-
conn, _ = server.accept()
154-
with conn:
155-
try:
156-
data = conn.recv(CONN_BUFFER_SIZE).decode()
157-
if not data:
158-
continue
159-
conn.sendall(self.solve(data).encode())
160-
except ConnectionError as e:
161-
print(f"connection error: {e}")
162-
except SystemExit as e:
163-
print(f"system exit: {e}")
164-
return e.code
156+
# def start(self, detach_process: bool | None = None):
157+
# if not os.path.exists(SERVER_HOME):
158+
# print(f"creating server home: {SERVER_HOME}")
159+
# os.makedirs(SERVER_HOME)
160+
161+
# stdout_file = open(STDOUT_PATH, "w+") # noqa: SIM115
162+
# stderr_file = open(STDERR_PATH, "w+") # noqa: SIM115
163+
164+
# print(f"daemonizing... (`tail -f {STDOUT_PATH[:-4]}.{{err,out}}` to view logs)")
165+
# with daemon.DaemonContext(
166+
# stdout=stdout_file,
167+
# stderr=stderr_file,
168+
# detach_process=detach_process,
169+
# pidfile=PIDFile(PID_PATH),
170+
# ):
171+
# if os.path.exists(SOCKET_PATH):
172+
# print(f"removing existing socket: {SOCKET_PATH}")
173+
# os.remove(SOCKET_PATH)
174+
175+
# print(f"binding socket: {SOCKET_PATH}")
176+
# with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as server:
177+
# server.bind(SOCKET_PATH)
178+
# server.listen(1)
179+
180+
# while True:
181+
# try:
182+
# conn, _ = server.accept()
183+
# with conn:
184+
# try:
185+
# data = conn.recv(CONN_BUFFER_SIZE).decode()
186+
# if not data:
187+
# continue
188+
# print(f"solving: {data}")
189+
# conn.sendall(self.solve(data).encode())
190+
# except ConnectionError as e:
191+
# print(f"connection error: {e}")
192+
# except SystemExit as e:
193+
# print(f"system exit: {e}")
194+
# return e.code
165195

166196

167197
if __name__ == "__main__":
168-
Server(Config()).start()
198+
199+
async def run_server():
200+
server = Server(Config())
201+
await server.start()
202+
203+
stdout_file = open(STDOUT_PATH, "w+") # noqa: SIM115
204+
stderr_file = open(STDERR_PATH, "w+") # noqa: SIM115
205+
206+
with daemon.DaemonContext(stdout=stdout_file, stderr=stderr_file):
207+
asyncio.run(run_server())

0 commit comments

Comments
 (0)