Skip to content

Commit 8518cf0

Browse files
authored
asyncnet ssl overhaul (#24896)
Fixes #24895 - Remove all bio handling - Remove all `sendPendingSslData` which only seems to make things work by chance - Wrap the client socket on `acceptAddr` (std/net does this) - Do the SSL handshake on accept (std/net does this) The only concern is if addWrite/addRead works well on Windows.
1 parent d7b1f0a commit 8518cf0

File tree

2 files changed

+165
-96
lines changed

2 files changed

+165
-96
lines changed

lib/pure/asyncnet.nim

Lines changed: 86 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,6 @@ type
126126
when defineSsl:
127127
sslHandle: SslPtr
128128
sslContext: SslContext
129-
bioIn: BIO
130-
bioOut: BIO
131129
sslNoShutdown: bool
132130
domain: Domain
133131
sockType: SockType
@@ -210,7 +208,7 @@ when defineSsl:
210208
proc raiseSslHandleError =
211209
raiseSSLError("The SSL Handle is closed/unset")
212210

213-
proc getSslError(socket: AsyncSocket, err: cint): cint =
211+
proc getSslError(socket: AsyncSocket, flags: set[SocketFlag], err: cint): cint =
214212
assert socket.isSsl
215213
assert err < 0
216214
var ret = SSL_get_error(socket.sslHandle, err.cint)
@@ -223,47 +221,49 @@ when defineSsl:
223221
return ret
224222
of SSL_ERROR_WANT_X509_LOOKUP:
225223
raiseSSLError("Function for x509 lookup has been called.")
226-
of SSL_ERROR_SYSCALL, SSL_ERROR_SSL:
224+
of SSL_ERROR_SYSCALL:
225+
socket.sslNoShutdown = true
226+
let osErr = osLastError()
227+
if not flags.isDisconnectionError(osErr):
228+
var errStr = "IO error has occurred"
229+
let sslErr = ERR_peek_last_error()
230+
if sslErr == 0 and err == 0:
231+
errStr.add ' '
232+
errStr.add "because an EOF was observed that violates the protocol"
233+
elif sslErr == 0 and err == -1:
234+
errStr.add ' '
235+
errStr.add "in the BIO layer"
236+
else:
237+
let errStr = $ERR_error_string(sslErr, nil)
238+
raiseSSLError(errStr & ": " & errStr)
239+
raiseOSError(osErr, errStr)
240+
else:
241+
return ret
242+
of SSL_ERROR_SSL:
227243
socket.sslNoShutdown = true
228244
raiseSSLError()
229245
else: raiseSSLError("Unknown Error")
230246

231-
proc sendPendingSslData(socket: AsyncSocket,
232-
flags: set[SocketFlag]) {.async.} =
233-
if socket.sslHandle == nil:
234-
raiseSslHandleError()
235-
let len = bioCtrlPending(socket.bioOut)
236-
if len > 0:
237-
var data = newString(len)
238-
let read = bioRead(socket.bioOut, cast[cstring](addr data[0]), len)
239-
assert read != 0
240-
if read < 0:
241-
raiseSSLError()
242-
data.setLen(read)
243-
await socket.fd.AsyncFD.send(data, flags)
244-
245-
proc appeaseSsl(socket: AsyncSocket, flags: set[SocketFlag],
246-
sslError: cint): owned(Future[bool]) {.async.} =
247+
proc handleSslFailure(socket: AsyncSocket, flags: set[SocketFlag], sslError: cint): Future[bool] =
247248
## Returns `true` if `socket` is still connected, otherwise `false`.
248-
result = true
249+
let retFut = newFuture[bool]("asyncnet.handleSslFailure")
249250
case sslError
250-
of SSL_ERROR_WANT_WRITE:
251-
await sendPendingSslData(socket, flags)
251+
of SSL_ERROR_WANT_WRITE, SSL_ERROR_WANT_CONNECT, SSL_ERROR_WANT_ACCEPT:
252+
addWrite(socket.fd.AsyncFD, proc (sock: AsyncFD): bool =
253+
retFut.complete(true)
254+
return true
255+
)
252256
of SSL_ERROR_WANT_READ:
253-
var data = await recv(socket.fd.AsyncFD, BufferSize, flags)
254-
if socket.sslHandle == nil:
255-
raiseSslHandleError()
256-
let length = len(data)
257-
if length > 0:
258-
let ret = bioWrite(socket.bioIn, cast[cstring](addr data[0]), length.cint)
259-
if ret < 0:
260-
raiseSSLError()
261-
elif length == 0:
262-
# connection not properly closed by remote side or connection dropped
263-
SSL_set_shutdown(socket.sslHandle, SSL_RECEIVED_SHUTDOWN)
264-
result = false
257+
addRead(socket.fd.AsyncFD, proc (sock: AsyncFD): bool =
258+
retFut.complete(true)
259+
return true
260+
)
261+
of SSL_ERROR_SYSCALL:
262+
assert flags.isDisconnectionError(osLastError())
263+
retFut.complete(false)
265264
else:
266-
raiseSSLError("Cannot appease SSL.")
265+
raiseSSLError("Cannot handle SSL failure.")
266+
return retFut
267267

268268
template sslLoop(socket: AsyncSocket, flags: set[SocketFlag],
269269
op: untyped) =
@@ -274,20 +274,12 @@ when defineSsl:
274274
ErrClearError()
275275
# Call the desired operation.
276276
opResult = op
277-
let err =
278-
if opResult < 0:
279-
getSslError(socket, opResult.cint)
280-
else:
281-
SSL_ERROR_NONE
282-
# Send any remaining pending SSL data.
283-
await sendPendingSslData(socket, flags)
284-
285277
# If the operation failed, try to see if SSL has some data to read
286278
# or write.
287279
if opResult < 0:
288-
let fut = appeaseSsl(socket, flags, err.cint)
289-
yield fut
290-
if not fut.read():
280+
let err = getSslError(socket, flags, opResult.cint)
281+
let connected = await handleSslFailure(socket, flags, err.cint)
282+
if not connected:
291283
# Socket disconnected.
292284
if SocketFlag.SafeDisconn in flags:
293285
opResult = 0.cint
@@ -323,8 +315,7 @@ proc connect*(socket: AsyncSocket, address: string, port: Port) {.async.} =
323315
discard SSL_set_tlsext_host_name(socket.sslHandle, address)
324316

325317
let flags = {SocketFlag.SafeDisconn}
326-
sslSetConnectState(socket.sslHandle)
327-
sslLoop(socket, flags, sslDoHandshake(socket.sslHandle))
318+
sslLoop(socket, flags, SSL_connect(socket.sslHandle))
328319

329320
template readInto(buf: pointer, size: int, socket: AsyncSocket,
330321
flags: set[SocketFlag]): int =
@@ -461,7 +452,6 @@ proc send*(socket: AsyncSocket, buf: pointer, size: int,
461452
when defineSsl:
462453
sslLoop(socket, flags,
463454
sslWrite(socket.sslHandle, cast[cstring](buf), size.cint))
464-
await sendPendingSslData(socket, flags)
465455
else:
466456
await send(socket.fd.AsyncFD, buf, size, flags)
467457

@@ -475,52 +465,9 @@ proc send*(socket: AsyncSocket, data: string,
475465
var copy = data
476466
sslLoop(socket, flags,
477467
sslWrite(socket.sslHandle, cast[cstring](addr copy[0]), copy.len.cint))
478-
await sendPendingSslData(socket, flags)
479468
else:
480469
await send(socket.fd.AsyncFD, data, flags)
481470

482-
proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
483-
inheritable = defined(nimInheritHandles)):
484-
owned(Future[tuple[address: string, client: AsyncSocket]]) =
485-
## Accepts a new connection. Returns a future containing the client socket
486-
## corresponding to that connection and the remote address of the client.
487-
##
488-
## If `inheritable` is false (the default), the resulting client socket will
489-
## not be inheritable by child processes.
490-
##
491-
## The future will complete when the connection is successfully accepted.
492-
var retFuture = newFuture[tuple[address: string, client: AsyncSocket]]("asyncnet.acceptAddr")
493-
var fut = acceptAddr(socket.fd.AsyncFD, flags, inheritable)
494-
fut.callback =
495-
proc (future: Future[tuple[address: string, client: AsyncFD]]) =
496-
assert future.finished
497-
if future.failed:
498-
retFuture.fail(future.readError)
499-
else:
500-
let resultTup = (future.read.address,
501-
newAsyncSocket(future.read.client, socket.domain,
502-
socket.sockType, socket.protocol, socket.isBuffered, inheritable))
503-
retFuture.complete(resultTup)
504-
return retFuture
505-
506-
proc accept*(socket: AsyncSocket,
507-
flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
508-
## Accepts a new connection. Returns a future containing the client socket
509-
## corresponding to that connection.
510-
## If `inheritable` is false (the default), the resulting client socket will
511-
## not be inheritable by child processes.
512-
## The future will complete when the connection is successfully accepted.
513-
var retFut = newFuture[AsyncSocket]("asyncnet.accept")
514-
var fut = acceptAddr(socket, flags)
515-
fut.callback =
516-
proc (future: Future[tuple[address: string, client: AsyncSocket]]) =
517-
assert future.finished
518-
if future.failed:
519-
retFut.fail(future.readError)
520-
else:
521-
retFut.complete(future.read.client)
522-
return retFut
523-
524471
proc recvLineInto*(socket: AsyncSocket, resString: FutureVar[string],
525472
flags = {SocketFlag.SafeDisconn}, maxLength = MaxLineLength) {.async.} =
526473
## Reads a line of data from `socket` into `resString`.
@@ -776,9 +723,8 @@ when defineSsl:
776723
if socket.sslHandle == nil:
777724
raiseSSLError()
778725
779-
socket.bioIn = bioNew(bioSMem())
780-
socket.bioOut = bioNew(bioSMem())
781-
sslSetBio(socket.sslHandle, socket.bioIn, socket.bioOut)
726+
if SSL_set_fd(socket.sslHandle, socket.fd) != 1:
727+
raiseSSLError()
782728
783729
socket.sslNoShutdown = true
784730
@@ -795,6 +741,8 @@ when defineSsl:
795741
##
796742
## **Disclaimer**: This code is not well tested, may be very unsafe and
797743
## prone to security vulnerabilities.
744+
if socket.isSsl:
745+
return
798746
wrapSocket(ctx, socket)
799747
800748
case handshake
@@ -818,6 +766,48 @@ when defineSsl:
818766
else:
819767
result = getPeerCertificates(socket.sslHandle)
820768
769+
proc acceptAddr*(socket: AsyncSocket, flags = {SocketFlag.SafeDisconn},
770+
inheritable = defined(nimInheritHandles)):
771+
owned(Future[tuple[address: string, client: AsyncSocket]]) {.async.} =
772+
## Accepts a new connection. Returns a future containing the client socket
773+
## corresponding to that connection and the remote address of the client.
774+
##
775+
## If `inheritable` is false (the default), the resulting client socket will
776+
## not be inheritable by child processes.
777+
##
778+
## The future will complete when the connection is successfully accepted.
779+
let (address, fd) = await acceptAddr(socket.fd.AsyncFD, flags, inheritable)
780+
let client = newAsyncSocket(fd, socket.domain, socket.sockType,
781+
socket.protocol, socket.isBuffered, inheritable)
782+
result = (address, client)
783+
if socket.isSsl:
784+
when defineSsl:
785+
if socket.sslContext == nil:
786+
raiseSSLError("The SSL Context is closed/unset")
787+
wrapSocket(socket.sslContext, result.client)
788+
if result.client.sslHandle == nil:
789+
raiseSslHandleError()
790+
let flags = {SocketFlag.SafeDisconn}
791+
sslLoop(result.client, flags, SSL_accept(result.client.sslHandle))
792+
793+
proc accept*(socket: AsyncSocket,
794+
flags = {SocketFlag.SafeDisconn}): owned(Future[AsyncSocket]) =
795+
## Accepts a new connection. Returns a future containing the client socket
796+
## corresponding to that connection.
797+
## If `inheritable` is false (the default), the resulting client socket will
798+
## not be inheritable by child processes.
799+
## The future will complete when the connection is successfully accepted.
800+
var retFut = newFuture[AsyncSocket]("asyncnet.accept")
801+
var fut = acceptAddr(socket, flags)
802+
fut.callback =
803+
proc (future: Future[tuple[address: string, client: AsyncSocket]]) =
804+
assert future.finished
805+
if future.failed:
806+
retFut.fail(future.readError)
807+
else:
808+
retFut.complete(future.read.client)
809+
return retFut
810+
821811
proc getSockOpt*(socket: AsyncSocket, opt: SOBool, level = SOL_SOCKET): bool {.
822812
tags: [ReadIOEffect].} =
823813
## Retrieves option `opt` as a boolean value.

tests/async/t24895.nim

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
discard """
2+
cmd: "nim $target --hints:on --define:ssl $options $file"
3+
"""
4+
5+
{.define: ssl.}
6+
7+
import std/[asyncdispatch, asyncnet, net, openssl]
8+
9+
var port0: Port
10+
var checked = 0
11+
12+
proc server {.async.} =
13+
let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true)
14+
doAssert sock != nil
15+
defer: sock.close()
16+
let sslCtx = newContext(
17+
protSSLv23,
18+
verifyMode = CVerifyNone,
19+
certFile = "tests/testdata/mycert.pem",
20+
keyFile = "tests/testdata/mycert.pem"
21+
)
22+
doAssert sslCtx != nil
23+
defer: sslCtx.destroyContext()
24+
wrapSocket(sslCtx, sock)
25+
#sock.bindAddr(Port 8181)
26+
sock.bindAddr()
27+
port0 = getLocalAddr(sock)[1]
28+
sock.listen()
29+
echo "accept"
30+
let clientSocket = await sock.accept()
31+
defer: clientSocket.close()
32+
wrapConnectedSocket(
33+
sslCtx, clientSocket, handshakeAsServer, "localhost"
34+
)
35+
let sdata = "x" & newString(41)
36+
let sfut = clientSocket.send(sdata)
37+
let rdata = newString(42)
38+
let rfut = clientSocket.recvInto(addr rdata[0], rdata.len)
39+
echo "send"
40+
await sfut
41+
echo "recv"
42+
let rLen = await rfut # it hang here until the client closes the connection or sends more data
43+
doAssert rLen == 42, $rLen
44+
doAssert rdata[0] == 'x', $rdata[0]
45+
echo "ok"
46+
inc checked
47+
48+
proc client {.async.} =
49+
let sock = newAsyncSocket(AF_INET, SOCK_STREAM, IPPROTO_TCP, buffered = true)
50+
doAssert sock != nil
51+
defer: sock.close()
52+
let sslCtx = newContext(
53+
protSSLv23,
54+
verifyMode = CVerifyNone
55+
)
56+
doAssert sslCtx != nil
57+
defer: sslCtx.destroyContext()
58+
wrapSocket(sslCtx, sock)
59+
#await sock.connect("127.0.0.1", Port 8181)
60+
await sock.connect("localhost", port0)
61+
let sdata = "x" & newString(41)
62+
echo "send"
63+
await sock.send(sdata)
64+
let rdata = newString(42)
65+
echo "recv"
66+
let rLen = await sock.recvInto(addr rdata[0], rdata.len)
67+
doAssert rLen == 42, $rLen
68+
doAssert rdata[0] == 'x', $rdata[0]
69+
#await sleepAsync(10_000)
70+
#await sock.send("x")
71+
echo "ok"
72+
inc checked
73+
74+
discard getGlobalDispatcher()
75+
let serverFut = server()
76+
waitFor client()
77+
waitFor serverFut
78+
doAssert checked == 2
79+
doAssert not hasPendingOperations()

0 commit comments

Comments
 (0)