Skip to content

Commit db76e6e

Browse files
authored
Merge pull request #32 from mutablelogic/v1
V1
2 parents 315af70 + f1a5b14 commit db76e6e

File tree

12 files changed

+664
-248
lines changed

12 files changed

+664
-248
lines changed

pkg/whisper/context/context.go

Lines changed: 0 additions & 97 deletions
This file was deleted.

pkg/whisper/model/store.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package model
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"io/fs"
78
"net/http"
@@ -11,6 +12,7 @@ import (
1112
"sync"
1213

1314
// Packages
15+
1416
"github.com/mutablelogic/go-whisper/sys/whisper"
1517

1618
// Namespace imports
@@ -65,6 +67,36 @@ func NewStore(path, ext, modelUrl string) (*Store, error) {
6567
return store, nil
6668
}
6769

70+
//////////////////////////////////////////////////////////////////////////////
71+
// STRINGIFY
72+
73+
func (s *Store) MarshalJSON() ([]byte, error) {
74+
modelNames := func() []string {
75+
result := make([]string, len(s.models))
76+
for i, model := range s.models {
77+
result[i] = model.Id
78+
}
79+
return result
80+
}
81+
return json.Marshal(struct {
82+
Path string `json:"path"`
83+
Ext string `json:"ext,omitempty"`
84+
Models []string `json:"models"`
85+
}{
86+
Path: s.path,
87+
Ext: s.ext,
88+
Models: modelNames(),
89+
})
90+
}
91+
92+
func (s *Store) String() string {
93+
data, err := json.MarshalIndent(s, "", " ")
94+
if err != nil {
95+
return err.Error()
96+
}
97+
return string(data)
98+
}
99+
68100
//////////////////////////////////////////////////////////////////////////////
69101
// PUBLIC METHODS
70102

pkg/whisper/opt.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,18 @@ import (
1010

1111
type opts struct {
1212
MaxConcurrent int
13+
logfn LogFn
14+
debug bool
15+
gpu int
1316
}
1417

1518
type Opt func(*opts) error
19+
type LogFn func(string)
1620

1721
///////////////////////////////////////////////////////////////////////////////
1822
// PUBLIC METHODS
1923

24+
// Set maximum number of concurrent tasks
2025
func OptMaxConcurrent(v int) Opt {
2126
return func(o *opts) error {
2227
if v < 1 {
@@ -26,3 +31,27 @@ func OptMaxConcurrent(v int) Opt {
2631
return nil
2732
}
2833
}
34+
35+
// Set logging function
36+
func OptLog(fn LogFn) Opt {
37+
return func(o *opts) error {
38+
o.logfn = fn
39+
return nil
40+
}
41+
}
42+
43+
// Set debugging
44+
func OptDebug() Opt {
45+
return func(o *opts) error {
46+
o.debug = true
47+
return nil
48+
}
49+
}
50+
51+
// Disable GPU acceleration
52+
func OptNoGPU() Opt {
53+
return func(o *opts) error {
54+
o.gpu = -1
55+
return nil
56+
}
57+
}

pkg/whisper/pool/contextpool.go

Lines changed: 38 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
package pool
22

33
import (
4-
5-
// Packages
4+
"encoding/json"
65
"fmt"
7-
"path/filepath"
86

7+
// Packages
98
model "github.com/mutablelogic/go-whisper/pkg/whisper/model"
10-
"github.com/mutablelogic/go-whisper/sys/whisper"
9+
task "github.com/mutablelogic/go-whisper/pkg/whisper/task"
1110

1211
// Namespace imports
1312
. "github.com/djthorpe/go-errors"
@@ -23,26 +22,25 @@ type ContextPool struct {
2322

2423
// Base path for models
2524
path string
26-
}
2725

28-
// Context is used for running the transcription or translation
29-
type Context struct {
30-
Model *model.Model
31-
Context *whisper.Context
32-
Params whisper.FullParams
26+
// GPU flags
27+
gpu int
3328
}
3429

3530
//////////////////////////////////////////////////////////////////////////////
3631
// LIFECYCLE
3732

3833
// Create a new context pool of context objects, up to 'max' items
3934
// Set the path for the model storage
40-
func NewContextPool(path string, max int32) *ContextPool {
35+
// If GPU is -1 then disable, if 0 then use default, if >0 then enable
36+
// and use the specified device
37+
func NewContextPool(path string, max int, gpu int) *ContextPool {
4138
pool := new(ContextPool)
4239
pool.Pool = NewPool(max, func() any {
43-
return &Context{}
40+
return task.New()
4441
})
4542
pool.path = path
43+
pool.gpu = gpu
4644

4745
// Return success
4846
return pool
@@ -53,84 +51,63 @@ func (m *ContextPool) Close() error {
5351
return m.Pool.Close()
5452
}
5553

56-
// Init the context
57-
func (m *Context) Init(path string, model *model.Model) error {
58-
// Check parameters
59-
if model == nil {
60-
return ErrBadParameter
61-
}
62-
63-
// Get a context
64-
ctx := whisper.Whisper_init_from_file_with_params(filepath.Join(path, model.Path), whisper.DefaultContextParams())
65-
if ctx == nil {
66-
return ErrInternalAppError.With("whisper_init_from_file_with_params")
67-
}
68-
69-
// Set resources
70-
m.Context = ctx
71-
m.Model = model
72-
73-
// Return success
74-
return nil
54+
//////////////////////////////////////////////////////////////////////////////
55+
// STRINGIFY
56+
57+
func (m *ContextPool) MarshalJSON() ([]byte, error) {
58+
return json.Marshal(struct {
59+
Gpu int `json:"gpu"`
60+
N int `json:"n"`
61+
Max int `json:"max"`
62+
}{
63+
Gpu: m.gpu,
64+
N: m.N(),
65+
Max: m.max,
66+
})
7567
}
7668

77-
// Close the context and release all resources
78-
func (m *Context) Close() error {
79-
var result error
80-
81-
// Do nothing if nil
82-
if m == nil {
83-
return nil
69+
func (m *ContextPool) String() string {
70+
data, err := json.MarshalIndent(m, "", " ")
71+
if err != nil {
72+
return err.Error()
8473
}
85-
86-
// Release resources
87-
if m.Context != nil {
88-
whisper.Whisper_free(m.Context)
89-
}
90-
m.Context = nil
91-
m.Model = nil
92-
93-
// Return any errors
94-
return result
74+
return string(data)
9575
}
9676

9777
//////////////////////////////////////////////////////////////////////////////
9878
// PUBLIC METHODS
9979

10080
// Get a context from the pool, for a model
101-
func (m *ContextPool) Get(model *model.Model) (*Context, error) {
81+
func (m *ContextPool) Get(model *model.Model) (*task.Context, error) {
10282
// Check parameters
10383
if model == nil {
10484
return nil, ErrBadParameter
10585
}
10686

10787
// Get a context from the pool
108-
ctx, ok := m.Pool.Get().(*Context)
109-
if !ok || ctx == nil {
88+
t, ok := m.Pool.Get().(*task.Context)
89+
if !ok || t == nil {
11090
return nil, ErrChannelBlocked.With("unable to get a context from the pool, try again later")
11191
}
11292

113-
// If the model matches, return it
114-
if ctx.Model != nil && ctx.Model.Id == model.Id {
115-
return ctx, nil
116-
}
117-
118-
// Model didn't match: close the context
119-
if err := ctx.Close(); err != nil {
93+
// If the model matches, return it, or else release the resources
94+
if t.Is(model) {
95+
return t, nil
96+
} else if err := t.Close(); err != nil {
12097
return nil, err
12198
}
12299

123100
// Initialise the context
124-
if err := ctx.Init(m.path, model); err != nil {
101+
if err := t.Init(m.path, model, m.gpu); err != nil {
125102
return nil, err
126103
}
127104

128105
// Return the context
129-
return ctx, nil
106+
return t, nil
130107
}
131108

132109
// Put a context back into the pool
133-
func (m *ContextPool) Put(ctx *Context) {
110+
func (m *ContextPool) Put(ctx *task.Context) {
134111
m.Pool.Put(ctx)
135112
}
136113

pkg/whisper/pool/contextpool_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ package pool_test
33
import (
44
"testing"
55

6+
// Packages
67
model "github.com/mutablelogic/go-whisper/pkg/whisper/model"
7-
"github.com/mutablelogic/go-whisper/pkg/whisper/pool"
8+
pool "github.com/mutablelogic/go-whisper/pkg/whisper/pool"
89
)
910

1011
func Test_contextpool_001(t *testing.T) {
11-
var pool = pool.NewContextPool(t.TempDir(), 2)
12+
var pool = pool.NewContextPool(t.TempDir(), 2, 0)
1213

1314
model1, err := pool.Get(&model.Model{
1415
Id: "model1",

0 commit comments

Comments
 (0)