Skip to content

Commit b505539

Browse files
authored
node : add language detection support (#3190)
This commit adds support for language detection in the Whisper Node.js addon example. It also updates the node addon to return an object instead of an array as the result.

The motivation for this change is to enable the inclusion of the detected language in the result, in addition to the transcription segments.

For example, when using the `detect_language` option, the result will now be:
```console
{ language: 'en' }
```

And if the `language` option is set to "auto", it will also return:
```console
{
  language: 'en',
  transcription: [
    [
      '00:00:00.000',
      '00:00:07.600',
      ' And so my fellow Americans, ask not what your country can do for you,'
    ],
    [
      '00:00:07.600',
      '00:00:10.600',
      ' ask what you can do for your country.'
    ]
  ]
}
```
1 parent 7fd6fa8 commit b505539

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

examples/addon.node/__test__/whisper.spec.js

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ const whisperParamsMock = {
1717
comma_in_time: false,
1818
translate: true,
1919
no_timestamps: false,
20+
detect_language: false,
2021
audio_ctx: 0,
2122
max_len: 0,
2223
prompt: "",
@@ -30,8 +31,9 @@ const whisperParamsMock = {
3031
describe("Run whisper.node", () => {
3132
test("it should receive a non-empty value", async () => {
3233
let result = await whisperAsync(whisperParamsMock);
34+
console.log(result);
3335

34-
expect(result.length).toBeGreaterThan(0);
36+
expect(result['transcription'].length).toBeGreaterThan(0);
3537
}, 10000);
3638
});
3739

examples/addon.node/addon.cpp

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ struct whisper_params {
3838
bool print_progress = false;
3939
bool no_timestamps = false;
4040
bool no_prints = false;
41+
bool detect_language= false;
4142
bool use_gpu = true;
4243
bool flash_attn = false;
4344
bool comma_in_time = true;
@@ -130,6 +131,11 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
130131

131132
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
132133

134+
struct whisper_result {
135+
std::vector<std::vector<std::string>> segments;
136+
std::string language;
137+
};
138+
133139
class ProgressWorker : public Napi::AsyncWorker {
134140
public:
135141
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
@@ -160,15 +166,27 @@ class ProgressWorker : public Napi::AsyncWorker {
160166

161167
void OnOK() override {
162168
Napi::HandleScope scope(Env());
163-
Napi::Object res = Napi::Array::New(Env(), result.size());
164-
for (uint64_t i = 0; i < result.size(); ++i) {
169+
170+
if (params.detect_language) {
171+
Napi::Object resultObj = Napi::Object::New(Env());
172+
resultObj.Set("language", Napi::String::New(Env(), result.language));
173+
Callback().Call({Env().Null(), resultObj});
174+
}
175+
176+
Napi::Object returnObj = Napi::Object::New(Env());
177+
if (!result.language.empty()) {
178+
returnObj.Set("language", Napi::String::New(Env(), result.language));
179+
}
180+
Napi::Array transcriptionArray = Napi::Array::New(Env(), result.segments.size());
181+
for (uint64_t i = 0; i < result.segments.size(); ++i) {
165182
Napi::Object tmp = Napi::Array::New(Env(), 3);
166183
for (uint64_t j = 0; j < 3; ++j) {
167-
tmp[j] = Napi::String::New(Env(), result[i][j]);
184+
tmp[j] = Napi::String::New(Env(), result.segments[i][j]);
168185
}
169-
res[i] = tmp;
170-
}
171-
Callback().Call({Env().Null(), res});
186+
transcriptionArray[i] = tmp;
187+
}
188+
returnObj.Set("transcription", transcriptionArray);
189+
Callback().Call({Env().Null(), returnObj});
172190
}
173191

174192
// Progress callback function - using thread-safe function
@@ -185,12 +203,12 @@ class ProgressWorker : public Napi::AsyncWorker {
185203

186204
private:
187205
whisper_params params;
188-
std::vector<std::vector<std::string>> result;
206+
whisper_result result;
189207
Napi::Env env;
190208
Napi::ThreadSafeFunction tsfn;
191209

192210
// Custom run function with progress callback support
193-
int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
211+
int run_with_progress(whisper_params &params, whisper_result & result) {
194212
if (params.no_prints) {
195213
whisper_log_set(cb_log_disable, NULL);
196214
}
@@ -279,7 +297,8 @@ class ProgressWorker : public Napi::AsyncWorker {
279297
wparams.print_timestamps = !params.no_timestamps;
280298
wparams.print_special = params.print_special;
281299
wparams.translate = params.translate;
282-
wparams.language = params.language.c_str();
300+
wparams.language = params.detect_language ? "auto" : params.language.c_str();
301+
wparams.detect_language = params.detect_language;
283302
wparams.n_threads = params.n_threads;
284303
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
285304
wparams.offset_ms = params.offset_t_ms;
@@ -330,18 +349,22 @@ class ProgressWorker : public Napi::AsyncWorker {
330349
return 10;
331350
}
332351
}
333-
}
352+
}
334353

354+
if (params.detect_language || params.language == "auto") {
355+
result.language = whisper_lang_str(whisper_full_lang_id(ctx));
356+
}
335357
const int n_segments = whisper_full_n_segments(ctx);
336-
result.resize(n_segments);
358+
result.segments.resize(n_segments);
359+
337360
for (int i = 0; i < n_segments; ++i) {
338361
const char * text = whisper_full_get_segment_text(ctx, i);
339362
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
340363
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
341364

342-
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
343-
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
344-
result[i].emplace_back(text);
365+
result.segments[i].emplace_back(to_timestamp(t0, params.comma_in_time));
366+
result.segments[i].emplace_back(to_timestamp(t1, params.comma_in_time));
367+
result.segments[i].emplace_back(text);
345368
}
346369

347370
whisper_print_timings(ctx);
@@ -366,6 +389,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
366389
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
367390
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
368391
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
392+
bool detect_language = whisper_params.Get("detect_language").As<Napi::Boolean>();
369393
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
370394
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
371395
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
@@ -418,6 +442,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
418442
params.max_context = max_context;
419443
params.print_progress = print_progress;
420444
params.prompt = prompt;
445+
params.detect_language = detect_language;
421446

422447
Napi::Function callback = info[1].As<Napi::Function>();
423448
// Create a new Worker class with progress callback support

examples/addon.node/index.js

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ const whisperParams = {
1717
comma_in_time: false,
1818
translate: true,
1919
no_timestamps: false,
20+
detect_language: false,
2021
audio_ctx: 0,
2122
max_len: 0,
2223
progress_callback: (progress) => {
@@ -31,6 +32,8 @@ const params = Object.fromEntries(
3132
const [key, value] = item.slice(2).split("=");
3233
if (key === "audio_ctx") {
3334
whisperParams[key] = parseInt(value);
35+
} else if (key === "detect_language") {
36+
whisperParams[key] = value === "true";
3437
} else {
3538
whisperParams[key] = value;
3639
}

0 commit comments

Comments
 (0)