diff --git a/docs/source/JAX_Vision_transformer.ipynb b/docs/source/JAX_Vision_transformer.ipynb
index 4d8d73b..1e28775 100644
--- a/docs/source/JAX_Vision_transformer.ipynb
+++ b/docs/source/JAX_Vision_transformer.ipynb
@@ -295,7 +295,7 @@
     "\n",
     "    # Notice the use of `flax.nnx.state`.\n",
     "    flax_model_params = nnx.state(dst_model, nnx.Param)\n",
-    "    flax_model_params_fstate = flax_model_params.flat_state()\n",
+    "    flax_model_params_fstate = dict(flax_model_params.flat_state())\n",
     "\n",
     "    # Mapping from Flax parameter names to TF parameter names.\n",
     "    params_name_mapping = {\n",
diff --git a/docs/source/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md
index b09376b..f72614d 100644
--- a/docs/source/JAX_Vision_transformer.md
+++ b/docs/source/JAX_Vision_transformer.md
@@ -251,7 +251,7 @@ def vit_inplace_copy_weights(*, src_model, dst_model):
 
     # Notice the use of `flax.nnx.state`.
     flax_model_params = nnx.state(dst_model, nnx.Param)
-    flax_model_params_fstate = flax_model_params.flat_state()
+    flax_model_params_fstate = dict(flax_model_params.flat_state())
 
     # Mapping from Flax parameter names to TF parameter names.
     params_name_mapping = {
diff --git a/docs/source/JAX_image_captioning.ipynb b/docs/source/JAX_image_captioning.ipynb
index 1f85d2f..b975735 100644
--- a/docs/source/JAX_image_captioning.ipynb
+++ b/docs/source/JAX_image_captioning.ipynb
@@ -700,7 +700,7 @@
     "    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)\n",
     "\n",
     "    flax_model_params = nnx.state(dst_model, nnx.Param)\n",
-    "    flax_model_params_fstate = flax_model_params.flat_state()\n",
+    "    flax_model_params_fstate = dict(flax_model_params.flat_state())\n",
     "\n",
     "    src_num_params = sum([p.size for p in tf_model_params_fstate.values()])\n",
     "    dst_num_params = sum([p.value.size for p in flax_model_params_fstate.values()])\n",
@@ -1286,7 +1286,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "for key in list(nnx.state(model, trainable_params_filter).flat_state().keys()):\n",
+    "for key, _ in nnx.state(model, trainable_params_filter).flat_state():\n",
     "    assert \"encoder\" not in key"
    ]
   },
diff --git a/docs/source/JAX_image_captioning.md b/docs/source/JAX_image_captioning.md
index 62b1a32..1f0f1e6 100644
--- a/docs/source/JAX_image_captioning.md
+++ b/docs/source/JAX_image_captioning.md
@@ -458,7 +458,7 @@ def vit_inplace_copy_weights(*, src_model, dst_model):
     tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)
 
     flax_model_params = nnx.state(dst_model, nnx.Param)
-    flax_model_params_fstate = flax_model_params.flat_state()
+    flax_model_params_fstate = dict(flax_model_params.flat_state())
 
     src_num_params = sum([p.size for p in tf_model_params_fstate.values()])
     dst_num_params = sum([p.value.size for p in flax_model_params_fstate.values()])
@@ -913,7 +913,7 @@ model_diffstate = nnx.DiffState(0, trainable_params_filter)
 ```
 
 ```{code-cell} ipython3
-for key in list(nnx.state(model, trainable_params_filter).flat_state().keys()):
+for key, _ in nnx.state(model, trainable_params_filter).flat_state():
     assert "encoder" not in key
 ```
 
diff --git a/docs/source/JAX_porting_PyTorch_model.ipynb b/docs/source/JAX_porting_PyTorch_model.ipynb
index da41dbb..d1b651a 100644
--- a/docs/source/JAX_porting_PyTorch_model.ipynb
+++ b/docs/source/JAX_porting_PyTorch_model.ipynb
@@ -376,13 +376,13 @@
     "To inspect NNX module's learnable parameters and buffers, we can use `nnx.state`:\n",
     "```python\n",
     "nnx_module = ...\n",
-    "for k, v in nnx.state(nnx_module, nnx.Param).flat_state().items():\n",
+    "for k, v in nnx.state(nnx_module, nnx.Param).flat_state():\n",
     "    print(\n",
     "        k,\n",
     "        v.value.mean() if v.value is not None else None\n",
     "    )\n",
     "\n",
-    "for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state().items():\n",
+    "for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state():\n",
     "    print(\n",
     "        k,\n",
     "        v.value.mean() if v.value.dtype == \"float32\" else v.value.sum()\n",
diff --git a/docs/source/JAX_porting_PyTorch_model.md b/docs/source/JAX_porting_PyTorch_model.md
index cbd0af4..403aaad 100644
--- a/docs/source/JAX_porting_PyTorch_model.md
+++ b/docs/source/JAX_porting_PyTorch_model.md
@@ -180,13 +180,13 @@ class Model(nnx.Module):
 To inspect NNX module's learnable parameters and buffers, we can use `nnx.state`:
 ```python
 nnx_module = ...
-for k, v in nnx.state(nnx_module, nnx.Param).flat_state().items():
+for k, v in nnx.state(nnx_module, nnx.Param).flat_state():
    print(
        k,
        v.value.mean() if v.value is not None else None
    )
 
-for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state().items():
+for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state():
    print(
        k,
        v.value.mean() if v.value.dtype == "float32" else v.value.sum()
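All six hunks follow one pattern: the docs appear to be adapting to the Flax NNX change where `State.flat_state()` returns a sequence of `(path, value)` pairs rather than a dict-like object, so `.items()` and `.keys()` are no longer available on its result. The minimal sketch below assumes that newer `flat_state()` behavior and uses an `nnx.Linear` stand-in model to show both replacement patterns: iterating the pairs directly, and wrapping them in `dict(...)` when keyed lookup is needed.

```python
from flax import nnx

# Stand-in model; any nnx.Module behaves the same way here.
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)

# flat_state() yields (path, value) pairs, so iterate it directly...
for path, value in params.flat_state():
    print(path, value.value.shape)  # e.g. ('kernel',) (2, 3)

# ...or wrap it in dict() when keyed lookups are required, as in the
# weight-copying helpers patched above.
flat = dict(params.flat_state())
assert ("kernel",) in flat and ("bias",) in flat
```

Wrapping in `dict(...)` at the call site keeps the downstream lookup code unchanged, which is presumably why the weight-copying helpers take that route instead of rewriting their name-mapping logic.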