Running the "Porting a PyTorch model to JAX" tutorial raises the following errors:
- In the `MaxVit` implementation, calling `MaxVit` raises `AttributeError: 'NoneType' object has no attribute 'value'`. I think the fix might be to change `if module.bias.value is not None:` to `if module.bias is not None:`.
Traceback:
AttributeError Traceback (most recent call last)
/tmp/ipython-input-49-3830862216.py in <cell line: 0>()
1 x = jnp.ones((4, 224, 224, 3))
2
----> 3 mod = MaxVit(
4 input_size=(224, 224),
5 stem_channels=64,
/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py in __call__(cls, *args, **kwargs)
139
140 def __call__(cls, *args: Any, **kwargs: Any) -> Any:
--> 141 return _graph_node_meta_call(cls, *args, **kwargs)
142
143 def _object_meta_construct(cls, self, *args, **kwargs):
/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py in _graph_node_meta_call(cls, *args, **kwargs)
148 node = cls.__new__(cls, *args, **kwargs)
149 vars(node)['_object__state'] = ObjectState()
--> 150 cls._object_meta_construct(node, *args, **kwargs)
151
152 return node
/usr/local/lib/python3.11/dist-packages/flax/nnx/object.py in _object_meta_construct(cls, self, *args, **kwargs)
142
143 def _object_meta_construct(cls, self, *args, **kwargs):
--> 144 self.__init__(*args, **kwargs)
145
146
/tmp/ipython-input-48-3474323775.py in __init__(self, input_size, stem_channels, partition_size, block_channels, block_layers, head_dim, stochastic_depth_prob, norm_layer, activation_layer, squeeze_ratio, expansion_ratio, mlp_ratio, mlp_dropout, attention_dropout, num_classes, rngs)
133 )
134
--> 135 self._init_weights(rngs)
136
137 def __call__(self, x: jax.Array) -> jax.Array:
/tmp/ipython-input-48-3474323775.py in _init_weights(self, rngs)
148 rngs(), module.kernel.value.shape, module.kernel.value.dtype
149 )
--> 150 if module.bias.value is not None:
151 module.bias.value = jnp.zeros(
152 module.bias.value.shape, dtype=module.bias.value.dtype
AttributeError: 'NoneType' object has no attribute 'value'
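A minimal sketch of the proposed guard. It assumes, as the `Conv` repr printed in the `AssertionError` traceback shows (`use_bias=False`, `bias=None`), that a layer without a bias stores `None` in `bias` itself rather than a `Param` whose value is `None`. `FakeParam`, `FakeConv`, and `init_bias_to_zero` are hypothetical stand-ins so the snippet runs without Flax; in the tutorial the same check would sit inside `_init_weights`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeParam:
    """Hypothetical stand-in for nnx.Param: a box holding an array in `.value`."""
    value: list


@dataclass
class FakeConv:
    """Hypothetical stand-in for a conv layer; bias is None when use_bias=False."""
    kernel: FakeParam
    bias: Optional[FakeParam] = None


def init_bias_to_zero(module) -> None:
    # Test the attribute itself, not `.value`: with use_bias=False the
    # attribute is None, so `module.bias.value` raises AttributeError.
    if module.bias is not None:
        module.bias.value = [0.0] * len(module.bias.value)


with_bias = FakeConv(kernel=FakeParam([1.0, 2.0]), bias=FakeParam([0.5, 0.5]))
no_bias = FakeConv(kernel=FakeParam([1.0, 2.0]))  # bias=None, like use_bias=False

init_bias_to_zero(with_bias)
init_bias_to_zero(no_bias)  # skipped instead of raising AttributeError
print(with_bias.bias.value, no_bias.bias)
```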
- Several tests fail with `AssertionError`, for example the `Conv2dNormActivation` test.
Traceback:
AssertionError Traceback (most recent call last)
/tmp/ipython-input-68-247890887.py in <cell line: 0>()
5 nnx_module = Conv2dNormActivation(32, 64, 3, 2, 1)
6
----> 7 t2f.copy_module(torch_module, nnx_module)
8
9 test_modules(nnx_module, torch_module, torch.randn(4, 32, 46, 46))
/tmp/ipython-input-67-547272536.py in _copy_params_buffers(self, torch_nn_module, nnx_module)
83 torch_value = getattr(torch_nn_module, torch_key)
84 nnx_param = getattr(nnx_module, nnx_key)
---> 85 assert nnx_param is not None, (torch_key, nnx_key, nnx_module)
86
87 if torch_value is None:
AssertionError: ('bias', 'bias', Conv( # Param: 18,432 (73.7 KB)
kernel_shape=(3, 3, 32, 64),
kernel=Param( # 18,432 (73.7 KB)
value=Array(shape=(3, 3, 32, 64), dtype=dtype('float32'))
),
bias=None,
in_features=32,
out_features=64,
kernel_size=(3, 3),
strides=(2, 2),
padding=((1, 1), (1, 1)),
input_dilation=1,
kernel_dilation=(1, 1),
feature_group_count=1,
use_bias=False,
mask=None,
dtype=None,
param_dtype=float32,
precision=None,
kernel_init=<function variance_scaling.<locals>.init at 0x7ffaa4b5d3a0>,
bias_init=<function zeros at 0x7ffaa5b03100>,
conv_general_dilated=<function conv_general_dilated at 0x7ffaa70cc040>,
promote_dtype=<function promote_dtype at 0x7ffaa4b5c2c0>
))
Very similar errors come from `_copy_params_buffers` in the tests for `MBConv`, `MaxVitLayer`, `MaxVitBlock`, and `MaxVit`.
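In the traceback, `assert nnx_param is not None` runs before the existing `if torch_value is None:` branch, so a layer with no bias on either side (PyTorch `bias=None`, Flax `bias=None`) trips the assert. One possible fix, assuming that branch is meant to skip absent biases, is to handle the `None` case first. `copy_attr` and the `SimpleNamespace` modules below are hypothetical stand-ins for one iteration of the tutorial's copy loop, not its actual code.

```python
from types import SimpleNamespace


def copy_attr(torch_nn_module, nnx_module, torch_key: str, nnx_key: str) -> bool:
    """Copy one attribute; return False when both sides are intentionally absent."""
    torch_value = getattr(torch_nn_module, torch_key)
    nnx_param = getattr(nnx_module, nnx_key)

    # A layer built without bias has bias=None on both sides, so check this
    # *before* asserting that the Flax parameter exists.
    if torch_value is None:
        assert nnx_param is None, (torch_key, nnx_key, nnx_module)
        return False

    assert nnx_param is not None, (torch_key, nnx_key, nnx_module)
    # A real copier would also convert the tensor to a JAX array here.
    nnx_param.value = torch_value
    return True


# Hypothetical modules mimicking a conv layer created without bias.
torch_conv = SimpleNamespace(weight=[1.0], bias=None)
nnx_conv = SimpleNamespace(kernel=SimpleNamespace(value=None), bias=None)

print(copy_attr(torch_conv, nnx_conv, "bias", "bias"))      # no AssertionError
print(copy_attr(torch_conv, nnx_conv, "weight", "kernel"))
print(nnx_conv.kernel.value)
```

The assert is kept for the mismatched case (PyTorch has a bias, Flax does not), which still indicates a real porting error.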
- The final prediction scores don't match. For the dog image, both models predict the same class but with very different scores:
- PyTorch model result: ['n02113023', 'Pembroke'], score: 0.7800846099853516
- Flax model result: ['n02113023', 'Pembroke'], score: 0.0008441798854619265
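To check whether the gap comes from the models themselves or from post-processing, a sketch like the following applies one softmax to both models' raw outputs before comparing. It assumes both models return unnormalized logits over the same class ordering, converted to plain Python lists; `top1_with_score` and `compare_predictions` are hypothetical helpers, not part of the tutorial.

```python
import math
from typing import Sequence, Tuple


def top1_with_score(logits: Sequence[float]) -> Tuple[int, float]:
    """Return (index, softmax probability) of the largest logit."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    idx = max(range(len(logits)), key=lambda i: logits[i])
    return idx, exps[idx] / total


def compare_predictions(torch_logits, flax_logits, atol: float = 1e-3) -> bool:
    """True if both models agree on the top-1 class and its score within atol."""
    t_idx, t_score = top1_with_score(torch_logits)
    f_idx, f_score = top1_with_score(flax_logits)
    return t_idx == f_idx and abs(t_score - f_score) <= atol


same = [2.0, 0.5, 0.1]
flat = [0.01, 0.0, 0.0]  # same top-1 class, much lower confidence
print(compare_predictions(same, same))  # True
print(compare_predictions(same, flat))  # False
```

If the top-1 class matches but the scores differ, as in the numbers above, comparing intermediate activations layer by layer is usually the next step.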