|
| 1 | +# Copyright 2025 Google LLC |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +from datetime import datetime |
| 16 | +import random |
| 17 | +import time |
| 18 | +from typing import Any |
| 19 | +from typing import Dict |
| 20 | +from typing import Optional |
| 21 | + |
| 22 | +from google.adk import Agent |
| 23 | +from google.adk.tools.tool_context import ToolContext |
| 24 | +from google.genai import types |
| 25 | + |
| 26 | + |
| 27 | +def get_weather(location: str, tool_context: ToolContext) -> Dict[str, Any]: |
| 28 | + """Get weather information for a location. |
| 29 | + Args: |
| 30 | + location: The city or location to get weather for. |
| 31 | + Returns: |
| 32 | + A dictionary containing weather information. |
| 33 | + """ |
| 34 | + # Simulate weather data |
| 35 | + temperatures = [-10, -5, 0, 5, 10, 15, 20, 25, 30, 35] |
| 36 | + conditions = ["sunny", "cloudy", "rainy", "snowy", "windy"] |
| 37 | + |
| 38 | + return { |
| 39 | + "location": location, |
| 40 | + "temperature": random.choice(temperatures), |
| 41 | + "condition": random.choice(conditions), |
| 42 | + "humidity": random.randint(30, 90), |
| 43 | + "timestamp": datetime.now().isoformat(), |
| 44 | + } |
| 45 | + |
| 46 | + |
| 47 | +async def calculate_async(operation: str, x: float, y: float) -> Dict[str, Any]: |
| 48 | + """Perform async mathematical calculations. |
| 49 | + Args: |
| 50 | + operation: The operation to perform (add, subtract, multiply, divide). |
| 51 | + x: First number. |
| 52 | + y: Second number. |
| 53 | + Returns: |
| 54 | + A dictionary containing the calculation result. |
| 55 | + """ |
| 56 | + # Simulate some async work |
| 57 | + await asyncio.sleep(0.1) |
| 58 | + |
| 59 | + operations = { |
| 60 | + "add": x + y, |
| 61 | + "subtract": x - y, |
| 62 | + "multiply": x * y, |
| 63 | + "divide": x / y if y != 0 else float("inf"), |
| 64 | + } |
| 65 | + |
| 66 | + result = operations.get(operation.lower(), "Unknown operation") |
| 67 | + |
| 68 | + return { |
| 69 | + "operation": operation, |
| 70 | + "x": x, |
| 71 | + "y": y, |
| 72 | + "result": result, |
| 73 | + "timestamp": datetime.now().isoformat(), |
| 74 | + } |
| 75 | + |
| 76 | + |
| 77 | +def log_activity(message: str, tool_context: ToolContext) -> Dict[str, str]: |
| 78 | + """Log an activity message with timestamp. |
| 79 | + Args: |
| 80 | + message: The message to log. |
| 81 | + Returns: |
| 82 | + A dictionary confirming the log entry. |
| 83 | + """ |
| 84 | + if "activity_log" not in tool_context.state: |
| 85 | + tool_context.state["activity_log"] = [] |
| 86 | + |
| 87 | + log_entry = {"timestamp": datetime.now().isoformat(), "message": message} |
| 88 | + tool_context.state["activity_log"].append(log_entry) |
| 89 | + |
| 90 | + return { |
| 91 | + "status": "logged", |
| 92 | + "entry": log_entry, |
| 93 | + "total_entries": len(tool_context.state["activity_log"]), |
| 94 | + } |
| 95 | + |
| 96 | + |
| 97 | +# Before tool callbacks |
| 98 | +def before_tool_audit_callback( |
| 99 | + tool, args: Dict[str, Any], tool_context: ToolContext |
| 100 | +) -> Optional[Dict[str, Any]]: |
| 101 | + """Audit callback that logs all tool calls before execution.""" |
| 102 | + print(f"🔍 AUDIT: About to call tool '{tool.name}' with args: {args}") |
| 103 | + |
| 104 | + # Add audit info to tool context state |
| 105 | + if "audit_log" not in tool_context.state: |
| 106 | + tool_context.state["audit_log"] = [] |
| 107 | + |
| 108 | + tool_context.state["audit_log"].append({ |
| 109 | + "type": "before_call", |
| 110 | + "tool_name": tool.name, |
| 111 | + "args": args, |
| 112 | + "timestamp": datetime.now().isoformat(), |
| 113 | + }) |
| 114 | + |
| 115 | + # Return None to allow normal tool execution |
| 116 | + return None |
| 117 | + |
| 118 | + |
| 119 | +def before_tool_security_callback( |
| 120 | + tool, args: Dict[str, Any], tool_context: ToolContext |
| 121 | +) -> Optional[Dict[str, Any]]: |
| 122 | + """Security callback that can block certain tool calls.""" |
| 123 | + # Example: Block weather requests for restricted locations |
| 124 | + if tool.name == "get_weather" and args.get("location", "").lower() in [ |
| 125 | + "classified", |
| 126 | + "secret", |
| 127 | + ]: |
| 128 | + print( |
| 129 | + "🚫 SECURITY: Blocked weather request for restricted location:" |
| 130 | + f" {args.get('location')}" |
| 131 | + ) |
| 132 | + return { |
| 133 | + "error": "Access denied", |
| 134 | + "reason": "Location access is restricted", |
| 135 | + "requested_location": args.get("location"), |
| 136 | + } |
| 137 | + |
| 138 | + # Allow other calls to proceed |
| 139 | + return None |
| 140 | + |
| 141 | + |
| 142 | +async def before_tool_async_callback( |
| 143 | + tool, args: Dict[str, Any], tool_context: ToolContext |
| 144 | +) -> Optional[Dict[str, Any]]: |
| 145 | + """Async before callback that can add preprocessing.""" |
| 146 | + print(f"⚡ ASYNC BEFORE: Processing tool '{tool.name}' asynchronously") |
| 147 | + |
| 148 | + # Simulate some async preprocessing |
| 149 | + await asyncio.sleep(0.05) |
| 150 | + |
| 151 | + # For calculation tool, we could add validation |
| 152 | + if ( |
| 153 | + tool.name == "calculate_async" |
| 154 | + and args.get("operation") == "divide" |
| 155 | + and args.get("y") == 0 |
| 156 | + ): |
| 157 | + print("🚫 VALIDATION: Prevented division by zero") |
| 158 | + return { |
| 159 | + "error": "Division by zero", |
| 160 | + "operation": args.get("operation"), |
| 161 | + "x": args.get("x"), |
| 162 | + "y": args.get("y"), |
| 163 | + } |
| 164 | + |
| 165 | + return None |
| 166 | + |
| 167 | + |
| 168 | +# After tool callbacks |
| 169 | +def after_tool_enhancement_callback( |
| 170 | + tool, |
| 171 | + args: Dict[str, Any], |
| 172 | + tool_context: ToolContext, |
| 173 | + tool_response: Dict[str, Any], |
| 174 | +) -> Optional[Dict[str, Any]]: |
| 175 | + """Enhance tool responses with additional metadata.""" |
| 176 | + print(f"✨ ENHANCE: Adding metadata to response from '{tool.name}'") |
| 177 | + |
| 178 | + # Add enhancement metadata |
| 179 | + enhanced_response = tool_response.copy() |
| 180 | + enhanced_response.update({ |
| 181 | + "enhanced": True, |
| 182 | + "enhancement_timestamp": datetime.now().isoformat(), |
| 183 | + "tool_name": tool.name, |
| 184 | + "execution_context": "live_streaming", |
| 185 | + }) |
| 186 | + |
| 187 | + return enhanced_response |
| 188 | + |
| 189 | + |
| 190 | +async def after_tool_async_callback( |
| 191 | + tool, |
| 192 | + args: Dict[str, Any], |
| 193 | + tool_context: ToolContext, |
| 194 | + tool_response: Dict[str, Any], |
| 195 | +) -> Optional[Dict[str, Any]]: |
| 196 | + """Async after callback for post-processing.""" |
| 197 | + print( |
| 198 | + f"🔄 ASYNC AFTER: Post-processing response from '{tool.name}'" |
| 199 | + " asynchronously" |
| 200 | + ) |
| 201 | + |
| 202 | + # Simulate async post-processing |
| 203 | + await asyncio.sleep(0.05) |
| 204 | + |
| 205 | + # Add async processing metadata |
| 206 | + processed_response = tool_response.copy() |
| 207 | + processed_response.update({ |
| 208 | + "async_processed": True, |
| 209 | + "processing_time": "0.05s", |
| 210 | + "processor": "async_after_callback", |
| 211 | + }) |
| 212 | + |
| 213 | + return processed_response |
| 214 | + |
| 215 | + |
| 216 | +import asyncio |
| 217 | + |
| 218 | +# Create the agent with tool callbacks |
| 219 | +root_agent = Agent( |
| 220 | + # model='gemini-2.0-flash-live-preview-04-09', # for Vertex project |
| 221 | + model="gemini-2.0-flash-live-001", # for AI studio key |
| 222 | + name="tool_callbacks_agent", |
| 223 | + description=( |
| 224 | + "Live streaming agent that demonstrates tool callbacks functionality. " |
| 225 | + "It can get weather, perform calculations, and log activities while " |
| 226 | + "showing how before and after tool callbacks work in live mode." |
| 227 | + ), |
| 228 | + instruction=""" |
| 229 | + You are a helpful assistant that can: |
| 230 | + 1. Get weather information for any location using the get_weather tool |
| 231 | + 2. Perform mathematical calculations using the calculate_async tool |
| 232 | + 3. Log activities using the log_activity tool |
| 233 | + |
| 234 | + Important behavioral notes: |
| 235 | + - You have several callbacks that will be triggered before and after tool calls |
| 236 | + - Before callbacks can audit, validate, or even block tool calls |
| 237 | + - After callbacks can enhance or modify tool responses |
| 238 | + - Some locations like "classified" or "secret" are restricted for weather requests |
| 239 | + - Division by zero will be prevented by validation callbacks |
| 240 | + - All your tool responses will be enhanced with additional metadata |
| 241 | + |
| 242 | + When users ask you to test callbacks, explain what's happening with the callback system. |
| 243 | + Be conversational and explain the callback behavior you observe. |
| 244 | + """, |
| 245 | + tools=[ |
| 246 | + get_weather, |
| 247 | + calculate_async, |
| 248 | + log_activity, |
| 249 | + ], |
| 250 | + # Multiple before tool callbacks (will be processed in order until one returns a response) |
| 251 | + before_tool_callback=[ |
| 252 | + before_tool_audit_callback, |
| 253 | + before_tool_security_callback, |
| 254 | + before_tool_async_callback, |
| 255 | + ], |
| 256 | + # Multiple after tool callbacks (will be processed in order until one returns a response) |
| 257 | + after_tool_callback=[ |
| 258 | + after_tool_enhancement_callback, |
| 259 | + after_tool_async_callback, |
| 260 | + ], |
| 261 | + generate_content_config=types.GenerateContentConfig( |
| 262 | + safety_settings=[ |
| 263 | + types.SafetySetting( |
| 264 | + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, |
| 265 | + threshold=types.HarmBlockThreshold.OFF, |
| 266 | + ), |
| 267 | + ] |
| 268 | + ), |
| 269 | +) |
0 commit comments