Skip to content

Commit ed66517

Browse files
committed
Added Gemini
1 parent 4b771a6 commit ed66517

File tree

7 files changed

+291
-2
lines changed

7 files changed

+291
-2
lines changed

cmd/llm/main.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ type Globals struct {
3131
Anthropic `embed:"" help:"Anthropic configuration"`
3232
Mistral `embed:"" help:"Mistral configuration"`
3333
OpenAI `embed:"" help:"OpenAI configuration"`
34+
Gemini `embed:"" help:"Gemini configuration"`
3435

3536
// Tools
3637
NewsAPI `embed:"" help:"NewsAPI configuration"`
@@ -58,6 +59,10 @@ type OpenAI struct {
5859
OpenAIKey string `env:"OPENAI_API_KEY" help:"OpenAI API Key"`
5960
}
6061

62+
type Gemini struct {
63+
GeminiKey string `env:"GEMINI_API_KEY" help:"Gemini API Key"`
64+
}
65+
6166
type NewsAPI struct {
6267
NewsKey string `env:"NEWSAPI_KEY" help:"News API Key"`
6368
}
@@ -129,6 +134,9 @@ func main() {
129134
if cli.OpenAIKey != "" {
130135
opts = append(opts, agent.WithOpenAI(cli.OpenAIKey, clientopts...))
131136
}
137+
if cli.GeminiKey != "" {
138+
opts = append(opts, agent.WithGemini(cli.GeminiKey, clientopts...))
139+
}
132140

133141
// Make a toolkit
134142
toolkit := tool.NewToolKit()

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ require (
77
github.com/alecthomas/kong v1.7.0
88
github.com/djthorpe/go-errors v1.0.3
99
github.com/fatih/color v1.9.0
10+
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1
1011
github.com/mutablelogic/go-client v1.0.10
1112
github.com/stretchr/testify v1.10.0
1213
golang.org/x/term v0.28.0
1314
)
1415

1516
require (
1617
github.com/davecgh/go-spew v1.1.1 // indirect
17-
github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 // indirect
1818
github.com/mattn/go-colorable v0.1.4 // indirect
1919
github.com/mattn/go-isatty v0.0.11 // indirect
2020
github.com/mattn/go-runewidth v0.0.15 // indirect

pkg/agent/opt.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ import (
44
// Packages
55
client "github.com/mutablelogic/go-client"
66
llm "github.com/mutablelogic/go-llm"
7-
"github.com/mutablelogic/go-llm/pkg/ollama"
7+
gemini "github.com/mutablelogic/go-llm/pkg/gemini"
8+
ollama "github.com/mutablelogic/go-llm/pkg/ollama"
89
openai "github.com/mutablelogic/go-llm/pkg/openai"
910
)
1011

@@ -56,3 +57,14 @@ func WithOpenAI(key string, opts ...client.ClientOpt) llm.Opt {
5657
}
5758
}
5859
}
60+
61+
func WithGemini(key string, opts ...client.ClientOpt) llm.Opt {
62+
return func(o *llm.Opts) error {
63+
client, err := gemini.New(key, opts...)
64+
if err != nil {
65+
return err
66+
} else {
67+
return llm.WithAgent(client)(o)
68+
}
69+
}
70+
}

pkg/gemini/client.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/*
2+
gemini implements an API client for Google's Gemini LLM (https://ai.google.dev/gemini-api/docs)
3+
*/
4+
package gemini
5+
6+
import (
7+
8+
// Packages
9+
client "github.com/mutablelogic/go-client"
10+
llm "github.com/mutablelogic/go-llm"
11+
)
12+
13+
///////////////////////////////////////////////////////////////////////////////
14+
// TYPES
15+
16+
type Client struct {
17+
*client.Client
18+
cache map[string]llm.Model
19+
}
20+
21+
var _ llm.Agent = (*Client)(nil)
22+
23+
///////////////////////////////////////////////////////////////////////////////
24+
// GLOBALS
25+
26+
const (
27+
endPoint = "https://generativelanguage.googleapis.com/v1beta"
28+
defaultName = "gemini"
29+
)
30+
31+
///////////////////////////////////////////////////////////////////////////////
32+
// LIFECYCLE
33+
34+
// Create a new client
35+
func New(ApiKey string, opts ...client.ClientOpt) (*Client, error) {
36+
// Create client
37+
opts = append(opts, client.OptEndpoint(endPointWithKey(endPoint, ApiKey)))
38+
client, err := client.New(opts...)
39+
if err != nil {
40+
return nil, err
41+
}
42+
43+
// Return the client
44+
return &Client{client, nil}, nil
45+
}
46+
47+
///////////////////////////////////////////////////////////////////////////////
48+
// PUBLIC METHODS
49+
50+
// Return the name of the agent
51+
func (Client) Name() string {
52+
return defaultName
53+
}
54+
55+
///////////////////////////////////////////////////////////////////////////////
56+
// PRIVATE METHODS
57+
58+
func endPointWithKey(endpoint, key string) string {
59+
return endpoint + "?key=" + key
60+
}

pkg/gemini/client_test.go

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
package gemini_test
2+
3+
import (
4+
"flag"
5+
"log"
6+
"os"
7+
"strconv"
8+
"testing"
9+
10+
// Packages
11+
opts "github.com/mutablelogic/go-client"
12+
gemini "github.com/mutablelogic/go-llm/pkg/gemini"
13+
assert "github.com/stretchr/testify/assert"
14+
)
15+
16+
///////////////////////////////////////////////////////////////////////////////
17+
// TEST SET-UP
18+
19+
var (
20+
client *gemini.Client
21+
)
22+
23+
func TestMain(m *testing.M) {
24+
var verbose bool
25+
26+
// Verbose output
27+
flag.Parse()
28+
if f := flag.Lookup("test.v"); f != nil {
29+
if v, err := strconv.ParseBool(f.Value.String()); err == nil {
30+
verbose = v
31+
}
32+
}
33+
34+
// API KEY
35+
api_key := os.Getenv("GEMINI_API_KEY")
36+
if api_key == "" {
37+
log.Print("GEMINI_API_KEY not set")
38+
os.Exit(0)
39+
}
40+
41+
// Create client
42+
var err error
43+
client, err = gemini.New(api_key, opts.OptTrace(os.Stderr, verbose))
44+
if err != nil {
45+
log.Println(err)
46+
os.Exit(-1)
47+
}
48+
os.Exit(m.Run())
49+
}
50+
51+
///////////////////////////////////////////////////////////////////////////////
52+
// TESTS
53+
54+
func Test_client_001(t *testing.T) {
55+
assert := assert.New(t)
56+
assert.NotNil(client)
57+
t.Log(client)
58+
}

pkg/gemini/model.go

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
package gemini
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
7+
"github.com/mutablelogic/go-client"
8+
"github.com/mutablelogic/go-llm"
9+
)
10+
11+
///////////////////////////////////////////////////////////////////////////////
12+
// TYPES
13+
14+
type model struct {
15+
*Client `json:"-"`
16+
meta Model
17+
}
18+
19+
var _ llm.Model = (*model)(nil)
20+
21+
type Model struct {
22+
Name string `json:"name"`
23+
Version string `json:"version"`
24+
DisplayName string `json:"displayName"`
25+
Description string `json:"description"`
26+
InputTokenLimit uint64 `json:"inputTokenLimit"`
27+
OutputTokenLimit uint64 `json:"outputTokenLimit"`
28+
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
29+
Temperature float64 `json:"temperature"`
30+
TopP float64 `json:"topP"`
31+
TopK uint64 `json:"topK"`
32+
}
33+
34+
///////////////////////////////////////////////////////////////////////////////
35+
// STRINGIFY
36+
37+
func (m model) MarshalJSON() ([]byte, error) {
38+
return json.Marshal(m.meta)
39+
}
40+
41+
func (m model) String() string {
42+
data, err := json.MarshalIndent(m, "", " ")
43+
if err != nil {
44+
return err.Error()
45+
}
46+
return string(data)
47+
}
48+
49+
///////////////////////////////////////////////////////////////////////////////
50+
// PUBLIC METHODS - llm.Model implementation
51+
52+
// Return model name
53+
func (m model) Name() string {
54+
return m.meta.Name
55+
}
56+
57+
// Return the models
58+
func (gemini *Client) Models(ctx context.Context) ([]llm.Model, error) {
59+
// Cache models
60+
if gemini.cache == nil {
61+
models, err := gemini.ListModels(ctx)
62+
if err != nil {
63+
return nil, err
64+
}
65+
gemini.cache = make(map[string]llm.Model, len(models))
66+
for _, m := range models {
67+
gemini.cache[m.Name] = &model{gemini, m}
68+
}
69+
}
70+
71+
// Return models
72+
result := make([]llm.Model, 0, len(gemini.cache))
73+
for _, model := range gemini.cache {
74+
result = append(result, model)
75+
}
76+
return result, nil
77+
}
78+
79+
// Return a model by name, or nil if not found.
80+
// Panics on error.
81+
func (openai *Client) Model(ctx context.Context, name string) llm.Model {
82+
if openai.cache == nil {
83+
if _, err := openai.Models(ctx); err != nil {
84+
panic(err)
85+
}
86+
}
87+
return openai.cache[name]
88+
}
89+
90+
///////////////////////////////////////////////////////////////////////////////
91+
// PUBLIC METHODS - API
92+
93+
// ListModels returns all the models
94+
func (c *Client) ListModels(ctx context.Context) ([]Model, error) {
95+
// Response
96+
var response struct {
97+
Data []Model `json:"models"`
98+
}
99+
if err := c.DoWithContext(ctx, nil, &response, client.OptPath("models")); err != nil {
100+
return nil, err
101+
}
102+
103+
// Return success
104+
return response.Data, nil
105+
}
106+
107+
///////////////////////////////////////////////////////////////////////////////
108+
// PUBLIC METHODS - MODEL
109+
110+
// Return am empty session context object for the model,
111+
// setting session options
112+
func (m model) Context(...llm.Opt) llm.Context {
113+
return nil
114+
}
115+
116+
// Create a completion from a text prompt
117+
func (m model) Completion(context.Context, string, ...llm.Opt) (llm.Completion, error) {
118+
return nil, llm.ErrNotImplemented
119+
}
120+
121+
// Create a completion from a chat session
122+
func (m model) Chat(context.Context, []llm.Completion, ...llm.Opt) (llm.Completion, error) {
123+
return nil, llm.ErrNotImplemented
124+
}
125+
126+
// Embedding vector generation
127+
func (m model) Embedding(context.Context, string, ...llm.Opt) ([]float64, error) {
128+
return nil, llm.ErrNotImplemented
129+
}

pkg/gemini/model_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package gemini_test
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"testing"
7+
8+
// Packages
9+
assert "github.com/stretchr/testify/assert"
10+
)
11+
12+
func Test_models_001(t *testing.T) {
13+
assert := assert.New(t)
14+
15+
response, err := client.ListModels(context.TODO())
16+
assert.NoError(err)
17+
assert.NotEmpty(response)
18+
19+
data, err := json.MarshalIndent(response, "", " ")
20+
assert.NoError(err)
21+
t.Log(string(data))
22+
}

0 commit comments

Comments
 (0)