
Commit b5cb859

Fixed issue with .flat_state()
1 parent 5b0e470 commit b5cb859

6 files changed (+10, -10 lines)
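
The common thread across all six files: in recent Flax releases, `State.flat_state()` returns a sequence of `(path, variable)` pairs rather than a dict-like mapping, so `.items()` and `.keys()` calls break. Call sites either wrap the result in `dict(...)` or unpack the pairs directly. A minimal sketch of both patterns (the `nnx.Linear` toy model is illustrative only, not taken from these tutorials):

```python
from flax import nnx

# Stand-in for the tutorials' ViT / captioning models.
model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)

# Pattern 1: materialize a mapping when dict methods are needed,
# mirroring the `.values()` usage in the captioning diffs below.
flat = dict(params.flat_state())
n_params = sum(v.value.size for v in flat.values())

# Pattern 2: unpack the (path, variable) pairs directly.
for path, variable in params.flat_state():
    print(path, variable.value.shape)
```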

docs/source/JAX_Vision_transformer.ipynb

Lines changed: 1 addition & 1 deletion

@@ -295,7 +295,7 @@
 "\n",
 "    # Notice the use of `flax.nnx.state`.\n",
 "    flax_model_params = nnx.state(dst_model, nnx.Param)\n",
-"    flax_model_params_fstate = flax_model_params.flat_state()\n",
+"    flax_model_params_fstate = dict(flax_model_params.flat_state())\n",
 "\n",
 "    # Mapping from Flax parameter names to TF parameter names.\n",
 "    params_name_mapping = {\n",

docs/source/JAX_Vision_transformer.md

Lines changed: 1 addition & 1 deletion

@@ -251,7 +251,7 @@ def vit_inplace_copy_weights(*, src_model, dst_model):
 
     # Notice the use of `flax.nnx.state`.
     flax_model_params = nnx.state(dst_model, nnx.Param)
-    flax_model_params_fstate = flax_model_params.flat_state()
+    flax_model_params_fstate = dict(flax_model_params.flat_state())
 
     # Mapping from Flax parameter names to TF parameter names.
     params_name_mapping = {

docs/source/JAX_image_captioning.ipynb

Lines changed: 2 additions & 2 deletions

@@ -700,7 +700,7 @@
 "    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)\n",
 "\n",
 "    flax_model_params = nnx.state(dst_model, nnx.Param)\n",
-"    flax_model_params_fstate = flax_model_params.flat_state()\n",
+"    flax_model_params_fstate = dict(flax_model_params.flat_state())\n",
 "\n",
 "    src_num_params = sum([p.size for p in tf_model_params_fstate.values()])\n",
 "    dst_num_params = sum([p.value.size for p in flax_model_params_fstate.values()])\n",
@@ -1286,7 +1286,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"for key in list(nnx.state(model, trainable_params_filter).flat_state().keys()):\n",
+"for key, _ in nnx.state(model, trainable_params_filter).flat_state():\n",
 "    assert \"encoder\" not in key"
 ]
 },

docs/source/JAX_image_captioning.md

Lines changed: 2 additions & 2 deletions

@@ -458,7 +458,7 @@ def vit_inplace_copy_weights(*, src_model, dst_model):
     tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)
 
     flax_model_params = nnx.state(dst_model, nnx.Param)
-    flax_model_params_fstate = flax_model_params.flat_state()
+    flax_model_params_fstate = dict(flax_model_params.flat_state())
 
     src_num_params = sum([p.size for p in tf_model_params_fstate.values()])
     dst_num_params = sum([p.value.size for p in flax_model_params_fstate.values()])
@@ -913,7 +913,7 @@ model_diffstate = nnx.DiffState(0, trainable_params_filter)
 ```
 
 ```{code-cell} ipython3
-for key in list(nnx.state(model, trainable_params_filter).flat_state().keys()):
+for key, _ in nnx.state(model, trainable_params_filter).flat_state():
     assert "encoder" not in key
 ```

docs/source/JAX_porting_PyTorch_model.ipynb

Lines changed: 2 additions & 2 deletions

@@ -376,13 +376,13 @@
 "To inspect NNX module's learnable parameters and buffers, we can use `nnx.state`:\n",
 "```python\n",
 "nnx_module = ...\n",
-"for k, v in nnx.state(nnx_module, nnx.Param).flat_state().items():\n",
+"for k, v in nnx.state(nnx_module, nnx.Param).flat_state():\n",
 "    print(\n",
 "        k,\n",
 "        v.value.mean() if v.value is not None else None\n",
 "    )\n",
 "\n",
-"for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state().items():\n",
+"for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state():\n",
 "    print(\n",
 "        k,\n",
 "        v.value.mean() if v.value.dtype == \"float32\" else v.value.sum()\n",

docs/source/JAX_porting_PyTorch_model.md

Lines changed: 2 additions & 2 deletions

@@ -180,13 +180,13 @@ class Model(nnx.Module):
 To inspect NNX module's learnable parameters and buffers, we can use `nnx.state`:
 ```python
 nnx_module = ...
-for k, v in nnx.state(nnx_module, nnx.Param).flat_state().items():
+for k, v in nnx.state(nnx_module, nnx.Param).flat_state():
     print(
         k,
         v.value.mean() if v.value is not None else None
     )
 
-for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state().items():
+for k, v in nnx.state(nnx_module, (nnx.BatchStat, Buffer)).flat_state():
     print(
         k,
         v.value.mean() if v.value.dtype == "float32" else v.value.sum()
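
The same pair unpacking also covers the key-filter check updated in the image-captioning files: paths arrive as tuples of parts, so a membership test runs against each path component. A runnable sketch (the toy model is a placeholder for the tutorials' encoder/decoder model):

```python
from flax import nnx

model = nnx.Linear(4, 2, rngs=nnx.Rngs(0))  # placeholder model

# Each key is a path tuple such as ('kernel',) or ('bias',),
# so 'encoder' membership is checked per path component.
for key, _ in nnx.state(model, nnx.Param).flat_state():
    assert "encoder" not in key  # this toy model has no 'encoder' submodule
```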
