Skip to content

Commit 7a2daa1

Browse files
author
litongmacos
committed
add some function
1 parent 76ef40f commit 7a2daa1

File tree

6 files changed

+190
-67
lines changed

6 files changed

+190
-67
lines changed

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ add_executable(stream_local examples/stream_local.cpp common/common.cpp common/c
5959
)
6060
target_link_libraries(stream_local whisper ${SDL2_LIBRARIES})
6161

62-
add_executable(whisper_http_server_base_httplib whisper_http_server_base_httplib.cpp common/common.cpp httplib/httplib.h nlohmann/json.hpp handler/inference_handler.cpp params/whisper_params.cpp)
62+
add_executable(whisper_http_server_base_httplib whisper_http_server_base_httplib.cpp
63+
common/common.cpp common/utils.cpp handler/inference_handler.cpp params/whisper_params.cpp)
6364
target_link_libraries(whisper_http_server_base_httplib whisper)
6465

6566
add_executable(whisper_server_base_on_uwebsockets whisper_server_base_on_uwebsockets.cpp common/common.cpp stream/stream_components_service.cpp common/utils.cpp)

common/common.cpp

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,93 @@ read_wav(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::
713713
return true;
714714
}
715715

716+
/**
 * Load an MP3 file into float PCM buffers.
 *
 * MP3 decoding is not implemented: the function reports that on stderr and
 * returns false, so callers take their regular "failed to read" path. (A
 * value-returning function with an empty body is undefined behaviour.)
 *
 * @param fname   path of the MP3 file (only used in the diagnostic)
 * @param pcmf32  mono output buffer; left untouched
 * @param pcmf32s per-channel output buffers; left untouched
 * @param stereo  whether per-channel output was requested; ignored
 * @return always false until a decoder is wired in
 */
bool
read_mp3(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::vector<float>> &pcmf32s, bool stereo) {
    // Nothing is decoded, so the outputs and the stereo flag are deliberately unused.
    (void) pcmf32;
    (void) pcmf32s;
    (void) stereo;

    fprintf(stderr, "%s: MP3 decoding is not implemented, cannot read '%s'\n", __func__, fname.c_str());
    return false;
}
720+
721+
bool
722+
read_m4a(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::vector<float>> &pcmf32s, bool stereo) {
723+
drwav wav;
724+
std::vector<uint8_t> wav_data; // used for pipe input from stdin
725+
726+
if (fname == "-") {
727+
{
728+
uint8_t buf[1024];
729+
while (true) {
730+
const size_t n = fread(buf, 1, sizeof(buf), stdin);
731+
if (n == 0) {
732+
break;
733+
}
734+
wav_data.insert(wav_data.end(), buf, buf + n);
735+
}
736+
}
737+
738+
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
739+
fprintf(stderr, "error: failed to open WAV file from stdin\n");
740+
return false;
741+
}
742+
743+
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
744+
} else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
745+
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
746+
return false;
747+
}
748+
749+
if (wav.channels != 1 && wav.channels != 2) {
750+
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
751+
return false;
752+
}
753+
754+
if (stereo && wav.channels != 2) {
755+
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
756+
return false;
757+
}
758+
759+
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
760+
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE / 1000);
761+
return false;
762+
}
763+
764+
if (wav.bitsPerSample != 16) {
765+
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
766+
return false;
767+
}
768+
769+
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size() /
770+
(wav.channels * wav.bitsPerSample / 8);
771+
772+
std::vector<int16_t> pcm16;
773+
pcm16.resize(n * wav.channels);
774+
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
775+
drwav_uninit(&wav);
776+
777+
// convert to mono, float
778+
pcmf32.resize(n);
779+
if (wav.channels == 1) {
780+
for (uint64_t i = 0; i < n; i++) {
781+
pcmf32[i] = float(pcm16[i]) / 32768.0f;
782+
}
783+
} else {
784+
for (uint64_t i = 0; i < n; i++) {
785+
pcmf32[i] = float(pcm16[2 * i] + pcm16[2 * i + 1]) / 65536.0f;
786+
}
787+
}
788+
789+
if (stereo) {
790+
// convert to stereo, float
791+
pcmf32s.resize(2);
792+
793+
pcmf32s[0].resize(n);
794+
pcmf32s[1].resize(n);
795+
for (uint64_t i = 0; i < n; i++) {
796+
pcmf32s[0][i] = float(pcm16[2 * i]) / 32768.0f;
797+
pcmf32s[1][i] = float(pcm16[2 * i + 1]) / 32768.0f;
798+
}
799+
}
800+
801+
return true;
802+
}
716803
void high_pass_filter(std::vector<float> &data, float cutoff, float sample_rate) {
717804
const float rc = 1.0f / (2.0f * M_PI * cutoff);
718805
const float dt = 1.0f / sample_rate;

common/common.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,10 @@ bool read_wav(
143143
std::vector<float> &pcmf32,
144144
std::vector<std::vector<float>> &pcmf32s,
145145
bool stereo);
146-
146+
// Loader used for the "mp3" audio format; takes the same output buffers and
// stereo flag as read_wav above and returns true on success.
// NOTE(review): confirm that the definition in common/common.cpp really
// decodes MP3 data before relying on it.
bool
147+
read_mp3(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::vector<float>> &pcmf32s, bool stereo);
148+
// Loader used for the "m4a" audio format; same parameters and return
// convention as read_mp3.
// NOTE(review): the definition in common/common.cpp opens its input through
// dr_wav's WAV routines - verify that M4A input is actually supported.
bool
149+
read_m4a(const std::string &fname, std::vector<float> &pcmf32, std::vector<std::vector<float>> &pcmf32s, bool stereo);
147150
// Write PCM data into WAV audio file
148151
class wav_writer {
149152
private:

handler/inference_handler.cpp

Lines changed: 95 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "../common/common.h"
44
#include "../params/whisper_params.h"
55
#include "../nlohmann/json.hpp"
6-
#include "common/utils.h"
6+
#include "../common/utils.h"
77

88
using json = nlohmann::json;
99

@@ -210,55 +210,31 @@ void getReqParameters(const Request &req, whisper_params &params) {
210210

211211
void getReqParameters(const Request &request, whisper_params &params);
212212

213-
void handleInference(const Request &req, Response &res, std::mutex &whisper_mutex, whisper_params &params,
214-
whisper_context *ctx, char *arg_audio_file) {
215-
// aquire whisper model mutex lock
216-
whisper_mutex.lock();
217-
218-
// first check user requested fields of the request
219-
if (!req.has_file("file")) {
220-
fprintf(stderr, "error: no 'file' field in the request\n");
221-
const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}";
222-
res.set_content(error_resp, "application/json");
223-
whisper_mutex.unlock();
224-
return;
225-
}
226-
auto audio_file = req.get_file_value("file");
227-
228-
// check non-required fields
229-
getReqParameters(req, params);
213+
bool read_audio_file(std::string audio_format, std::string filename, std::vector<float> & pcmf32,
214+
std::vector<std::vector<float>> & pcmf32s, bool diarize) {
230215

231-
std::string filename{audio_file.filename};
232-
printf("%s: Received filename: %s,audio_format\n",get_current_time().c_str(),filename.c_str(),params.audio_format.c_str());
233-
234-
// audio arrays
235-
std::vector<float> pcmf32; // mono-channel F32 PCM
236-
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
237-
238-
// write file to temporary file
239-
std::ofstream temp_file{filename, std::ios::binary};
240-
temp_file << audio_file.content;
241-
242-
// read wav content into pcmf32
243-
if(params.audio_format=="mp3"){
244-
245-
}else if(params.audio_format=="m4a"){
246-
247-
}else{
248-
if (!::read_wav(filename, pcmf32, pcmf32s, params.diarize)) {
216+
// read audio content into pcmf32
217+
if (audio_format == "mp3") {
218+
if (!::read_mp3(filename, pcmf32, pcmf32s, diarize)) {
219+
fprintf(stderr, "error: failed to read mp3 file '%s'\n", filename.c_str());
220+
return false;
221+
}
222+
} else if (audio_format == "m4a") {
223+
if (!::read_m4a(filename, pcmf32, pcmf32s, diarize)) {
224+
fprintf(stderr, "error: failed to read m4a file '%s'\n", filename.c_str());
225+
return false;
226+
}
227+
} else {
228+
if (!::read_wav(filename, pcmf32, pcmf32s, diarize)) {
249229
fprintf(stderr, "error: failed to read WAV file '%s'\n", filename.c_str());
250-
const std::string error_resp = "{\"error\":\"failed to read WAV file\"}";
251-
res.set_content(error_resp, "application/json");
252-
whisper_mutex.unlock();
253-
return;
230+
return false;
254231
}
255232
}
233+
return true;
234+
}
256235

257-
// remove temp file
258-
std::remove(filename.c_str());
259-
260-
printf("Successfully loaded %s\n", filename.c_str());
261-
236+
bool run(std::mutex &whisper_mutex, whisper_params &params, whisper_context *ctx, std::string filename,
237+
const std::vector<std::vector<float>>& pcmf32s, std::vector<float> pcmf32) {
262238
// print system information
263239
{
264240
fprintf(stderr, "\n");
@@ -368,31 +344,87 @@ void handleInference(const Request &req, Response &res, std::mutex &whisper_mute
368344
wparams.abort_callback_user_data = &is_aborted;
369345
}
370346

347+
// aquire whisper model mutex lock
348+
whisper_mutex.lock();
371349
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
372-
fprintf(stderr, "%s: failed to process audio\n", arg_audio_file);
373-
const std::string error_resp = "{\"error\":\"failed to process audio\"}";
374-
res.set_content(error_resp, "application/json");
350+
fprintf(stderr, "%s: failed to process audio\n", filename.c_str());
375351
whisper_mutex.unlock();
376-
return;
352+
return false;
377353
}
354+
whisper_mutex.unlock();
355+
return true;
378356
}
357+
}
379358

380-
// return results to user
381-
if (params.response_format == text_format) {
382-
std::string results = output_str(ctx, params, pcmf32s);
383-
res.set_content(results.c_str(), "text/html");
384-
}
385-
// TODO add more output formats
386-
else {
387-
std::string results = output_str(ctx, params, pcmf32s);
359+
360+
void handleInference(const Request &request, Response &response, std::mutex &whisper_mutex, whisper_params &params,
361+
whisper_context *ctx, char *arg_audio_file) {
362+
// first check user requested fields of the request
363+
if (!request.has_file("file")) {
364+
fprintf(stderr, "error: no 'file' field in the request\n");
388365
json jres = json{
389-
{"text", results}
366+
{"code",-1},
367+
{"msg", "no 'file' field in the request"}
390368
};
391-
res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
392-
"application/json");
369+
auto json_string = jres.dump(-1, ' ', false,json::error_handler_t::replace);
370+
response.set_content(json_string,"application/json");
371+
return;
393372
}
373+
auto audio_file = request.get_file_value("file");
394374

395-
// return whisper model mutex lock
396-
whisper_mutex.unlock();
397-
}
375+
// check non-required fields
376+
getReqParameters(request, params);
377+
378+
std::string filename{audio_file.filename};
379+
printf("%s: Received filename: %s,audio_format:%s \n",get_current_time().c_str(),filename.c_str(),params.audio_format.c_str());
380+
381+
// audio arrays
382+
std::vector<float> pcmf32; // mono-channel F32 PCM
383+
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
384+
385+
// write file to temporary file
386+
std::ofstream temp_file{filename, std::ios::binary};
387+
temp_file << audio_file.content;
398388

389+
bool isOK=read_audio_file(params.audio_format,filename,pcmf32,pcmf32s,params.diarize);
390+
if(!isOK){
391+
json json_obj={
392+
{"code",-1},
393+
{"msg","error: failed to read WAV file "}
394+
};
395+
auto json_string = json_obj.dump(-1, ' ', false,json::error_handler_t::replace);
396+
response.set_content(json_string, "application/json");
397+
return;
398+
}
399+
400+
// remove temp file
401+
std::remove(filename.c_str());
402+
403+
printf("Successfully loaded %s\n", filename.c_str());
404+
405+
bool isOk= run(whisper_mutex, params, ctx, filename, pcmf32s, pcmf32);
406+
if(isOk){
407+
// return results to user
408+
if (params.response_format == text_format) {
409+
std::string results = output_str(ctx, params, pcmf32s);
410+
response.set_content(results.c_str(), "text/html");
411+
}
412+
// TODO add more output formats
413+
else {
414+
std::string results = output_str(ctx, params, pcmf32s);
415+
json jres = json{
416+
{"code",0},
417+
{"text", results}
418+
};
419+
response.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace),
420+
"application/json");
421+
}
422+
}else{
423+
json jres = json{
424+
{"code",-1},
425+
{"msg", "run error"}
426+
};
427+
auto json_string = jres.dump(-1, ' ', false,json::error_handler_t::replace);
428+
response.set_content(json_string,"application/json");
429+
}
430+
}

handler/inference_handler.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@
55

66
using namespace httplib;
77

8-
void handleInference(const Request &req, Response &res, std::mutex &whisper_mutex, whisper_params &params,
8+
void handleInference(const Request &request, Response &response, std::mutex &whisper_mutex, whisper_params &params,
99
whisper_context *ctx, char *arg_audio_file);

params/whisper_params.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ struct whisper_params {
5050
std::string language = "en";
5151
std::string prompt = "";
5252
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
53-
std::string model = "models/ggml-base.en.bin";
53+
std::string model = "../models/ggml-base.en.bin";
5454

5555
std::string response_format = json_format;
5656

0 commit comments

Comments
 (0)