
Update links to PaliGemma 2 in fine-tuning docs #545


Merged · 2 commits · Dec 12, 2024
10 changes: 5 additions & 5 deletions site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb
@@ -105,7 +105,7 @@
"Before using PaliGemma for the first time, you must request access to the model through Kaggle by completing the following steps:\n",
"\n",
"1. Log in to [Kaggle](https://www.kaggle.com), or create a new Kaggle account if you don't already have one.\n",
"1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma/) and click **Request Access**.\n",
"1. Go to the [PaliGemma model card](https://www.kaggle.com/models/google/paligemma-2) and click **Request Access**.\n",
"1. Complete the consent form and accept the terms and conditions."
]
},
@@ -287,7 +287,7 @@
"source": [
"### Download the model checkpoint\n",
"\n",
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma-2/jax/paligemma2-3b-pt-224).\n",
"\n",
"Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
]
@@ -365,7 +365,7 @@
"source": [
"# Define model\n",
"\n",
"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property. Set it to 0.0\n",
"# for better transfer results.\n",
"model_config = ml_collections.FrozenConfigDict({\n",
" \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
@@ -434,7 +434,7 @@
"\n",
"# Loading all params in simultaneous - albeit much faster and more succinct -\n",
"# requires more RAM than the T4 colab runtimes have by default.\n",
"# Instead we do it param by param.\n",
"# Instead, do it param by param.\n",
"params, treedef = jax.tree.flatten(params)\n",
"sharding_leaves = jax.tree.leaves(params_sharding)\n",
"trainable_leaves = jax.tree.leaves(trainable_mask)\n",
@@ -708,7 +708,7 @@
" targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])\n",
"\n",
" # Compute the loss per example. i.e. the mean of per token pplx.\n",
" # Since each example has a different number of tokens we normalize it.\n",
" # Since each example has a different number of tokens, normalize it.\n",
" token_pplx = jnp.sum(logp * targets, axis=-1) # sum across vocab_size.\n",
" example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1) # sum across seq_len.\n",
" example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1) # weight by num of tokens.\n",