From 0716c281bbcba6a1a74e0663ea2531bae309b0a1 Mon Sep 17 00:00:00 2001 From: Joe Fernandez Date: Wed, 4 Dec 2024 21:28:46 -0800 Subject: [PATCH 1/2] Updating PaliGemma notebooks --- .../paligemma/fine-tuning-paligemma.ipynb | 92 ++++-- .../docs/paligemma/inference-with-keras.ipynb | 290 +++++++++--------- 2 files changed, 215 insertions(+), 167 deletions(-) diff --git a/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb b/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb index b2cb645f0..4efd63bb2 100644 --- a/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb +++ b/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb @@ -6,8 +6,8 @@ "id": "G3MMAcssHTML" }, "source": [ - "\n", - "" + "\n", + "\n" ] }, { @@ -59,15 +59,8 @@ "\n", "View source on GitHub\n", "\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wR53lePHuiP-" - }, - "source": [ + "\n", + "\n", "This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n", "\n", "### What's in this notebook\n", @@ -128,7 +121,8 @@ "\n", "To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n", "\n", - "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n" + "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n", + "\n" ] }, { @@ -172,7 +166,11 @@ "# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n", "\n", "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n", - "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')" + "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n", + "\n", + "# The T4 runtime is tight on memory to finetune this model. Preallocate\n", + "# all memory ahead of time to avoid out-of-memory due to fragmentation.\n", + "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\"" ] }, { @@ -265,7 +263,7 @@ "tf.config.set_visible_devices([], \"GPU\")\n", "tf.config.set_visible_devices([], \"TPU\")\n", "\n", - "backend = jax.lib.xla_bridge.get_backend()\n", + "backend = jax.extend.backend.get_backend()\n", "print(f\"JAX version: {jax.__version__}\")\n", "print(f\"JAX platform: {backend.platform}\")\n", "print(f\"JAX devices: {jax.device_count()}\")" @@ -292,7 +290,7 @@ "\n", "PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n", "\n", - "Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete." + "Download the model checkpoint from Kaggle by running the following code. 
This process takes several minutes to complete." ] }, { @@ -306,12 +304,19 @@ "import os\n", "import kagglehub\n", "\n", - "MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n", + "# Use these for PaliGemma-2 3B 224px²\n", + "LLM_VARIANT = \"gemma2_2b\"\n", + "MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n", + "KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n", + "\n", + "# Use these for PaliGemma 1:\n", + "# LLM_VARIANT = \"gemma_2b\"\n", + "# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n", + "# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n", + "\n", "if not os.path.exists(MODEL_PATH):\n", " print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n", - " # Note: kaggle archive contains the same checkpoint in multiple formats.\n", - " # Download only the float16 model.\n", - " MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n", + " MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n", " print(f\"Model path: {MODEL_PATH}\")\n", "\n", "TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n", @@ -360,8 +365,11 @@ "outputs": [], "source": [ "# Define model\n", + "\n", + "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n", + "# for better transfer results.\n", "model_config = ml_collections.FrozenConfigDict({\n", - " \"llm\": {\"vocab_size\": 257_152},\n", + " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n", " \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n", "})\n", "model = paligemma.Model(**model_config)\n", @@ -420,7 +428,9 @@ "\n", "@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n", "def maybe_cast_to_f32(params, trainable):\n", - " return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n", + " # Cast others to float16, since some GPUs don't support bf16.\n", + " return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n", + " if m else p.astype(jnp.float16),\n", " params, trainable)\n", "\n", "# Loading all params in simultaneous - albeit much faster and more succinct -\n", @@ -492,7 +502,7 @@ "\n", " image = tf.constant(image)\n", " image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n", - " return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]\n", + " return image.numpy() / 127.5 - 1.0 # [0, 255]-\u003e[-1,1]\n", "\n", "def preprocess_tokens(prefix, suffix=None, seqlen=None):\n", " # Model has been trained to handle tokenized text composed of a prefix with\n", @@ -632,12 +642,12 @@ " return f\"data:image/jpeg;base64,{image_b64}\"\n", "\n", "def render_example(image, caption):\n", - " image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n", + " image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -\u003e [0, 255]\n", " return f\"\"\"\n", - "
\n", - " \n", - "

{html.escape(caption)}

\n", - "
\n", + " \u003cdiv style=\"display: inline-flex; align-items: center; justify-content: center;\"\u003e\n", + " \u003cimg style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(64,64))}\" /\u003e\n", + " \u003cp style=\"width:256px; margin:10px; font-size:small;\"\u003e{html.escape(caption)}\u003c/p\u003e\n", + " \u003c/div\u003e\n", " \"\"\"\n", "\n", "html_out = \"\"\n", @@ -754,7 +764,7 @@ " # Append to html output.\n", " for example, response in zip(examples, responses):\n", " outputs.append((example[\"image\"], response))\n", - " if num_examples and len(outputs) >= num_examples:\n", + " if num_examples and len(outputs) \u003e= num_examples:\n", " return outputs" ] }, @@ -862,14 +872,36 @@ ], "metadata": { "colab": { - "name": "fine-tuning-paligemma.ipynb", + "gpuType": "T4", + "last_runtime": { + "build_target": "//learning/grp/tools/ml_python:ml_notebook", + "kind": "private" + }, + "private_outputs": true, + "provenance": [ + { + "file_id": "17AiK8gRY7oiquQGkBH0d08PFQo3Kyx1I", + "timestamp": 1715287187925 + }, + { + "file_id": "1qZlJfPyfKRrNcz2shxQ93HnnE5Ge1LLn", + "timestamp": 1715019972450 + }, + { + "file_id": "1JFnlD2kSiTNexdPw_NYRtuW6uuSTI0kD", + "timestamp": 1714585741026 + } + ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" + }, + "language_info": { + "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/site/en/gemma/docs/paligemma/inference-with-keras.ipynb b/site/en/gemma/docs/paligemma/inference-with-keras.ipynb index 32581fb4b..3aca91cf2 100644 --- a/site/en/gemma/docs/paligemma/inference-with-keras.ipynb +++ b/site/en/gemma/docs/paligemma/inference-with-keras.ipynb @@ -44,31 +44,31 @@ { "cell_type": "markdown", "metadata": { - "id": "etcMXWCUJApZ" + "id": "Q5_nIe-8gdJV" }, "source": [ - "# Inference with Keras\n" + "# Generate PaliGemma output with Keras\n", + "\n", + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + "\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma\"\u003e\u003cimg src=\"https://ai.google.dev/static/site-assets/images/docs/notebook-site-button.png\" height=\"32\" width=\"32\" /\u003eView on ai.google.dev\u003c/a\u003e\n", + "\u003c/td\u003e\n", + "\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/inference-with-keras.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + "\u003c/td\u003e\n", + "\u003ctd\u003e\n", + "\u003ca target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/inference-with-keras.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + "\u003c/td\u003e\n", + "\u003c/table\u003e" ] }, { "cell_type": "markdown", "metadata": { - "id": "Q5_nIe-8gdJV" + "id": "9hhIuS9sEKHx" }, "source": [ - "\n", - "\n", - "\n", - "\n", - "
\n", - "View on ai.google.dev\n", - "\n", - "Run in Google Colab\n", - "\n", - "View source on GitHub\n", - "
\n", - "\n", - "When your AI model produces a conclusion or a prediction, it goes through a process called *inference*. This tutorial goes over how to use PaliGemma with Keras to set up a simple model that can infer information about supplied images and answer questions about them." + "PaliGemma models have *multimodal* capabilities, allowing you to generate output using both text and image input data. You can use image data with these models to provide additional context for your requests, or use the model to analyze the content of images. This tutorial shows you how to use PaliGemma with Keras to can analyze images and answer questions about them." ] }, { @@ -223,7 +223,7 @@ "outputs": [], "source": [ "import keras\n", - "import keras_nlp\n", + "import keras_hub\n", "import numpy as np\n", "import PIL\n", "import requests\n", @@ -237,26 +237,17 @@ "keras.config.set_floatx(\"bfloat16\")" ] }, - { - "cell_type": "markdown", - "metadata": { - "id": "ftjt5DiueVkL" - }, - "source": [ - "## Create your model\n", - "\n", - "Now that you've set everything up, you can download the pre-trained model and create some utility methods to help your model generate its responses." - ] - }, { "cell_type": "markdown", "metadata": { "id": "X-LE2E1uiSpP" }, "source": [ - "### Download the model checkpoint\n", + "## Load the model\n", + "\n", + "Now that you've set everything up, you can download the pre-trained model and create some utility methods to help your model generate its responses.\n", + "In this step, you download a model using `PaliGemmaCausalLM` from Keras Hub. This class helps you manage and run the causal visual language model structure of PaliGemma. A *causal visual language model* predicts the next token based on previous tokens. Keras Hub provides implementations of many popular [model architectures](https://keras.io/keras_hub/api/models/).\n", "\n", - "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this notebook, you'll create a model using `PaliGemmaCausalLM`, an end-to-end PaliGemma model for *causal visual language modeling*. A causal visual language model predicts the next token based on previous tokens.\n", "\n", "Create the model using the `from_preset` method and print its summary. This process will take about a minute to complete." ] @@ -269,7 +260,7 @@ }, "outputs": [], "source": [ - "paligemma = keras_nlp.models.PaliGemmaCausalLM.from_preset(\"pali_gemma_3b_mix_224\")\n", + "paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset(\"paligemma_3b_mix_224\")\n", "paligemma.summary()" ] }, @@ -279,7 +270,7 @@ "id": "FBsWvKEvoGMe" }, "source": [ - "### Create utility methods\n", + "## Create utility methods\n", "\n", "To help you generate responses from your model, create two utility methods:\n", "\n", @@ -287,7 +278,8 @@ "* **`read_img`:** Helper method for `read_img_from_url`. This method is what actually opens the image, resizes it so that it fits in the model's constraints, and puts it into an array that can be interpreted by the model.\n", "* **`read_img_from_url`:** Takes in an image via a valid URL. You need this method to pass the image to the model.\n", "\n", - "You'll use `read_img_from_url` in the next step of this notebook.\n" + "You'll use `read_img_from_url` in the next step of this notebook.\n", + "\n" ] }, { @@ -318,8 +310,8 @@ "\n", "def parse_bbox_and_labels(detokenized_output: str):\n", " matches = re.finditer(\n", - " '\\d\\d\\d\\d)>\\d\\d\\d\\d)>\\d\\d\\d\\d)>\\d\\d\\d\\d)>'\n", - " ' (?P