Skip to content

Commit 4e8e1dd

Browse files
author
litongmacos
committed
add resample
1 parent 0c21b4d commit 4e8e1dd

File tree

10 files changed

+141
-96
lines changed

10 files changed

+141
-96
lines changed

CMakeLists.txt

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ message(STATUS "SDL2 libraries: ${SDL2_LIBRARIES}")
1919

2020
include_directories(${SDL2_INCLUDE_DIRS})
2121

22+
find_package(SampleRate CONFIG REQUIRED)
23+
2224
# webrtc
2325
include_directories(webrtc)
2426
include_directories(.)
@@ -51,17 +53,17 @@ add_executable(sdl_version examples/sdl_version.cpp)
5153
target_link_libraries(sdl_version ${SDL2_LIBRARIES})
5254

5355
add_executable(simplest examples/simplest.cpp common/common.cpp)
54-
target_link_libraries(simplest whisper)
56+
target_link_libraries(simplest whisper SampleRate::samplerate)
5557

5658
add_executable(stream_local examples/stream_local.cpp common/common.cpp common/common-sdl.cpp
5759
stream/stream_components_service.cpp stream/stream_components_audio.cpp
5860
stream/stream_components_output.cpp
5961
)
60-
target_link_libraries(stream_local whisper ${SDL2_LIBRARIES})
62+
target_link_libraries(stream_local whisper ${SDL2_LIBRARIES} SampleRate::samplerate)
6163

6264
add_executable(whisper_http_server_base_httplib whisper_http_server_base_httplib.cpp
6365
common/common.cpp common/utils.cpp handler/inference_handler.cpp params/whisper_params.cpp)
64-
target_link_libraries(whisper_http_server_base_httplib whisper)
66+
target_link_libraries(whisper_http_server_base_httplib whisper SampleRate::samplerate)
6567

6668
add_executable(whisper_server_base_on_uwebsockets whisper_server_base_on_uwebsockets.cpp common/common.cpp stream/stream_components_service.cpp common/utils.cpp)
6769
#add uwebsockets head files
@@ -70,14 +72,16 @@ target_include_directories(whisper_server_base_on_uwebsockets PRIVATE ${UWEBSOCK
7072
# Detecting Operating Systems
7173
if (WIN32)
7274
# if Windows
73-
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE whisper ZLIB::ZLIB libuv::uv ${USOCKETS_LIBRARY})
75+
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE libuv::uv)
7476
elseif (APPLE)
7577
# if MacOS
76-
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE whisper ZLIB::ZLIB libuv::uv_a ${USOCKETS_LIBRARY})
78+
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE libuv::uv_a)
7779
else ()
7880
# if others eg. Linux
79-
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE whisper ZLIB::ZLIB libuv::uv ${USOCKETS_LIBRARY})
81+
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE libuv::uv)
8082
endif ()
8183

84+
target_link_libraries(whisper_server_base_on_uwebsockets PRIVATE whisper ZLIB::ZLIB ${USOCKETS_LIBRARY} SampleRate::samplerate)
85+
8286

8387

common/common.cpp

Lines changed: 56 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,18 @@
88

99
#include "../dr_libs/dr_wav.h"
1010

11+
#define DR_MP3_IMPLEMENTATION
12+
13+
#include "dr_libs/dr_mp3.h"
14+
#include <samplerate.h>
1115
#include <cmath>
1216
#include <cstring>
1317
#include <fstream>
1418
#include <regex>
1519
#include <locale>
1620
#include <codecvt>
1721
#include <sstream>
18-
#define DR_MP3_IMPLEMENTATION
19-
#include "dr_libs/dr_mp3.h"
22+
2023
#if defined(_MSC_VER)
2124
#pragma warning(disable: 4244 4267) // possible loss of data
2225
#endif
@@ -631,6 +634,43 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
631634

632635
}
633636

637+
bool resample(const float *input, size_t inputSampleRate, size_t inputSize,
638+
std::vector<float> &output, size_t outputSampleRate) {
639+
// Initialize Converter
640+
int error;
641+
SRC_STATE *src_state = src_new(SRC_SINC_FASTEST, 1, &error);
642+
if (src_state == NULL) {
643+
fprintf(stderr,"error %s\n",src_strerror(error));
644+
return false;
645+
}
646+
647+
// set convert param
648+
SRC_DATA src_data;
649+
src_data.data_in = input;
650+
src_data.input_frames = inputSize;
651+
src_data.data_out = new float[inputSize]; // assign size
652+
src_data.output_frames = inputSize;
653+
src_data.src_ratio = double(outputSampleRate) / inputSampleRate;
654+
655+
// convert
656+
error = src_process(src_state, &src_data);
657+
if (error) {
658+
fprintf(stderr,"Error converting sample rate: %d",error);
659+
delete[] src_data.data_out;
660+
src_delete(src_state);
661+
return false;
662+
}
663+
664+
// Copy the transformed data into the output vector
665+
output.assign(src_data.data_out, src_data.data_out + src_data.output_frames_gen);
666+
667+
// clean
668+
delete[] src_data.data_out;
669+
src_delete(src_state);
670+
671+
return true;
672+
}
673+
634674
bool
635675
read_wav(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::vector<float>> &pcmf32s, bool stereo) {
636676
drwav wav;
@@ -721,11 +761,6 @@ bool read_mp3(const std::string &fname, std::vector<float> &pcmf32, bool stereo)
721761
return false;
722762
}
723763

724-
if (mp3.sampleRate != COMMON_SAMPLE_RATE) {
725-
fprintf(stderr, "%s: MP3 file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE / 1000);
726-
return false;
727-
}
728-
729764
if (mp3.channels != 1 && mp3.channels != 2) {
730765
fprintf(stderr, "%s: MP3 file '%s' must be mono or stereo\n", __func__, fname.c_str());
731766
return false;
@@ -739,7 +774,19 @@ bool read_mp3(const std::string &fname, std::vector<float> &pcmf32, bool stereo)
739774
drmp3_uint64 frameCount;
740775
float *pSampleData = drmp3__full_read_and_close_f32(&mp3, nullptr, &frameCount);
741776

742-
pcmf32.assign(pSampleData, pSampleData + frameCount * mp3.channels);
777+
if (mp3.sampleRate != COMMON_SAMPLE_RATE) {
778+
std::vector<float> resampledData;
779+
if (!resample(pSampleData, mp3.sampleRate, frameCount * mp3.channels, resampledData, COMMON_SAMPLE_RATE)) {
780+
fprintf(stderr, "error: failed to resample MP3 data\n");
781+
drmp3_free(pSampleData, nullptr);
782+
return false;
783+
}
784+
785+
pcmf32.swap(resampledData); // 使用转换后的数据
786+
787+
} else {
788+
pcmf32.assign(pSampleData, pSampleData + frameCount * mp3.channels);
789+
}
743790
drmp3_free(pSampleData, nullptr);
744791

745792
return true;
@@ -827,6 +874,7 @@ read_m4a(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::
827874

828875
return true;
829876
}
877+
830878
void high_pass_filter(std::vector<float> &data, float cutoff, float sample_rate) {
831879
const float rc = 1.0f / (2.0f * M_PI * cutoff);
832880
const float dt = 1.0f / sample_rate;

common/common.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ struct gpt_params {
4242
int32_t interactive_port = -1;
4343
};
4444

45+
4546
bool gpt_params_parse(int argc, char **argv, gpt_params &params);
4647

4748
void gpt_print_usage(int argc, char **argv, const gpt_params &params);
@@ -134,7 +135,8 @@ gpt_vocab::id gpt_sample_top_k_top_p_repeat(
134135
//
135136
// Audio utils
136137
//
137-
138+
bool resample(const float *input, size_t inputSampleRate, size_t inputSize,
139+
std::vector<float> &output, size_t outputSampleRate);
138140
// Read WAV audio file and store the PCM data into pcmf32
139141
// The sample rate of the audio must be equal to COMMON_SAMPLE_RATE
140142
// If stereo flag is set and the audio has 2 channels, the pcmf32s will contain 2 channel PCM

common/utils.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,23 @@ long get_current_time_millis(){
2828
return std::chrono::duration_cast<std::chrono::milliseconds>(start.time_since_epoch()).count();
2929
}
3030

31+
// 500 -> 00:05.000
32+
// 6000 -> 01:00.000
33+
std::string to_timestamp(int64_t t, bool comma) {
34+
int64_t msec = t * 10;
35+
int64_t hr = msec / (1000 * 60 * 60);
36+
msec = msec - hr * (1000 * 60 * 60);
37+
int64_t min = msec / (1000 * 60);
38+
msec = msec - min * (1000 * 60);
39+
int64_t sec = msec / 1000;
40+
msec = msec - sec * 1000;
41+
42+
char buf[32];
43+
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
44+
45+
return std::string(buf);
46+
}
47+
3148
nlohmann::json get_result(whisper_context *ctx) {
3249
nlohmann::json results = nlohmann::json(nlohmann::json::array());
3350
const int n_segments = whisper_full_n_segments(ctx);
@@ -37,12 +54,11 @@ nlohmann::json get_result(whisper_context *ctx) {
3754
int64_t t1 = whisper_full_get_segment_t1(ctx, i);
3855
const char *sentence = whisper_full_get_segment_text(ctx, i);
3956
auto result = std::to_string(t0) + "-->" + std::to_string(t1) + ":" + sentence + "\n";
40-
printf("%s: result:%s\n", get_current_time().c_str(), result.c_str());
41-
segment["t0"] = t0;
42-
segment["t1"] = t1;
57+
//printf("%s: result:%s\n", get_current_time().c_str(), result.c_str());
58+
segment["t0"] = to_timestamp(t0);
59+
segment["t1"] = to_timestamp(t1);
4360
segment["sentence"] = sentence;
4461
results.push_back(segment);
4562
}
4663
return results;
4764
}
48-

common/utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@
66

77
std::string get_current_time();
88
long get_current_time_millis();
9+
std::string to_timestamp(int64_t t, bool comma = false);
910
nlohmann::json get_result(whisper_context *ctx);

examples/simplest.cpp

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "../common/common.h"
22

33
#include "whisper.h"
4+
#include "common/utils.h"
45

56
#include <cmath>
67
#include <cstdio>
@@ -20,23 +21,6 @@ const std::vector<std::string> k_colors = {
2021
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
2122
};
2223

23-
// 500 -> 00:05.000
24-
// 6000 -> 01:00.000
25-
std::string to_timestamp(int64_t t, bool comma = false) {
26-
int64_t msec = t * 10;
27-
int64_t hr = msec / (1000 * 60 * 60);
28-
msec = msec - hr * (1000 * 60 * 60);
29-
int64_t min = msec / (1000 * 60);
30-
msec = msec - min * (1000 * 60);
31-
int64_t sec = msec / 1000;
32-
msec = msec - sec * 1000;
33-
34-
char buf[32];
35-
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
36-
37-
return std::string(buf);
38-
}
39-
4024
int timestamp_to_sample(int64_t t, int n_samples) {
4125
return std::max(0, std::min((int) n_samples - 1, (int) ((t * WHISPER_SAMPLE_RATE) / 100)));
4226
}

handler/inference_handler.cpp

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,23 +23,6 @@ const std::vector<std::string> k_colors = {
2323
};
2424

2525

26-
// 500 -> 00:05.000
27-
// 6000 -> 01:00.000
28-
std::string to_timestamp(int64_t t, bool comma = false) {
29-
int64_t msec = t * 10;
30-
int64_t hr = msec / (1000 * 60 * 60);
31-
msec = msec - hr * (1000 * 60 * 60);
32-
int64_t min = msec / (1000 * 60);
33-
msec = msec - min * (1000 * 60);
34-
int64_t sec = msec / 1000;
35-
msec = msec - sec * 1000;
36-
37-
char buf[32];
38-
snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? "," : ".", (int) msec);
39-
40-
return std::string(buf);
41-
}
42-
4326
int timestamp_to_sample(int64_t t, int n_samples) {
4427
return std::max(0, std::min((int) n_samples - 1, (int) ((t * WHISPER_SAMPLE_RATE) / 100)));
4528
}
@@ -202,20 +185,20 @@ void getReqParameters(const Request &req, whisper_params &params) {
202185
if (req.has_file("temerature")) {
203186
params.userdef_temp = std::stof(req.get_file_value("temperature").content);
204187
}
205-
if(req.has_file("audio_format")){
206-
params.audio_format=req.get_file_value("audio_format").content;
188+
if (req.has_file("audio_format")) {
189+
params.audio_format = req.get_file_value("audio_format").content;
207190
}
208191
}
209192

210193

211194
void getReqParameters(const Request &request, whisper_params &params);
212195

213-
bool read_audio_file(std::string audio_format, std::string filename, std::vector<float> & pcmf32,
214-
std::vector<std::vector<float>> & pcmf32s, bool diarize) {
196+
bool read_audio_file(std::string audio_format, std::string filename, std::vector<float> &pcmf32,
197+
std::vector<std::vector<float>> &pcmf32s, bool diarize) {
215198

216199
// read audio content into pcmf32
217200
if (audio_format == "mp3") {
218-
if (!::read_mp3(filename, pcmf32,diarize)) {
201+
if (!::read_mp3(filename, pcmf32, diarize)) {
219202
fprintf(stderr, "error: failed to read mp3 file '%s'\n", filename.c_str());
220203
return false;
221204
}
@@ -234,7 +217,7 @@ bool read_audio_file(std::string audio_format, std::string filename, std::vector
234217
}
235218

236219
bool run(std::mutex &whisper_mutex, whisper_params &params, whisper_context *ctx, std::string filename,
237-
const std::vector<std::vector<float>>& pcmf32s, std::vector<float> pcmf32) {
220+
const std::vector<std::vector<float>> &pcmf32s, std::vector<float> pcmf32) {
238221
// print system information
239222
{
240223
fprintf(stderr, "\n");
@@ -363,19 +346,19 @@ void handleInference(const Request &request, Response &response, std::mutex &whi
363346
if (!request.has_file("file")) {
364347
fprintf(stderr, "error: no 'file' field in the request\n");
365348
json jres = json{
366-
{"code",-1},
367-
{"msg", "no 'file' field in the request"}
349+
{"code", -1},
350+
{"msg", "no 'file' field in the request"}
368351
};
369-
auto json_string = jres.dump(-1, ' ', false,json::error_handler_t::replace);
370-
response.set_content(json_string,"application/json");
352+
auto json_string = jres.dump(-1, ' ', false, json::error_handler_t::replace);
353+
response.set_content(json_string, "application/json");
371354
return;
372355
}
373356
auto audio_file = request.get_file_value("file");
374357
std::string filename{audio_file.filename};
375-
printf("%s: Received filename: %s \n",get_current_time().c_str(),filename.c_str());
358+
printf("%s: Received filename: %s \n", get_current_time().c_str(), filename.c_str());
376359
// check non-required fields
377360
getReqParameters(request, params);
378-
printf("%s: audio_format:%s \n",get_current_time().c_str(),params.audio_format.c_str());
361+
printf("%s: audio_format:%s \n", get_current_time().c_str(), params.audio_format.c_str());
379362

380363
// audio arrays
381364
std::vector<float> pcmf32; // mono-channel F32 PCM
@@ -385,13 +368,13 @@ void handleInference(const Request &request, Response &response, std::mutex &whi
385368
std::ofstream temp_file{filename, std::ios::binary};
386369
temp_file << audio_file.content;
387370

388-
bool isOK=read_audio_file(params.audio_format,filename,pcmf32,pcmf32s,params.diarize);
389-
if(!isOK){
390-
json json_obj={
391-
{"code",-1},
392-
{"msg","error: failed to read WAV file "}
371+
bool isOK = read_audio_file(params.audio_format, filename, pcmf32, pcmf32s, params.diarize);
372+
if (!isOK) {
373+
json json_obj = {
374+
{"code", -1},
375+
{"msg", "error: failed to read WAV file "}
393376
};
394-
auto json_string = json_obj.dump(-1, ' ', false,json::error_handler_t::replace);
377+
auto json_string = json_obj.dump(-1, ' ', false, json::error_handler_t::replace);
395378
response.set_content(json_string, "application/json");
396379
return;
397380
}
@@ -401,8 +384,8 @@ void handleInference(const Request &request, Response &response, std::mutex &whi
401384

402385
printf("Successfully loaded %s\n", filename.c_str());
403386

404-
bool isOk= run(whisper_mutex, params, ctx, filename, pcmf32s, pcmf32);
405-
if(isOk){
387+
bool isOk = run(whisper_mutex, params, ctx, filename, pcmf32s, pcmf32);
388+
if (isOk) {
406389
// return results to user
407390
if (params.response_format == text_format) {
408391
std::string results = output_str(ctx, params, pcmf32s);
@@ -412,18 +395,18 @@ void handleInference(const Request &request, Response &response, std::mutex &whi
412395
else {
413396
auto results = get_result(ctx);
414397
json jres = json{
415-
{"code",0},
398+
{"code", 0},
416399
{"data", results}
417400
};
418401
response.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
419402
"application/json");
420403
}
421-
}else{
404+
} else {
422405
json jres = json{
423-
{"code",-1},
424-
{"msg", "run error"}
406+
{"code", -1},
407+
{"msg", "run error"}
425408
};
426-
auto json_string = jres.dump(-1, ' ', false,json::error_handler_t::replace);
427-
response.set_content(json_string,"application/json");
409+
auto json_string = jres.dump(-1, ' ', false, json::error_handler_t::replace);
410+
response.set_content(json_string, "application/json");
428411
}
429412
}

0 commit comments

Comments
 (0)