Skip to content

Commit 001871d

Browse files
committed
Re-organized schema
1 parent 314da78 commit 001871d

File tree

16 files changed

+176
-152
lines changed

16 files changed

+176
-152
lines changed

Makefile

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,18 @@ OS ?= $(shell uname | tr A-Z a-z)
99
VERSION ?= $(shell git describe --tags --always | sed 's/^v//')
1010
DOCKER_REGISTRY ?= ghcr.io/mutablelogic
1111

12-
# Set docker tag
12+
# Set docker tag, etc
1313
BUILD_TAG := ${DOCKER_REGISTRY}/go-whisper-${OS}-${ARCH}:${VERSION}
1414
ROOT_PATH := $(CURDIR)
1515
BUILD_DIR := build
16+
17+
# Build flags
18+
BUILD_MODULE := $(shell cat go.mod | head -1 | cut -d ' ' -f 2)
19+
BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/whisper/version.GitSource=${BUILD_MODULE}
20+
BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/whisper/version.GitTag=$(shell git describe --tags --always)
21+
BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/whisper/version.GitBranch=$(shell git name-rev HEAD --name-only --always)
22+
BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/whisper/version.GitHash=$(shell git rev-parse HEAD)
23+
BUILD_LD_FLAGS += -X $(BUILD_MODULE)/pkg/whisper/version.GoBuildTime=$(shell date -u '+%Y-%m-%dT%H:%M:%SZ')
1624
BUILD_FLAGS = -ldflags "-s -w $(BUILD_LD_FLAGS)"
1725

1826
# If GGML_CUDA is set, then add a cuda tag for the go ${BUILD_FLAGS}
@@ -51,15 +59,11 @@ docker: docker-dep submodule
5159
-f etc/Dockerfile.${ARCH} .
5260

5361
# Test whisper bindings
54-
test: go-tidy libwhisper libggml
62+
test: generate libwhisper libggml
5563
@echo "Running tests (sys)"
56-
@CGO_CFLAGS="-I${ROOT_PATH}/third_party/whisper.cpp/include -I${ROOT_PATH}/third_party/whisper.cpp/ggml/include" \
57-
CGO_LDFLAGS="-L${ROOT_PATH}/third_party/whisper.cpp" \
58-
${GO} test -v ./sys/whisper/...
64+
@PKG_CONFIG_PATH=${ROOT_PATH}/${BUILD_DIR} ${GO} test -v ./sys/whisper/...
5965
@echo "Running tests (pkg)"
60-
@CGO_CFLAGS="-I${ROOT_PATH}/third_party/whisper.cpp/include -I${ROOT_PATH}/third_party/whisper.cpp/ggml/include" \
61-
CGO_LDFLAGS="-L${ROOT_PATH}/third_party/whisper.cpp" \
62-
${GO} test -v ./pkg/whisper/...
66+
@PKG_CONFIG_PATH=${ROOT_PATH}/${BUILD_DIR} ${GO} test -v ./pkg/whisper/...
6367

6468
# Build whisper-static-library
6569
libwhisper: submodule

pkg/whisper/api/models.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,15 @@ import (
1313
"github.com/mutablelogic/go-server/pkg/httprequest"
1414
"github.com/mutablelogic/go-server/pkg/httpresponse"
1515
"github.com/mutablelogic/go-whisper/pkg/whisper"
16-
"github.com/mutablelogic/go-whisper/pkg/whisper/model"
16+
"github.com/mutablelogic/go-whisper/pkg/whisper/schema"
1717
)
1818

1919
///////////////////////////////////////////////////////////////////////////////
2020
// TYPES
2121

2222
type respModels struct {
23-
Object string `json:"object,omitempty"`
24-
Models []*model.Model `json:"models"`
23+
Object string `json:"object,omitempty"`
24+
Models []*schema.Model `json:"models"`
2525
}
2626

2727
type reqDownloadModel struct {

pkg/whisper/api/transcribe.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import (
1111
"github.com/mutablelogic/go-server/pkg/httprequest"
1212
"github.com/mutablelogic/go-server/pkg/httpresponse"
1313
"github.com/mutablelogic/go-whisper/pkg/whisper"
14+
"github.com/mutablelogic/go-whisper/pkg/whisper/schema"
1415
"github.com/mutablelogic/go-whisper/pkg/whisper/segmenter"
1516
"github.com/mutablelogic/go-whisper/pkg/whisper/task"
16-
"github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
1717

1818
// Namespace imports
1919
. "github.com/djthorpe/go-errors"
@@ -76,7 +76,7 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
7676
}
7777

7878
// Get context for the model, perform transcription
79-
var result *transcription.Transcription
79+
var result *schema.Transcription
8080
if err := service.WithModel(model, func(task *task.Context) error {
8181
// Check model
8282
if translate && !task.CanTranslate() {
@@ -100,7 +100,7 @@ func TranscribeFile(ctx context.Context, service *whisper.Whisper, w http.Respon
100100
// Read samples and transcribe them
101101
if err := segmenter.Decode(ctx, func(ts time.Duration, buf []float32) error {
102102
// Perform the transcription, return any errors
103-
return task.Transcribe(ctx, ts, buf, req.OutputSegments(), func(segment *transcription.Segment) {
103+
return task.Transcribe(ctx, ts, buf, req.OutputSegments(), func(segment *schema.Segment) {
104104
fmt.Println("TODO: ", segment)
105105
})
106106
}); err != nil {

pkg/whisper/client/client.go

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ import (
1111
"github.com/mutablelogic/go-client"
1212
"github.com/mutablelogic/go-client/pkg/multipart"
1313
"github.com/mutablelogic/go-server/pkg/httprequest"
14-
"github.com/mutablelogic/go-whisper/pkg/whisper/model"
15-
"github.com/mutablelogic/go-whisper/pkg/whisper/transcription"
14+
"github.com/mutablelogic/go-whisper/pkg/whisper/schema"
1615
)
1716

1817
///////////////////////////////////////////////////////////////////////////////
@@ -38,9 +37,9 @@ func New(endpoint string, opts ...client.ClientOpt) (*Client, error) {
3837
///////////////////////////////////////////////////////////////////////////////
3938
// MODELS
4039

41-
func (c *Client) ListModels(ctx context.Context) ([]model.Model, error) {
40+
func (c *Client) ListModels(ctx context.Context) ([]schema.Model, error) {
4241
var models struct {
43-
Models []model.Model `json:"models"`
42+
Models []schema.Model `json:"models"`
4443
}
4544
if err := c.DoWithContext(ctx, client.MethodGet, &models, client.OptPath("models")); err != nil {
4645
return nil, err
@@ -53,12 +52,12 @@ func (c *Client) DeleteModel(ctx context.Context, model string) error {
5352
return c.DoWithContext(ctx, client.MethodDelete, nil, client.OptPath("models", model))
5453
}
5554

56-
func (c *Client) DownloadModel(ctx context.Context, path string, fn func(status string, cur, total int64)) (model.Model, error) {
55+
func (c *Client) DownloadModel(ctx context.Context, path string, fn func(status string, cur, total int64)) (schema.Model, error) {
5756
var req struct {
5857
Path string `json:"path"`
5958
}
6059
type resp struct {
61-
model.Model
60+
schema.Model
6261
Status string `json:"status"`
6362
Total int64 `json:"total,omitempty"`
6463
Completed int64 `json:"completed,omitempty"`
@@ -75,7 +74,7 @@ func (c *Client) DownloadModel(ctx context.Context, path string, fn func(status
7574

7675
var r resp
7776
if payload, err := client.NewJSONRequest(req); err != nil {
78-
return model.Model{}, err
77+
return schema.Model{}, err
7978
} else if err := c.DoWithContext(ctx, payload, &r,
8079
client.OptPath("models"),
8180
client.OptQuery(query),
@@ -87,20 +86,20 @@ func (c *Client) DownloadModel(ctx context.Context, path string, fn func(status
8786
return nil
8887
}),
8988
); err != nil {
90-
return model.Model{}, err
89+
return schema.Model{}, err
9190
}
9291

9392
// Return success
9493
return r.Model, nil
9594
}
9695

97-
func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt ...Opt) (*transcription.Transcription, error) {
96+
func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt ...Opt) (*schema.Transcription, error) {
9897
var request struct {
9998
File multipart.File `json:"file"`
10099
Model string `json:"model"`
101100
opts
102101
}
103-
var response transcription.Transcription
102+
var response schema.Transcription
104103

105104
// Get the name from the io.Reader
106105
name := ""
@@ -131,13 +130,13 @@ func (c *Client) Transcribe(ctx context.Context, model string, r io.Reader, opt
131130
return &response, nil
132131
}
133132

134-
func (c *Client) Translate(ctx context.Context, model string, r io.Reader, opt ...Opt) (*transcription.Transcription, error) {
133+
func (c *Client) Translate(ctx context.Context, model string, r io.Reader, opt ...Opt) (*schema.Transcription, error) {
135134
var request struct {
136135
File multipart.File `json:"file"`
137136
Model string `json:"model"`
138137
opts
139138
}
140-
var response transcription.Transcription
139+
var response schema.Transcription
141140

142141
// Get the name from the io.Reader
143142
name := ""

pkg/whisper/model/store.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ import (
1212
"sync"
1313

1414
// Packages
15-
16-
"github.com/mutablelogic/go-whisper/sys/whisper"
15+
schema "github.com/mutablelogic/go-whisper/pkg/whisper/schema"
16+
whisper "github.com/mutablelogic/go-whisper/sys/whisper"
1717

1818
// Namespace imports
1919
. "github.com/djthorpe/go-errors"
@@ -29,7 +29,7 @@ type Store struct {
2929
path, ext string
3030

3131
// list of all models
32-
models []*Model
32+
models []*schema.Model
3333

3434
// download models
3535
client whisper.Client
@@ -101,7 +101,7 @@ func (s *Store) String() string {
101101
// PUBLIC METHODS
102102

103103
// Return the models
104-
func (s *Store) List() []*Model {
104+
func (s *Store) List() []*schema.Model {
105105
s.RLock()
106106
defer s.RUnlock()
107107
return s.models
@@ -120,7 +120,7 @@ func (s *Store) Rescan() error {
120120
}
121121

122122
// Return a model by its Id
123-
func (s *Store) ById(name string) *Model {
123+
func (s *Store) ById(name string) *schema.Model {
124124
s.RLock()
125125
defer s.RUnlock()
126126
name = modelNameToId(name)
@@ -133,7 +133,7 @@ func (s *Store) ById(name string) *Model {
133133
}
134134

135135
// Return a model by path
136-
func (s *Store) ByPath(path string) *Model {
136+
func (s *Store) ByPath(path string) *schema.Model {
137137
s.RLock()
138138
defer s.RUnlock()
139139
for _, model := range s.models {
@@ -177,7 +177,7 @@ func (s *Store) Delete(id string) error {
177177
//
178178
// A function can be provided to track the progress of the download. If no Content-Length is
179179
// provided by the server, the total bytes will be unknown and is set to zero.
180-
func (s *Store) Download(ctx context.Context, path string, fn func(curBytes, totalBytes uint64)) (*Model, error) {
180+
func (s *Store) Download(ctx context.Context, path string, fn func(curBytes, totalBytes uint64)) (*schema.Model, error) {
181181
// abspath should be contained within the models directory
182182
abspath := filepath.Clean(filepath.Join(s.path, path))
183183
if !strings.HasPrefix(abspath, s.path) {
@@ -255,8 +255,8 @@ func toError(err error) error {
255255
return err
256256
}
257257

258-
func listModels(path, ext string) ([]*Model, error) {
259-
result := make([]*Model, 0, 100)
258+
func listModels(path, ext string) ([]*schema.Model, error) {
259+
result := make([]*schema.Model, 0, 100)
260260

261261
// Walk filesystem
262262
return result, fs.WalkDir(os.DirFS(path), ".", func(path string, d fs.DirEntry, err error) error {
@@ -292,7 +292,7 @@ func listModels(path, ext string) ([]*Model, error) {
292292
}
293293

294294
// Get model information
295-
model := new(Model)
295+
model := new(schema.Model)
296296
model.Object = "model"
297297
model.Path = path
298298
model.Created = info.ModTime().Unix()

pkg/whisper/pool/contextpool.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import (
55
"fmt"
66

77
// Packages
8-
model "github.com/mutablelogic/go-whisper/pkg/whisper/model"
8+
9+
"github.com/mutablelogic/go-whisper/pkg/whisper/schema"
910
task "github.com/mutablelogic/go-whisper/pkg/whisper/task"
1011

1112
// Namespace imports
@@ -78,7 +79,7 @@ func (m *ContextPool) String() string {
7879
// PUBLIC METHODS
7980

8081
// Get a context from the pool, for a model
81-
func (m *ContextPool) Get(model *model.Model) (*task.Context, error) {
82+
func (m *ContextPool) Get(model *schema.Model) (*task.Context, error) {
8283
// Check parameters
8384
if model == nil {
8485
return nil, ErrBadParameter
@@ -112,7 +113,7 @@ func (m *ContextPool) Put(ctx *task.Context) {
112113
}
113114

114115
// Drain the pool of all contexts for a model, freeing resources
115-
func (m *ContextPool) Drain(model *model.Model) error {
116+
func (m *ContextPool) Drain(model *schema.Model) error {
116117
fmt.Println("TODO: DRAIN", model.Id)
117118
return nil
118119
}

pkg/whisper/pool/contextpool_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ import (
44
"testing"
55

66
// Packages
7-
model "github.com/mutablelogic/go-whisper/pkg/whisper/model"
87
pool "github.com/mutablelogic/go-whisper/pkg/whisper/pool"
8+
schema "github.com/mutablelogic/go-whisper/pkg/whisper/schema"
99
)
1010

1111
func Test_contextpool_001(t *testing.T) {
1212
var pool = pool.NewContextPool(t.TempDir(), 2, 0)
1313

14-
model1, err := pool.Get(&model.Model{
14+
model1, err := pool.Get(&schema.Model{
1515
Id: "model1",
1616
})
1717
if err != nil {
@@ -22,7 +22,7 @@ func Test_contextpool_001(t *testing.T) {
2222
}
2323
t.Log("Got model1", model1)
2424

25-
model2, err := pool.Get(&model.Model{
25+
model2, err := pool.Get(&schema.Model{
2626
Id: "model2",
2727
})
2828
if err != nil {
@@ -35,7 +35,7 @@ func Test_contextpool_001(t *testing.T) {
3535

3636
pool.Put(model1)
3737

38-
model3, err := pool.Get(&model.Model{
38+
model3, err := pool.Get(&schema.Model{
3939
Id: "model1",
4040
})
4141
if err != nil {

pkg/whisper/model/model.go renamed to pkg/whisper/schema/model.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package model
1+
package schema
22

33
import (
44
"encoding/json"

pkg/whisper/schema/segment.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package schema
2+
3+
import (
4+
"encoding/json"
5+
"time"
6+
)
7+
8+
//////////////////////////////////////////////////////////////////////////////
9+
// TYPES
10+
11+
type Segment struct {
12+
Id int32 `json:"id"`
13+
Start time.Duration `json:"start"`
14+
End time.Duration `json:"end"`
15+
Text string `json:"text"`
16+
SpeakerTurn bool `json:"speaker_turn,omitempty"` // TODO
17+
}
18+
19+
//////////////////////////////////////////////////////////////////////////////
20+
// STRINGIFY
21+
22+
func (s *Segment) String() string {
23+
data, err := json.MarshalIndent(s, "", " ")
24+
if err != nil {
25+
return err.Error()
26+
}
27+
return string(data)
28+
}

pkg/whisper/schema/transcription.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package schema
2+
3+
import (
4+
"encoding/json"
5+
"time"
6+
)
7+
8+
//////////////////////////////////////////////////////////////////////////////
9+
// TYPES
10+
11+
type Transcription struct {
12+
Task string `json:"task,omitempty"`
13+
Language string `json:"language,omitempty" writer:",width:8"`
14+
Duration time.Duration `json:"duration,omitempty" writer:",width:8,right"`
15+
Text string `json:"text" writer:",width:60,wrap"`
16+
Segments []*Segment `json:"segments,omitempty" writer:",width:40,wrap"`
17+
}
18+
19+
//////////////////////////////////////////////////////////////////////////////
20+
// STRINGIFY
21+
22+
func (t *Transcription) String() string {
23+
data, err := json.MarshalIndent(t, "", " ")
24+
if err != nil {
25+
return err.Error()
26+
}
27+
return string(data)
28+
}

0 commit comments

Comments
 (0)