Skip to content

Commit 8fdd6fa

Browse files
feat(genai): Live API WebSocket Example (#13404)
* feat(genai): Live API WebSocket Example * feat(genai): Add TextGen socket example using Audio Input * feat(genai): Add test files * feat(genai): Update bearer token code * feat(genai): Update bearer token code * feat(genai): Update bearer token code * feat(genai): Update project env variable name * feat(genai): Update project env variable name * feat(genai): Update google genai sdk version * feat(genai): Update google genai sdk version * feat(genai): Update google genai sdk version
1 parent 1ba39ca commit 8fdd6fa

9 files changed

+643
-5
lines changed
95.4 KB
Binary file not shown.
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import os
17+
18+
19+
def get_bearer_token() -> str:
20+
import google.auth
21+
from google.auth.transport.requests import Request
22+
23+
creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
24+
auth_req = Request()
25+
creds.refresh(auth_req)
26+
bearer_token = creds.token
27+
return bearer_token
28+
29+
30+
# get bearer token
31+
BEARER_TOKEN = get_bearer_token()
32+
33+
34+
async def generate_content() -> str:
35+
"""
36+
Connects to the Gemini API via WebSocket, sends a text prompt,
37+
and returns the aggregated text response.
38+
"""
39+
# [START googlegenaisdk_live_audiogen_websocket_with_txt]
40+
import base64
41+
import json
42+
import numpy as np
43+
44+
from websockets.asyncio.client import connect
45+
from scipy.io import wavfile
46+
47+
# Configuration Constants
48+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
49+
LOCATION = "us-central1"
50+
GEMINI_MODEL_NAME = "gemini-2.0-flash-live-preview-04-09"
51+
# To generate a bearer token in CLI, use:
52+
# $ gcloud auth application-default print-access-token
53+
# It's recommended to fetch this token dynamically rather than hardcoding.
54+
# BEARER_TOKEN = "ya29.a0AW4XtxhRb1s51TxLPnj..."
55+
56+
# Websocket Configuration
57+
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
58+
WEBSOCKET_SERVICE_URL = (
59+
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
60+
)
61+
62+
# Websocket Authentication
63+
headers = {
64+
"Content-Type": "application/json",
65+
"Authorization": f"Bearer {BEARER_TOKEN}",
66+
}
67+
68+
# Model Configuration
69+
model_path = (
70+
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
71+
)
72+
model_generation_config = {
73+
"response_modalities": ["AUDIO"],
74+
"speech_config": {
75+
"voice_config": {"prebuilt_voice_config": {"voice_name": "Aoede"}},
76+
"language_code": "es-ES",
77+
},
78+
}
79+
80+
async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
81+
# 1. Send setup configuration
82+
websocket_config = {
83+
"setup": {
84+
"model": model_path,
85+
"generation_config": model_generation_config,
86+
}
87+
}
88+
await websocket_session.send(json.dumps(websocket_config))
89+
90+
# 2. Receive setup response
91+
raw_setup_response = await websocket_session.recv()
92+
setup_response = json.loads(
93+
raw_setup_response.decode("utf-8")
94+
if isinstance(raw_setup_response, bytes)
95+
else raw_setup_response
96+
)
97+
print(f"Setup Response: {setup_response}")
98+
# Example response: {'setupComplete': {}}
99+
if "setupComplete" not in setup_response:
100+
print(f"Setup failed: {setup_response}")
101+
return "Error: WebSocket setup failed."
102+
103+
# 3. Send text message
104+
text_input = "Hello? Gemini are you there?"
105+
print(f"Input: {text_input}")
106+
107+
user_message = {
108+
"client_content": {
109+
"turns": [{"role": "user", "parts": [{"text": text_input}]}],
110+
"turn_complete": True,
111+
}
112+
}
113+
await websocket_session.send(json.dumps(user_message))
114+
115+
# 4. Receive model response
116+
aggregated_response_parts = []
117+
async for raw_response_chunk in websocket_session:
118+
response_chunk = json.loads(raw_response_chunk.decode("utf-8"))
119+
120+
server_content = response_chunk.get("serverContent")
121+
if not server_content:
122+
# This might indicate an error or an unexpected message format
123+
print(f"Received non-serverContent message or empty content: {response_chunk}")
124+
break
125+
126+
# Collect audio chunks
127+
model_turn = server_content.get("modelTurn")
128+
if model_turn and "parts" in model_turn and model_turn["parts"]:
129+
for part in model_turn["parts"]:
130+
if part["inlineData"]["mimeType"] == "audio/pcm":
131+
audio_chunk = base64.b64decode(part["inlineData"]["data"])
132+
aggregated_response_parts.append(np.frombuffer(audio_chunk, dtype=np.int16))
133+
134+
# End of response
135+
if server_content.get("turnComplete"):
136+
break
137+
138+
# Save audio to a file
139+
if aggregated_response_parts:
140+
wavfile.write("output.wav", 24000, np.concatenate(aggregated_response_parts))
141+
# Example response:
142+
# Setup Response: {'setupComplete': {}}
143+
# Input: Hello? Gemini are you there?
144+
# Audio Response: Hello there. I'm here. What can I do for you today?
145+
# [END googlegenaisdk_live_audiogen_websocket_with_txt]
146+
return "output.wav"
147+
148+
149+
if __name__ == "__main__":
150+
asyncio.run(generate_content())
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import asyncio
16+
import os
17+
18+
19+
def get_bearer_token() -> str:
20+
import google.auth
21+
from google.auth.transport.requests import Request
22+
23+
creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
24+
auth_req = Request()
25+
creds.refresh(auth_req)
26+
bearer_token = creds.token
27+
return bearer_token
28+
29+
30+
# get bearer token
31+
BEARER_TOKEN = get_bearer_token()
32+
33+
34+
async def generate_content() -> str:
35+
"""
36+
Connects to the Gemini API via WebSocket, sends a text prompt,
37+
and returns the aggregated text response.
38+
"""
39+
# [START googlegenaisdk_live_websocket_audiotranscript_with_txt]
40+
import base64
41+
import json
42+
import numpy as np
43+
44+
from websockets.asyncio.client import connect
45+
from scipy.io import wavfile
46+
47+
# Configuration Constants
48+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
49+
LOCATION = "us-central1"
50+
GEMINI_MODEL_NAME = "gemini-2.0-flash-live-preview-04-09"
51+
# To generate a bearer token in CLI, use:
52+
# $ gcloud auth application-default print-access-token
53+
# It's recommended to fetch this token dynamically rather than hardcoding.
54+
# BEARER_TOKEN = "ya29.a0AW4XtxhRb1s51TxLPnj..."
55+
56+
# Websocket Configuration
57+
WEBSOCKET_HOST = "us-central1-aiplatform.googleapis.com"
58+
WEBSOCKET_SERVICE_URL = (
59+
f"wss://{WEBSOCKET_HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
60+
)
61+
62+
# Websocket Authentication
63+
headers = {
64+
"Content-Type": "application/json",
65+
"Authorization": f"Bearer {BEARER_TOKEN}",
66+
}
67+
68+
# Model Configuration
69+
model_path = (
70+
f"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/{GEMINI_MODEL_NAME}"
71+
)
72+
model_generation_config = {
73+
"response_modalities": ["AUDIO"],
74+
"speech_config": {
75+
"voice_config": {"prebuilt_voice_config": {"voice_name": "Aoede"}},
76+
"language_code": "es-ES",
77+
},
78+
}
79+
80+
async with connect(WEBSOCKET_SERVICE_URL, additional_headers=headers) as websocket_session:
81+
# 1. Send setup configuration
82+
websocket_config = {
83+
"setup": {
84+
"model": model_path,
85+
"generation_config": model_generation_config,
86+
# Audio transcriptions for input and output
87+
"input_audio_transcription": {},
88+
"output_audio_transcription": {},
89+
}
90+
}
91+
await websocket_session.send(json.dumps(websocket_config))
92+
93+
# 2. Receive setup response
94+
raw_setup_response = await websocket_session.recv()
95+
setup_response = json.loads(
96+
raw_setup_response.decode("utf-8")
97+
if isinstance(raw_setup_response, bytes)
98+
else raw_setup_response
99+
)
100+
print(f"Setup Response: {setup_response}")
101+
# Expected response: {'setupComplete': {}}
102+
if "setupComplete" not in setup_response:
103+
print(f"Setup failed: {setup_response}")
104+
return "Error: WebSocket setup failed."
105+
106+
# 3. Send text message
107+
text_input = "Hello? Gemini are you there?"
108+
print(f"Input: {text_input}")
109+
110+
user_message = {
111+
"client_content": {
112+
"turns": [{"role": "user", "parts": [{"text": text_input}]}],
113+
"turn_complete": True,
114+
}
115+
}
116+
await websocket_session.send(json.dumps(user_message))
117+
118+
# 4. Receive model response
119+
aggregated_response_parts = []
120+
input_transcriptions_parts = []
121+
output_transcriptions_parts = []
122+
async for raw_response_chunk in websocket_session:
123+
response_chunk = json.loads(raw_response_chunk.decode("utf-8"))
124+
125+
server_content = response_chunk.get("serverContent")
126+
if not server_content:
127+
# This might indicate an error or an unexpected message format
128+
print(f"Received non-serverContent message or empty content: {response_chunk}")
129+
break
130+
131+
# Transcriptions
132+
if server_content.get("inputTranscription"):
133+
text = server_content.get("inputTranscription").get("text", "")
134+
input_transcriptions_parts.append(text)
135+
if server_content.get("outputTranscription"):
136+
text = server_content.get("outputTranscription").get("text", "")
137+
output_transcriptions_parts.append(text)
138+
139+
# Collect audio chunks
140+
model_turn = server_content.get("modelTurn")
141+
if model_turn and "parts" in model_turn and model_turn["parts"]:
142+
for part in model_turn["parts"]:
143+
if part["inlineData"]["mimeType"] == "audio/pcm":
144+
audio_chunk = base64.b64decode(part["inlineData"]["data"])
145+
aggregated_response_parts.append(np.frombuffer(audio_chunk, dtype=np.int16))
146+
147+
# End of response
148+
if server_content.get("turnComplete"):
149+
break
150+
151+
# Save audio to a file
152+
final_response_audio = np.concatenate(aggregated_response_parts)
153+
wavfile.write("output.wav", 24000, final_response_audio)
154+
print(f"Input transcriptions: {''.join(input_transcriptions_parts)}")
155+
print(f"Output transcriptions: {''.join(output_transcriptions_parts)}")
156+
# Example response:
157+
# Setup Response: {'setupComplete': {}}
158+
# Input: Hello? Gemini are you there?
159+
# Audio Response(output.wav): Yes, I'm here. How can I help you today?
160+
# Input transcriptions:
161+
# Output transcriptions: Yes, I'm here. How can I help you today?
162+
# [END googlegenaisdk_live_websocket_audiotranscript_with_txt]
163+
return "output.wav"
164+
165+
166+
if __name__ == "__main__":
167+
asyncio.run(generate_content())

0 commit comments

Comments
 (0)