update pytorch_gemma.ipynb to use gemma 2 #501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 2 commits · Jul 31, 2024
Changes from 1 commit
123 changes: 91 additions & 32 deletions site/en/gemma/docs/pytorch_gemma.ipynb
@@ -71,12 +71,19 @@
"id": "jbza6uQdA-0P"
},
"source": [
"## Kaggle access\n",
"### 1. Set up Kaggle access for Gemma\n",
"\n",
"To login to Kaggle, you can either store your `kaggle.json` credentials file at\n",
"`~/.kaggle/kaggle.json` or run the following in a Colab environment. See the\n",
"[`kagglehub` package documentation](https://github.com/Kaggle/kagglehub#authenticate)\n",
"for more details."
"To complete this tutorial, you first need to follow the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), which show you how to do the following:\n",
"\n",
"* Get access to Gemma on [kaggle.com](https://www.kaggle.com/models/google/gemma/).\n",
"* Select a Colab runtime with sufficient resources to run the Gemma model.\n",
"* Generate and configure a Kaggle username and API key.\n",
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
"\n",
"### 2. Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
]
},
{
@@ -87,9 +94,11 @@
},
"outputs": [],
"source": [
"import kagglehub\n",
"import os\n",
"from google.colab import userdata # `userdata` is a Colab API.\n",
"\n",
"kagglehub.login()"
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
]
},
{
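Note: `google.colab.userdata` exists only in Colab. For readers following along locally, a minimal fallback sketch — the prompt strings are illustrative, and `kagglehub` reads the same two environment variables:

```python
# Fallback for non-Colab environments: prompt for the same two variables.
# Sketch only; assumes you already generated a Kaggle API key in step 1.
import os
from getpass import getpass

for var in ("KAGGLE_USERNAME", "KAGGLE_KEY"):
    if var not in os.environ:
        os.environ[var] = getpass(f"{var}: ")
```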
@@ -107,7 +116,24 @@
"metadata": {
"id": "bMboT70Xop8G"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m797.2/797.2 MB\u001b[0m \u001b[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m209.4/209.4 MB\u001b[0m \u001b[31m6.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m51.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.3/21.3 MB\u001b[0m \u001b[31m55.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"fastai 2.7.15 requires torch<2.4,>=1.10, but you have torch 2.4.0 which is incompatible.\n",
"torchaudio 2.3.1+cu121 requires torch==2.3.1, but you have torch 2.4.0 which is incompatible.\n",
"torchvision 0.18.1+cu121 requires torch==2.3.1, but you have torch 2.4.0 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
],
"source": [
"!pip install -q -U torch immutabledict sentencepiece"
]
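The resolver warnings in the captured output come from Colab's preinstalled `torchaudio`/`torchvision`, which pin an older `torch`. They are harmless for this tutorial, since neither package is used. If a clean environment matters, one option (a sketch, not a tested pin set) is to upgrade the companion packages together:

```python
# Sketch: upgrade torch's companion packages alongside it so their pins match.
!pip install -q -U torch torchvision torchaudio immutabledict sentencepiece
```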
@@ -130,29 +156,43 @@
"outputs": [],
"source": [
"# Choose variant and machine type\n",
"VARIANT = '2b-it' #@param ['2b', '2b-it', '7b', '7b-it', '7b-quant', '7b-it-quant']\n",
"MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']"
"VARIANT = '2b-it' #@param ['2b', '2b-it', '9b', '9b-it', '27b', '27b-it']\n",
"MACHINE_TYPE = 'cuda' #@param ['cuda', 'cpu']\n",
"\n",
"CONFIG = VARIANT[:2]\n",
"if CONFIG == '2b':\n",
" CONFIG = '2b-v2'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "viESUwjq5cAz"
"id": "ONRhkIDrE4Un"
},
"outputs": [],
"source": [
"import os\n",
"import kagglehub\n",
"\n",
"# Load model weights\n",
"weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')\n",
"\n",
"weights_dir = kagglehub.model_download(f'google/gemma-2-2b/pyTorch/gemma-2-{VARIANT}')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "viESUwjq5cAz"
},
"outputs": [],
"source": [
"# Ensure that the tokenizer is present\n",
"tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')\n",
"assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'\n",
"\n",
"# Ensure that the checkpoint is present\n",
"ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')\n",
"ckpt_path = os.path.join(weights_dir, f'model.ckpt')\n",
"assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'"
]
},
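A quick way to confirm the snapshot layout the assertions above rely on — `tokenizer.model` and `model.ckpt` come from the diff; any other filenames are whatever Kaggle ships with the snapshot:

```python
# Sanity check: list the downloaded files and their sizes.
import os

for name in sorted(os.listdir(weights_dir)):
    size_mb = os.path.getsize(os.path.join(weights_dir, name)) / 1e6
    print(f"{name}: {size_mb:.1f} MB")
```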
@@ -171,7 +211,21 @@
"metadata": {
"id": "ww83zI9ToPso"
},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Cloning into 'gemma_pytorch'...\n",
"remote: Enumerating objects: 239, done.\u001b[K\n",
"remote: Counting objects: 100% (123/123), done.\u001b[K\n",
"remote: Compressing objects: 100% (68/68), done.\u001b[K\n",
"remote: Total 239 (delta 86), reused 58 (delta 55), pack-reused 116\u001b[K\n",
"Receiving objects: 100% (239/239), 2.18 MiB | 20.83 MiB/s, done.\n",
"Resolving deltas: 100% (135/135), done.\n"
]
}
],
"source": [
"# NOTE: The \"installation\" is just cloning the repo.\n",
"!git clone https://github.com/google/gemma_pytorch.git"
@@ -198,8 +252,12 @@
},
"outputs": [],
"source": [
"from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b\n",
"from gemma_pytorch.gemma.model import GemmaForCausalLM"
"from gemma.config import GemmaConfig, get_model_config\n",
"from gemma.model import GemmaForCausalLM\n",
"from gemma.tokenizer import Tokenizer\n",
"import contextlib\n",
"import os\n",
"import torch"
]
},
{
@@ -219,10 +277,8 @@
},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Set up model config.\n",
"model_config = get_config_for_2b() if \"2b\" in VARIANT else get_config_for_7b()\n",
"model_config = get_model_config(CONFIG)\n",
"model_config.tokenizer = tokenizer_path\n",
"model_config.quant = 'quant' in VARIANT\n",
"\n",
@@ -255,7 +311,7 @@
"- `user`: user turn\n",
"- `model`: model turn\n",
"- `<start_of_turn>`: beginning of dialogue turn\n",
"- `<end_of_turn>`: end of dialogue turn\n",
"- `<end_of_turn><eos>`: end of dialogue turn\n",
"\n",
"Read about the Gemma formatting for instruction tuning and system instructions\n",
"[here](https://ai.google.dev/gemma/docs/formatting)."
@@ -274,11 +330,11 @@
"text": [
"Chat prompt:\n",
" <start_of_turn>user\n",
"What is a good place for travel in the US?<end_of_turn>\n",
"What is a good place for travel in the US?<end_of_turn><eos>\n",
"<start_of_turn>model\n",
"California.<end_of_turn>\n",
"California.<end_of_turn><eos>\n",
"<start_of_turn>user\n",
"What can I do in California?<end_of_turn>\n",
"What can I do in California?<end_of_turn><eos>\n",
"<start_of_turn>model\n",
"\n"
]
@@ -289,10 +345,10 @@
"type": "string"
},
"text/plain": [
"\"* **Visit the Golden Gate Bridge and Alcatraz Island in San Francisco.**\\n* **Head to Yosemite National Park and marvel at nature's beauty.**\\n* **Explore the bustling metropolis of Los Angeles.**\\n* **Relax on the pristine beaches of Santa Monica or Malibu.**\\n* **Go whale watching in Monterey Bay.**\\n* **Discover the charming coastal towns of Monterey Bay and Carmel-by-the-Sea.**\\n* **Visit Disneyland and Disney California Adventure in Anaheim.**\\n*\""
"\"California is a state brimming with diverse activities! To give you a great list, tell me: \\n\\n* **What kind of trip are you looking for?** Nature, City life, Beach, Theme Parks, Food, History, something else? \\n* **What are you interested in (e.g., hiking, museums, art, nightlife, shopping)?** \\n* **What's your budget like?** \\n* **Who are you traveling with?** (family, friends, solo) \\n\\nThe more you tell me, the better recommendations I can give! 😊 \\n<end_of_turn>\""
]
},
"execution_count": 55,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
@@ -301,8 +357,8 @@
"# Generate with one request in chat mode\n",
"\n",
"# Chat templates\n",
"USER_CHAT_TEMPLATE = '<start_of_turn>user\\n{prompt}<end_of_turn>\\n'\n",
"MODEL_CHAT_TEMPLATE = '<start_of_turn>model\\n{prompt}<end_of_turn>\\n'\n",
"USER_CHAT_TEMPLATE = \"<start_of_turn>user\\n{prompt}<end_of_turn><eos>\\n\"\n",
"MODEL_CHAT_TEMPLATE = \"<start_of_turn>model\\n{prompt}<end_of_turn><eos>\\n\"\n",
"\n",
"# Sample formatted prompt\n",
"prompt = (\n",
@@ -318,7 +374,7 @@
"model.generate(\n",
" USER_CHAT_TEMPLATE.format(prompt=prompt),\n",
" device=device,\n",
" output_len=100,\n",
" output_len=128,\n",
")"
]
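To keep the dialogue going, append the reply and a new user turn before generating again. A hedged sketch — it assumes, per the captured output above, that replies end with `<end_of_turn>`, and it passes the prompt straight to `generate` rather than re-wrapping it in `USER_CHAT_TEMPLATE`:

```python
# Sketch: one more dialogue turn. `prompt` already ends with
# '<start_of_turn>model\n', so the reply continues the model turn.
reply = model.generate(prompt, device=device, output_len=128)
next_prompt = (
    prompt
    + reply.strip() + '<eos>\n'  # reply ends with <end_of_turn>; close the turn
    + USER_CHAT_TEMPLATE.format(prompt='Which of those is best in winter?')
    + '<start_of_turn>model\n'
)
print(model.generate(next_prompt, device=device, output_len=128))
```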
},
Expand All @@ -331,11 +387,14 @@
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"['\\n\\nThe fingers dance on the keys,\\nA symphony of thoughts and dreams.\\nThe mind, a canvas yet uncouth,\\nScribbling its secrets in the night.\\n\\nThe ink, a whispered voice from deep,\\nA language ancient, never to sleep.\\nEach stroke an echo of']"
"\"\\n\\nA swirling cloud of data, raw and bold,\\nIt hums and whispers, a story untold.\\nAn LLM whispers, code into refrain,\\nCrafting words of rhyme, a lyrical strain.\\n\\nA world of pixels, logic's vibrant hue,\\nFlows through its veins, forever anew.\\nThe human touch it seeks, a gentle hand,\\nTo mold and shape, understand.\\n\\nEmotions it might learn, from snippets of prose,\\nInspiration it seeks, a yearning\""
]
},
"execution_count": 56,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
@@ -345,7 +404,7 @@
"model.generate(\n",
" 'Write a poem about an llm writing a poem.',\n",
" device=device,\n",
" output_len=60,\n",
" output_len=100,\n",
")"
]
},