Skip to content

Commit 50fda73

Browse files
ruby : add encoder begin callback related methods (ggml-org#3076)
* Lazy run TestBase.whisper * Fix indentation * Remove disused GGML_HIP_UMA from Ruby * Add encoder_begin_callback * Comment out existing abort mechanism * Add test for encoder_begin_callback * Add signatures for encoder_begin_callback related methods * Update gem date
1 parent 1c20f46 commit 50fda73

File tree

9 files changed

+192
-17
lines changed

9 files changed

+192
-17
lines changed

bindings/ruby/ext/options.rb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ def configure
114114
bool "GGML_HIP_GRAPHS"
115115
bool "GGML_HIP_NO_VMM"
116116
bool "GGML_HIP_ROCWMMA_FATTN"
117-
bool "GGML_HIP_UMA"
118117
ignored "GGML_INCLUDE_INSTALL_DIR"
119118
bool "GGML_KOMPUTE"
120119
bool "GGML_LASX"

bindings/ruby/ext/ruby_whisper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ typedef struct {
1919
bool diarize;
2020
ruby_whisper_callback_container *new_segment_callback_container;
2121
ruby_whisper_callback_container *progress_callback_container;
22+
ruby_whisper_callback_container *encoder_begin_callback_container;
2223
ruby_whisper_callback_container *abort_callback_container;
2324
} ruby_whisper_params;
2425

bindings/ruby/ext/ruby_whisper_params.c

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \
2727
rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1);
2828

29-
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30
29+
#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 32
3030

3131
extern VALUE cParams;
3232

@@ -63,6 +63,8 @@ static ID id_new_segment_callback;
6363
static ID id_new_segment_callback_user_data;
6464
static ID id_progress_callback;
6565
static ID id_progress_callback_user_data;
66+
static ID id_encoder_begin_callback;
67+
static ID id_encoder_begin_callback_user_data;
6668
static ID id_abort_callback;
6769
static ID id_abort_callback_user_data;
6870

@@ -126,6 +128,33 @@ static void progress_callback(struct whisper_context *ctx, struct whisper_state
126128
}
127129
}
128130

131+
static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) {
132+
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
133+
bool is_aborted = false;
134+
VALUE result;
135+
136+
// The state argument is not supported yet (nil is passed instead) because
137+
// exposing it requires resolving GC-related problems first.
138+
if (!NIL_P(container->callback)) {
139+
result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data);
140+
if (result == Qfalse) {
141+
is_aborted = true;
142+
}
143+
}
144+
const long callbacks_len = RARRAY_LEN(container->callbacks);
145+
if (0 == callbacks_len) {
146+
return !is_aborted;
147+
}
148+
for (int j = 0; j < callbacks_len; j++) {
149+
VALUE cb = rb_ary_entry(container->callbacks, j);
150+
result = rb_funcall(cb, id_call, 0);
151+
if (result == Qfalse) {
152+
is_aborted = true;
153+
}
154+
}
155+
return !is_aborted;
156+
}
157+
129158
static bool abort_callback(void * user_data) {
130159
const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data;
131160
if (!NIL_P(container->callback)) {
@@ -161,6 +190,12 @@ void register_callbacks(ruby_whisper_params * rwp, VALUE * context) {
161190
rwp->params.progress_callback_user_data = rwp->progress_callback_container;
162191
}
163192

193+
if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) {
194+
rwp->encoder_begin_callback_container->context = context;
195+
rwp->params.encoder_begin_callback = encoder_begin_callback;
196+
rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container;
197+
}
198+
164199
if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) {
165200
rwp->abort_callback_container->context = context;
166201
rwp->params.abort_callback = abort_callback;
@@ -173,6 +208,7 @@ rb_whisper_params_mark(ruby_whisper_params *rwp)
173208
{
174209
rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container);
175210
rb_whisper_callbcack_container_mark(rwp->progress_callback_container);
211+
rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container);
176212
rb_whisper_callbcack_container_mark(rwp->abort_callback_container);
177213
}
178214

@@ -198,6 +234,7 @@ ruby_whisper_params_allocate(VALUE klass)
198234
rwp->diarize = false;
199235
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
200236
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
237+
rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate();
201238
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
202239
return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp);
203240
}
@@ -849,6 +886,57 @@ ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value)
849886
rwp->progress_callback_container->user_data = value;
850887
return value;
851888
}
889+
890+
static VALUE
891+
ruby_whisper_params_get_encoder_begin_callback(VALUE self)
892+
{
893+
ruby_whisper_params *rwp;
894+
Data_Get_Struct(self, ruby_whisper_params, rwp);
895+
return rwp->encoder_begin_callback_container->callback;
896+
}
897+
898+
/*
899+
* Sets encoder begin callback, called when the encoder starts. If the callback returns +false+, the run is aborted.
900+
*
901+
* params.encoder_begin_callback = ->(context, _, user_data) {
902+
* # ...
903+
* }
904+
*
905+
* call-seq:
906+
* encoder_begin_callback = callback -> callback
907+
*/
908+
static VALUE
909+
ruby_whisper_params_set_encoder_begin_callback(VALUE self, VALUE value)
910+
{
911+
ruby_whisper_params *rwp;
912+
Data_Get_Struct(self, ruby_whisper_params, rwp);
913+
rwp->encoder_begin_callback_container->callback = value;
914+
return value;
915+
}
916+
917+
static VALUE
918+
ruby_whisper_params_get_encoder_begin_callback_user_data(VALUE self)
919+
{
920+
ruby_whisper_params *rwp;
921+
Data_Get_Struct(self, ruby_whisper_params, rwp);
922+
return rwp->encoder_begin_callback_container->user_data;
923+
}
924+
925+
/*
926+
* Sets user data passed to the last argument of encoder begin callback.
927+
*
928+
* call-seq:
929+
* encoder_begin_callback_user_data = user_data -> user_data
930+
*/
931+
static VALUE
932+
ruby_whisper_params_set_encoder_begin_callback_user_data(VALUE self, VALUE value)
933+
{
934+
ruby_whisper_params *rwp;
935+
Data_Get_Struct(self, ruby_whisper_params, rwp);
936+
rwp->encoder_begin_callback_container->user_data = value;
937+
return value;
938+
}
939+
852940
static VALUE
853941
ruby_whisper_params_get_abort_callback(VALUE self)
854942
{
@@ -958,6 +1046,8 @@ ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self)
9581046
SET_PARAM_IF_SAME(new_segment_callback_user_data)
9591047
SET_PARAM_IF_SAME(progress_callback)
9601048
SET_PARAM_IF_SAME(progress_callback_user_data)
1049+
SET_PARAM_IF_SAME(encoder_begin_callback)
1050+
SET_PARAM_IF_SAME(encoder_begin_callback_user_data)
9611051
SET_PARAM_IF_SAME(abort_callback)
9621052
SET_PARAM_IF_SAME(abort_callback_user_data)
9631053
}
@@ -1008,6 +1098,26 @@ ruby_whisper_params_on_progress(VALUE self)
10081098
return Qnil;
10091099
}
10101100

1101+
/*
1102+
* Hook called when the encoder starts.
1103+
*
1104+
* params.on_encoder_begin do
1105+
* # ...
1106+
* end
1107+
*
1108+
* call-seq:
1109+
* on_encoder_begin { ... }
1110+
*/
1111+
static VALUE
1112+
ruby_whisper_params_on_encoder_begin(VALUE self)
1113+
{
1114+
ruby_whisper_params *rws;
1115+
Data_Get_Struct(self, ruby_whisper_params, rws);
1116+
const VALUE blk = rb_block_proc();
1117+
rb_ary_push(rws->encoder_begin_callback_container->callbacks, blk);
1118+
return Qnil;
1119+
}
1120+
10111121
/*
10121122
* Call block to determine whether abort or not. Return +true+ when you want to abort.
10131123
*
@@ -1068,10 +1178,13 @@ init_ruby_whisper_params(VALUE *mWhisper)
10681178
DEFINE_PARAM(new_segment_callback_user_data, 25)
10691179
DEFINE_PARAM(progress_callback, 26)
10701180
DEFINE_PARAM(progress_callback_user_data, 27)
1071-
DEFINE_PARAM(abort_callback, 28)
1072-
DEFINE_PARAM(abort_callback_user_data, 29)
1181+
DEFINE_PARAM(encoder_begin_callback, 28)
1182+
DEFINE_PARAM(encoder_begin_callback_user_data, 29)
1183+
DEFINE_PARAM(abort_callback, 30)
1184+
DEFINE_PARAM(abort_callback_user_data, 31)
10731185

10741186
rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0);
10751187
rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0);
1188+
rb_define_method(cParams, "on_encoder_begin", ruby_whisper_params_on_encoder_begin, 0);
10761189
rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0);
10771190
}

bindings/ruby/ext/ruby_whisper_transcribe.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,16 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
5050
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
5151
return self;
5252
}
53-
{
54-
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
53+
// Commented out because it is work in progress
54+
// {
55+
// static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
5556

56-
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
57-
bool is_aborted = *(bool*)user_data;
58-
return !is_aborted;
59-
};
60-
rwp->params.encoder_begin_callback_user_data = &is_aborted;
61-
}
57+
// rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
58+
// bool is_aborted = *(bool*)user_data;
59+
// return !is_aborted;
60+
// };
61+
// rwp->params.encoder_begin_callback_user_data = &is_aborted;
62+
// }
6263

6364
register_callbacks(rwp, &self);
6465

bindings/ruby/lib/whisper/model/uri.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def request(uri, headers)
5353
http.request request do |response|
5454
case response
5555
when Net::HTTPNotModified
56-
# noop
56+
# noop
5757
when Net::HTTPOK
5858
download response
5959
when Net::HTTPRedirection
@@ -68,7 +68,7 @@ def request(uri, headers)
6868
rescue => err
6969
if cache_path.exist?
7070
warn err
71-
# Use cache file
71+
# Use cache file
7272
else
7373
raise
7474
end

bindings/ruby/sig/whisper.rbs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ module Whisper
77
type log_callback = ^(Integer level, String message, Object user_data) -> void
88
type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void
99
type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void
10+
type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> boolish
1011
type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish
1112

1213
LOG_LEVEL_NONE: Integer
@@ -146,6 +147,8 @@ module Whisper
146147
?new_segment_callback_user_data: Object,
147148
?progress_callback: progress_callback,
148149
?progress_callback_user_data: Object,
150+
?encoder_begin_callback: encoder_begin_callback,
151+
?encoder_begin_callback_user_data: Object,
149152
?abort_callback: abort_callback,
150153
?abort_callback_user_data: Object
151154
) -> instance
@@ -306,6 +309,18 @@ module Whisper
306309

307310
def progress_callback_user_data: () -> Object
308311

312+
# Sets encoder begin callback, called when the encoder starts.
313+
#
314+
def encoder_begin_callback=: (encoder_begin_callback) -> encoder_begin_callback
315+
316+
def encoder_begin_callback: () -> (encoder_begin_callback | nil)
317+
318+
# Sets user data passed to the last argument of encoder begin callback.
319+
#
320+
def encoder_begin_callback_user_data=: (Object) -> Object
321+
322+
def encoder_begin_callback_user_data: () -> Object
323+
309324
# Sets abort callback, called to check if the process should be aborted.
310325
#
311326
# params.abort_callback = ->(user_data) {
@@ -335,6 +350,10 @@ module Whisper
335350
#
336351
def on_progress: { (Integer progress) -> void } -> void
337352

353+
# Hook called when the encoder starts.
354+
#
355+
def on_encoder_begin: { () -> boolish } -> void
356+
338357
# Call block to determine whether abort or not. Return +true+ when you want to abort.
339358
#
340359
# params.abort_on do

bindings/ruby/tests/helper.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@ class TestBase < Test::Unit::TestCase
66
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
77

88
class << self
9-
attr_reader :whisper
9+
def whisper
10+
return @whisper if @whisper
1011

11-
def startup
1212
@whisper = Whisper::Context.new("base.en")
1313
params = Whisper::Params.new
1414
params.print_timestamps = false

bindings/ruby/tests/test_callback.rb

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,48 @@ def test_on_progress
111111
assert_equal 100, last
112112
end
113113

114+
def test_encoder_begin_callback
115+
i = 0
116+
@params.encoder_begin_callback = ->(context, state, user_data) {
117+
i += 1
118+
}
119+
@whisper.transcribe(@audio, @params)
120+
assert i > 0
121+
end
122+
123+
def test_encoder_begin_callback_abort
124+
logs = []
125+
Whisper.log_set -> (level, buffer, user_data) {
126+
logs << buffer if level == Whisper::LOG_LEVEL_ERROR
127+
}, logs
128+
@params.encoder_begin_callback = ->(context, state, user_data) {
129+
return false
130+
}
131+
@whisper.transcribe(@audio, @params)
132+
assert_match(/encoder_begin_callback returned false - aborting/, logs.join)
133+
Whisper.log_set ->(level, buffer, user_data) {}, nil
134+
end
135+
136+
def test_encoder_begin_callback_user_data
137+
udata = Object.new
138+
@params.encoder_begin_callback_user_data = udata
139+
yielded = nil
140+
@params.encoder_begin_callback = ->(context, state, user_data) {
141+
yielded = user_data
142+
}
143+
@whisper.transcribe(@audio, @params)
144+
assert_same udata, yielded
145+
end
146+
147+
def test_on_encoder_begin
148+
i = 0
149+
@params.on_encoder_begin do
150+
i += 1
151+
end
152+
@whisper.transcribe(@audio, @params)
153+
assert i > 0
154+
end
155+
114156
def test_abort_callback
115157
i = 0
116158
@params.abort_callback = ->(user_data) {

bindings/ruby/whispercpp.gemspec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Gem::Specification.new do |s|
44
s.name = "whispercpp"
55
s.authors = ["Georgi Gerganov", "Todd A. Fisher"]
66
s.version = '1.3.2'
7-
s.date = '2025-04-17'
7+
s.date = '2025-04-25'
88
s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby}
99
s.email = 'todd.fisher@gmail.com'
1010
s.extra_rdoc_files = ['LICENSE', 'README.md']

0 commit comments

Comments
 (0)