
Commit 0e912b0

Fixed a typo (#213)
1 parent 68efe18 commit 0e912b0

2 files changed: +2 −2 lines changed


docs/source/JAX_for_LLM_pretraining.ipynb

Lines changed: 1 addition & 1 deletion
@@ -626,7 +626,7 @@
 "\n",
 "Start training. It takes ~50 minutes on Colab.\n",
 "\n",
-"Note that for data parallel, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedeSharding`.\n",
+"Note that for data parallel, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedSharding`.\n",
 "\n",
 "We are also using the `jax.vmap` transformation to produce the target sequences faster."
 ]

docs/source/JAX_for_LLM_pretraining.md

Lines changed: 1 addition & 1 deletion
@@ -464,7 +464,7 @@ def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetri
 
 Start training. It takes ~50 minutes on Colab.
 
-Note that for data parallel, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedeSharding`.
+Note that for data parallel, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedSharding`.
 
 We are also using the `jax.vmap` transformation to produce the target sequences faster.
 

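For reference, here is a minimal sketch of the pattern the corrected sentence describes: sharding a training batch along a `batch` mesh axis with `jax.device_put` and `NamedSharding`, then applying `jax.vmap` over the batch dimension. The mesh layout, array shapes, and the shift-by-one target construction are illustrative assumptions, not taken from the notebook itself.

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 1-D device mesh with a single "batch" axis; the notebook's
# actual mesh layout is not shown in this diff.
mesh = Mesh(jax.devices(), axis_names=("batch",))

# Data parallel: shard the training batch along its leading (batch) axis.
batch = jnp.zeros((32, 128), dtype=jnp.int32)  # (batch, seq_len), example shapes
sharded_batch = jax.device_put(batch, NamedSharding(mesh, P("batch", None)))

# jax.vmap maps a per-example function over the batch dimension, e.g.
# producing next-token targets by shifting each sequence left by one.
make_targets = jax.vmap(lambda seq: jnp.roll(seq, -1))
targets = make_targets(sharded_batch)
```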