diff --git a/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb b/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb
index a0d762e69..fc6b165f6 100644
--- a/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb
+++ b/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb
@@ -105,7 +105,7 @@
     "Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
     "\n",
     "1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
-    "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) and click **Request Access**.\n",
+    "1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n",
     "1. Complete the consent form and accept the terms and conditions."
    ]
   },
@@ -287,7 +287,7 @@
    "source": [
     "### Download the model checkpoint\n",
     "\n",
-    "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
+    "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224).\n",
     "\n",
     "Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
    ]
   },
@@ -365,7 +365,7 @@
    "source": [
     "# Define model\n",
     "\n",
-    "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
+    "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n",
     "# for better transfer results.\n",
     "model_config = ml_collections.FrozenConfigDict({\n",
     "    \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
@@ -434,7 +434,7 @@
     "\n",
     "# Loading all params in simultaneous - albeit much faster and more succinct -\n",
     "# requires more RAM than the T4 colab runtimes have by default.\n",
-    "# Instead we do it param by param.\n",
+    "# Instead, do it param by param.\n",
     "params, treedef = jax.tree.flatten(params)\n",
     "sharding_leaves = jax.tree.leaves(params_sharding)\n",
     "trainable_leaves = jax.tree.leaves(trainable_mask)\n",
@@ -708,7 +708,7 @@
     "  targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
     "\n",
     "  # Compute the loss per example. i.e. the mean of per token pplx.\n",
-    "  # Since each example has a different number of tokens we normalize it.\n",
+    "  # Since each example has a different number of tokens, normalize it.\n",
     "  token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n",
     "  example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n",
     "  example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n",