Skip to content

Commit 161a644

Browse files
authored
Merge pull request #912 from rust-lang/retry-one-at-a-time
Wait for a WebSocket connection to close before reconnecting
2 parents ec9d5af + e15e15a commit 161a644

File tree

3 files changed

+97
-29
lines changed

3 files changed

+97
-29
lines changed

ui/frontend/reducers/websocket.ts

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import { PayloadAction, createSlice } from '@reduxjs/toolkit';
22
import z from 'zod';
33

4-
import { createWebsocketResponseSchema } from '../websocketActions';
4+
import { createWebsocketResponseSchema, makeWebSocketMeta } from '../websocketActions';
55

66
export type State = {
77
connected: boolean;
@@ -14,6 +14,11 @@ const initialState: State = {
1414
featureFlagEnabled: false,
1515
};
1616

17+
const websocketConnectedPayloadSchema = z.object({
18+
iAcceptThisIsAnUnsupportedApi: z.boolean(),
19+
});
20+
type websocketConnectedPayload = z.infer<typeof websocketConnectedPayloadSchema>;
21+
1722
const websocketErrorPayloadSchema = z.object({
1823
error: z.string(),
1924
});
@@ -23,9 +28,18 @@ const slice = createSlice({
2328
name: 'websocket',
2429
initialState,
2530
reducers: {
26-
connected: (state) => {
27-
state.connected = true;
28-
delete state.error;
31+
connected: {
32+
reducer: (state, _action: PayloadAction<websocketConnectedPayload>) => {
33+
state.connected = true;
34+
delete state.error;
35+
},
36+
37+
prepare: () => ({
38+
payload: {
39+
iAcceptThisIsAnUnsupportedApi: true,
40+
},
41+
meta: makeWebSocketMeta(),
42+
}),
2943
},
3044

3145
disconnected: (state) => {
@@ -49,6 +63,11 @@ export const {
4963
featureFlagEnabled: websocketFeatureFlagEnabled,
5064
} = slice.actions;
5165

66+
export const websocketConnectedSchema = createWebsocketResponseSchema(
67+
websocketConnected,
68+
websocketConnectedPayloadSchema,
69+
);
70+
5271
export const websocketErrorSchema = createWebsocketResponseSchema(
5372
websocketError,
5473
websocketErrorPayloadSchema,

ui/frontend/websocketMiddleware.ts

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,49 @@ import { z } from 'zod';
44
import { wsExecuteResponseSchema } from './reducers/output/execute';
55
import {
66
websocketConnected,
7+
websocketConnectedSchema,
78
websocketDisconnected,
89
websocketError,
910
websocketErrorSchema,
1011
} from './reducers/websocket';
1112

1213
const WSMessageResponse = z.discriminatedUnion('type', [
14+
websocketConnectedSchema,
1315
websocketErrorSchema,
1416
wsExecuteResponseSchema,
1517
]);
1618

17-
const reportWebSocketError = async (error: string) => {
18-
try {
19-
await fetch('/nowebsocket', {
20-
method: 'post',
21-
headers: {
22-
'Content-Type': 'application/json',
23-
},
24-
body: JSON.stringify({ error }),
25-
});
26-
} catch (reportError) {
27-
console.log('Unable to report WebSocket error', error, reportError);
28-
}
29-
};
19+
const reportWebSocketError = (() => {
20+
let lastReport: string | undefined;
21+
let lastReportTime = 0;
22+
23+
return async (error: string) => {
24+
// Don't worry about reporting the same thing again.
25+
if (lastReport === error) {
26+
return;
27+
}
28+
lastReport = error;
29+
30+
// Don't worry about spamming the server with reports.
31+
const now = Date.now();
32+
if (now - lastReportTime < 1000) {
33+
return;
34+
}
35+
lastReportTime = now;
36+
37+
try {
38+
await fetch('/nowebsocket', {
39+
method: 'post',
40+
headers: {
41+
'Content-Type': 'application/json',
42+
},
43+
body: JSON.stringify({ error }),
44+
});
45+
} catch (reportError) {
46+
console.log('Unable to report WebSocket error', error, reportError);
47+
}
48+
};
49+
})();
3050

3151
const openWebSocket = (currentLocation: Location) => {
3252
try {
@@ -76,18 +96,16 @@ export const websocketMiddleware =
7696
resetTimeout();
7797

7898
socket.addEventListener('open', () => {
79-
store.dispatch(websocketConnected());
80-
81-
wasConnected = true;
99+
if (socket) {
100+
socket.send(JSON.stringify(websocketConnected()));
101+
}
82102
});
83103

84104
socket.addEventListener('close', (event) => {
85105
store.dispatch(websocketDisconnected());
86106

87107
// Reconnect if we've previously connected
88108
if (wasConnected && !event.wasClean) {
89-
wasConnected = false;
90-
reconnectAttempt = 0;
91109
reconnect();
92110
}
93111
});
@@ -104,6 +122,12 @@ export const websocketMiddleware =
104122
try {
105123
const rawMessage = JSON.parse(event.data);
106124
const message = WSMessageResponse.parse(rawMessage);
125+
126+
if (websocketConnected.match(message)) {
127+
wasConnected = true;
128+
reconnectAttempt = 0;
129+
}
130+
107131
store.dispatch(message);
108132
resetTimeout();
109133
} catch (e) {
@@ -114,15 +138,10 @@ export const websocketMiddleware =
114138
};
115139

116140
const reconnect = () => {
117-
if (socket && socket.readyState == socket.OPEN) {
118-
return;
119-
}
120-
121-
connect();
122-
123141
const delay = backoffMs(reconnectAttempt);
124142
reconnectAttempt += 1;
125-
setTimeout(reconnect, delay);
143+
144+
window.setTimeout(connect, delay);
126145
};
127146

128147
connect();

ui/src/server_axum/websocket.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,23 @@ use tokio::{sync::mpsc, task::JoinSet};
1515

1616
type Meta = serde_json::Value;
1717

18+
#[derive(serde::Deserialize)]
19+
#[serde(tag = "type")]
20+
enum HandshakeMessage {
21+
#[serde(rename = "websocket/connected")]
22+
Connected {
23+
payload: Connected,
24+
#[allow(unused)]
25+
meta: Meta,
26+
},
27+
}
28+
29+
#[derive(serde::Deserialize)]
30+
#[serde(rename_all = "camelCase")]
31+
struct Connected {
32+
i_accept_this_is_an_unsupported_api: bool,
33+
}
34+
1835
#[derive(serde::Deserialize)]
1936
#[serde(tag = "type")]
2037
enum WSMessageRequest {
@@ -107,6 +124,10 @@ pub async fn handle(mut socket: WebSocket) {
107124
LIVE_WS.inc();
108125
let start = Instant::now();
109126

127+
if !connect_handshake(&mut socket).await {
128+
return;
129+
}
130+
110131
let (tx, mut rx) = mpsc::channel(3);
111132
let mut tasks = JoinSet::new();
112133

@@ -172,6 +193,15 @@ pub async fn handle(mut socket: WebSocket) {
172193
DURATION_WS.observe(elapsed.as_secs_f64());
173194
}
174195

196+
async fn connect_handshake(socket: &mut WebSocket) -> bool {
197+
let Some(Ok(Message::Text(txt))) = socket.recv().await else { return false };
198+
let Ok(HandshakeMessage::Connected { payload, .. }) = serde_json::from_str::<HandshakeMessage>(&txt) else { return false };
199+
if !payload.i_accept_this_is_an_unsupported_api {
200+
return false;
201+
}
202+
socket.send(Message::Text(txt)).await.is_ok()
203+
}
204+
175205
fn error_to_response(error: Error) -> MessageResponse {
176206
let error = error.to_string();
177207
// TODO: thread through the Meta from the originating request

0 commit comments

Comments
 (0)