1
1
package pool
2
2
3
3
import (
4
-
5
- // Packages
4
+ "encoding/json"
6
5
"fmt"
7
- "path/filepath"
8
6
7
+ // Packages
9
8
model "github.com/mutablelogic/go-whisper/pkg/whisper/model"
10
- "github.com/mutablelogic/go-whisper/sys /whisper"
9
+ task "github.com/mutablelogic/go-whisper/pkg /whisper/task "
11
10
12
11
// Namespace imports
13
12
. "github.com/djthorpe/go-errors"
@@ -23,26 +22,25 @@ type ContextPool struct {
23
22
24
23
// Base path for models
25
24
path string
26
- }
27
25
28
- // Context is used for running the transcription or translation
29
- type Context struct {
30
- Model * model.Model
31
- Context * whisper.Context
32
- Params whisper.FullParams
26
+ // GPU flags
27
+ gpu int
33
28
}
34
29
35
30
//////////////////////////////////////////////////////////////////////////////
36
31
// LIFECYCLE
37
32
38
33
// Create a new context pool of context objects, up to 'max' items
39
34
// Set the path for the model storage
40
- func NewContextPool (path string , max int32 ) * ContextPool {
35
+ // If GPU is -1 then disable, if 0 then use default, if >0 then enable
36
+ // and use the specified device
37
+ func NewContextPool (path string , max int , gpu int ) * ContextPool {
41
38
pool := new (ContextPool )
42
39
pool .Pool = NewPool (max , func () any {
43
- return & Context {}
40
+ return task . New ()
44
41
})
45
42
pool .path = path
43
+ pool .gpu = gpu
46
44
47
45
// Return success
48
46
return pool
@@ -53,84 +51,63 @@ func (m *ContextPool) Close() error {
53
51
return m .Pool .Close ()
54
52
}
55
53
56
- // Init the context
57
- func (m * Context ) Init (path string , model * model.Model ) error {
58
- // Check parameters
59
- if model == nil {
60
- return ErrBadParameter
61
- }
62
-
63
- // Get a context
64
- ctx := whisper .Whisper_init_from_file_with_params (filepath .Join (path , model .Path ), whisper .DefaultContextParams ())
65
- if ctx == nil {
66
- return ErrInternalAppError .With ("whisper_init_from_file_with_params" )
67
- }
68
-
69
- // Set resources
70
- m .Context = ctx
71
- m .Model = model
72
-
73
- // Return success
74
- return nil
54
+ //////////////////////////////////////////////////////////////////////////////
55
+ // STRINGIFY
56
+
57
+ func (m * ContextPool ) MarshalJSON () ([]byte , error ) {
58
+ return json .Marshal (struct {
59
+ Gpu int `json:"gpu"`
60
+ N int `json:"n"`
61
+ Max int `json:"max"`
62
+ }{
63
+ Gpu : m .gpu ,
64
+ N : m .N (),
65
+ Max : m .max ,
66
+ })
75
67
}
76
68
77
- // Close the context and release all resources
78
- func (m * Context ) Close () error {
79
- var result error
80
-
81
- // Do nothing if nil
82
- if m == nil {
83
- return nil
69
+ func (m * ContextPool ) String () string {
70
+ data , err := json .MarshalIndent (m , "" , " " )
71
+ if err != nil {
72
+ return err .Error ()
84
73
}
85
-
86
- // Release resources
87
- if m .Context != nil {
88
- whisper .Whisper_free (m .Context )
89
- }
90
- m .Context = nil
91
- m .Model = nil
92
-
93
- // Return any errors
94
- return result
74
+ return string (data )
95
75
}
96
76
97
77
//////////////////////////////////////////////////////////////////////////////
98
78
// PUBLIC METHODS
99
79
100
80
// Get a context from the pool, for a model
101
- func (m * ContextPool ) Get (model * model.Model ) (* Context , error ) {
81
+ func (m * ContextPool ) Get (model * model.Model ) (* task. Context , error ) {
102
82
// Check parameters
103
83
if model == nil {
104
84
return nil , ErrBadParameter
105
85
}
106
86
107
87
// Get a context from the pool
108
- ctx , ok := m .Pool .Get ().(* Context )
109
- if ! ok || ctx == nil {
88
+ t , ok := m .Pool .Get ().(* task. Context )
89
+ if ! ok || t == nil {
110
90
return nil , ErrChannelBlocked .With ("unable to get a context from the pool, try again later" )
111
91
}
112
92
113
- // If the model matches, return it
114
- if ctx .Model != nil && ctx .Model .Id == model .Id {
115
- return ctx , nil
116
- }
117
-
118
- // Model didn't match: close the context
119
- if err := ctx .Close (); err != nil {
93
+ // If the model matches, return it, or else release the resources
94
+ if t .Is (model ) {
95
+ return t , nil
96
+ } else if err := t .Close (); err != nil {
120
97
return nil , err
121
98
}
122
99
123
100
// Initialise the context
124
- if err := ctx .Init (m .path , model ); err != nil {
101
+ if err := t .Init (m .path , model , m . gpu ); err != nil {
125
102
return nil , err
126
103
}
127
104
128
105
// Return the context
129
- return ctx , nil
106
+ return t , nil
130
107
}
131
108
132
109
// Put a context back into the pool
133
- func (m * ContextPool ) Put (ctx * Context ) {
110
+ func (m * ContextPool ) Put (ctx * task. Context ) {
134
111
m .Pool .Put (ctx )
135
112
}
136
113
0 commit comments