Skip to content

Commit e827d13

Browse files
committed
Allow the SSHServer auth_completed method to be a coroutine
This commit allows to auth_completed callback in SSHServer to be either a callable or a coroutine, allowing async operations to be performed from it when auth completes successfully. This callback will run to completion before any sessions are started up, and before an acceptor task is started if one is specified.
1 parent 7dd29f5 commit e827d13

File tree

5 files changed

+30
-22
lines changed

5 files changed

+30
-22
lines changed

asyncssh/auth.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2013-2025 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -548,10 +548,10 @@ def send_failure(self, partial_success: bool = False) -> None:
548548

549549
self._conn.send_userauth_failure(partial_success)
550550

551-
def send_success(self) -> None:
551+
async def send_success(self) -> None:
552552
"""Send a user authentication success response"""
553553

554-
self._conn.send_userauth_success()
554+
await self._conn.send_userauth_success()
555555

556556

557557
class _ServerNullAuth(ServerAuth):
@@ -596,7 +596,7 @@ async def _start(self, packet: SSHPacket) -> None:
596596
(await self._conn.validate_gss_principal(self._username,
597597
self._gss.user,
598598
self._gss.host))):
599-
self.send_success()
599+
await self.send_success()
600600
else:
601601
self.send_failure()
602602

@@ -650,7 +650,7 @@ async def _finish(self) -> None:
650650
if (await self._conn.validate_gss_principal(self._username,
651651
self._gss.user,
652652
self._gss.host)):
653-
self.send_success()
653+
await self.send_success()
654654
else:
655655
self.send_failure()
656656

@@ -757,7 +757,7 @@ async def _start(self, packet: SSHPacket) -> None:
757757
key_data, client_host,
758758
client_username,
759759
msg, signature)):
760-
self.send_success()
760+
await self.send_success()
761761
else:
762762
self.send_failure()
763763

@@ -795,7 +795,7 @@ async def _start(self, packet: SSHPacket) -> None:
795795
if (await self._conn.validate_public_key(self._username, key_data,
796796
msg, signature)):
797797
if sig_present:
798-
self.send_success()
798+
await self.send_success()
799799
else:
800800
self.send_packet(MSG_USERAUTH_PK_OK, String(algorithm),
801801
String(key_data))
@@ -832,9 +832,9 @@ async def _start(self, packet: SSHPacket) -> None:
832832

833833
challenge = await self._conn.get_kbdint_challenge(self._username,
834834
lang, submethods)
835-
self._send_challenge(challenge)
835+
await self._send_challenge(challenge)
836836

837-
def _send_challenge(self, challenge: KbdIntChallenge) -> None:
837+
async def _send_challenge(self, challenge: KbdIntChallenge) -> None:
838838
"""Send a keyboard interactive authentication request"""
839839

840840
if isinstance(challenge, (tuple, list)):
@@ -848,7 +848,7 @@ def _send_challenge(self, challenge: KbdIntChallenge) -> None:
848848
String(instruction), String(lang),
849849
UInt32(num_prompts), *prompts_bytes)
850850
elif challenge:
851-
self.send_success()
851+
await self.send_success()
852852
else:
853853
self.send_failure()
854854

@@ -857,7 +857,7 @@ async def _validate_response(self, responses: KbdIntResponse) -> None:
857857

858858
next_challenge = \
859859
await self._conn.validate_kbdint_response(self._username, responses)
860-
self._send_challenge(next_challenge)
860+
await self._send_challenge(next_challenge)
861861

862862
def _process_info_response(self, _pkttype: int, _pktid: int,
863863
packet: SSHPacket) -> None:
@@ -922,7 +922,7 @@ async def _start(self, packet: SSHPacket) -> None:
922922
await self._conn.validate_password(self._username, password)
923923

924924
if result:
925-
self.send_success()
925+
await self.send_success()
926926
else:
927927
self.send_failure()
928928
except PasswordChangeRequired as exc:

asyncssh/connection.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2013-2025 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -2069,7 +2069,7 @@ def send_userauth_failure(self, partial_success: bool) -> None:
20692069
self.send_packet(MSG_USERAUTH_FAILURE, NameList(methods),
20702070
Boolean(partial_success))
20712071

2072-
def send_userauth_success(self) -> None:
2072+
async def send_userauth_success(self) -> None:
20732073
"""Send a user authentication success response"""
20742074

20752075
self.logger.info('Auth for user %s succeeded', self._username)
@@ -2086,7 +2086,10 @@ def send_userauth_success(self) -> None:
20862086
self._set_keepalive_timer()
20872087

20882088
if self._owner: # pragma: no branch
2089-
self._owner.auth_completed()
2089+
result = self._owner.auth_completed()
2090+
2091+
if inspect.isawaitable(result):
2092+
await result
20902093

20912094
if self._acceptor:
20922095
result = self._acceptor(self)
@@ -2506,7 +2509,7 @@ async def _finish_userauth(self, begin_auth: bool, method: bytes,
25062509
result = await cast(Awaitable[bool], result)
25072510

25082511
if not result:
2509-
self.send_userauth_success()
2512+
await self.send_userauth_success()
25102513
return
25112514

25122515
if not self._owner: # pragma: no cover
@@ -4130,7 +4133,6 @@ async def _finish_hostkeys(self, packet: SSHPacket) -> None:
41304133
retained, revoked)
41314134

41324135
if inspect.isawaitable(result):
4133-
assert result is not None
41344136
await result
41354137

41364138
self._report_global_response(True)

asyncssh/server.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2013-2024 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2013-2025 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -157,7 +157,7 @@ def begin_auth(self, username: str) -> MaybeAwait[bool]:
157157

158158
return True # pragma: no cover
159159

160-
def auth_completed(self) -> None:
160+
def auth_completed(self) -> MaybeAwait[None]:
161161
"""Authentication was completed successfully
162162
163163
This method is called when authentication has completed
@@ -167,6 +167,9 @@ def auth_completed(self) -> None:
167167
user before any sessions are opened or forwarding requests
168168
are handled.
169169
170+
If blocking operations need to be performed when authentication
171+
completes, this method may be defined as a coroutine.
172+
170173
"""
171174

172175
def validate_gss_principal(self, username: str, user_principal: str,

tests/test_auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2015-2022 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2015-2025 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -326,7 +326,7 @@ def send_userauth_failure(self, partial_success):
326326
self.send_userauth_packet(MSG_USERAUTH_FAILURE, NameList([]),
327327
Boolean(partial_success))
328328

329-
def send_userauth_success(self):
329+
async def send_userauth_success(self):
330330
"""Send a user authentication success response"""
331331

332332
self._auth = None

tests/test_connection_auth.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2016-2022 by Ron Frederick <ronf@timeheart.net> and others.
1+
# Copyright (c) 2016-2025 by Ron Frederick <ronf@timeheart.net> and others.
22
#
33
# This program and the accompanying materials are made available under
44
# the terms of the Eclipse Public License v2.0 which accompanies this
@@ -75,6 +75,9 @@ async def begin_auth(self, username):
7575

7676
return False
7777

78+
async def auth_completed(self):
79+
"""Handle client authentication request"""
80+
7881

7982
class _HostBasedServer(Server):
8083
"""Server for testing host-based authentication"""

0 commit comments

Comments
 (0)