
Commit d189ae3

Update README.md (#180)
* Update README.md: add documentation pointers
* Add some initial documentation for MLXLLM -- make sure the navigation works ok in swiftpackageindex
1 parent 4b67a79 commit d189ae3

File tree

10 files changed: +277 -8 lines changed
Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
# ``MLXLLM``

Example implementations of various Large Language Models (LLMs).

## Other MLX Libraries Packages

- [MLXEmbedders](MLXEmbedders)
- [MLXLLM](MLXLLM)
- [MLXLMCommon](MLXLMCommon)
- [MLXMNIST](MLXMNIST)
- [MLXVLM](MLXVLM)
- [StableDiffusion](StableDiffusion)

## Topics

- <doc:adding-model>
- <doc:using-model>

### Models

- ``CohereModel``
- ``GemmaModel``
- ``Gemma2Model``
- ``InternLM2Model``
- ``LlamaModel``
- ``OpenELMModel``
- ``PhiModel``
- ``Phi3Model``
- ``PhiMoEModel``
- ``Qwen2Model``
- ``Starcoder2Model``
Lines changed: 94 additions & 0 deletions

@@ -0,0 +1,94 @@
# Adding a Model

If the model follows the typical LLM pattern, you can add a new
model in a few steps. The model's repository on Hugging Face should
provide the usual files:

- `config.json`, `tokenizer.json`, and `tokenizer_config.json`
- `*.safetensors`

You can follow the pattern of the models in the [Models](Models) directory
and create a `.swift` file for your new model.

## Create a Configuration

Create a configuration struct to match the `config.json` (any parameters needed):

```swift
public struct YourModelConfiguration: Codable, Sendable {
    public let hiddenSize: Int

    // use this pattern for values that need defaults
    public let _layerNormEps: Float?
    public var layerNormEps: Float { _layerNormEps ?? 1e-6 }

    enum CodingKeys: String, CodingKey {
        case hiddenSize = "hidden_size"
        case _layerNormEps = "layer_norm_eps"
    }
}
```
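Since the struct is `Codable` with snake_case `CodingKeys`, you can sanity-check it against a real `config.json` with `JSONDecoder`. A minimal sketch, assuming a local copy of the file -- the factory normally performs this decode for you, and the function name and path here are illustrative:

```swift
import Foundation

// Sketch: decode YourModelConfiguration from a config.json on disk.
// LLMModelFactory normally performs this step when it loads a model.
func loadConfiguration(from url: URL) throws -> YourModelConfiguration {
    let data = try Data(contentsOf: url)
    return try JSONDecoder().decode(YourModelConfiguration.self, from: data)
}

// Snake_case keys such as "hidden_size" map through the CodingKeys above, and
// layerNormEps falls back to 1e-6 when "layer_norm_eps" is absent from the file.
```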

## Create the Model Class

Create the model class. The top-level public class should have a
structure something like this:

```swift
public class YourModel: Module, LLMModel, KVCacheDimensionProvider, LoRAModel {

    public let kvHeads: [Int]

    @ModuleInfo var model: YourModelInner

    public func loraLinearLayers() -> LoRALinearLayers {
        // TODO: modify as needed
        model.layers.map { ($0.attention, ["q_proj", "v_proj"]) }
    }

    public init(_ args: YourModelConfiguration) {
        self.kvHeads = Array(repeating: args.kvHeads, count: args.hiddenLayers)
        self.model = YourModelInner(args)
    }

    public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
        // TODO: modify as needed
        let out = model(inputs, cache: cache)
        return model.embedTokens.asLinear(out)
    }
}
```

## Register the Model

In [LLMModelFactory.swift](LLMModelFactory.swift) register the model type itself.
This is independent of the model id -- the key should match the `model_type`
value in the model's `config.json`:

```swift
public class ModelTypeRegistry: @unchecked Sendable {
    ...
    private var creators: [String: @Sendable (URL) throws -> any LanguageModel] = [
        "yourModel": create(YourModelConfiguration.self, YourModel.init),
```

Add a constant for the model in the `ModelRegistry` (not strictly required, but useful
for callers to refer to it in code):

```swift
public class ModelRegistry: @unchecked Sendable {
    ...
    static public let yourModel_4bit = ModelConfiguration(
        id: "mlx-community/YourModel-4bit",
        defaultPrompt: "What is the gravity on Mars and the moon?"
    )
```

Finally, add it to the `all()` list -- this lets users find the model
configuration by id:

```swift
private static func all() -> [ModelConfiguration] {
    [
        codeLlama13b4bit,
        ...
        yourModel_4bit,
```
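Once registered, the new model loads like any other. A small sketch using the hypothetical `yourModel_4bit` constant from above and the `loadContainer` API described in <doc:using-model>:

```swift
import MLXLLM
import MLXLMCommon

// Sketch: load the newly registered model by its configuration constant.
// `yourModel_4bit` is the hypothetical entry added to ModelRegistry above.
func loadYourModel() async throws -> ModelContainer {
    try await LLMModelFactory.shared.loadContainer(
        configuration: ModelRegistry.yourModel_4bit)
}
```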

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@
# Using a Model

Using a model is easy: load the weights, tokenize, and evaluate.

## Loading a Model

A model is typically loaded by using a `ModelFactory` and a `ModelConfiguration`:

```swift
// e.g. LLMModelFactory.shared
let modelFactory: ModelFactory

// e.g. MLXLLM.ModelRegistry.llama3_8B_4bit
let modelConfiguration: ModelConfiguration

let container = try await modelFactory.loadContainer(configuration: modelConfiguration)
```

The `container` provides an isolation context (an `actor`) for running inference with the model.

Predefined `ModelConfiguration` instances are provided as static variables
on the `ModelRegistry` types, or they can be created directly:

```swift
let modelConfiguration = ModelConfiguration(id: "mlx-community/llama3_8B_4bit")
```

The flow inside the `ModelFactory` goes like this:

```swift
public class LLMModelFactory: ModelFactory {

    public func _load(
        hub: HubApi, configuration: ModelConfiguration,
        progressHandler: @Sendable @escaping (Progress) -> Void
    ) async throws -> ModelContext {
        // download the weights and config using HubApi
        // load the base configuration
        // using the typeRegistry create a model (random weights)
        // load the weights, apply quantization as needed, update the model
        // calls model.sanitize() for weight preparation
        // load the tokenizer
        // (vlm) load the processor configuration, create the processor
    }
}
```

Callers with specialized requirements can use these individual components to manually
load models, if needed.

## Evaluation Flow

- Load the Model
- UserInput
- LMInput
- generate()
- NaiveStreamingDetokenizer
- TokenIterator

## Evaluating a Model

Once a model is loaded you can evaluate a prompt or series of
messages. Minimally you need to prepare the user input:

```swift
let prompt = "Describe the image in English"
var input = UserInput(prompt: prompt, images: image.map { .url($0) })
input.processing.resize = .init(width: 256, height: 256)
```

This example shows adding some images and processing instructions -- if the
model accepts text only then these parts can be omitted. The inference
calls are the same.

Assuming you are using a `ModelContainer` (an actor that holds
a `ModelContext`, which is the bundled set of types that implement a
model), the first step is to convert the `UserInput` into
`LMInput` (language model input):

```swift
let generateParameters: GenerateParameters
let input: UserInput

let result = try await modelContainer.perform { [input] context in
    let input = try context.processor.prepare(input: input)
```

Given that `input`, we can call `generate()` to produce a stream
of tokens. In this example we use a `NaiveStreamingDetokenizer`
to assist in converting the stream of tokens into text and print it.
The stream is stopped after we hit a maximum number of tokens:

```swift
var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)

return try MLXLMCommon.generate(
    input: input, parameters: generateParameters, context: context
) { tokens in

    if let last = tokens.last {
        detokenizer.append(token: last)
    }

    if let new = detokenizer.next() {
        print(new, terminator: "")
        fflush(stdout)
    }

    if tokens.count >= maxTokens {
        return .stop
    } else {
        return .more
    }
}
}
```

Libraries/MLXLLM/Models/Gemma2.swift

Lines changed: 2 additions & 2 deletions

@@ -155,7 +155,7 @@ private class TransformerBlock: Module {
 }

 // Uses Gemma2TransformerBlock, otherwise same as GemmaModelInner
-public class ModelInner: Module {
+private class ModelInner: Module {
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

     fileprivate let layers: [TransformerBlock]

@@ -197,7 +197,7 @@ public class Gemma2Model: Module, LLMModel, KVCacheDimensionProvider {
     public let vocabularySize: Int
     public let kvHeads: [Int]

-    let model: ModelInner
+    private let model: ModelInner
     let logitSoftCap: Float

     public init(_ args: Gemma2Configuration) {

Libraries/MLXLLM/Models/Llama.swift

Lines changed: 1 addition & 0 deletions

@@ -283,6 +283,7 @@ private class LlamaModelInner: Module {
     }
 }

+/// Model for Llama and Mistral model types.
 public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider {

     public let vocabularySize: Int

Libraries/MLXLLM/Models/Phi3.swift

Lines changed: 2 additions & 2 deletions

@@ -151,7 +151,7 @@ private class TransformerBlock: Module {
     }
 }

-public class Phi3ModelInner: Module {
+private class Phi3ModelInner: Module {

     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

@@ -189,7 +189,7 @@ public class Phi3Model: Module, LLMModel, KVCacheDimensionProvider {
     public let vocabularySize: Int
     public let kvHeads: [Int]

-    let model: Phi3ModelInner
+    private let model: Phi3ModelInner

     @ModuleInfo(key: "lm_head") var lmHead: Linear

Libraries/MLXLLM/Models/Qwen2.swift

Lines changed: 2 additions & 2 deletions

@@ -133,7 +133,7 @@ private class TransformerBlock: Module {
     }
 }

-public class Qwen2ModelInner: Module {
+private class Qwen2ModelInner: Module {
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

     fileprivate let layers: [TransformerBlock]

@@ -169,7 +169,7 @@ public class Qwen2Model: Module, LLMModel, KVCacheDimensionProvider {
     public let vocabularySize: Int
     public let kvHeads: [Int]

-    let model: Qwen2ModelInner
+    private let model: Qwen2ModelInner
     let configuration: Qwen2Configuration

     @ModuleInfo(key: "lm_head") var lmHead: Linear

Libraries/MLXLLM/Models/Starcoder2.swift

Lines changed: 2 additions & 2 deletions

@@ -116,7 +116,7 @@ private class TransformerBlock: Module {
     }
 }

-public class Starcoder2ModelInner: Module {
+private class Starcoder2ModelInner: Module {
     @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

     fileprivate let layers: [TransformerBlock]

@@ -153,7 +153,7 @@ public class Starcoder2Model: Module, LLMModel, KVCacheDimensionProvider {
     public let kvHeads: [Int]

     public let tieWordEmbeddings: Bool
-    let model: Starcoder2ModelInner
+    private let model: Starcoder2ModelInner

     @ModuleInfo(key: "lm_head") var lmHead: Linear
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
# ``MLXLMCommon``

Common language model code.

## Other MLX Libraries Packages

- [MLXEmbedders](MLXEmbedders)
- [MLXLLM](MLXLLM)
- [MLXLMCommon](MLXLMCommon)
- [MLXMNIST](MLXMNIST)
- [MLXVLM](MLXVLM)
- [StableDiffusion](StableDiffusion)

README.md

Lines changed: 11 additions & 0 deletions

@@ -1,3 +1,14 @@
# Documentation

Developers can use these examples in their own programs -- just import the Swift package (see the sketch below)!

- [MLXLMCommon](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxlmcommon) -- common API for LLMs and VLMs
- [MLXLLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxllm) -- large language model example implementations
- [MLXVLM](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxvlm) -- vision language model example implementations
- [MLXEmbedders](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxembedders) -- example implementations of popular encoder / embedding models
- [StableDiffusion](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/stablediffusion) -- SDXL Turbo and Stable Diffusion model example implementations
- [MLXMNIST](https://swiftpackageindex.com/ml-explore/mlx-swift-examples/main/documentation/mlxmnist) -- MNIST implementation for all your digit recognition needs
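A minimal `Package.swift` sketch for pulling these libraries into your own project -- assuming the repository URL below and that the product names match the library names above; adjust platforms, products, and versions as needed:

```swift
// swift-tools-version: 5.9
import PackageDescription

let package = Package(
    name: "MyApp",
    // assumption: pick platform versions that satisfy MLX's minimums
    platforms: [.macOS(.v14), .iOS(.v16)],
    dependencies: [
        // assumption: track the main branch of this repository
        .package(url: "https://github.com/ml-explore/mlx-swift-examples/", branch: "main")
    ],
    targets: [
        .executableTarget(
            name: "MyApp",
            dependencies: [
                // e.g. the LLM example implementations and the common LM API
                .product(name: "MLXLLM", package: "mlx-swift-examples"),
                .product(name: "MLXLMCommon", package: "mlx-swift-examples"),
            ]
        )
    ]
)
```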

# MLX Swift Examples

Example [MLX Swift](https://github.com/ml-explore/mlx-swift) programs.
