
Commit b827ecf

Updated mistral and docs
1 parent 48e4c42 commit b827ecf

File tree

10 files changed: +725 -33 lines changed


README.md

Lines changed: 56 additions & 3 deletions
@@ -1,13 +1,14 @@
 # go-llm

 Large Language Model API interface. This is a simple API interface for large language models
-which run on [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md)
-and [Anthopic](https://docs.anthropic.com/en/api/getting-started).
+which run on [Ollama](https://github.com/ollama/ollama/blob/main/docs/api.md),
+[Anthropic](https://docs.anthropic.com/en/api/getting-started) and [Mistral](https://docs.mistral.ai/).

 The module includes the ability to utilize:

 * Maintaining a session of messages
 * Tool calling support
+* Creating embeddings from text
 * Streaming responses

 There is a command-line tool included in the module which can be used to interact with the API.
@@ -28,7 +29,12 @@ docker run \
 ## Programmatic Usage

 See the documentation [here](https://pkg.go.dev/github.com/mutablelogic/go-llm)
-for integration into your own Go programs. To create an
+for integration into your own Go programs.
+
+### Agent Instantiation
+
+For each LLM provider, you create an agent which can be used to interact with the API.
+To create an
 [Ollama](https://pkg.go.dev/github.com/mutablelogic/go-llm/pkg/ollama)
 agent,

@@ -66,6 +72,25 @@ func main() {
 }
 ```

+For Mistral models, you can use:
+
+```go
+import (
+	"os"
+
+	"github.com/mutablelogic/go-llm/pkg/mistral"
+)
+
+func main() {
+	// Create a new agent
+	agent, err := mistral.New(os.Getenv("MISTRAL_API_KEY"))
+	if err != nil {
+		panic(err)
+	}
+	// ...
+}
+```
+
+### Chat Sessions
+
 You create a **chat session** with a model as follows,

 ```go
@@ -90,6 +115,34 @@ func session(ctx context.Context, agent llm.Agent) error {
 }
 ```

+## Options
+
+You can add options to sessions, or to prompts. Different providers and models support
+different options.
+
+| Option | Ollama | Anthropic | Mistral | OpenAI | Description |
+|--------|--------|-----------|---------|--------|-------------|
+| `llm.WithTemperature(float64)` | Yes | Yes | Yes | - | What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.7 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. |
+| `llm.WithTopP(float64)` | Yes | Yes | Yes | - | Nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. |
+| `llm.WithTopK(uint64)` | Yes | Yes | No | - | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. |
+| `llm.WithMaxTokens(uint64)` | - | Yes | Yes | - | The maximum number of tokens to generate in the response. |
+| `llm.WithStream(func(llm.ContextContent))` | Can be enabled when tools are not used | Yes | Yes | - | Stream the response to a function. |
+| `llm.WithToolChoice(string, string, ...)` | No | Yes | Use `auto`, `any`, `none`, `required` or a function name. Only the first argument is used. | - | The tool to use for the model. |
+| `llm.WithToolKit(llm.ToolKit)` | Cannot be combined with streaming | Yes | Yes | - | The set of tools to use. |
+| `llm.WithStopSequence(string, string, ...)` | Yes | Yes | Yes | - | Stop generation if one of these tokens is detected. |
+| `llm.WithSystemPrompt(string)` | No | Yes | Yes | - | Set the system prompt for the model. |
+| `llm.WithSeed(uint64)` | No | Yes | Yes | - | The seed to use for random sampling. If set, different calls will generate deterministic results. |
+| `llm.WithFormat(string)` | No | Yes | Use `json_format` or `text` | - | The format of the response. For Mistral, you must also instruct the model to produce JSON yourself with a system or a user message. |
+| `mistral.WithPresencePenalty(float64)` | - | - | Yes | - | Determines how much the model penalizes the repetition of words or phrases. A higher presence penalty encourages the model to use a wider variety of words and phrases, making the output more diverse and creative. |
+| `mistral.WithFrequencyPenalty(float64)` | - | - | Yes | - | Penalizes the repetition of words based on their frequency in the generated text. A higher frequency penalty discourages the model from repeating words that have already appeared frequently in the output, promoting diversity and reducing repetition. |
+| `mistral.WithPrediction(string)` | - | - | Yes | - | Enables you to specify expected results, optimizing response times by leveraging known or predictable content. This approach is especially effective for updating text documents or code files with minimal changes, reducing latency while maintaining high-quality results. |
+| `llm.WithSafePrompt()` | - | - | Yes | - | Whether to inject a safety prompt before all conversations. |
+| `llm.WithNumCompletions(uint64)` | - | - | Yes | - | Number of completions to return for each request. |
+| `llm.WithAttachment(io.Reader)` | Yes | Yes | Yes | - | Attach a file to a user prompt. It is the responsibility of the caller to close the reader. |
+| `anthropic.WithEphemeral()` | No | Yes | No | - | Attachments should be cached server-side. |
+| `anthropic.WithCitations()` | No | Yes | No | - | Attachments should be used in citations. |
+| `anthropic.WithUser(string)` | No | Yes | No | - | Indicate the user name for the request, for debugging. |
+
 ## Contributing & Distribution

 *This module is currently in development and subject to change*. Please do file
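
As a minimal sketch of how options from the table above might be passed to a provider call: the `llm.Context` chat session is assumed to have been created elsewhere, the helper name `complete` is hypothetical, and only the `ChatCompletion` signature and option names from this commit are relied upon.

```go
import (
	"context"
	"fmt"

	"github.com/mutablelogic/go-llm"
	"github.com/mutablelogic/go-llm/pkg/mistral"
)

// complete is a hypothetical helper: it assumes you already hold a
// *mistral.Client and an llm.Context chat session created elsewhere.
func complete(ctx context.Context, client *mistral.Client, session llm.Context) error {
	// Pass options from the README table directly to the completion call
	response, err := client.ChatCompletion(ctx, session,
		llm.WithTemperature(0.2),
		llm.WithMaxTokens(512),
		llm.WithSystemPrompt("Reply concisely."),
	)
	if err != nil {
		return err
	}
	// Response implements fmt.Stringer, so this prints the indented JSON
	fmt.Println(response)
	return nil
}
```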

opt.go

Lines changed: 60 additions & 1 deletion
@@ -216,17 +216,76 @@ func WithTopP(v float64) Opt {

 // Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more
 // diverse answers, while a lower value (e.g. 10) will be more conservative.
-func WithTopK(v uint) Opt {
+func WithTopK(v uint64) Opt {
 	return func(o *Opts) error {
 		o.Set("top_k", v)
 		return nil
 	}
 }

+// The maximum number of tokens to generate in the completion.
+func WithMaxTokens(v uint64) Opt {
+	return func(o *Opts) error {
+		o.Set("max_tokens", v)
+		return nil
+	}
+}
+
 // Set system prompt
 func WithSystemPrompt(v string) Opt {
 	return func(o *Opts) error {
 		o.system = v
 		return nil
 	}
 }
+
+// Set stop sequence
+func WithStopSequence(v ...string) Opt {
+	return func(o *Opts) error {
+		o.Set("stop", v)
+		return nil
+	}
+}
+
+// Set random seed for deterministic behavior
+func WithSeed(v uint64) Opt {
+	return func(o *Opts) error {
+		o.Set("seed", v)
+		return nil
+	}
+}
+
+// Set format
+func WithFormat(v any) Opt {
+	return func(o *Opts) error {
+		o.Set("format", v)
+		return nil
+	}
+}
+
+// Set tool choices: can be auto, none, required, any or a list of tool names
+func WithToolChoice(v ...string) Opt {
+	return func(o *Opts) error {
+		o.Set("tool_choice", v)
+		return nil
+	}
+}
+
+// Number of completions to return for each request
+func WithNumCompletions(v uint64) Opt {
+	return func(o *Opts) error {
+		if v < 1 || v > 8 {
+			return ErrBadParameter.With("num_completions must be between 1 and 8")
+		}
+		o.Set("num_completions", v)
+		return nil
+	}
+}
+
+// Inject a safety prompt before all conversations.
+func WithSafePrompt() Opt {
+	return func(o *Opts) error {
+		o.Set("safe_prompt", true)
+		return nil
+	}
+}
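
A minimal sketch of how these `Opt` functions compose from outside the package, using the same `llm.ApplyOpts` helper the Mistral client calls before building its request (the option values here are arbitrary):

```go
package main

import (
	"fmt"

	"github.com/mutablelogic/go-llm"
)

func main() {
	// Each With* function returns an Opt; ApplyOpts evaluates them in order
	// and returns an error if any option is invalid (for example,
	// WithNumCompletions rejects values outside 1..8).
	opts, err := llm.ApplyOpts(
		llm.WithTopK(50),
		llm.WithMaxTokens(1024),
		llm.WithStopSequence("\n\n"),
		llm.WithSeed(42),
		llm.WithNumCompletions(2),
	)
	if err != nil {
		panic(err)
	}
	fmt.Println(opts)
}
```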

pkg/anthropic/opt.go

Lines changed: 0 additions & 14 deletions
@@ -17,27 +17,13 @@ type optmetadata struct {
 ////////////////////////////////////////////////////////////////////////////////
 // OPTIONS

-func WithMaxTokens(v uint) llm.Opt {
-	return func(o *llm.Opts) error {
-		o.Set("max_tokens", v)
-		return nil
-	}
-}
-
 func WithUser(v string) llm.Opt {
 	return func(o *llm.Opts) error {
 		o.Set("user", v)
 		return nil
 	}
 }

-func WithStopSequences(v ...string) llm.Opt {
-	return func(o *llm.Opts) error {
-		o.Set("stop", v)
-		return nil
-	}
-}
-
 func WithEphemeral() llm.Opt {
 	return func(o *llm.Opts) error {
 		o.Set("ephemeral", true)

pkg/mistral/chat_completion.go

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
+package mistral
+
+import (
+	"context"
+	"encoding/json"
+
+	"github.com/mutablelogic/go-client"
+	"github.com/mutablelogic/go-llm"
+)
+
+///////////////////////////////////////////////////////////////////////////////
+// TYPES
+
+// Chat Completion Response
+type Response struct {
+	Id      string   `json:"id"`
+	Type    string   `json:"object"`
+	Created uint64   `json:"created"`
+	Model   string   `json:"model"`
+	Choices []Choice `json:"choices"`
+	Metrics `json:"usage,omitempty"`
+}
+
+// Response variation
+type Choice struct {
+	Index   uint64      `json:"index"`
+	Message MessageMeta `json:"message"`
+	Reason  string      `json:"finish_reason,omitempty"`
+}
+
+// Metrics
+type Metrics struct {
+	InputTokens  uint64 `json:"prompt_tokens,omitempty"`
+	OutputTokens uint   `json:"completion_tokens,omitempty"`
+	TotalTokens  uint   `json:"total_tokens,omitempty"`
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// STRINGIFY
+
+func (r Response) String() string {
+	data, err := json.MarshalIndent(r, "", "  ")
+	if err != nil {
+		return err.Error()
+	}
+	return string(data)
+}
+
+///////////////////////////////////////////////////////////////////////////////
+// PUBLIC METHODS
+
+type reqChatCompletion struct {
+	Model            string         `json:"model"`
+	Temperature      float64        `json:"temperature,omitempty"`
+	TopP             float64        `json:"top_p,omitempty"`
+	MaxTokens        uint64         `json:"max_tokens,omitempty"`
+	Stream           bool           `json:"stream,omitempty"`
+	StopSequences    []string       `json:"stop,omitempty"`
+	Seed             uint64         `json:"random_seed,omitempty"`
+	Messages         []*MessageMeta `json:"messages"`
+	Format           any            `json:"response_format,omitempty"`
+	Tools            []llm.Tool     `json:"tools,omitempty"`
+	ToolChoice       any            `json:"tool_choice,omitempty"`
+	PresencePenalty  float64        `json:"presence_penalty,omitempty"`
+	FrequencyPenalty float64        `json:"frequency_penalty,omitempty"`
+	NumChoices       uint64         `json:"n,omitempty"`
+	Prediction       *Content       `json:"prediction,omitempty"`
+	SafePrompt       bool           `json:"safe_prompt,omitempty"`
+}
+
+func (mistral *Client) ChatCompletion(ctx context.Context, context llm.Context, opts ...llm.Opt) (*Response, error) {
+	// Apply options
+	opt, err := llm.ApplyOpts(opts...)
+	if err != nil {
+		return nil, err
+	}
+
+	// Append the system prompt at the beginning
+	seq := make([]*MessageMeta, 0, len(context.(*session).seq)+1)
+	if system := opt.SystemPrompt(); system != "" {
+		seq = append(seq, systemPrompt(system))
+	}
+	seq = append(seq, context.(*session).seq...)
+
+	// Request
+	req, err := client.NewJSONRequest(reqChatCompletion{
+		Model:            context.(*session).model.Name(),
+		Temperature:      optTemperature(opt),
+		TopP:             optTopP(opt),
+		MaxTokens:        optMaxTokens(opt),
+		Stream:           optStream(opt),
+		StopSequences:    optStopSequences(opt),
+		Seed:             optSeed(opt),
+		Messages:         seq,
+		Format:           optFormat(opt),
+		Tools:            optTools(mistral, opt),
+		ToolChoice:       optToolChoice(opt),
+		PresencePenalty:  optPresencePenalty(opt),
+		FrequencyPenalty: optFrequencyPenalty(opt),
+		NumChoices:       optNumCompletions(opt),
+		Prediction:       optPrediction(opt),
+		SafePrompt:       optSafePrompt(opt),
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	// Response
+	var response Response
+	if err := mistral.DoWithContext(ctx, req, &response, client.OptPath("chat", "completions")); err != nil {
+		return nil, err
+	}
+
+	// Return success
+	return &response, nil
+}
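
A short sketch of consuming the `Response` defined above; the helper name `printUsage` is hypothetical, and it relies only on the struct fields in this file (the embedded `Metrics` struct means the usage fields are promoted onto `Response`):

```go
import (
	"fmt"

	"github.com/mutablelogic/go-llm/pkg/mistral"
)

// printUsage is a hypothetical helper: given a *mistral.Response as returned
// by ChatCompletion above, report each choice and the token usage.
func printUsage(response *mistral.Response) {
	for _, choice := range response.Choices {
		fmt.Println(choice.Index, choice.Reason)
	}
	// Fields of the embedded Metrics struct (the "usage" object) are promoted
	fmt.Println("prompt tokens:", response.InputTokens)
	fmt.Println("total tokens:", response.TotalTokens)
}
```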
