2
2
import os
3
3
import signal
4
4
import traceback
5
- from typing import Any , Dict , List , Optional
5
+ from typing import Optional
6
6
7
7
import typer
8
8
from rich import print
@@ -40,8 +40,8 @@ async def run_agent(
40
40
41
41
config , prompt = _load_agent_config (agent_path )
42
42
43
- inputs : List [ Dict [ str , Any ]] = config .get ("inputs" , [])
44
- servers : List [ Dict [ str , Any ]] = config .get ("servers" , [])
43
+ inputs = config .get ("inputs" , [])
44
+ servers = config .get ("servers" , [])
45
45
46
46
abort_event = asyncio .Event ()
47
47
exit_event = asyncio .Event ()
@@ -82,38 +82,45 @@ def _sigint_handler() -> None:
82
82
env_special_value = "${input:" + input_id + "}" # Special value to indicate env variable injection
83
83
84
84
# Check env variables that will use this input
85
- input_vars = list (
86
- {
87
- key
88
- for server in servers
89
- for key , value in server .get ("config" , {}).get ("env" , {}).items ()
90
- if value == env_special_value
91
- }
92
- )
85
+ input_vars = set ()
86
+ for server in servers :
87
+ # Check stdio's "env" and http/sse's "headers" mappings
88
+ env_or_headers = (
89
+ server ["config" ].get ("env" , {})
90
+ if server ["type" ] == "stdio"
91
+ else server ["config" ].get ("options" , {}).get ("requestInit" , {}).get ("headers" , {})
92
+ )
93
+ for key , value in env_or_headers .items ():
94
+ if env_special_value in value :
95
+ input_vars .add (key )
93
96
94
97
if not input_vars :
95
98
print (f"[yellow]Input { input_id } defined in config but not used by any server.[/yellow]" )
96
99
continue
97
100
98
101
# Prompt user for input
99
102
print (
100
- f"[blue] • { input_id } [/blue]: { description } . (default: load from { ', ' .join (input_vars )} )." ,
103
+ f"[blue] • { input_id } [/blue]: { description } . (default: load from { ', ' .join (sorted ( input_vars ) )} )." ,
101
104
end = " " ,
102
105
)
103
106
user_input = (await _async_prompt (exit_event = exit_event )).strip ()
104
107
if exit_event .is_set ():
105
108
return
106
109
107
- # Inject user input (or env variable) into servers' env
110
+ # Inject user input (or env variable) into stdio's env or http/sse's headers
108
111
for server in servers :
109
- env = server .get ("config" , {}).get ("env" , {})
110
- for key , value in env .items ():
111
- if value == env_special_value :
112
+ env_or_headers = (
113
+ server ["config" ].get ("env" , {})
114
+ if server ["type" ] == "stdio"
115
+ else server ["config" ].get ("options" , {}).get ("requestInit" , {}).get ("headers" , {})
116
+ )
117
+ for key , value in env_or_headers .items ():
118
+ if env_special_value in value :
112
119
if user_input :
113
- env [key ] = user_input
120
+ env_or_headers [key ] = env_or_headers [ key ]. replace ( env_special_value , user_input )
114
121
else :
115
122
value_from_env = os .getenv (key , "" )
116
- env [key ] = value_from_env
123
+ env_or_headers [key ] = env_or_headers [ key ]. replace ( env_special_value , value_from_env )
117
124
if value_from_env :
118
125
print (f"[green]Value successfully loaded from '{ key } '[/green]" )
119
126
else :
@@ -125,10 +132,10 @@ def _sigint_handler() -> None:
125
132
126
133
# Main agent loop
127
134
async with Agent (
128
- provider = config .get ("provider" ),
135
+ provider = config .get ("provider" ), # type: ignore[arg-type]
129
136
model = config .get ("model" ),
130
- base_url = config .get ("endpointUrl" ),
131
- servers = servers ,
137
+ base_url = config .get ("endpointUrl" ), # type: ignore[arg-type]
138
+ servers = servers , # type: ignore[arg-type]
132
139
prompt = prompt ,
133
140
) as agent :
134
141
await agent .load_tools ()
0 commit comments