Skip to content

Commit 7f9d06b

Browse files
committed
Added test for KTOR-7234 Fix WS session closure
1 parent c09ed4c commit 7f9d06b

File tree

3 files changed

+188
-1
lines changed

3 files changed

+188
-1
lines changed

krpc/krpc-ktor/krpc-ktor-core/build.gradle.kts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,17 @@ kotlin {
2626
implementation(projects.krpc.krpcSerialization.krpcSerializationJson)
2727
implementation(projects.krpc.krpcKtor.krpcKtorServer)
2828
implementation(projects.krpc.krpcKtor.krpcKtorClient)
29+
implementation(projects.krpc.krpcLogging)
2930

3031
implementation(libs.kotlin.test)
3132
implementation(libs.ktor.server.netty)
3233
implementation(libs.ktor.server.test.host)
34+
implementation(libs.ktor.server.websockets)
35+
implementation(libs.ktor.client.core)
36+
implementation(libs.ktor.client.websockets)
37+
implementation(libs.ktor.client.cio)
38+
implementation(libs.logback.classic)
39+
implementation(libs.coroutines.debug)
3340
}
3441
}
3542
}

krpc/krpc-ktor/krpc-ktor-core/src/jvmTest/kotlin/kotlinx/rpc/krpc/ktor/KtorTransportTest.kt

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,23 @@
66

77
package kotlinx.rpc.krpc.ktor
88

9+
import io.ktor.client.*
10+
import io.ktor.client.engine.cio.*
11+
import io.ktor.client.request.*
12+
import io.ktor.client.statement.*
913
import io.ktor.server.application.*
14+
import io.ktor.server.engine.*
15+
import io.ktor.server.netty.*
16+
import io.ktor.server.response.*
17+
import io.ktor.server.routing.*
1018
import io.ktor.server.testing.*
11-
import kotlinx.coroutines.cancel
19+
import kotlinx.coroutines.*
20+
import kotlinx.coroutines.debug.DebugProbes
21+
import kotlinx.coroutines.test.runTest
1222
import kotlinx.rpc.annotations.Rpc
23+
import kotlinx.rpc.krpc.client.KrpcClient
24+
import kotlinx.rpc.krpc.internal.logging.RpcInternalDumpLogger
25+
import kotlinx.rpc.krpc.internal.logging.RpcInternalDumpLoggerContainer
1326
import kotlinx.rpc.krpc.ktor.client.installKrpc
1427
import kotlinx.rpc.krpc.ktor.client.rpc
1528
import kotlinx.rpc.krpc.ktor.client.rpcConfig
@@ -18,7 +31,15 @@ import kotlinx.rpc.krpc.ktor.server.rpc
1831
import kotlinx.rpc.krpc.serialization.json.json
1932
import kotlinx.rpc.withService
2033
import org.junit.Assert.assertEquals
34+
import org.junit.platform.commons.logging.Logger
35+
import org.junit.platform.commons.logging.LoggerFactory
36+
import java.net.ServerSocket
37+
import java.util.concurrent.Executors
38+
import java.util.concurrent.TimeUnit
39+
import kotlin.coroutines.cancellation.CancellationException
40+
import kotlin.test.Ignore
2141
import kotlin.test.Test
42+
import kotlin.time.Duration.Companion.seconds
2243

2344
@Rpc
2445
interface NewService {
@@ -35,6 +56,23 @@ class NewServiceImpl(
3556
}
3657
}
3758

59+
@Rpc
60+
interface SlowService {
61+
suspend fun verySlow(): String
62+
}
63+
64+
class SlowServiceImpl : SlowService {
65+
val received = CompletableDeferred<Unit>()
66+
67+
override suspend fun verySlow(): String {
68+
received.complete(Unit)
69+
70+
delay(Int.MAX_VALUE.toLong())
71+
72+
error("Must not be called")
73+
}
74+
}
75+
3876
class KtorTransportTest {
3977
@Test
4078
fun testEcho() = testApplication {
@@ -96,4 +134,130 @@ class KtorTransportTest {
96134

97135
clientWithNoConfig.cancel()
98136
}
137+
138+
@OptIn(DelicateCoroutinesApi::class, ExperimentalCoroutinesApi::class)
139+
@Test
140+
@Ignore("Wait for Ktor fix (https://github.com/ktorio/ktor/pull/4927) or apply workaround if rejected")
141+
fun testEndpointsTerminateWhenWsDoes() = runTest(timeout = 15.seconds) {
142+
DebugProbes.install()
143+
144+
val logger = setupLogger()
145+
146+
val port: Int = findFreePort()
147+
148+
val newPool = Executors.newCachedThreadPool().asCoroutineDispatcher()
149+
150+
val serverReady = CompletableDeferred<Unit>()
151+
val dropServer = CompletableDeferred<Unit>()
152+
153+
val service = SlowServiceImpl()
154+
155+
val serverJob = GlobalScope.launch(CoroutineName("server")) {
156+
withContext(newPool) {
157+
val server = embeddedServer(
158+
factory = Netty,
159+
port = port,
160+
parentCoroutineContext = newPool,
161+
) {
162+
install(Krpc)
163+
164+
routing {
165+
get {
166+
call.respondText("hello")
167+
}
168+
169+
rpc("/rpc") {
170+
rpcConfig {
171+
serialization {
172+
json()
173+
}
174+
}
175+
176+
registerService<SlowService> { service }
177+
}
178+
}
179+
}.start(wait = false)
180+
181+
serverReady.complete(Unit)
182+
183+
dropServer.await()
184+
185+
server.stop(shutdownGracePeriod = 100L, shutdownTimeout = 100L, timeUnit = TimeUnit.MILLISECONDS)
186+
}
187+
188+
logger.info { "Server stopped" }
189+
}
190+
191+
val ktorClient = HttpClient(CIO) {
192+
installKrpc {
193+
serialization {
194+
json()
195+
}
196+
}
197+
}
198+
199+
serverReady.await()
200+
201+
assertEquals("hello", ktorClient.get("http://0.0.0.0:$port").bodyAsText())
202+
203+
val rpcClient = ktorClient.rpc("ws://0.0.0.0:$port/rpc")
204+
205+
launch {
206+
try {
207+
rpcClient.withService<SlowService>().verySlow()
208+
error("Must not be called")
209+
} catch (_: CancellationException) {
210+
logger.info { "Cancellation exception caught for RPC request" }
211+
ensureActive()
212+
}
213+
}
214+
215+
service.received.await()
216+
217+
logger.info { "Received RPC request" }
218+
219+
dropServer.complete(Unit)
220+
221+
logger.info { "Waiting for RPC client to complete" }
222+
223+
(rpcClient as KrpcClient).awaitCompletion()
224+
225+
logger.info { "RPC client completed" }
226+
227+
ktorClient.close()
228+
newPool.close()
229+
230+
serverJob.cancel()
231+
}
232+
233+
private fun findFreePort(): Int {
234+
val port: Int
235+
while (true) {
236+
val socket = try {
237+
ServerSocket(0)
238+
} catch (_: Throwable) {
239+
continue
240+
}
241+
242+
port = socket.localPort
243+
socket.close()
244+
break
245+
}
246+
return port
247+
}
248+
249+
private fun setupLogger(): Logger {
250+
val logger = LoggerFactory.getLogger(KtorTransportTest::class.java)
251+
252+
RpcInternalDumpLoggerContainer.set(object : RpcInternalDumpLogger {
253+
254+
override val isEnabled: Boolean = true
255+
256+
override fun dump(vararg tags: String, message: () -> String) {
257+
logger.info { "[${tags.joinToString()}] ${message()}" }
258+
}
259+
})
260+
261+
return logger
262+
}
99263
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
<!--
2+
~ Copyright 2023-2024 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
-->
4+
5+
<configuration>
6+
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
7+
<encoder>
8+
<pattern>%d{YYYY-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
9+
</encoder>
10+
</appender>
11+
<root level="trace">
12+
<appender-ref ref="STDOUT"/>
13+
</root>
14+
<logger name="org.eclipse.jetty" level="INFO"/>
15+
<logger name="io.netty" level="TRACE"/>
16+
</configuration>

0 commit comments

Comments
 (0)