Skip to content

Commit c164dd0

Browse files
committed
potential fix for local-stt-server connection issues
1 parent 55e4ee7 commit c164dd0

File tree

6 files changed

+83
-80
lines changed

6 files changed

+83
-80
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ reqwest = "0.12"
9898
tokio = "1"
9999
tokio-stream = "0.1.17"
100100
tokio-tungstenite = "0.26.0"
101+
tokio-util = "0.7.15"
101102

102103
anyhow = "1"
103104
approx = "0.5.1"

crates/ws/src/client.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,6 @@ impl WebSocketClient {
3535
.with_delay(std::time::Duration::from_millis(500)),
3636
)
3737
.when(|e| {
38-
if let crate::Error::Connection(te) = e {
39-
if let tokio_tungstenite::tungstenite::Error::Http(res) = te {
40-
if res.status() == 429 {
41-
return true;
42-
}
43-
}
44-
}
45-
4638
tracing::error!("ws_connect_failed: {:?}", e);
4739
true
4840
})
@@ -105,7 +97,7 @@ impl WebSocketClient {
10597
tracing::info!("connect_async: {:?}", req.uri());
10698

10799
let (ws_stream, _) =
108-
tokio::time::timeout(std::time::Duration::from_secs(4), connect_async(req)).await??;
100+
tokio::time::timeout(std::time::Duration::from_secs(8), connect_async(req)).await??;
109101

110102
Ok(ws_stream)
111103
}

plugins/local-stt/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ thiserror = { workspace = true }
4444
tracing = { workspace = true }
4545

4646
axum = { workspace = true, features = ["ws"] }
47+
tower-http = { workspace = true, features = ["cors", "trace"] }
48+
4749
futures-util = { workspace = true }
4850
tokio = { workspace = true, features = ["rt", "macros"] }
49-
tower-http = { workspace = true, features = ["cors", "trace"] }
51+
tokio-util = { workspace = true }
5052

5153
[target.'cfg(not(target_os = "macos"))'.dependencies]
5254
kalosm-sound = { workspace = true, default-features = false }

plugins/local-stt/src/manager.rs

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,41 @@
1-
use std::{
2-
sync::atomic::{AtomicUsize, Ordering},
3-
sync::Arc,
4-
};
1+
use std::sync::{Arc, Mutex};
2+
use tokio_util::sync::CancellationToken;
53

64
#[derive(Clone)]
75
pub struct ConnectionManager {
8-
num_connections: Arc<AtomicUsize>,
6+
inner: Arc<Mutex<Option<CancellationToken>>>,
97
}
108

119
impl Default for ConnectionManager {
1210
fn default() -> Self {
1311
Self {
14-
num_connections: Arc::new(AtomicUsize::new(0)),
12+
inner: Arc::new(Mutex::new(None)),
1513
}
1614
}
1715
}
1816

1917
impl ConnectionManager {
20-
pub fn try_acquire_connection(&self) -> Option<ConnectionGuard> {
21-
let current = self.num_connections.load(Ordering::SeqCst);
22-
if current >= 1 {
23-
return None;
24-
}
18+
pub fn acquire_connection(&self) -> ConnectionGuard {
19+
let mut slot = self.inner.lock().unwrap();
2520

26-
match self
27-
.num_connections
28-
.compare_exchange(0, 1, Ordering::SeqCst, Ordering::SeqCst)
29-
{
30-
Ok(_) => Some(ConnectionGuard(self.num_connections.clone())),
31-
Err(_) => None,
21+
if let Some(old) = slot.take() {
22+
old.cancel();
3223
}
24+
25+
let token = CancellationToken::new();
26+
*slot = Some(token.clone());
27+
28+
ConnectionGuard { token }
3329
}
3430
}
3531

36-
pub struct ConnectionGuard(Arc<AtomicUsize>);
32+
pub struct ConnectionGuard {
33+
token: CancellationToken,
34+
}
3735

38-
impl Drop for ConnectionGuard {
39-
fn drop(&mut self) {
40-
self.0.fetch_sub(1, Ordering::SeqCst);
36+
impl ConnectionGuard {
37+
pub async fn cancelled(&self) {
38+
self.token.cancelled().await
4139
}
4240
}
4341

@@ -70,9 +68,7 @@ mod tests {
7068
ws: WebSocketUpgrade,
7169
AxumState(manager): AxumState<ConnectionManager>,
7270
) -> Result<impl IntoResponse, StatusCode> {
73-
let guard = manager
74-
.try_acquire_connection()
75-
.ok_or(StatusCode::TOO_MANY_REQUESTS)?;
71+
let guard = manager.acquire_connection();
7672

7773
Ok(ws.on_upgrade(move |socket| handle_socket(socket, guard)))
7874
}

plugins/local-stt/src/server.rs

Lines changed: 57 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,6 @@ pub struct ServerState {
5656
connection_manager: ConnectionManager,
5757
}
5858

59-
impl ServerState {
60-
pub fn try_acquire_connection(&self) -> Option<ConnectionGuard> {
61-
self.connection_manager.try_acquire_connection()
62-
}
63-
}
64-
6559
#[derive(Clone)]
6660
pub struct ServerHandle {
6761
pub addr: SocketAddr,
@@ -114,13 +108,25 @@ async fn listen(
114108
ws: WebSocketUpgrade,
115109
AxumState(state): AxumState<ServerState>,
116110
) -> Result<impl IntoResponse, StatusCode> {
117-
let guard = state
118-
.try_acquire_connection()
119-
.ok_or(StatusCode::TOO_MANY_REQUESTS)?;
111+
let guard = state.connection_manager.acquire_connection();
120112

121-
let model_path = state.model_type.model_path(&state.model_cache_dir);
113+
Ok(ws.on_upgrade(move |socket| async move {
114+
websocket_with_model(socket, params, state, guard).await
115+
}))
116+
}
117+
118+
async fn websocket_with_model(
119+
socket: WebSocket,
120+
params: ListenParams,
121+
state: ServerState,
122+
guard: ConnectionGuard,
123+
) {
124+
let model_type = state.model_type;
125+
let model_cache_dir = state.model_cache_dir.clone();
126+
127+
let model_path = model_type.model_path(&model_cache_dir);
122128
let language = params.language.try_into().unwrap_or_else(|e| {
123-
tracing::error!("convert_to_whisper_language: {:?}", e);
129+
tracing::error!("convert_to_whisper_language: {e:?}");
124130
hypr_whisper::Language::En
125131
});
126132

@@ -131,15 +137,11 @@ async fn listen(
131137
.dynamic_prompt(&params.dynamic_prompt)
132138
.build();
133139

134-
Ok(ws.on_upgrade(move |socket| websocket(socket, model, guard)))
140+
websocket(socket, model, guard).await;
135141
}
136142

137143
#[tracing::instrument(skip_all)]
138-
async fn websocket(
139-
socket: WebSocket,
140-
model: hypr_whisper::local::Whisper,
141-
_guard: ConnectionGuard,
142-
) {
144+
async fn websocket(socket: WebSocket, model: hypr_whisper::local::Whisper, guard: ConnectionGuard) {
143145
let (mut ws_sender, ws_receiver) = socket.split();
144146
let mut stream = {
145147
let audio_source = WebSocketAudioSource::new(ws_receiver, 16 * 1000);
@@ -148,35 +150,44 @@ async fn websocket(
148150
hypr_whisper::local::TranscribeChunkedAudioStreamExt::transcribe(chunked, model)
149151
};
150152

151-
while let Some(chunk) = stream.next().await {
152-
let text = chunk.text().to_string();
153-
let start = chunk.start() as u64;
154-
let duration = chunk.duration() as u64;
155-
let confidence = chunk.confidence();
156-
157-
if confidence < 0.45 {
158-
tracing::warn!(confidence, "skipping_transcript: {}", text);
159-
continue;
160-
}
161-
162-
let data = ListenOutputChunk {
163-
words: text
164-
.split_whitespace()
165-
.filter(|w| !w.is_empty())
166-
.map(|w| Word {
167-
text: w.trim().to_string(),
168-
speaker: None,
169-
start_ms: Some(start),
170-
end_ms: Some(start + duration),
171-
confidence: Some(confidence),
172-
})
173-
.collect(),
174-
};
175-
176-
let msg = Message::Text(serde_json::to_string(&data).unwrap().into());
177-
if let Err(e) = ws_sender.send(msg).await {
178-
tracing::warn!("websocket_send_error: {}", e);
179-
break;
153+
loop {
154+
tokio::select! {
155+
_ = guard.cancelled() => {
156+
tracing::info!("websocket_cancelled_by_new_connection");
157+
break;
158+
}
159+
chunk_opt = stream.next() => {
160+
let Some(chunk) = chunk_opt else { break };
161+
let text = chunk.text().to_string();
162+
let start = chunk.start() as u64;
163+
let duration = chunk.duration() as u64;
164+
let confidence = chunk.confidence();
165+
166+
if confidence < 0.45 {
167+
tracing::warn!(confidence, "skipping_transcript: {}", text);
168+
continue;
169+
}
170+
171+
let data = ListenOutputChunk {
172+
words: text
173+
.split_whitespace()
174+
.filter(|w| !w.is_empty())
175+
.map(|w| Word {
176+
text: w.trim().to_string(),
177+
speaker: None,
178+
start_ms: Some(start),
179+
end_ms: Some(start + duration),
180+
confidence: Some(confidence),
181+
})
182+
.collect(),
183+
};
184+
185+
let msg = Message::Text(serde_json::to_string(&data).unwrap().into());
186+
if let Err(e) = ws_sender.send(msg).await {
187+
tracing::warn!("websocket_send_error: {}", e);
188+
break;
189+
}
190+
}
180191
}
181192
}
182193

0 commit comments

Comments
 (0)