Skip to content

Commit cb87de9

Browse files
committed
Use an executor to prevent GSSAPI calls from blocking the event loop
Some operations such as GSSAPI calls can sometimes block the event loop if not run in an executor. However, doing that requires packet handlers to be asynchronous. This commit adds support for async packet handlers for key exchange and auth, and changes the GSSAPI handlers to run the step() call in an executor.
1 parent 358c175 commit cb87de9

File tree

10 files changed

+210
-155
lines changed

10 files changed

+210
-155
lines changed

asyncssh/auth.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2022 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -27,6 +27,7 @@
2727
from .gss import GSSBase, GSSError
2828
from .logging import SSHLogger
2929
from .misc import ProtocolError, PasswordChangeRequired, get_symbol_names
30+
from .misc import run_in_executor
3031
from .packet import Boolean, String, UInt32, SSHPacket, SSHPacketHandler
3132
from .public_key import SigningKey
3233
from .saslprep import saslprep, SASLPrepError
@@ -199,8 +200,8 @@ def _finish(self) -> None:
199200
else:
200201
self.send_packet(MSG_USERAUTH_GSSAPI_EXCHANGE_COMPLETE)
201202

202-
def _process_response(self, _pkttype: int, _pktid: int,
203-
packet: SSHPacket) -> None:
203+
async def _process_response(self, _pkttype: int, _pktid: int,
204+
packet: SSHPacket) -> None:
204205
"""Process a GSS response from the server"""
205206

206207
mech = packet.get_string()
@@ -212,7 +213,7 @@ def _process_response(self, _pkttype: int, _pktid: int,
212213
raise ProtocolError('Mechanism mismatch')
213214

214215
try:
215-
token = self._gss.step()
216+
token = await run_in_executor(self._gss.step)
216217
assert token is not None
217218

218219
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
@@ -225,8 +226,8 @@ def _process_response(self, _pkttype: int, _pktid: int,
225226

226227
self._conn.try_next_auth()
227228

228-
def _process_token(self, _pkttype: int, _pktid: int,
229-
packet: SSHPacket) -> None:
229+
async def _process_token(self, _pkttype: int, _pktid: int,
230+
packet: SSHPacket) -> None:
230231
"""Process a GSS token from the server"""
231232

232233
token: Optional[bytes] = packet.get_string()
@@ -235,7 +236,7 @@ def _process_token(self, _pkttype: int, _pktid: int,
235236
assert self._gss is not None
236237

237238
try:
238-
token = self._gss.step(token)
239+
token = await run_in_executor(self._gss.step, token)
239240

240241
if token:
241242
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
@@ -261,8 +262,8 @@ def _process_error(self, _pkttype: int, _pktid: int,
261262
self.logger.debug1('GSS error from server: %s', msg)
262263
self._got_error = True
263264

264-
def _process_error_token(self, _pkttype: int, _pktid: int,
265-
packet: SSHPacket) -> None:
265+
async def _process_error_token(self, _pkttype: int, _pktid: int,
266+
packet: SSHPacket) -> None:
266267
"""Process a GSS error token from the server"""
267268

268269
token = packet.get_string()
@@ -271,7 +272,7 @@ def _process_error_token(self, _pkttype: int, _pktid: int,
271272
assert self._gss is not None
272273

273274
try:
274-
self._gss.step(token)
275+
await run_in_executor(self._gss.step, token)
275276
except GSSError as exc:
276277
if not self._got_error: # pragma: no cover
277278
self.logger.debug1('GSS error from server: %s', str(exc))
@@ -649,15 +650,15 @@ async def _finish(self) -> None:
649650
else:
650651
self.send_failure()
651652

652-
def _process_token(self, _pkttype: int, _pktid: int,
653-
packet: SSHPacket) -> None:
653+
async def _process_token(self, _pkttype: int, _pktid: int,
654+
packet: SSHPacket) -> None:
654655
"""Process a GSS token from the client"""
655656

656657
token: Optional[bytes] = packet.get_string()
657658
packet.check_end()
658659

659660
try:
660-
token = self._gss.step(token)
661+
token = await run_in_executor(self._gss.step, token)
661662

662663
if token:
663664
self.send_packet(MSG_USERAUTH_GSSAPI_TOKEN, String(token))
@@ -682,15 +683,15 @@ def _process_exchange_complete(self, _pkttype: int, _pktid: int,
682683
else:
683684
self.send_failure()
684685

685-
def _process_error_token(self, _pkttype: int, _pktid: int,
686-
packet: SSHPacket) -> None:
686+
async def _process_error_token(self, _pkttype: int, _pktid: int,
687+
packet: SSHPacket) -> None:
687688
"""Process a GSS error token from the client"""
688689

689690
token = packet.get_string()
690691
packet.check_end()
691692

692693
try:
693-
self._gss.step(token)
694+
await run_in_executor(self._gss.step, token)
694695
except GSSError as exc:
695696
self.logger.debug1('GSS error from client: %s', str(exc))
696697

asyncssh/connection.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,17 +1326,7 @@ def data_received(self, data: bytes, datatype: DataType = None) -> None:
13261326

13271327
self._inpbuf += data
13281328

1329-
self._reset_keepalive_timer()
1330-
1331-
# pylint: disable=broad-except
1332-
try:
1333-
while self._inpbuf and self._recv_handler():
1334-
pass
1335-
except DisconnectError as exc:
1336-
self._send_disconnect(exc.code, exc.reason, exc.lang)
1337-
self._force_close(exc)
1338-
except Exception:
1339-
self.internal_error()
1329+
self._recv_data()
13401330
# pylint: enable=arguments-differ
13411331

13421332
def eof_received(self) -> None:
@@ -1442,6 +1432,21 @@ def _send_version(self) -> None:
14421432

14431433
self._send(version + b'\r\n')
14441434

1435+
def _recv_data(self) -> None:
1436+
"""Parse received data"""
1437+
1438+
self._reset_keepalive_timer()
1439+
1440+
# pylint: disable=broad-except
1441+
try:
1442+
while self._inpbuf and self._recv_handler():
1443+
pass
1444+
except DisconnectError as exc:
1445+
self._send_disconnect(exc.code, exc.reason, exc.lang)
1446+
self._force_close(exc)
1447+
except Exception:
1448+
self.internal_error()
1449+
14451450
def _recv_version(self) -> bool:
14461451
"""Receive and parse the remote SSH version"""
14471452

@@ -1595,11 +1600,20 @@ def _recv_packet(self) -> bool:
15951600

15961601
if not skip_reason:
15971602
try:
1598-
processed = handler.process_packet(pkttype, seq, packet)
1603+
result = handler.process_packet(pkttype, seq, packet)
15991604
except PacketDecodeError as exc:
16001605
raise ProtocolError(str(exc)) from None
16011606

1602-
if not processed:
1607+
if inspect.isawaitable(result):
1608+
# Buffer received data until current packet is processed
1609+
self._recv_handler = lambda: False
1610+
1611+
task = self.create_task(result)
1612+
task.add_done_callback(functools.partial(
1613+
self._finish_recv_packet, pkttype, seq, is_async=True))
1614+
1615+
return False
1616+
elif not result:
16031617
if self._strict_kex and not self._recv_encryption:
16041618
exc_reason = 'Strict key exchange violation: ' \
16051619
'unexpected packet type %d received' % pkttype
@@ -1611,6 +1625,14 @@ def _recv_packet(self) -> bool:
16111625
if exc_reason:
16121626
raise ProtocolError(exc_reason)
16131627

1628+
self._finish_recv_packet(pkttype, seq)
1629+
return True
1630+
1631+
def _finish_recv_packet(self, pkttype: int, seq: int,
1632+
_task: Optional[asyncio.Task] = None,
1633+
is_async: bool = False) -> None:
1634+
"""Finish processing a packet"""
1635+
16141636
if pkttype > MSG_USERAUTH_LAST:
16151637
self._auth_final = True
16161638

@@ -1625,7 +1647,8 @@ def _recv_packet(self) -> bool:
16251647
else:
16261648
self._recv_seq = (seq + 1) & 0xffffffff
16271649

1628-
return True
1650+
if is_async and self._inpbuf:
1651+
self._recv_data()
16291652

16301653
def send_packet(self, pkttype: int, *args: bytes,
16311654
handler: Optional[SSHPacketLogger] = None) -> None:
@@ -2218,8 +2241,8 @@ def _process_ext_info(self, _pkttype: int, _pktid: int,
22182241
self._server_sig_algs = \
22192242
set(extensions.get(b'server-sig-algs', b'').split(b','))
22202243

2221-
def _process_kexinit(self, _pkttype: int, _pktid: int,
2222-
packet: SSHPacket) -> None:
2244+
async def _process_kexinit(self, _pkttype: int, _pktid: int,
2245+
packet: SSHPacket) -> None:
22232246
"""Process a key exchange request"""
22242247

22252248
if self._kex:
@@ -2323,7 +2346,7 @@ def _process_kexinit(self, _pkttype: int, _pktid: int,
23232346
self.logger.debug1('Beginning key exchange')
23242347
self.logger.debug2(' Key exchange alg: %s', self._kex.algorithm)
23252348

2326-
self._kex.start()
2349+
await self._kex.start()
23272350

23282351
def _process_newkeys(self, _pkttype: int, _pktid: int,
23292352
packet: SSHPacket) -> None:

asyncssh/kex.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2022 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -58,7 +58,7 @@ def __init__(self, alg: bytes, conn: 'SSHConnection', hash_alg: HashType):
5858
self._hash_alg = hash_alg
5959

6060

61-
def start(self) -> None:
61+
async def start(self) -> None:
6262
"""Start key exchange"""
6363

6464
raise NotImplementedError

asyncssh/kex_dh.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2022 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -33,7 +33,7 @@
3333
from .gss import GSSError
3434
from .kex import Kex, register_kex_alg, register_gss_kex_alg
3535
from .misc import HashType, KeyExchangeFailed, ProtocolError
36-
from .misc import get_symbol_names
36+
from .misc import get_symbol_names, run_in_executor
3737
from .packet import Boolean, MPInt, String, UInt32, SSHPacket
3838
from .public_key import SigningKey, VerifyingKey
3939

@@ -274,7 +274,7 @@ def _process_reply(self, _pkttype: int, _pktid: int,
274274
host_key = client_conn.validate_server_host_key(host_key_data)
275275
self._verify_reply(host_key, host_key_data, sig)
276276

277-
def start(self) -> None:
277+
async def start(self) -> None:
278278
"""Start DH key exchange"""
279279

280280
if self._conn.is_client():
@@ -384,7 +384,7 @@ def _process_group(self, _pkttype: int, _pktid: int,
384384
self._gex_data += MPInt(p) + MPInt(g)
385385
self._perform_init()
386386

387-
def start(self) -> None:
387+
async def start(self) -> None:
388388
"""Start DH group exchange"""
389389

390390
if self._conn.is_client():
@@ -455,7 +455,7 @@ def _compute_server_shared(self) -> bytes:
455455
except ValueError:
456456
raise ProtocolError('Invalid ECDH client public key') from None
457457

458-
def start(self) -> None:
458+
async def start(self) -> None:
459459
"""Start ECDH key exchange"""
460460

461461
if self._conn.is_client():
@@ -567,11 +567,11 @@ def _send_continue(self) -> None:
567567

568568
self.send_packet(MSG_KEXGSS_CONTINUE, String(self._token))
569569

570-
def _process_token(self, token: Optional[bytes] = None) -> None:
570+
async def _process_token(self, token: Optional[bytes] = None) -> None:
571571
"""Process a GSS token"""
572572

573573
try:
574-
self._token = self._gss.step(token)
574+
self._token = await run_in_executor(self._gss.step, token)
575575
except GSSError as exc:
576576
if self._conn.is_server():
577577
self.send_packet(MSG_KEXGSS_ERROR, UInt32(exc.maj_code),
@@ -583,8 +583,8 @@ def _process_token(self, token: Optional[bytes] = None) -> None:
583583

584584
raise KeyExchangeFailed(str(exc)) from None
585585

586-
def _process_init(self, _pkttype: int, _pktid: int,
587-
packet: SSHPacket) -> None:
586+
async def _process_gss_init(self, _pkttype: int, _pktid: int,
587+
packet: SSHPacket) -> None:
588588
"""Process a GSS init message"""
589589

590590
if self._conn.is_client():
@@ -603,7 +603,7 @@ def _process_init(self, _pkttype: int, _pktid: int,
603603
else:
604604
self._host_key_data = b''
605605

606-
self._process_token(token)
606+
await self._process_token(token)
607607

608608
if self._gss.complete:
609609
self._check_secure()
@@ -612,8 +612,8 @@ def _process_init(self, _pkttype: int, _pktid: int,
612612
else:
613613
self._send_continue()
614614

615-
def _process_continue(self, _pkttype: int, _pktid: int,
616-
packet: SSHPacket) -> None:
615+
async def _process_continue(self, _pkttype: int, _pktid: int,
616+
packet: SSHPacket) -> None:
617617
"""Process a GSS continue message"""
618618

619619
token = packet.get_string()
@@ -622,16 +622,16 @@ def _process_continue(self, _pkttype: int, _pktid: int,
622622
if self._conn.is_client() and self._gss.complete:
623623
raise ProtocolError('Unexpected kexgss continue msg')
624624

625-
self._process_token(token)
625+
await self._process_token(token)
626626

627627
if self._conn.is_server() and self._gss.complete:
628628
self._check_secure()
629629
self._perform_reply(self._gss, self._host_key_data)
630630
else:
631631
self._send_continue()
632632

633-
def _process_complete(self, _pkttype: int, _pktid: int,
634-
packet: SSHPacket) -> None:
633+
async def _process_complete(self, _pkttype: int, _pktid: int,
634+
packet: SSHPacket) -> None:
635635
"""Process a GSS complete message"""
636636

637637
if self._conn.is_server():
@@ -647,7 +647,7 @@ def _process_complete(self, _pkttype: int, _pktid: int,
647647
if self._gss.complete:
648648
raise ProtocolError('Non-empty token after complete')
649649

650-
self._process_token(token)
650+
await self._process_token(token)
651651

652652
if self._token:
653653
raise ProtocolError('Non-empty token after complete')
@@ -682,12 +682,12 @@ def _process_error(self, _pkttype: int, _pktid: int,
682682
self._conn.logger.debug1('GSS error: %s',
683683
msg.decode('utf-8', errors='ignore'))
684684

685-
def start(self) -> None:
685+
async def start(self) -> None:
686686
"""Start GSS key exchange"""
687687

688688
if self._conn.is_client():
689-
self._process_token()
690-
super().start()
689+
await self._process_token()
690+
await super().start()
691691

692692

693693
class _KexGSS(_KexGSSBase, _KexDH):
@@ -696,7 +696,7 @@ class _KexGSS(_KexGSSBase, _KexDH):
696696
_handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_')
697697

698698
_packet_handlers = {
699-
MSG_KEXGSS_INIT: _KexGSSBase._process_init,
699+
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init,
700700
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue,
701701
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete,
702702
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey,
@@ -713,7 +713,7 @@ class _KexGSSGex(_KexGSSBase, _KexDHGex):
713713
_group_type = MSG_KEXGSS_GROUP
714714

715715
_packet_handlers = {
716-
MSG_KEXGSS_INIT: _KexGSSBase._process_init,
716+
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init,
717717
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue,
718718
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete,
719719
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey,
@@ -729,7 +729,7 @@ class _KexGSSECDH(_KexGSSBase, _KexECDH):
729729
_handler_names = get_symbol_names(globals(), 'MSG_KEXGSS_')
730730

731731
_packet_handlers = {
732-
MSG_KEXGSS_INIT: _KexGSSBase._process_init,
732+
MSG_KEXGSS_INIT: _KexGSSBase._process_gss_init,
733733
MSG_KEXGSS_CONTINUE: _KexGSSBase._process_continue,
734734
MSG_KEXGSS_COMPLETE: _KexGSSBase._process_complete,
735735
MSG_KEXGSS_HOSTKEY: _KexGSSBase._process_hostkey,

0 commit comments

Comments
 (0)