Execution errors in "Porting a PyTorch model to JAX" #218

Open
@pavithraes

Description

Executing the "Porting a PyTorch model to JAX" tutorial runs into the following problems:

  1. Under "MaxVit implementation", calling MaxVit raises AttributeError: 'NoneType' object has no attribute 'value'. I think the fix is: if module.bias.value is not None: --> if module.bias is not None: ?
Traceback:
AttributeError                            Traceback (most recent call last)
[/tmp/ipython-input-49-3830862216.py](https://localhost:8080/#) in <cell line: 0>()
      1 x = jnp.ones((4, 224, 224, 3))
      2 
----> 3 mod = MaxVit(
      4     input_size=(224, 224),
      5     stem_channels=64,

4 frames
[/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py](https://localhost:8080/#) in __call__(cls, *args, **kwargs)
    139 
    140     def __call__(cls, *args: Any, **kwargs: Any) -> Any:
--> 141       return _graph_node_meta_call(cls, *args, **kwargs)
    142 
    143   def _object_meta_construct(cls, self, *args, **kwargs):

[/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py](https://localhost:8080/#) in _graph_node_meta_call(cls, *args, **kwargs)
    148   node = cls.__new__(cls, *args, **kwargs)
    149   vars(node)['_object__state'] = ObjectState()
--> 150   cls._object_meta_construct(node, *args, **kwargs)
    151 
    152   return node

[/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py](https://localhost:8080/#) in _object_meta_construct(cls, self, *args, **kwargs)
    142 
    143   def _object_meta_construct(cls, self, *args, **kwargs):
--> 144     self.__init__(*args, **kwargs)
    145 
    146 

[/tmp/ipython-input-48-3474323775.py](https://localhost:8080/#) in __init__(self, input_size, stem_channels, partition_size, block_channels, block_layers, head_dim, stochastic_depth_prob, norm_layer, activation_layer, squeeze_ratio, expansion_ratio, mlp_ratio, mlp_dropout, attention_dropout, num_classes, rngs)
    133         )
    134 
--> 135         self._init_weights(rngs)
    136 
    137     def __call__(self, x: jax.Array) -> jax.Array:

[/tmp/ipython-input-48-3474323775.py](https://localhost:8080/#) in _init_weights(self, rngs)
    148                     rngs(), module.kernel.value.shape, module.kernel.value.dtype
    149                 )
--> 150                 if module.bias.value is not None:
    151                     module.bias.value = jnp.zeros(
    152                         module.bias.value.shape, dtype=module.bias.value.dtype

AttributeError: 'NoneType' object has no attribute 'value'
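
The guard proposed for _init_weights can be illustrated without Flax. This is only a minimal sketch of the pattern, not the tutorial's code: Param, FakeConv and zero_bias_if_present are hypothetical stand-ins, chosen to mimic a layer whose bias attribute is None when it was built with use_bias=False (as the Conv repr further down shows).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Param:
    """Hypothetical stand-in for a parameter wrapper exposing `.value`."""
    value: list


@dataclass
class FakeConv:
    """Hypothetical stand-in for a conv layer; `bias` is None without a bias."""
    kernel: Param
    bias: Optional[Param] = None


def zero_bias_if_present(module) -> bool:
    """Zero the bias in place and report whether a bias existed."""
    # Test the attribute itself: when it is None, `module.bias.value`
    # raises AttributeError before the `is not None` comparison runs.
    if module.bias is not None:
        module.bias.value = [0.0] * len(module.bias.value)
        return True
    return False


# A bias-free layer is skipped instead of raising.
zero_bias_if_present(FakeConv(kernel=Param([1.0, 2.0])))
```

The only change relative to the failing line is that the None check happens on module.bias rather than on module.bias.value.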
  2. Several tests raise AssertionError, for example the Conv2dNormActivation test.
Traceback:
  AssertionError                            Traceback (most recent call last)
  [/tmp/ipython-input-68-247890887.py](https://localhost:8080/#) in <cell line: 0>()
        5 nnx_module = Conv2dNormActivation(32, 64, 3, 2, 1)
        6 
  ----> 7 t2f.copy_module(torch_module, nnx_module)
        8 
        9 test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))
  
  3 frames
  [/tmp/ipython-input-67-547272536.py](https://localhost:8080/#) in _copy_params_buffers(self, torch_nn_module, nnx_module)
       83             torch_value = getattr(torch_nn_module, torch_key)
       84             nnx_param = getattr(nnx_module, nnx_key)
  ---> 85             assert nnx_param is not None, (torch_key, nnx_key, nnx_module)
       86 
       87             if torch_value is None:
  
  AssertionError: ('bias', 'bias', Conv( # Param: 18,432 (73.7 KB)
    kernel_shape=(3, 3, 32, 64),
    kernel=Param( # 18,432 (73.7 KB)
      value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
    ),
    bias=None,
    in_features=32,
    out_features=64,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding=((1, 1), (1, 1)),
    input_dilation=1,
    kernel_dilation=(1, 1),
    feature_group_count=1,
    use_bias=False,
    mask=None,
    dtype=None,
    param_dtype=float32,
    precision=None,
    kernel_init=<function variance_scaling.<locals>.init at 0x7ffaa4b5d3a0>,
    bias_init=<function zeros at 0x7ffaa5b03100>,
    conv_general_dilated=<function conv_general_dilated at 0x7ffaa70cc040>,
    promote_dtype=<function promote_dtype at 0x7ffaa4b5c2c0>
  ))

The same assertion in _copy_params_buffers fails in the tests for MBConv, MaxVitLayer, MaxVitBlock, and MaxVit.
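
One possible way to make the copy helper tolerate bias-free layers is to check the PyTorch value first and only require an NNX parameter when there is something to copy. The sketch below is an assumption about how that could look, not the tutorial's _copy_params_buffers: copy_params and the SimpleNamespace objects are hypothetical stand-ins, and the real helper would also have to convert tensors and transpose kernel layouts.

```python
from types import SimpleNamespace


def copy_params(torch_module, nnx_module, key_pairs):
    """Copy values for each (torch_key, nnx_key) pair; return the keys copied."""
    copied = []
    for torch_key, nnx_key in key_pairs:
        torch_value = getattr(torch_module, torch_key)
        nnx_param = getattr(nnx_module, nnx_key)

        if torch_value is None:
            # Nothing to copy (e.g. a conv created without a bias):
            # the target must then be absent as well.
            assert nnx_param is None, (torch_key, nnx_key)
            continue

        # Only now is a missing target parameter an actual error.
        assert nnx_param is not None, (torch_key, nnx_key)
        nnx_param.value = torch_value  # real code would convert/transposed here
        copied.append(nnx_key)
    return copied


# Hypothetical bias-free layer on both sides.
torch_like = SimpleNamespace(weight=[1.0, 2.0], bias=None)
nnx_like = SimpleNamespace(kernel=SimpleNamespace(value=None), bias=None)
copy_params(torch_like, nnx_like, [("weight", "kernel"), ("bias", "bias")])
```

With this ordering, a pair where both sides are None is skipped, while a mismatch (bias on one side only) still fails loudly.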

  3. The final prediction scores don't match:
Prediction for the Dog:
- PyTorch model result: ['n02113023', 'Pembroke'], score: 0.7800846099853516
- Flax model result: ['n02113023', 'Pembroke'], score: 0.0008441798854619265
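
To narrow down where the two models start to disagree, intermediate outputs could be compared stage by stage. The following is a rough diagnostic sketch under stated assumptions: the stage names, the flat lists of floats, and the helpers max_abs_diff and first_divergent_stage are all hypothetical, and in practice the outputs would be arrays captured from both models on the same input.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length lists."""
    assert len(a) == len(b), "outputs must have the same number of elements"
    return max(abs(x - y) for x, y in zip(a, b))


def first_divergent_stage(torch_outputs, flax_outputs, atol=1e-3):
    """Return the first stage name whose outputs differ by more than atol."""
    # Dicts preserve insertion order, so stages are checked in model order.
    for name, torch_out in torch_outputs.items():
        if max_abs_diff(torch_out, flax_outputs[name]) > atol:
            return name
    return None


# Hypothetical captured outputs: the stem agrees, the first block does not.
torch_outputs = {"stem": [0.10, 0.20], "block0": [0.30, 0.40]}
flax_outputs = {"stem": [0.10, 0.20], "block0": [0.90, 0.40]}
first_divergent_stage(torch_outputs, flax_outputs)
```

If every stage matches but the final score still differs, the discrepancy would more likely be in preprocessing or in the classification head than in the copied backbone weights.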
