Add Gemma 3n #340
Conversation
I think the configs are now being created correctly, but I'm running into a lot of problems related to loading the model. I'm going to leave it here for now and wait for experts to add their input. I can't do this alone. I need help understanding what's going wrong during the model loading. You can test it by running llm-tool in Xcode. This is the current debug output:
Resolving this error will reveal many more like it, so there must be a more systemic fix.
This is the loading code confusing model weights with computed weights. The pattern (in mlx python as well) is to name the computed parameters with a leading underscore:

```swift
let _invFreq: MLXArray
let _originalInvFreq: MLXArray
```

This matches the python code:

```python
self._inv_freq = mx.array(inv_freq, dtype=mx.float32)
self._original_inv_freq = mx.array(inv_freq, dtype=mx.float32)
```

As it is, this expects those "weights" to be present on load.
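A minimal Swift sketch of that pattern (the names are illustrative, and it assumes the leading-underscore convention described above keeps these arrays out of the weights expected at load time):

```swift
import MLX
import MLXNN

// Illustrative only: the rotary frequencies are computed in init rather than loaded.
private class RotaryEmbedding: Module {
    let _invFreq: MLXArray          // leading underscore: computed, not a checkpoint weight
    let _originalInvFreq: MLXArray

    init(dimensions: Int, base: Float = 10_000) {
        // inv_freq[i] = base^(-2i / dim), the usual RoPE definition
        let exponents = stride(from: 0, to: dimensions, by: 2).map { Float($0) / Float(dimensions) }
        let invFreq = pow(MLXArray(base), -MLXArray(exponents))
        self._invFreq = invFreq
        self._originalInvFreq = invFreq
        super.init()
    }
}
```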
Working on an update, stand by.
See also ml-explore/mlx-swift#251 -- this might make doing some of this easier.
OK, I think I fixed these key issues:

I found this in the python code:

```python
self.query = NamedSequential()
self.query.add_module(
    "proj",
    create_conv2d(
        dim,
        self.num_heads * self.key_dim,
        kernel_size=1,
    ),
)
```

corresponding to:

```swift
private class MultiQueryAttention2d: Module {
    @ModuleInfo var queryProj: Conv2d
```

but I am not sure what that does (yet). The weights didn't have these values (it must be optional). The current failure is:
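For what it's worth, here is a hypothetical sketch (not the PR's actual code) of one way to mirror python's `self.query.add_module("proj", ...)`: wrap the conv in a small sub-module so the flattened weight key comes out as `query.proj.weight` rather than `queryProj.weight`. The class and parameter names are placeholders.

```swift
import MLX
import MLXNN

private class QueryModule: Module, UnaryLayer {
    @ModuleInfo var proj: Conv2d

    init(dim: Int, numHeads: Int, keyDim: Int) {
        // 1x1 conv, matching python's create_conv2d(dim, num_heads * key_dim, kernel_size=1)
        self._proj.wrappedValue = Conv2d(
            inputChannels: dim,
            outputChannels: numHeads * keyDim,
            kernelSize: 1
        )
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        proj(x)
    }
}
```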
Wait, isn't it supported via …
Thank you! As you can see, that's something I meant to follow up on once the model is working. Claude 4 Sonnet was convinced that it's not possible to write custom Metal kernels in MLX Swift, which is why you see the incorrect comment there. LLMs are getting better at writing MLX code in Swift, but clearly they still have gaps in their knowledge. I'm hoping that by adding more Swift ports, we can fix that for future versions.
Force-pushed from 99bdc09 to ab32fa6.
I added some complicated fixes for sanitization, which have resolved some but not all of the problems. These are probably not the right solution, since the sanitization in Python is much simpler. But the original logic that closely followed the Python implementation wasn't working. Currently it's failing when it tries to assign to …
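For context, a minimal sketch of the sanitize pattern used by model ports: a key-remapping pass over the checkpoint weights before they are applied to the module tree. The specific keys below are placeholders, not Gemma 3n's actual ones.

```swift
import MLX

func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] {
    var sanitized = [String: MLXArray]()
    for (key, value) in weights {
        // drop values that are computed at runtime rather than loaded
        if key.contains("rotary_emb.inv_freq") { continue }
        // example remap: strip a wrapping prefix that the Swift module tree doesn't use
        let newKey = key.hasPrefix("model.") ? String(key.dropFirst("model.".count)) : key
        sanitized[newKey] = value
    }
    return sanitized
}
```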
@Blaizzy, can you help me understand what's going wrong with the weights sanitization here? Is this something that should already be handled during the conversion of the model to MLX format? I used mlx-vlm to quantize the model to 4 bits. When I try to run it, … I'm testing the same model in this PR, since I can't run the bf16 model on my MacBook Pro with 16 GB of RAM.
@DePasqualeOrg see Blaizzy/mlx-vlm#400 and Blaizzy/mlx-vlm#398

I wanted to see how the … Sure enough, there is no …

Are you able to successfully load a model? If so, which one?
I already updated this for Blaizzy/mlx-vlm#398, but I see that today there were some significant changes in mlx-vlm. I'll wait for the Python implementation to stabilize before proceeding here.
OK, I went back to 7a36f2eda1a304e4ef89fef874971c94352ab5d4 (398) and used mlx-community/gemma-3n-E2B-it-bf16 and that loads, so I can look at how that works.
In mlx-community/gemma-3n-E2B-it-bf16 the … which looks like an array of arrays or an array of tuples (the latter is pretty common in vision models). That corresponds to:

```python
blocks = []
in_chs = self.conv_stem.out_chs
for stage, block_config in enumerate(gemma3n_mobilenet_def()):
    block_group = []
    for config in block_config:
        ...
```

However the …
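As a rough sketch (names are illustrative, and it assumes mlx-swift's module reflection flattens nested arrays of modules; if it doesn't, the keys would need remapping in sanitize), the nested python `blocks` could be mirrored with an array of arrays so weight keys come out as `blocks.<stage>.<index>.…`:

```swift
import MLX
import MLXNN

// Placeholder block type standing in for the real MobileNet-style blocks.
private class PlaceholderBlock: Module, UnaryLayer {
    @ModuleInfo var conv: Conv2d

    init(channels: Int) {
        self._conv.wrappedValue = Conv2d(
            inputChannels: channels, outputChannels: channels, kernelSize: 3, padding: 1)
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray { x + conv(x) }
}

private class VisionBlocks: Module {
    // stage -> blocks within that stage, mirroring the python list of lists
    @ModuleInfo var blocks: [[PlaceholderBlock]]

    init(stageChannels: [[Int]]) {
        self._blocks.wrappedValue = stageChannels.map { stage in
            stage.map { PlaceholderBlock(channels: $0) }
        }
        super.init()
    }
}
```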
OK, next issue is:

```swift
@ModuleInfo var weight: MLXArray?
```

should be:

```swift
@ParameterInfo var weight: MLXArray?
```

or simply omit the `ParameterInfo` since we don't need to override the key. After that the …

Anyway, this is a little bit tricky to implement as it does this in python:

```python
def add_module(self, name, module):
    setattr(self, name, module)
```

so it dynamically adds properties that will later be picked up by …

```swift
private class NamedSequential: Module, UnaryLayer {
    var _items = ModuleItems()
    var _names = [String]()

    override init() {
    }

    init(_ name: String, _ module: UnaryLayer) {
        super.init()
        self.add(name, module)
    }

    func add(_ name: String, _ module: UnaryLayer) {
        _items[name] = .value(.module(module))
        _names.append(name)
    }

    override func items() -> ModuleItems {
        _items
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = x
        for name in _names {
            guard let module = _items[name]?.module else {
                fatalError("Cannot find \(name) in items")
            }
            x = module(x)
        }
        return x
    }
}
```
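A hypothetical usage sketch of the NamedSequential above, mirroring the python add_module call (channel sizes are placeholders, not Gemma 3n's real dimensions):

```swift
// Stored under a property named `query`, the flattened weight key would be "query.proj.weight".
let query = NamedSequential(
    "proj",
    Conv2d(inputChannels: 256, outputChannels: 8 * 32, kernelSize: 1)
)
let y = query(zeros([1, 16, 16, 256]))  // MLX's Conv2d expects NHWC input
```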
But, seeing how it is actually used, I think this is simpler and clearer:

```swift
private class ProjectionBlock: Module, UnaryLayer {
    @ModuleInfo(key: "down_conv") var down: Conv2d?
    @ModuleInfo var norm: RMSNormAct2d?
    @ModuleInfo var proj: Conv2d
}
```
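A sketch of how that block might be completed, assuming the forward pass simply chains the optional down conv, the optional norm, and the projection in that order (an assumption based on the python Sequential usage, not confirmed against the PR):

```swift
extension ProjectionBlock {
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = x
        if let down { x = down(x) }   // optional "down_conv"
        if let norm { x = norm(x) }   // optional RMSNormAct2d
        return proj(x)
    }
}
```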
Working through some more issues:

I commented out (with a comment to delete when ready) some of the code in the vision tower sanitize -- the code now matches the python code, I think. It loads the model and starts to run but fails with:

… in …

Getting close, I think. This is with …
BTW, you may need to update your …
Thank you! It looks like the 4-bit quantized model is now loading, and I get this error: …

I will wait until @Blaizzy confirms that the Python implementation is stable before proceeding here.
Heads up: mlx-vlm finally fixed the audio and visual issues (according to my checks) and released v0.3.0: https://github.com/Blaizzy/mlx-vlm/releases/tag/v0.3.0
It's stable, @DePasqualeOrg 👌🏽 v0.3.0 fixes all Gemma3n bugs.
@xlab yep, audio was working fine; the only issues were: …
I'll wait for the Gemma team to fix the remaining issues. There are also at least two other efforts to port this model that I'm aware of; I think we should pick one to focus on, to avoid duplicating labor. I'm happy to let others take this on if they're interested in completing the task.
Edit: I can quantize the model using mlx-vlm, so I will be able to test this after we fix the sanitization.