diff --git a/site/en/gemma/docs/paligemma/inference-with-keras.ipynb b/site/en/gemma/docs/paligemma/inference-with-keras.ipynb
index 8b44ed6d1..aa4dc5eca 100644
--- a/site/en/gemma/docs/paligemma/inference-with-keras.ipynb
+++ b/site/en/gemma/docs/paligemma/inference-with-keras.ipynb
@@ -1,43 +1,16 @@
{
"cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3_lX1k54KKrx"
- },
- "source": [
- "Copyright 2024 Google LLC."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "Gr4W9nspKGtb"
- },
- "outputs": [],
- "source": [
- "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
- "# you may not use this file except in compliance with the License.\n",
- "# You may obtain a copy of the License at\n",
- "#\n",
- "# https://www.apache.org/licenses/LICENSE-2.0\n",
- "#\n",
- "# Unless required by applicable law or agreed to in writing, software\n",
- "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
- "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
- "# See the License for the specific language governing permissions and\n",
- "# limitations under the License."
- ]
- },
{
"cell_type": "markdown",
"metadata": {
"id": "etcMXWCUJApZ"
},
"source": [
- "# Inference with Keras\n"
+ "Project: /gemma/_project.yaml\n",
+ "Book: /gemma/_book.yaml\n",
+ " \n",
+ "# Inference with Keras\n",
+ "\n"
]
},
{
@@ -48,7 +21,7 @@
"source": [
"
\n",
"\n",
- " View on ai.google.dev \n",
+ " View on ai.google.dev \n",
" \n",
"\n",
" Run in Google Colab \n",
@@ -277,7 +250,8 @@
"* **`read_img`:** Helper method for `read_img_from_url`. This method is what actually opens the image, resizes it so that it fits in the model's constraints, and puts it into an array that can be interpreted by the model.\n",
"* **`read_img_from_url`:** Takes in an image via a valid URL. You need this method to pass the image to the model.\n",
"\n",
- "You'll use `read_img_from_url` in the next step of this notebook.\n"
+ "You'll use `read_img_from_url` in the next step of this notebook.\n",
+ "\n"
]
},
{
@@ -342,26 +316,46 @@
"\n",
" plt.show()\n",
"\n",
- "def display_segment_output(image, segment_mask, target_image_size):\n",
- " # Calculate scaling factors\n",
- " h, w = target_image_size\n",
- " x_scale = w / 64\n",
- " y_scale = h / 64\n",
- "\n",
- " # Create coordinate grids for the new image\n",
- " x_coords = np.arange(w)\n",
- " y_coords = np.arange(h)\n",
- " x_coords = (x_coords / x_scale).astype(int)\n",
- " y_coords = (y_coords / y_scale).astype(int)\n",
- " resized_array = segment_mask[y_coords[:, np.newaxis], x_coords]\n",
- " # Create a figure and axis\n",
- " fig, ax = plt.subplots()\n",
- "\n",
- " # Display the image\n",
- " ax.imshow(image)\n",
- "\n",
- " # Overlay the mask with transparency\n",
- " ax.imshow(resized_array, cmap='jet', alpha=0.5)"
+ "def display_segment_output(image, bounding_box, segment_mask, target_image_size):\n",
+ " # Initialize a full mask with the target size\n",
+ " full_mask = np.zeros(target_image_size, dtype=np.uint8)\n",
+ " target_width, target_height = target_image_size\n",
+ "\n",
+ " for bbox, mask in zip(bounding_box, segment_mask):\n",
+ " y1, x1, y2, x2 = bbox\n",
+ " x1 = int(x1 * target_width)\n",
+ " y1 = int(y1 * target_height)\n",
+ " x2 = int(x2 * target_width)\n",
+ " y2 = int(y2 * target_height)\n",
+ "\n",
+ " # Ensure mask is 2D before converting to Image\n",
+ " if mask.ndim == 3:\n",
+ " mask = mask.squeeze(axis=-1)\n",
+ " mask = Image.fromarray(mask)\n",
+ " mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)\n",
+ " mask = np.array(mask)\n",
+ " binary_mask = (mask > 0.5).astype(np.uint8)\n",
+ "\n",
+ "\n",
+ " # Place the binary mask onto the full mask\n",
+ " full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)\n",
+ " cmap = plt.get_cmap('jet')\n",
+ " colored_mask = cmap(full_mask / 1.0)\n",
+ " colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)\n",
+ " if isinstance(image, Image.Image):\n",
+ " image = np.array(image)\n",
+ " blended_image = image.copy()\n",
+ " mask_indices = full_mask > 0\n",
+ " alpha = 0.5\n",
+ "\n",
+ " for c in range(3):\n",
+ " blended_image[:, :, c] = np.where(mask_indices,\n",
+ " (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],\n",
+ " image[:, :, c])\n",
+ "\n",
+ " fig, ax = plt.subplots()\n",
+ " ax.imshow(blended_image)\n",
+ " plt.show()"
]
},
{
@@ -374,7 +368,8 @@
"\n",
"Now you're ready to give an image and prompt to your model and have it infer the response.\n",
"\n",
- "Lets look at our test image and read it\n"
+ "Lets look at our test image and read it\n",
+ "\n"
]
},
{
@@ -393,22 +388,17 @@
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "R3FE63iMqb2G"
- },
"source": [
"Here's a generation call with a single image and prompt. The prompts have to end with a `\\n`.\n",
"\n",
"We've supplied you with several example prompts — play around with it! Comment and uncomment the prompt variables to change what prompt you supply the model with."
- ]
+ ],
+ "metadata": {
+ "id": "R3FE63iMqb2G"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "WQVmhpr6qdNH"
- },
- "outputs": [],
"source": [
"prompt = 'answer en where is the cow standing?\\n'\n",
"# prompt = 'svar no hvor står kuen?'\n",
@@ -421,7 +411,12 @@
" }\n",
")\n",
"print(output)"
- ]
+ ],
+ "metadata": {
+ "id": "WQVmhpr6qdNH"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
@@ -493,20 +488,15 @@
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "eUgPcUGDpsdx"
- },
"source": [
"### Parse detect output"
- ]
+ ],
+ "metadata": {
+ "id": "eUgPcUGDpsdx"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "noRIZtqnpzNn"
- },
- "outputs": [],
"source": [
"prompt = 'detect cow\\n'\n",
"output = paligemma.generate(\n",
@@ -517,25 +507,24 @@
")\n",
"boxes, labels = parse_bbox_and_labels(output)\n",
"display_boxes(cow_image, boxes, labels, target_size)"
- ]
+ ],
+ "metadata": {
+ "id": "noRIZtqnpzNn"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "mUwNeaV0sGau"
- },
"source": [
"### Parse segment output"
- ]
+ ],
+ "metadata": {
+ "id": "mUwNeaV0sGau"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "tuxyIMAvsM4B"
- },
- "outputs": [],
"source": [
"# @title Fetch big_vision code and install dependencies.\n",
"import os\n",
@@ -558,45 +547,46 @@
"\n",
"# Install missing dependencies. Assume jax~=0.4.25 with GPU available.\n",
"!pip3 install -q \"overrides\" \"ml_collections\" \"einops~=0.7\" \"sentencepiece\""
- ]
+ ],
+ "metadata": {
+ "cellView": "form",
+ "id": "tuxyIMAvsM4B"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "18S4uHWqutps"
- },
"source": [
"Let's take a look at another example image."
- ]
+ ],
+ "metadata": {
+ "id": "18S4uHWqutps"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Al6NedG4uNhp"
- },
- "outputs": [],
"source": [
"cat = read_image('https://big-vision-paligemma.hf.space/file=examples/barsik.jpg', target_size)\n",
"matplotlib.pyplot.imshow(cat)"
- ]
+ ],
+ "metadata": {
+ "id": "Al6NedG4uNhp"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "VRJYikeXu1As"
- },
"source": [
"Here is a function to help parse the segment output from PaliGemma"
- ]
+ ],
+ "metadata": {
+ "id": "VRJYikeXu1As"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "_ryCMfGVuxyd"
- },
- "outputs": [],
"source": [
"import big_vision.evaluators.proj.paligemma.transfers.segmentation as segeval\n",
"reconstruct_masks = segeval.get_reconstruct_masks('oi')\n",
@@ -613,24 +603,24 @@
" boxes.append([fmt_box(d['y0']), fmt_box(d['x0']), fmt_box(d['y1']), fmt_box(d['x1'])])\n",
" segs.append([int(d[f's{i}']) for i in range(16)])\n",
" return np.array(boxes), np.array(reconstruct_masks(np.array(segs)))"
- ]
+ ],
+ "metadata": {
+ "id": "_ryCMfGVuxyd"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "QITD66qJvCTO"
- },
"source": [
"Query PaliGemma to segment the cat in the image"
- ]
+ ],
+ "metadata": {
+ "id": "QITD66qJvCTO"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "fB7to-J4u5zY"
- },
- "outputs": [],
"source": [
"prompt = 'segment cat\\n'\n",
"output = paligemma.generate(\n",
@@ -639,39 +629,63 @@
" \"prompts\": prompt,\n",
" }\n",
")"
- ]
+ ],
+ "metadata": {
+ "id": "fB7to-J4u5zY"
+ },
+ "execution_count": null,
+ "outputs": []
},
{
"cell_type": "markdown",
- "metadata": {
- "id": "XZeu6-bovFvz"
- },
"source": [
"Visualize the generated mask from PaliGemma"
- ]
+ ],
+ "metadata": {
+ "id": "XZeu6-bovFvz"
+ }
},
{
"cell_type": "code",
- "execution_count": null,
+ "source": [
+ "bboxes, seg_masks = parse_segments(output)\n",
+ "display_segment_output(cat, bboxes, seg_masks, target_size)"
+ ],
"metadata": {
"id": "GcjOvoPbvAI-"
},
- "outputs": [],
- "source": [
- "_, seg_output = parse_segments(output)\n",
- "display_segment_output(cat, seg_output[0], target_size)"
- ]
+ "execution_count": null,
+ "outputs": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
- "name": "inference-with-keras.ipynb",
+ "gpuType": "T4",
+ "machine_shape": "hm",
+ "private_outputs": true,
+ "provenance": [
+ {
+ "file_id": "1Mh7SHaP1cd5XuuHLmprKoBfr9ddSVu3a",
+ "timestamp": 1715369116355
+ },
+ {
+ "file_id": "1Xo8MY-GJsjjBKS1l_hQ8AImMNw9yxqmH",
+ "timestamp": 1715208696058
+ },
+ {
+ "file_id": "1bLQYoesGy8awiA0UT0askWIjIqcivfdY",
+ "timestamp": 1715147134742
+ }
+ ],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
+ },
+ "language_info": {
+ "name": "python"
}
},
"nbformat": 4,