Just noticed the mlx `Conv2d` weight shape is different from the PyTorch `Conv2d` shape. I'm wondering if there's a specific reason to implement it differently? I'm bringing this up mainly because it causes some issues when loading PyTorch model weights: we have to do a weight conversion or use a custom `Conv2d` instead of the default `nn.Conv2d`. For example, in the CLIP example.
Well, not the only reason, but the first one is that we use NHWC vs NCHW. This means that the matrix multiplication would be with a weight of shape `(O, h, w, C)`, where `h` and `w` are the kernel sizes. PyTorch's would be `(O, C, h, w)`.
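For concreteness, a minimal conversion sketch; the helper name is hypothetical, and it assumes the PyTorch weight has already been pulled out as a NumPy array:

```python
import mlx.core as mx

def torch_to_mlx_conv_weight(w) -> mx.array:
    # Hypothetical helper. `w` is a PyTorch Conv2d weight as a NumPy array
    # (call .detach().numpy() on the tensor first), with shape (O, C, h, w).
    # mx.conv2d expects (O, h, w, C), so move the input-channel axis last.
    return mx.transpose(mx.array(w), (0, 2, 3, 1))
```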
Nice! A couple comments on that:

- You could subclass `nn.Conv2d` and override just the `__call__` method.
- Maybe even smoother: instead of overriding `__call__`, you could override `__setattr__` and just write the weight in the right order then. That at least lets you avoid the transpose on each call to `mx.conv2d`, which could come with a small perf penalty (sketches of both are below).

Though it's probably not worth being so clever about; I also find just pre-transforming the weights to be simple and explicit.
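To make both options concrete, here are minimal sketches. The class names are hypothetical, and both assume the PyTorch-ordered weight is assigned by plain attribute access after construction (bulk loaders like `Module.update` may bypass `__setattr__`):

```python
import mlx.core as mx
import mlx.nn as nn

class TorchLayoutConv2d(nn.Conv2d):
    """Sketch A (hypothetical): keep the weight in PyTorch's (O, C, h, w)
    order and transpose it on every forward pass. Only correct once you
    overwrite `self.weight` with the PyTorch-ordered tensor."""

    def __call__(self, x):
        w = mx.transpose(self.weight, (0, 2, 3, 1))  # -> (O, h, w, C)
        y = mx.conv2d(x, w, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y


class TransposeOnLoadConv2d(nn.Conv2d):
    """Sketch B (hypothetical): transpose once when the weight is assigned,
    so `__call__` is untouched and pays no per-step transpose."""

    def __init__(self, *args, **kwargs):
        self._initialized = False  # don't transpose the random init weight
        super().__init__(*args, **kwargs)
        self._initialized = True

    def __setattr__(self, name, value):
        # Assumes weights assigned after construction (e.g. when loading a
        # PyTorch checkpoint) come in (O, C, h, w) order.
        if name == "weight" and getattr(self, "_initialized", False):
            value = mx.transpose(value, (0, 2, 3, 1))
        super().__setattr__(name, value)
```

With sketch B, loading is then just `conv.weight = mx.array(w_torch)` and the transpose happens once at load time rather than on every forward pass.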