Skip to content

Commit b2a2b11

Browse files
seanzhougooglecopybara-github
authored andcommitted
ADK changes
PiperOrigin-RevId: 761650284
1 parent 1773cda commit b2a2b11

File tree

2 files changed

+145
-0
lines changed

2 files changed

+145
-0
lines changed

src/google/adk/utils/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import re
16+
17+
from ..agents.invocation_context import InvocationContext
18+
from ..agents.readonly_context import ReadonlyContext
19+
from ..sessions.state import State
20+
21+
__all__ = [
22+
'inject_session_state',
23+
]
24+
25+
26+
async def inject_session_state(
27+
template: str,
28+
readonly_context: ReadonlyContext,
29+
) -> str:
30+
"""Populates values in the instruction template, e.g. state, artifact, etc.
31+
32+
This method is intended to be used in InstructionProvider based instruction
33+
and global_instruction which are called with readonly_context.
34+
35+
e.g.
36+
```
37+
...
38+
from google.adk.utils import instructions_utils
39+
40+
async def build_instruction(
41+
readonly_context: ReadonlyContext,
42+
) -> str:
43+
return await instructions_utils.inject_session_state(
44+
'You can inject a state variable like {var_name} or an artifact '
45+
'{artifact.file_name} into the instruction template.',
46+
readonly_context,
47+
)
48+
49+
agent = Agent(
50+
model="gemini-2.0-flash",
51+
name="agent",
52+
instruction=build_instruction,
53+
)
54+
```
55+
56+
Args:
57+
template: The instruction template.
58+
readonly_context: The read-only context
59+
60+
Returns:
61+
The instruction template with values populated.
62+
"""
63+
64+
invocation_context = readonly_context._invocation_context
65+
66+
async def _async_sub(pattern, repl_async_fn, string) -> str:
67+
result = []
68+
last_end = 0
69+
for match in re.finditer(pattern, string):
70+
result.append(string[last_end : match.start()])
71+
replacement = await repl_async_fn(match)
72+
result.append(replacement)
73+
last_end = match.end()
74+
result.append(string[last_end:])
75+
return ''.join(result)
76+
77+
async def _replace_match(match) -> str:
78+
var_name = match.group().lstrip('{').rstrip('}').strip()
79+
optional = False
80+
if var_name.endswith('?'):
81+
optional = True
82+
var_name = var_name.removesuffix('?')
83+
if var_name.startswith('artifact.'):
84+
var_name = var_name.removeprefix('artifact.')
85+
if invocation_context.artifact_service is None:
86+
raise ValueError('Artifact service is not initialized.')
87+
artifact = await invocation_context.artifact_service.load_artifact(
88+
app_name=invocation_context.session.app_name,
89+
user_id=invocation_context.session.user_id,
90+
session_id=invocation_context.session.id,
91+
filename=var_name,
92+
)
93+
if not var_name:
94+
raise KeyError(f'Artifact {var_name} not found.')
95+
return str(artifact)
96+
else:
97+
if not _is_valid_state_name(var_name):
98+
return match.group()
99+
if var_name in invocation_context.session.state:
100+
return str(invocation_context.session.state[var_name])
101+
else:
102+
if optional:
103+
return ''
104+
else:
105+
raise KeyError(f'Context variable not found: `{var_name}`.')
106+
107+
return await _async_sub(r'{+[^{}]*}+', _replace_match, template)
108+
109+
110+
def _is_valid_state_name(var_name):
111+
"""Checks if the variable name is a valid state name.
112+
113+
Valid state is either:
114+
- Valid identifier
115+
- <Valid prefix>:<Valid identifier>
116+
All the others will just return as it is.
117+
118+
Args:
119+
var_name: The variable name to check.
120+
121+
Returns:
122+
True if the variable name is a valid state name, False otherwise.
123+
"""
124+
parts = var_name.split(':')
125+
if len(parts) == 1:
126+
return var_name.isidentifier()
127+
128+
if len(parts) == 2:
129+
prefixes = [State.APP_PREFIX, State.USER_PREFIX, State.TEMP_PREFIX]
130+
if (parts[0] + ':') in prefixes:
131+
return parts[1].isidentifier()
132+
return False

0 commit comments

Comments
 (0)