
Difference in weight's shape between Conv2d and PyTorch's Conv2d #724

Answered by awni
mzbac asked this question in Q&A

Nice! A couple of comments on that:

  • You could inherit from nn.Conv2d and override just the __call__ method

Maybe even smoother: instead of overriding __call__, you could override __setattr__ and write the weight in the right order when it is assigned:

    def __setattr__(self, key: str, val: Any):
        # Reorder the weight once, when it is assigned
        if key == "weight":
            val = mx.swapaxes(val, 0, 3)
        self[key] = val

That at least lets you avoid the transpose on each call to mx.conv2d, which could come with a small perf penalty.
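To make the "reorder once at assignment" idea concrete, here is a minimal sketch that uses NumPy as a stand-in for mx so it runs anywhere. The class name LayoutConvertingConv is hypothetical and is not part of MLX. The permutation assumes the weight arrives in PyTorch's (out, in, kH, kW) layout and should end up in MLX's (out, kH, kW, in) layout; a different source layout would need a different permutation.

```python
import numpy as np


class LayoutConvertingConv:
    """Hypothetical stand-in for a conv layer (not an MLX class)."""

    def __setattr__(self, key, val):
        if key == "weight":
            # Assumed layouts: (out, in, kH, kW) -> (out, kH, kW, in).
            # This runs once per assignment, not once per forward call.
            val = np.transpose(val, (0, 2, 3, 1))
        super().__setattr__(key, val)


layer = LayoutConvertingConv()
source = np.arange(8 * 3 * 5 * 7).reshape(8, 3, 5, 7)
layer.weight = source
print(layer.weight.shape)  # (8, 5, 7, 3)
```

One caveat with this design: assigning a weight that is already in the target layout would permute it a second time, so the caller has to know which layout they are passing in.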

Though it's probably not worth being so clever about it. I also find that just pre-transforming the weights is simple and explicit.
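Pre-transforming could look roughly like the sketch below, again with NumPy standing in for the real arrays. The function name pretransform_conv_weights and the checkpoint keys are made up for illustration, and the sketch assumes every 4-D array in the dict is a conv weight in PyTorch's (out, in, kH, kW) layout, which will not hold for every model.

```python
import numpy as np


def pretransform_conv_weights(weights):
    """Return a new dict with 4-D arrays moved to (out, kH, kW, in)."""
    converted = {}
    for name, value in weights.items():
        if value.ndim == 4:
            # Assumption: any 4-D array is a conv weight in (out, in, kH, kW).
            value = np.transpose(value, (0, 2, 3, 1))
        converted[name] = value
    return converted


# Hypothetical checkpoint with one conv layer.
checkpoint = {
    "conv.weight": np.zeros((16, 4, 3, 5)),
    "conv.bias": np.zeros(16),
}
converted = pretransform_conv_weights(checkpoint)
print(converted["conv.weight"].shape)  # (16, 3, 5, 4)
```

Doing the conversion once, before the weights are loaded, keeps the layer itself unchanged and makes the layout change visible in one place.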

Replies: 1 comment 8 replies

awni Feb 22, 2024
Maintainer


Answer selected by mzbac

awni Apr 27, 2025
Maintainer

Category
Q&A
Labels
None yet
4 participants