Skip to content

Commit a19988d

Browse files
committed
add stt model getter and setter
1 parent e3e0c4b commit a19988d

File tree

16 files changed

+190
-1
lines changed

16 files changed

+190
-1
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

apps/desktop/src/components/settings/views/ai.tsx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ export default function LocalAI() {
1515
const sttRunning = useQuery({
1616
queryKey: ["local-stt", "running"],
1717
queryFn: async () => localSttCommands.isServerRunning(),
18+
refetchInterval: 3000,
1819
});
1920

2021
const llmRunning = useQuery({
2122
queryKey: ["local-llm", "running"],
2223
queryFn: async () => localLlmCommands.isServerRunning(),
24+
refetchInterval: 3000,
2325
});
2426

2527
return (
@@ -44,6 +46,20 @@ function SpeechToTextDetails(
4446
},
4547
});
4648

49+
// const currentModel = useQuery({
50+
// queryKey: ["local-stt", "current-model"],
51+
// queryFn: () => localSttCommands.getCurrentModel(),
52+
// });
53+
54+
// const setCurrentModel = useMutation({
55+
// mutationFn: async (model: SupportedModel) => {
56+
// await localSttCommands.setCurrentModel(model);
57+
// },
58+
// onSuccess: () => {
59+
// queryClient.invalidateQueries({ queryKey: ["local-stt", "current-model"] });
60+
// },
61+
// });
62+
4763
return (
4864
<div className="space-y-4">
4965
<div className="flex items-center justify-between rounded-lg border p-4">

plugins/local-stt/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,13 @@ hypr-file = { workspace = true }
2929
hypr-listener-interface = { workspace = true }
3030

3131
tauri = { workspace = true, features = ["test"] }
32+
tauri-plugin-store2 = { workspace = true }
3233
tauri-specta = { workspace = true, features = ["derive", "typescript"] }
3334

3435
serde = { workspace = true }
3536
serde_json = { workspace = true }
3637
specta = { workspace = true }
38+
strum = { workspace = true, features = ["derive"] }
3739
thiserror = { workspace = true }
3840
tracing = { workspace = true }
3941

plugins/local-stt/build.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ const COMMANDS: &[&str] = &[
44
"download_model",
55
"start_server",
66
"stop_server",
7+
"get_current_model",
8+
"set_current_model",
79
];
810

911
fn main() {

plugins/local-stt/js/bindings.gen.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,12 @@ async startServer() : Promise<null> {
2121
},
2222
async stopServer() : Promise<null> {
2323
return await TAURI_INVOKE("plugin:local-stt|stop_server");
24+
},
25+
async getCurrentModel() : Promise<SupportedModel> {
26+
return await TAURI_INVOKE("plugin:local-stt|get_current_model");
27+
},
28+
async setCurrentModel(model: SupportedModel) : Promise<null> {
29+
return await TAURI_INVOKE("plugin:local-stt|set_current_model", { model });
2430
}
2531
}
2632

@@ -34,6 +40,7 @@ async stopServer() : Promise<null> {
3440

3541
/** user-defined types **/
3642

43+
export type SupportedModel = "QuantizedTiny" | "QuantizedTinyEn" | "QuantizedLargeV3Turbo"
3744
export type TAURI_CHANNEL<TSend> = null
3845

3946
/** tauri-specta globals **/
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Automatically generated - DO NOT EDIT!
2+
3+
"$schema" = "../../schemas/schema.json"
4+
5+
[[permission]]
6+
identifier = "allow-get-current-model"
7+
description = "Enables the get_current_model command without any pre-configured scope."
8+
commands.allow = ["get_current_model"]
9+
10+
[[permission]]
11+
identifier = "deny-get-current-model"
12+
description = "Denies the get_current_model command without any pre-configured scope."
13+
commands.deny = ["get_current_model"]
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Automatically generated - DO NOT EDIT!
2+
3+
"$schema" = "../../schemas/schema.json"
4+
5+
[[permission]]
6+
identifier = "allow-set-current-model"
7+
description = "Enables the set_current_model command without any pre-configured scope."
8+
commands.allow = ["set_current_model"]
9+
10+
[[permission]]
11+
identifier = "deny-set-current-model"
12+
description = "Denies the set_current_model command without any pre-configured scope."
13+
commands.deny = ["set_current_model"]

plugins/local-stt/permissions/autogenerated/reference.md

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Default permissions for the plugin
77
- `allow-download-model`
88
- `allow-start-server`
99
- `allow-stop-server`
10+
- `allow-get-current-model`
11+
- `allow-set-current-model`
1012

1113
## Permission Table
1214

@@ -46,6 +48,32 @@ Denies the download_model command without any pre-configured scope.
4648
<tr>
4749
<td>
4850

51+
`local-stt:allow-get-current-model`
52+
53+
</td>
54+
<td>
55+
56+
Enables the get_current_model command without any pre-configured scope.
57+
58+
</td>
59+
</tr>
60+
61+
<tr>
62+
<td>
63+
64+
`local-stt:deny-get-current-model`
65+
66+
</td>
67+
<td>
68+
69+
Denies the get_current_model command without any pre-configured scope.
70+
71+
</td>
72+
</tr>
73+
74+
<tr>
75+
<td>
76+
4977
`local-stt:allow-get-status`
5078

5179
</td>
@@ -124,6 +152,32 @@ Denies the is_server_running command without any pre-configured scope.
124152
<tr>
125153
<td>
126154

155+
`local-stt:allow-set-current-model`
156+
157+
</td>
158+
<td>
159+
160+
Enables the set_current_model command without any pre-configured scope.
161+
162+
</td>
163+
</tr>
164+
165+
<tr>
166+
<td>
167+
168+
`local-stt:deny-set-current-model`
169+
170+
</td>
171+
<td>
172+
173+
Denies the set_current_model command without any pre-configured scope.
174+
175+
</td>
176+
</tr>
177+
178+
<tr>
179+
<td>
180+
127181
`local-stt:allow-start-server`
128182

129183
</td>

plugins/local-stt/permissions/default.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ permissions = [
66
"allow-download-model",
77
"allow-start-server",
88
"allow-stop-server",
9+
"allow-get-current-model",
10+
"allow-set-current-model",
911
]

plugins/local-stt/permissions/schemas/schema.json

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@
304304
"type": "string",
305305
"const": "deny-download-model"
306306
},
307+
{
308+
"description": "Enables the get_current_model command without any pre-configured scope.",
309+
"type": "string",
310+
"const": "allow-get-current-model"
311+
},
312+
{
313+
"description": "Denies the get_current_model command without any pre-configured scope.",
314+
"type": "string",
315+
"const": "deny-get-current-model"
316+
},
307317
{
308318
"description": "Enables the get_status command without any pre-configured scope.",
309319
"type": "string",
@@ -334,6 +344,16 @@
334344
"type": "string",
335345
"const": "deny-is-server-running"
336346
},
347+
{
348+
"description": "Enables the set_current_model command without any pre-configured scope.",
349+
"type": "string",
350+
"const": "allow-set-current-model"
351+
},
352+
{
353+
"description": "Denies the set_current_model command without any pre-configured scope.",
354+
"type": "string",
355+
"const": "deny-set-current-model"
356+
},
337357
{
338358
"description": "Enables the start_server command without any pre-configured scope.",
339359
"type": "string",

plugins/local-stt/src/commands.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,20 @@ pub async fn start_server<R: tauri::Runtime>(app: tauri::AppHandle<R>) -> Result
3838
pub async fn stop_server<R: tauri::Runtime>(app: tauri::AppHandle<R>) -> Result<(), String> {
3939
app.stop_server().await.map_err(|e| e.to_string())
4040
}
41+
42+
#[tauri::command]
43+
#[specta::specta]
44+
pub fn get_current_model<R: tauri::Runtime>(
45+
app: tauri::AppHandle<R>,
46+
) -> Result<crate::SupportedModel, String> {
47+
app.get_current_model().map_err(|e| e.to_string())
48+
}
49+
50+
#[tauri::command]
51+
#[specta::specta]
52+
pub fn set_current_model<R: tauri::Runtime>(
53+
app: tauri::AppHandle<R>,
54+
model: crate::SupportedModel,
55+
) -> Result<(), String> {
56+
app.set_current_model(model).map_err(|e| e.to_string())
57+
}

plugins/local-stt/src/error.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ pub enum Error {
1010
TauriError(#[from] tauri::Error),
1111
#[error(transparent)]
1212
IoError(#[from] std::io::Error),
13+
#[error(transparent)]
14+
StoreError(#[from] tauri_plugin_store2::Error),
1315
}
1416

1517
impl Serialize for Error {

plugins/local-stt/src/ext.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use std::future::Future;
22
use std::path::PathBuf;
33

44
use tauri::{ipc::Channel, Manager, Runtime};
5+
use tauri_plugin_store2::StorePluginExt;
56

67
#[derive(Debug, Clone)]
78
pub struct ModelConfig {
@@ -10,6 +11,7 @@ pub struct ModelConfig {
1011
}
1112

1213
pub trait LocalSttPluginExt<R: Runtime> {
14+
fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore<R, crate::StoreKey>;
1315
fn api_base(&self) -> impl Future<Output = Option<String>>;
1416
fn is_model_downloaded(&self) -> impl Future<Output = Result<bool, crate::Error>>;
1517
fn is_server_running(&self) -> impl Future<Output = bool>;
@@ -18,9 +20,15 @@ pub trait LocalSttPluginExt<R: Runtime> {
1820
fn download_config(&self) -> impl Future<Output = Result<(), crate::Error>>;
1921
fn download_tokenizer(&self) -> impl Future<Output = Result<(), crate::Error>>;
2022
fn download_model(&self, c: Channel<u8>) -> impl Future<Output = Result<(), crate::Error>>;
23+
fn get_current_model(&self) -> Result<crate::SupportedModel, crate::Error>;
24+
fn set_current_model(&self, model: crate::SupportedModel) -> Result<(), crate::Error>;
2125
}
2226

2327
impl<R: Runtime, T: Manager<R>> LocalSttPluginExt<R> for T {
28+
fn local_stt_store(&self) -> tauri_plugin_store2::ScopedStore<R, crate::StoreKey> {
29+
self.scoped_store(crate::PLUGIN_NAME).unwrap()
30+
}
31+
2432
#[tracing::instrument(skip_all)]
2533
async fn api_base(&self) -> Option<String> {
2634
let state = self.state::<crate::SharedState>();
@@ -186,4 +194,18 @@ impl<R: Runtime, T: Manager<R>> LocalSttPluginExt<R> for T {
186194

187195
Ok(())
188196
}
197+
198+
#[tracing::instrument(skip_all)]
199+
fn get_current_model(&self) -> Result<crate::SupportedModel, crate::Error> {
200+
let store = self.local_stt_store();
201+
let model = store.get(crate::StoreKey::DefaultModel)?;
202+
Ok(model.unwrap_or(crate::SupportedModel::QuantizedLargeV3Turbo))
203+
}
204+
205+
#[tracing::instrument(skip_all)]
206+
fn set_current_model(&self, model: crate::SupportedModel) -> Result<(), crate::Error> {
207+
let store = self.local_stt_store();
208+
store.set(crate::StoreKey::DefaultModel, model)?;
209+
Ok(())
210+
}
189211
}

plugins/local-stt/src/lib.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,13 @@ mod error;
66
mod ext;
77
mod model;
88
mod server;
9+
mod store;
910

1011
pub use error::*;
1112
pub use ext::*;
1213
use model::*;
1314
use server::*;
15+
use store::*;
1416

1517
pub type SharedState = std::sync::Arc<tokio::sync::Mutex<State>>;
1618

@@ -41,6 +43,8 @@ fn make_specta_builder<R: tauri::Runtime>() -> tauri_specta::Builder<R> {
4143
commands::download_model::<Wry>,
4244
commands::start_server::<Wry>,
4345
commands::stop_server::<Wry>,
46+
commands::get_current_model::<Wry>,
47+
commands::set_current_model::<Wry>,
4448
])
4549
.error_handling(tauri_specta::ErrorHandlingMode::Throw)
4650
}
@@ -51,6 +55,11 @@ pub fn init<R: tauri::Runtime>() -> tauri::plugin::TauriPlugin<R> {
5155
tauri::plugin::Builder::new(PLUGIN_NAME)
5256
.invoke_handler(specta_builder.invoke_handler())
5357
.setup(|app, _api| {
58+
let mut state = State::default();
59+
state.model = app
60+
.get_current_model()
61+
.unwrap_or(SupportedModel::QuantizedLargeV3Turbo);
62+
5463
app.manage(SharedState::default());
5564
Ok(())
5665
})

plugins/local-stt/src/model.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#[derive(Debug, Clone)]
1+
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, specta::Type)]
22
pub enum SupportedModel {
33
QuantizedTiny,
44
QuantizedTinyEn,

plugins/local-stt/src/store.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
use tauri_plugin_store2::ScopedStoreKey;
2+
3+
#[derive(serde::Deserialize, specta::Type, PartialEq, Eq, Hash, strum::Display)]
4+
pub enum StoreKey {
5+
DefaultModel,
6+
}
7+
8+
impl ScopedStoreKey for StoreKey {}

0 commit comments

Comments
 (0)