Add info about property wrappers to porting guide #347

Merged
127 changes: 127 additions & 0 deletions Libraries/MLXLMCommon/Documentation.docc/porting.md
@@ -112,6 +112,133 @@ enum CodingKeys: String, CodingKey {

This will read `rope_theta` from the JSON file if present and fall back to a default value of `10_000`.

### Understanding Property Wrappers

MLX Swift uses property wrappers to handle automatic registration of modules and parameters. This is different from Python MLX, which uses runtime discovery.

#### Python MLX: Automatic Discovery

In Python MLX, you simply assign modules and parameters as attributes, and MLX automatically discovers them:

```python
import mlx.core as mx
import mlx.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Just assign - MLX auto-discovers these
        self.linear = nn.Linear(256, 256)  # sub-module
        self.weight = mx.ones([256])       # parameter
```

#### Swift MLX: Property Wrappers

In Swift MLX, you must explicitly declare modules and parameters using property wrappers:

```swift
import MLX
import MLXNN

class MyModule: Module {
    @ModuleInfo var linear: Linear        // sub-module
    @ParameterInfo var weight: MLXArray   // parameter

    init() {
        // Assign through the wrappers' backing storage before calling super.init()
        self._linear.wrappedValue = Linear(256, 256)
        self._weight.wrappedValue = MLXArray.ones([256])
        super.init()
    }
}
```

#### When to Use Each

- **`@ModuleInfo`**: For neural network layers (anything with `callAsFunction`)
  - Examples: `Linear`, `Attention`, `MLP`, `RMSNorm`, arrays of layers
- **`@ParameterInfo`**: For raw `MLXArray` tensors you use manually
  - Examples: weight matrices, bias vectors, scaling factors

#### Weight Key Mapping

Both property wrappers support custom keys for weight loading:

```swift
@ModuleInfo(key: "self_attn") var attention: Attention
@ParameterInfo(key: "custom_weight") var weight: MLXArray
```

The `key` parameter names the corresponding entry in the saved model weights. If no key is specified, the property name is used.
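The fallback rule can be sketched in plain Swift without any MLX dependency. Everything in this example (`AnyKeyed`, `Keyed`, `Block`, `weightKeys`) is a hypothetical stand-in invented for illustration, not MLX Swift's implementation. It only shows the idea of "explicit key if given, otherwise the property name":

```swift
// Hypothetical stand-ins for illustration; these are NOT the MLX Swift types.
protocol AnyKeyed {
    var key: String? { get }
}

@propertyWrapper
struct Keyed<Value>: AnyKeyed {
    var wrappedValue: Value
    let key: String?

    // Used by `@Keyed var x = value` (no explicit key).
    init(wrappedValue: Value) {
        self.wrappedValue = wrappedValue
        self.key = nil
    }

    // Used by `@Keyed(key: "...") var x = value`.
    init(wrappedValue: Value, key: String) {
        self.wrappedValue = wrappedValue
        self.key = key
    }
}

struct Block {
    @Keyed(key: "self_attn") var attention: String = "attention layer"
    @Keyed var weight: String = "weight array"
}

// Reflection reports wrapped properties under their backing-storage label
// (`_attention`, `_weight`), so the leading underscore is stripped to
// recover the property name used as the fallback key.
func weightKeys(of value: Any) -> [String] {
    Mirror(reflecting: value).children.compactMap { child -> String? in
        guard let label = child.label, let keyed = child.value as? AnyKeyed else {
            return nil
        }
        let propertyName = label.hasPrefix("_") ? String(label.dropFirst()) : label
        return keyed.key ?? propertyName
    }
}

let keys = weightKeys(of: Block())
print(keys)  // ["self_attn", "weight"]
```

Here `attention` resolves to `self_attn` because a key was given, while `weight` falls back to its own name.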

#### Initialization Syntax

When a property uses a wrapper, initialize it through the wrapper's backing storage (the property name prefixed with an underscore) by assigning to its `wrappedValue`:

```swift
// Without property wrapper:
self.weight = MLXArray.ones([dimensions])

// With property wrapper:
self._weight.wrappedValue = MLXArray.ones([dimensions])
```

This explicit registration system provides type safety and ensures all modules and parameters are properly tracked for operations like weight loading, quantization, and gradient computation.
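The `_name.wrappedValue` form used in the initialization syntax above is ordinary Swift property-wrapper behavior: the compiler synthesizes a stored property named `_name` that holds the wrapper itself. Here is a minimal plain-Swift sketch of that mechanism. `Deferred`, `Base`, and `Scaler` are hypothetical names made up for this example and do not reflect how MLX Swift implements its wrappers:

```swift
// Hypothetical wrapper for illustration; NOT an MLX Swift type.
// It starts empty and receives its value later through `wrappedValue`.
@propertyWrapper
struct Deferred<Value> {
    private var storage: Value?

    init() {
        storage = nil
    }

    var wrappedValue: Value {
        get { storage! }            // traps if read before a value is assigned
        set { storage = newValue }
    }
}

class Base {
    init() {}
}

final class Scaler: Base {
    // The compiler synthesizes `_factor` of type `Deferred<Double>`,
    // implicitly initialized via `Deferred()`.
    @Deferred var factor: Double

    init(factor: Double) {
        // Assign through the backing storage before calling super.init().
        self._factor.wrappedValue = factor
        super.init()
    }
}

let scaler = Scaler(factor: 2.5)
print(scaler.factor)  // 2.5
```

After initialization, the property is read and written through its plain name (`scaler.factor`); the underscore form is only needed where direct access to the wrapper is required.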

#### Advanced Property Wrapper Patterns

Beyond the basic usage, there are several advanced patterns you'll encounter:

**Optional Modules**
```swift
@ModuleInfo(key: "lm_head") var lmHead: Linear?
@ModuleInfo(key: "text_projection") var textProjection: Linear?
```
Used when modules are conditionally created based on configuration.

**Arrays of Modules**
```swift
@ModuleInfo(key: "layers") var layers: [TransformerBlock]
@ModuleInfo(key: "down_blocks") var downBlocks: [EncoderDecoderBlock2D]
```
For dynamic numbers of repeated layers.

**Complex Module Types**
```swift
@ModuleInfo(key: "mid_blocks") var midBlocks: (ResnetBlock2D, Attention, ResnetBlock2D)
```
Tuples and other composite types are supported.

**Optional Parameters**
```swift
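// Optional parameter declared with @ParameterInfo:
@ParameterInfo var bias: MLXArray?

// Alternative declaration (from a different class) using @ModuleInfo with a key:
@ModuleInfo(key: "bias") var bias: MLXArray?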
```
For optional parameters that may not exist in all model variants.

**Special Case: @ModuleInfo with MLXArray**
```swift
@ModuleInfo(key: "weight") var weight: MLXArray
@ModuleInfo(key: "scales") var scales: MLXArray
```
In rare cases (like quantized layers), `@ModuleInfo` is used with `MLXArray` instead of `@ParameterInfo`. This typically occurs with specialized quantization or expert layers where the arrays are treated as sub-modules for weight loading purposes.

**Computed vs Loaded Parameters**
```swift
// Parameter loaded from weights - uses @ParameterInfo
@ParameterInfo(key: "correct_output_scale") var correctOutputScale: MLXArray

// Computed parameter - uses private underscore prefix (ignored during loading)
private let _routerInputScale: MLXArray
```

This distinction is crucial:
- `@ParameterInfo` properties are expected to exist in the weight files and will be loaded automatically
- Private properties with underscore prefix (`_`) are ignored during weight loading and must be computed during initialization
- Initialization syntax differs:
  ```swift
  // For @ParameterInfo (loaded from weights):
  self._correctOutputScale.wrappedValue = MLXArray(...)

  // For private computed parameters:
  self._routerInputScale = MLXArray(...)
  ```

Use private underscore-prefixed properties for values computed from the configuration or from other parameters. Because these values do not exist in the saved weights, keeping them out of weight loading prevents loading from failing on the missing keys.

### Porting Layers without Children

Now we can begin porting the layers (Modules). Here is an example layer with no child layers (e.g. `Linear`) but with parameters (e.g. `MLXArray`):