
Add Gemma 3n #340


Draft · wants to merge 19 commits into main from gemma-3n


Conversation

@DePasqualeOrg (Contributor) commented Jun 27, 2025

Edit: I can quantize the model using mlx-vlm, so I will be able to test this after we fix the sanitization.

@DePasqualeOrg (Contributor, Author) commented Jun 27, 2025

I think the configs are now being created correctly, but I'm running into a lot of problems related to loading the model.

I'm going to leave it here for now and wait for experts to add their input. I can't do this alone.

I need help understanding what's going wrong during the model loading. You can test it by running llm-tool in Xcode. This is the current debug output:

Loading mlx-community/gemma-3n-E2B-it-bf16...
🔍 Gemma3n.sanitize: Starting with 1556 weights
🔍 Gemma3n.sanitize: After prefix removal, have 1556 weights
Error: Key originalInvFreq not found in Gemma3nRotaryEmbedding
Program ended with exit code: 1

Resolving this error will only reveal many more like it, so a more systematic fix is needed.

@davidkoski (Collaborator) commented Jun 27, 2025

Error: Key originalInvFreq not found in Gemma3nRotaryEmbedding

This is the loading code confusing model weights with computed weights. The pattern (in mlx python as well) is to name the computed parameters with a leading underscore:

    let _invFreq: MLXArray
    let _originalInvFreq: MLXArray

This matches the Python code:

        self._inv_freq = mx.array(inv_freq, dtype=mx.float32)
        self._original_inv_freq = mx.array(inv_freq, dtype=mx.float32)

As it is, this code expects those "weights" to be present on load.
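
For illustration, a minimal sketch of the fix (the module and parameter names are taken from the error message above; the init math is the standard RoPE frequency computation, an assumption here):

    // Sketch: computed arrays get a leading underscore so that
    // Module.update(parameters:) does not expect them in the checkpoint.
    import Foundation
    import MLX
    import MLXNN

    class Gemma3nRotaryEmbedding: Module {
        let _invFreq: MLXArray          // computed at init, not loaded
        let _originalInvFreq: MLXArray  // computed at init, not loaded

        init(dimensions: Int, base: Float = 10000) {
            // inv_freq[i] = 1 / base^(2i / dimensions) -- standard RoPE frequencies
            let invFreq = MLXArray(
                stride(from: 0, to: dimensions, by: 2).map {
                    1 / powf(base, Float($0) / Float(dimensions))
                })
            self._invFreq = invFreq
            self._originalInvFreq = invFreq
            super.init()
        }
    }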

@davidkoski (Collaborator):

Working on an update, stand by

@davidkoski (Collaborator):

See also ml-explore/mlx-swift#251 -- this might make doing some of this easier.

@davidkoski (Collaborator):

OK, I think I fixed these key issues:

  • computed values need to be named with a leading underscore
  • some of the keys on the modules were missing

I found this in the Python code:

        self.query = NamedSequential()
        self.query.add_module(
            "proj",
            create_conv2d(
                dim,
                self.num_heads * self.key_dim,
                kernel_size=1,
            ),
        )

corresponding to:

private class MultiQueryAttention2d: Module {

    @ModuleInfo var queryProj: Conv2d

but I am not sure what that does (yet). The weights didn't have these values (it must be optional).

The current failure is:

Error: Mismatched parameter weight shape. Actual [1536, 1024], expected [1536, 640]
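
As an aside, one way to surface every mismatch at once rather than stopping at the first throw (a debugging sketch, assuming the checkpoint dictionary and the model instance are in scope as weights and model):

    // Debugging sketch: list every shape mismatch between model and checkpoint
    // before update(parameters:) throws on the first one it happens to visit.
    let expected = Dictionary(
        uniqueKeysWithValues: model.parameters().flattened().map { ($0.0, $0.1.shape) })
    for (key, array) in weights.sorted(by: { $0.key < $1.key }) {
        if let shape = expected[key], shape != array.shape {
            print("\(key): checkpoint \(array.shape) vs model \(shape)")
        }
    }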

@davidkoski (Collaborator):

FWIW, here is how I debugged these. First I set breakpoints on the throws of the errors seen here:

    open func update(parameters: ModuleParameters, verify: VerifyUpdate) throws -> Self {

        func apply(key: String, _ item: ModuleItem, _ value: NestedItem<String, MLXArray>) throws {
            if case .none = value, !verify.contains(.allModelKeysSet) {
                return
            }

            // item: single item from `items()`
            // value: single item with matching structure from `parameters()`
            //
            // match them up and apply the MLXArrays from value -> item

            switch (item, value) {
            case (.value(.parameters(let p)), .value(let newArray)):
                if verify.contains(.all), p.shape != newArray.shape {
                    throw UpdateError.mismatchedSize(
                        key: key, expectedShape: p.shape, actualShape: newArray.shape)
                }
                p._updateInternal(newArray)

            case (.value(.parameters(let p)), .none):
                if Self.parameterIsValid(key) {
                    throw UpdateError.keyNotFound(base: describeType(self), key: key)
                } else {
                    // ignore it -- this isn't a parameter that requires update
                }

then I can look at the call stack:

[Screenshot: Xcode call stack at the UpdateError throw, with the failing key visible]

notice the "key" value. You can also print the module value to see which module it is. E.g. in frame 3:

(lldb) po self
ConvNormAct(outChannels=768) {
  bn: RMSNormAct2d(applyAct=true, eps=1e-05) {
      act: GELU(approximation=none),
      drop: Identity,
    },
  conv: Conv2d(bias=nil, groups=768),
}

This gives the following error (there are a couple of these, and since the keys are in a dictionary you hit them in random order):

Error: Mismatched parameter weight shape. Actual [768, 5, 5, 1], expected [768, 5, 5, 768]

@davidkoski (Collaborator):

Looking into these mismatched sizes, I am not sure how this is supposed to work -- I wonder if the Python side doesn't verify sizes?

For example:

            self.dw_start = ConvNormAct(
                in_chs,
                in_chs,

The input/output channels will be the same, so the conv weight under that should have shape [X, .., .., X], but the safetensors has:

[Screenshot: safetensors shapes for the dw_start conv, showing unequal input/output channels]

If I disable the validation of sizes in the call to update:

    try model.update(parameters: parameters, verify: [])

it makes it past that but fails here:

Error: keyNotFound(CodingKeys(stringValue: "vision_soft_tokens_per_image", intValue: nil), Swift.DecodingError.Context(codingPath: [], debugDescription: "No value associated with key CodingKeys(stringValue: \"vision_soft_tokens_per_image\", intValue: nil) (\"vision_soft_tokens_per_image\").", underlyingError: nil))

which is curious because 1) that key is present (it could be looking in a different file) and 2) the value is optional, so why the complaint?
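
For reference, a sketch of a decode that tolerates the missing key, in case the field is being read with decode rather than decodeIfPresent (the type name and the default of 256 are assumptions, not the PR's actual config code):

    // Sketch: decodeIfPresent returns nil instead of throwing
    // DecodingError.keyNotFound when the key is absent from the JSON.
    struct VLMConfigSketch: Codable {
        var visionSoftTokensPerImage: Int

        enum CodingKeys: String, CodingKey {
            case visionSoftTokensPerImage = "vision_soft_tokens_per_image"
        }

        init(from decoder: Decoder) throws {
            let container = try decoder.container(keyedBy: CodingKeys.self)
            self.visionSoftTokensPerImage =
                try container.decodeIfPresent(Int.self, forKey: .visionSoftTokensPerImage)
                ?? 256  // assumed default
        }
    }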

@xlab commented Jun 27, 2025

    // MLX Swift currently doesn't have custom Metal kernel creation capabilities like Python's
    // mx.fast.metal_kernel(). Consider optimizing with vectorized MLX operations or requesting
    // custom kernel support from the MLX Swift team for better performance.

Wait, isn't it supported via MLXFast? https://github.com/ml-explore/mlx-swift/blob/b79c74ce773440b86a81ef925ea78dd5023a16c0/Source/MLXFast/MLXFastKernel.swift#L29

Example
https://github.com/ml-explore/mlx-swift/blob/b79c74ce773440b86a81ef925ea78dd5023a16c0/Tests/MLXTests/MLXFastKernelTests.swift#L38-L51
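
For reference, a minimal sketch along the lines of the linked test (API details as of that commit; treat the exact signatures as assumptions):

    // Sketch: an elementwise copy kernel built with MLXFast.metalKernel,
    // following the linked test file.
    import MLX
    import MLXFast

    let a = MLXArray((0 ..< 16).map { Float($0) }, [4, 4])
    let kernel = MLXFast.metalKernel(
        name: "copy",
        inputNames: ["a"],
        outputNames: ["out"],
        source: """
            uint elem = thread_position_in_grid.x;
            out[elem] = a[elem];
            """)
    let result = kernel(
        [a],
        grid: (16, 1, 1),
        threadGroup: (16, 1, 1),
        outputShapes: [[4, 4]],
        outputDTypes: [.float32])[0]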

@DePasqualeOrg (Contributor, Author):

Thank you! As you can see, that's something I meant to follow up on once the model is working. Claude 4 Sonnet was convinced that it's not possible to write custom Metal kernels in MLX Swift, which is why you see the incorrect comment there.

LLMs are getting better at writing MLX code in Swift, but clearly they still have gaps in their knowledge. I'm hoping that by adding more Swift ports, we can fix that for future versions.

@DePasqualeOrg force-pushed the gemma-3n branch 2 times, most recently from 99bdc09 to ab32fa6 on June 28, 2025 at 19:24
@DePasqualeOrg (Contributor, Author):

I added some complicated fixes for sanitization, which have resolved some but not all of the problems. These are probably not the right solution, since the sanitization in Python is much simpler. But the original logic that closely followed the Python implementation wasn't working.

Currently it's failing when it tries to assign to blocks in VisionTower.

@DePasqualeOrg (Contributor, Author) commented Jun 29, 2025

@Blaizzy, can you help me understand what's going wrong with the weights sanitization here? Is this something that should already be handled during the conversion of the model to MLX format? I used mlx-vlm to quantize the model to 4 bits.

When I try to run mlx-community/gemma-3n-E2B-it-4bit in mlx-vlm, I get this error: ValueError: [conv] Invalid input array with type uint32. Convolution currently only supports floating point types

I'm testing the same model in this PR, since I can't run the bf16 model on my MacBook Pro with 16 GB of RAM.

@davidkoski (Collaborator):

@DePasqualeOrg see Blaizzy/mlx-vlm#400 and Blaizzy/mlx-vlm#398

I wanted to see how the blocks.blocks were loaded, so I tried this too, but I can't load the model:

python -m mlx_vlm.generate --model mlx-community/gemma-3n-E4B-bf16 --prompt "describe these images in english" --image /Users/dkoski/Desktop/IMG_0691.jpeg

...

  File "/Users/dkoski/miniconda3/envs/mlx/lib/python3.11/site-packages/mlx/nn/layers/base.py", line 178, in load_weights
    raise ValueError(f"Received parameters not in model: {extras}.")
ValueError: Received parameters not in model: language_model.lm_head.weight.

Sure enough, there is no lm_head property in the model (I am building from source, ebafa5a789ed1a8e050b8366ae4e845dbe640b90)

Are you able to successfully load a model? If so, which one?

@DePasqualeOrg (Contributor, Author):

I already updated this for Blaizzy/mlx-vlm#398, but I see that today there were some significant changes in mlx-vlm. I'll wait for the Python implementation to stabilize before proceeding here.

@davidkoski (Collaborator):

OK, I went back to 7a36f2eda1a304e4ef89fef874971c94352ab5d4 (398) and used mlx-community/gemma-3n-E2B-it-bf16 and that loads, so I can look at how that works.

@davidkoski (Collaborator):

In mlx-community/gemma-3n-E2B-it-bf16 the blocks look like this:

    "model.vision_tower.timm_model.blocks.0.0.bn1.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.0.bn2.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.0.conv_exp.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.0.conv_pwl.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.1.bn1.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.1.bn2.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.1.conv_exp.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.1.conv_pwl.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.2.bn1.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.2.bn2.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.2.conv_exp.weight": "model-00001-of-00003.safetensors",
    "model.vision_tower.timm_model.blocks.0.2.conv_pwl.weight": "model-00001-of-00003.safetensors",

which looks like an array of arrays or an array of tuples (the latter is pretty common in vision models).

That corresponds to:

        blocks = []
        in_chs = self.conv_stem.out_chs
        for stage, block_config in enumerate(gemma3n_mobilenet_def()):
            block_group = []
            for config in block_config:
...

blocks is an array of block_group.

However, the sanitize() is converting that from blocks.0.0 to blocks.blocks.0 -- I guess that is ultimately VisionTower.sanitize(weights:).
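
For reference, a minimal sketch of a sanitize pass that only strips the prefix and leaves the nested block indices intact (this assumes the Swift module tree mirrors the Python array-of-arrays, so introspected keys already look like blocks.0.0.*):

    // Sketch: strip the checkpoint prefix but keep "blocks.<stage>.<index>" as-is,
    // so the keys line up with a nested blocks structure without further rewriting.
    func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
        let prefix = "model.vision_tower.timm_model."
        var sanitized = [String: MLXArray]()
        for (key, value) in weights {
            if key.hasPrefix(prefix) {
                sanitized[String(key.dropFirst(prefix.count))] = value
            } else {
                sanitized[key] = value
            }
        }
        return sanitized
    }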

@davidkoski (Collaborator):

OK, next issue is:

@ModuleInfo var weight: MLXArray?

should be:

@ParameterInfo var weight: MLXArray?

or simply omit the ParameterInfo since we don't need to override the key.

After that, the NamedSequential in MultiQueryAttention2d handles output.proj.weight where it would normally be output_proj.weight -- this needs to be added (I had called this out earlier as a bit of a mystery, as I couldn't yet see how it worked).

Anyway, this is a little bit tricky to implement, as the Python code does this:

    def add_module(self, name, module):
        setattr(self, name, module)

so it dynamically adds properties that will later be picked up by Module. We can do something like that by overriding items():

private class NamedSequential: Module, UnaryLayer {

    var _items = ModuleItems()
    var _names = [String]()

    override init() {
    }

    init(_ name: String, _ module: Module & UnaryLayer) {
        super.init()
        self.add(name, module)
    }

    func add(_ name: String, _ module: Module & UnaryLayer) {
        _items[name] = .value(.module(module))
        _names.append(name)
    }

    // supply the dynamically added modules instead of the introspected ones
    override func items() -> ModuleItems {
        _items
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = x
        for name in _names {
            // stored as .value(.module(...)) in add(_:_:), so unwrap the same way
            guard case .value(.module(let module))? = _items[name],
                let layer = module as? UnaryLayer
            else {
                fatalError("Cannot find \(name) in items")
            }
            x = layer(x)
        }
        return x
    }
}

items() is the cache of the introspected values -- here we supply our own custom-built items instead.
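
A usage sketch matching the Python add_module calls quoted earlier (dim, numHeads, and keyDim are assumed placeholder values, not Gemma 3n's actual dimensions):

    // Sketch: building the query projection the way the Python code does,
    // so the resulting weight key comes out as "query.proj.weight".
    let (dim, numHeads, keyDim) = (640, 8, 64)  // assumed values
    let query = NamedSequential()
    query.add("proj", Conv2d(
        inputChannels: dim,
        outputChannels: numHeads * keyDim,
        kernelSize: 1))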

@davidkoski (Collaborator):

But, seeing how it is actually used, I think this is simpler and clearer:

private class ProjectionBlock: Module, UnaryLayer {
    @ModuleInfo(key: "down_conv") var down: Conv2d?
    @ModuleInfo var norm: RMSNormAct2d?
    @ModuleInfo var proj: Conv2d
}

@davidkoski (Collaborator):

Working through some more issues:

  • mlx-swift: Conv2d with groups was missing a divide by the number of groups
  • let imgShape = img.shape.suffix(2) produces an ArraySlice -- its indices are not 0 and 1 (see the sketch after this list)
  • replaced the interpolation with Upsample
  • some config confusion -- vocabSize needs to come from the right config
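
To illustrate the second bullet (a standalone sketch, not code from the PR):

    // ArraySlice keeps the parent array's indices: suffix(2) of a 4-element
    // array is indexed by 2 and 3, so subscripting with 0 would trap.
    let shape = [1, 3, 224, 224]
    let hw = shape.suffix(2)
    // let h = hw[0]  // runtime trap: index out of bounds
    let h = hw[hw.startIndex]                    // 224
    let w = hw[hw.index(after: hw.startIndex)]   // 224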

I commented out (with a comment to delete when ready) some of the code in the vision tower sanitize -- the code now matches the Python code, I think.

It loads the model and starts to run but fails with:

Shapes (1,682,16,48) and (1,2048,1,1) cannot be broadcast

in rmsNorm2d().

Getting close, I think. This is with mlx-community/gemma-3n-E2B-it-bf16.
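
For reference, a sketch of what a channels-last rmsNorm2d might look like (the NHWC layout is an assumption here; the broadcast failure above suggests the per-channel scale was shaped for a different axis):

    // Sketch: RMS norm over the channel axis of an NHWC tensor, with the
    // per-channel scale reshaped so it broadcasts against [N, H, W, C].
    import MLX

    func rmsNorm2d(_ x: MLXArray, weight: MLXArray, eps: Float = 1e-5) -> MLXArray {
        let meanSquare = x.square().mean(axis: -1, keepDims: true)
        return x * rsqrt(meanSquare + eps) * weight.reshaped([1, 1, 1, -1])
    }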

@davidkoski (Collaborator):

BTW, you may need to update your mlx-swift dependency -- once we are ready here I will tag a release, but for now there were a couple of changes there.

@DePasqualeOrg (Contributor, Author) commented Jul 2, 2025

Thank you! It looks like the 4-bit quantized model now loads, but I get this error: Shapes (1,8) and (1,32) cannot be broadcast.

I will wait until @Blaizzy confirms that the Python implementation is stable before proceeding here.

@xlab commented Jul 5, 2025

Heads up: mlx-vlm has finally fixed the audio and vision issues (according to my checks) and released v0.3.0: https://github.com/Blaizzy/mlx-vlm/releases/tag/v0.3.0

@Blaizzy commented Jul 5, 2025

It's stable @DePasqualeOrg 👌🏽

v0.3.0 fixes all Gemma3n bugs

@Blaizzy commented Jul 5, 2025

@xlab yep, audio was working fine; the only issues were:

  1. The multimodal merging didn't allow for audio + vision features because we were returning each separately.
  2. A few parts of the vision module convolution needed padding.
  3. The JAX conv weights were transposed while converting to Torch, so OCR is broken unless you transpose the image's HW before processing. Note: the DeepMind team is aware and will fix the weights soon.

@DePasqualeOrg (Contributor, Author):

I'll wait for the Gemma team to fix the remaining issues. Also, there are at least two other efforts to port this model that I'm aware of; I think we should pick one to focus on, to avoid duplicating labor. I'm happy to let others take this on if they're interested in completing the task.
