Skip to content

Commit 42b0d21

Browse files
Add info about property wrappers to porting guide (#347)
* Add info about property wrappers to porting guide * Computed vs Loaded Parameters
1 parent 511323d commit 42b0d21

File tree

1 file changed

+127
-0
lines changed
  • Libraries/MLXLMCommon/Documentation.docc

1 file changed

+127
-0
lines changed

Libraries/MLXLMCommon/Documentation.docc/porting.md

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,133 @@ enum CodingKeys: String, CodingKey {
112112

113113
This will read `rope_theta` from the JSON file if present and fall back to a default value of `10_000`.
114114

115+
### Understanding Property Wrappers
116+
117+
MLX Swift uses property wrappers to handle automatic registration of modules and parameters. This is different from Python MLX, which uses runtime discovery.
118+
119+
#### Python MLX: Automatic Discovery
120+
121+
In Python MLX, you simply assign modules and parameters as attributes, and MLX automatically discovers them:
122+
123+
```python
124+
class MyModule(nn.Module):
125+
def __init__(self):
126+
super().__init__()
127+
# Just assign - MLX auto-discovers these
128+
self.linear = nn.Linear(256, 256) # sub-module
129+
self.weight = mx.ones([256]) # parameter
130+
```
131+
132+
#### Swift MLX: Property Wrappers
133+
134+
In Swift MLX, you must explicitly declare modules and parameters using property wrappers:
135+
136+
```swift
137+
class MyModule: Module {
138+
@ModuleInfo var linear: Linear // sub-module
139+
@ParameterInfo var weight: MLXArray // parameter
140+
141+
init() {
142+
self._linear.wrappedValue = Linear(256, 256)
143+
self._weight.wrappedValue = MLXArray.ones([256])
144+
super.init()
145+
}
146+
}
147+
```
148+
149+
#### When to Use Each
150+
151+
- **`@ModuleInfo`**: For neural network layers (anything with `callAsFunction`)
152+
- Examples: `Linear`, `Attention`, `MLP`, `RMSNorm`, arrays of layers
153+
- **`@ParameterInfo`**: For raw `MLXArray` tensors you use manually
154+
- Examples: weight matrices, bias vectors, scaling factors
155+
156+
#### Weight Key Mapping
157+
158+
Both property wrappers support custom keys for weight loading:
159+
160+
```swift
161+
@ModuleInfo(key: "self_attn") var attention: Attention
162+
@ParameterInfo(key: "custom_weight") var weight: MLXArray
163+
```
164+
165+
The `key` parameter maps to the corresponding key in the saved model weights. If no key is specified, it uses the property name.
166+
167+
#### Initialization Syntax
168+
169+
When using property wrappers, initialization uses the `wrappedValue` syntax:
170+
171+
```swift
172+
// Without property wrapper:
173+
self.weight = MLXArray.ones([dimensions])
174+
175+
// With property wrapper:
176+
self._weight.wrappedValue = MLXArray.ones([dimensions])
177+
```
178+
179+
This explicit registration system provides type safety and ensures all modules and parameters are properly tracked for operations like weight loading, quantization, and gradient computation.
180+
181+
#### Advanced Property Wrapper Patterns
182+
183+
Beyond the basic usage, there are several advanced patterns you'll encounter:
184+
185+
**Optional Modules**
186+
```swift
187+
@ModuleInfo(key: "lm_head") var lmHead: Linear?
188+
@ModuleInfo(key: "text_projection") var textProjection: Linear?
189+
```
190+
Used when modules are conditionally created based on configuration.
191+
192+
**Arrays of Modules**
193+
```swift
194+
@ModuleInfo(key: "layers") var layers: [TransformerBlock]
195+
@ModuleInfo(key: "down_blocks") var downBlocks: [EncoderDecoderBlock2D]
196+
```
197+
For dynamic numbers of repeated layers.
198+
199+
**Complex Module Types**
200+
```swift
201+
@ModuleInfo(key: "mid_blocks") var midBlocks: (ResnetBlock2D, Attention, ResnetBlock2D)
202+
```
203+
Tuples and other composite types are supported.
204+
205+
**Optional Parameters**
206+
```swift
207+
@ParameterInfo var bias: MLXArray?
208+
@ModuleInfo(key: "bias") var bias: MLXArray?
209+
```
210+
For optional parameters that may not exist in all model variants.
211+
212+
**Special Case: @ModuleInfo with MLXArray**
213+
```swift
214+
@ModuleInfo(key: "weight") var weight: MLXArray
215+
@ModuleInfo(key: "scales") var scales: MLXArray
216+
```
217+
In rare cases (like quantized layers), `@ModuleInfo` is used with `MLXArray` instead of `@ParameterInfo`. This typically occurs with specialized quantization or expert layers where the arrays are treated as sub-modules for weight loading purposes.
218+
219+
**Computed vs Loaded Parameters**
220+
```swift
221+
// Parameter loaded from weights - uses @ParameterInfo
222+
@ParameterInfo(key: "correct_output_scale") var correctOutputScale: MLXArray
223+
224+
// Computed parameter - uses private underscore prefix (ignored during loading)
225+
private let _routerInputScale: MLXArray
226+
```
227+
228+
This distinction is crucial:
229+
- `@ParameterInfo` properties are expected to exist in the weight files and will be loaded automatically
230+
- Private properties with underscore prefix (`_`) are ignored during weight loading and must be computed during initialization
231+
- Initialization syntax differs:
232+
```swift
233+
// For @ParameterInfo (loaded from weights):
234+
self._correctOutputScale.wrappedValue = MLXArray(...)
235+
236+
// For private computed parameters:
237+
self._routerInputScale = MLXArray(...)
238+
```
239+
240+
Use private underscore properties when you need to compute values based on configuration or other parameters, but don't want the weight loading to fail because these "parameters" don't exist in the saved weights.
241+
115242
### Porting Layers without Children
116243

117244
Now we can begin porting the layers (Modules). Here is an example layer with no child layers (e.g. `Linear`) but with parameters (e.g. `MLXArray`):

0 commit comments

Comments
 (0)