Skip to content

Commit 7299216

Browse files
committed
Immediately dispatch line/termination/finish (fixes #1049, fixes #1071)
Avoids races between queued up lines and command finish callbacks.
1 parent 71e7c31 commit 7299216

File tree

3 files changed

+47
-13
lines changed

3 files changed

+47
-13
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ New features:
3030

3131
Bugfixes:
3232

33+
* Fix unsolicited engine output may cause assertion errors with regard to
34+
command states.
3335
* Fix handling of whitespace in UCI engine communication.
3436
* For ``chess.Board.epd()`` and ``chess.Board.set_epd()``, require that EPD
3537
opcodes start with a letter.

chess/engine.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ def write(self, data: bytes) -> None:
883883
expectation, responses = self.expectations.popleft()
884884
assert expectation == line, f"expected {expectation}, got: {line}"
885885
if responses:
886-
self.protocol.pipe_data_received(1, "\n".join(responses + [""]).encode("utf-8"))
886+
self.protocol.loop.call_soon(self.protocol.pipe_data_received, 1, "\n".join(responses + [""]).encode("utf-8"))
887887

888888
def get_pid(self) -> int:
889889
return id(self)
@@ -934,12 +934,12 @@ def connection_lost(self, exc: Optional[Exception]) -> None:
934934
LOGGER.debug("%s: Connection lost (exit code: %d, error: %s)", self, code, exc)
935935

936936
# Terminate commands.
937-
if self.command is not None:
938-
self.command._engine_terminated(code)
939-
self.command = None
940-
if self.next_command is not None:
941-
self.next_command._engine_terminated(code)
942-
self.next_command = None
937+
command, self.command = self.command, None
938+
next_command, self.next_command = self.next_command, None
939+
if command:
940+
command._engine_terminated(code)
941+
if next_command:
942+
next_command._engine_terminated(code)
943943

944944
self.returncode.set_result(code)
945945

@@ -965,9 +965,9 @@ def pipe_data_received(self, fd: int, data: Union[bytes, str]) -> None:
965965
LOGGER.warning("%s: >> %r (%s)", self, bytes(line_bytes), err)
966966
else:
967967
if fd == 1:
968-
self.loop.call_soon(self._line_received, line)
968+
self._line_received(line)
969969
else:
970-
self.loop.call_soon(self.error_line_received, line)
970+
self.error_line_received(line)
971971

972972
def error_line_received(self, line: str) -> None:
973973
LOGGER.warning("%s: stderr >> %s", self, line)
@@ -998,7 +998,7 @@ async def communicate(self, command_factory: Callable[[Self], BaseCommand[T]]) -
998998

999999
self.next_command = command
10001000

1001-
def previous_command_finished(_: Optional[asyncio.Future[None]]) -> None:
1001+
def previous_command_finished() -> None:
10021002
self.command, self.next_command = self.next_command, None
10031003
if self.command is not None:
10041004
cmd = self.command
@@ -1008,11 +1008,11 @@ def cancel_if_cancelled(result: asyncio.Future[T]) -> None:
10081008
cmd._cancel()
10091009

10101010
cmd.result.add_done_callback(cancel_if_cancelled)
1011-
cmd.finished.add_done_callback(previous_command_finished)
10121011
cmd._start()
1012+
cmd.add_finished_callback(previous_command_finished)
10131013

10141014
if self.command is None:
1015-
previous_command_finished(None)
1015+
previous_command_finished()
10161016
elif not self.command.result.done():
10171017
self.command.result.cancel()
10181018
elif not self.command.result.cancelled():
@@ -1228,13 +1228,25 @@ def __init__(self, engine: Protocol) -> None:
12281228
self.result: asyncio.Future[T] = asyncio.Future()
12291229
self.finished: asyncio.Future[None] = asyncio.Future()
12301230

1231+
self._finished_callbacks: List[Callable[[], None]] = []
1232+
1233+
def add_finished_callback(self, callback: Callable[[], None]) -> None:
1234+
self._finished_callbacks.append(callback)
1235+
self._dispatch_finished()
1236+
1237+
def _dispatch_finished(self) -> None:
1238+
if self.finished.done():
1239+
while self._finished_callbacks:
1240+
self._finished_callbacks.pop()()
1241+
12311242
def _engine_terminated(self, code: int) -> None:
12321243
hint = ", binary not compatible with cpu?" if code in [-4, 0xc000001d] else ""
12331244
exc = EngineTerminatedError(f"engine process died unexpectedly (exit code: {code}{hint})")
12341245
if self.state == CommandState.ACTIVE:
12351246
self.engine_terminated(exc)
12361247
elif self.state == CommandState.CANCELLING:
12371248
self.finished.set_result(None)
1249+
self._dispatch_finished()
12381250
elif self.state == CommandState.NEW:
12391251
self._handle_exception(exc)
12401252

@@ -1251,13 +1263,15 @@ def _handle_exception(self, exc: Exception) -> None:
12511263

12521264
if not self.finished.done():
12531265
self.finished.set_result(None)
1266+
self._dispatch_finished()
12541267

12551268
def set_finished(self) -> None:
12561269
assert self.state in [CommandState.ACTIVE, CommandState.CANCELLING], self.state
12571270
if not self.result.done():
12581271
self.result.set_exception(EngineError(f"engine command finished before returning result: {self!r}"))
1259-
self.finished.set_result(None)
12601272
self.state = CommandState.DONE
1273+
self.finished.set_result(None)
1274+
self._dispatch_finished()
12611275

12621276
def _cancel(self) -> None:
12631277
if self.state != CommandState.CANCELLING and self.state != CommandState.DONE:

test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3527,6 +3527,24 @@ async def main():
35273527

35283528
asyncio.run(main())
35293529

3530+
def test_uci_output_after_command(self):
3531+
async def main():
3532+
protocol = chess.engine.UciProtocol()
3533+
mock = chess.engine.MockTransport(protocol)
3534+
3535+
mock.expect("uci", [
3536+
"Arasan v24.0.0-10-g367aa9f Copyright 1994-2023 by Jon Dart.",
3537+
"All rights reserved.",
3538+
"id name Arasan v24.0.0-10-g367aa9f",
3539+
"uciok",
3540+
"info string out of do_all_pending, list size=0"
3541+
])
3542+
await protocol.initialize()
3543+
3544+
mock.assert_done()
3545+
3546+
asyncio.run(main())
3547+
35303548
def test_hiarcs_bestmove(self):
35313549
async def main():
35323550
protocol = chess.engine.UciProtocol()

0 commit comments

Comments
 (0)