From 523b6c5082bcba70161ae68d55ed52e19b4deaa7 Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Sat, 3 May 2025 00:53:42 -0700 Subject: [PATCH 1/3] add kiro-cli crate --- crates/kiro-cli/.gitignore | 2 + crates/kiro-cli/Cargo.toml | 186 + crates/kiro-cli/build.rs | 280 ++ crates/kiro-cli/src/cli/chat/cli.rs | 25 + crates/kiro-cli/src/cli/chat/command.rs | 1093 +++++ crates/kiro-cli/src/cli/chat/consts.rs | 19 + crates/kiro-cli/src/cli/chat/context.rs | 1016 +++++ .../src/cli/chat/conversation_state.rs | 1049 +++++ crates/kiro-cli/src/cli/chat/hooks.rs | 557 +++ crates/kiro-cli/src/cli/chat/input_source.rs | 107 + crates/kiro-cli/src/cli/chat/message.rs | 384 ++ crates/kiro-cli/src/cli/chat/mod.rs | 3937 +++++++++++++++++ crates/kiro-cli/src/cli/chat/parse.rs | 762 ++++ crates/kiro-cli/src/cli/chat/parser.rs | 375 ++ crates/kiro-cli/src/cli/chat/prompt.rs | 364 ++ crates/kiro-cli/src/cli/chat/shared_writer.rs | 89 + .../kiro-cli/src/cli/chat/skim_integration.rs | 378 ++ crates/kiro-cli/src/cli/chat/token_counter.rs | 251 ++ crates/kiro-cli/src/cli/chat/tool_manager.rs | 1019 +++++ .../src/cli/chat/tools/custom_tool.rs | 241 + .../src/cli/chat/tools/execute_bash.rs | 373 ++ crates/kiro-cli/src/cli/chat/tools/fs_read.rs | 669 +++ .../kiro-cli/src/cli/chat/tools/fs_write.rs | 953 ++++ .../kiro-cli/src/cli/chat/tools/gh_issue.rs | 222 + crates/kiro-cli/src/cli/chat/tools/mod.rs | 432 ++ .../src/cli/chat/tools/tool_index.json | 176 + crates/kiro-cli/src/cli/chat/tools/use_aws.rs | 315 ++ crates/kiro-cli/src/cli/chat/util/issue.rs | 83 + crates/kiro-cli/src/cli/chat/util/mod.rs | 111 + crates/kiro-cli/src/cli/debug.rs | 109 + crates/kiro-cli/src/cli/diagnostics.rs | 68 + crates/kiro-cli/src/cli/feed.rs | 49 + crates/kiro-cli/src/cli/issue.rs | 39 + crates/kiro-cli/src/cli/mod.rs | 522 +++ crates/kiro-cli/src/cli/settings.rs | 152 + crates/kiro-cli/src/cli/telemetry.rs | 53 + crates/kiro-cli/src/cli/uninstall.rs | 174 + crates/kiro-cli/src/cli/update.rs | 58 + crates/kiro-cli/src/cli/user.rs | 471 ++ crates/kiro-cli/src/diagnostics.rs | 253 ++ .../src/fig_api_client/clients/client.rs | 208 + .../src/fig_api_client/clients/mod.rs | 9 + .../src/fig_api_client/clients/shared.rs | 65 + .../clients/streaming_client.rs | 339 ++ crates/kiro-cli/src/fig_api_client/consts.rs | 19 + .../src/fig_api_client/credentials/mod.rs | 80 + .../src/fig_api_client/customization.rs | 161 + .../kiro-cli/src/fig_api_client/endpoints.rs | 125 + crates/kiro-cli/src/fig_api_client/error.rs | 178 + .../src/fig_api_client/interceptor/mod.rs | 2 + .../src/fig_api_client/interceptor/opt_out.rs | 89 + .../fig_api_client/interceptor/session_id.rs | 82 + crates/kiro-cli/src/fig_api_client/mod.rs | 17 + crates/kiro-cli/src/fig_api_client/model.rs | 924 ++++ crates/kiro-cli/src/fig_api_client/profile.rs | 35 + crates/kiro-cli/src/fig_api_client/stage.rs | 40 + crates/kiro-cli/src/fig_auth/builder_id.rs | 708 +++ crates/kiro-cli/src/fig_auth/consts.rs | 25 + crates/kiro-cli/src/fig_auth/error.rs | 47 + crates/kiro-cli/src/fig_auth/index.html | 181 + crates/kiro-cli/src/fig_auth/mod.rs | 16 + crates/kiro-cli/src/fig_auth/pkce.rs | 627 +++ crates/kiro-cli/src/fig_auth/scope.rs | 33 + .../src/fig_auth/secret_store/linux.rs | 27 + .../src/fig_auth/secret_store/macos.rs | 80 + .../kiro-cli/src/fig_auth/secret_store/mod.rs | 102 + .../src/fig_auth/secret_store/sqlite.rs | 50 + .../src/fig_aws_common/http_client.rs | 198 + crates/kiro-cli/src/fig_aws_common/mod.rs | 36 + .../src/fig_aws_common/sdk_error_display.rs | 96 + 
.../user_agent_override_interceptor.rs | 227 + crates/kiro-cli/src/fig_install.rs | 119 + crates/kiro-cli/src/fig_log.rs | 313 ++ crates/kiro-cli/src/fig_os_shim/env.rs | 227 + crates/kiro-cli/src/fig_os_shim/fs.rs | 611 +++ crates/kiro-cli/src/fig_os_shim/mod.rs | 203 + crates/kiro-cli/src/fig_os_shim/platform.rs | 105 + crates/kiro-cli/src/fig_os_shim/providers.rs | 133 + crates/kiro-cli/src/fig_os_shim/sysinfo.rs | 68 + crates/kiro-cli/src/fig_settings/actions.json | 216 + crates/kiro-cli/src/fig_settings/error.rs | 70 + .../kiro-cli/src/fig_settings/keybindings.rs | 144 + crates/kiro-cli/src/fig_settings/mod.rs | 349 ++ crates/kiro-cli/src/fig_settings/settings.rs | 243 + crates/kiro-cli/src/fig_settings/sqlite.rs | 437 ++ .../sqlite_migrations/000_migration_table.sql | 5 + .../sqlite_migrations/001_history_table.sql | 13 + .../002_drop_history_in_ssh_docker.sql | 3 + .../003_improved_history_timing.sql | 3 + .../sqlite_migrations/004_state_table.sql | 4 + .../sqlite_migrations/005_auth_table.sql | 6 + crates/kiro-cli/src/fig_settings/state.rs | 202 + crates/kiro-cli/src/fig_telemetry/cognito.rs | 145 + .../kiro-cli/src/fig_telemetry/definitions.rs | 35 + crates/kiro-cli/src/fig_telemetry/endpoint.rs | 32 + crates/kiro-cli/src/fig_telemetry/event.rs | 150 + .../src/fig_telemetry/install_method.rs | 45 + crates/kiro-cli/src/fig_telemetry/mod.rs | 693 +++ crates/kiro-cli/src/fig_telemetry/util.rs | 162 + crates/kiro-cli/src/fig_telemetry_core.rs | 418 ++ crates/kiro-cli/src/fig_util/cli_context.rs | 58 + crates/kiro-cli/src/fig_util/consts.rs | 139 + crates/kiro-cli/src/fig_util/directories.rs | 304 ++ crates/kiro-cli/src/fig_util/error.rs | 25 + crates/kiro-cli/src/fig_util/manifest.rs | 343 ++ crates/kiro-cli/src/fig_util/mod.rs | 379 ++ crates/kiro-cli/src/fig_util/open.rs | 101 + crates/kiro-cli/src/fig_util/pid_file.rs | 167 + .../src/fig_util/process_info/freebsd.rs | 20 + .../src/fig_util/process_info/linux.rs | 41 + .../src/fig_util/process_info/macos.rs | 49 + .../kiro-cli/src/fig_util/process_info/mod.rs | 118 + .../src/fig_util/process_info/windows.rs | 136 + crates/kiro-cli/src/fig_util/region_check.rs | 15 + crates/kiro-cli/src/fig_util/spinner.rs | 126 + .../src/fig_util/system_info/linux.rs | 285 ++ .../kiro-cli/src/fig_util/system_info/mod.rs | 382 ++ crates/kiro-cli/src/main.rs | 102 + crates/kiro-cli/src/mcp_client/client.rs | 764 ++++ crates/kiro-cli/src/mcp_client/error.rs | 66 + .../src/mcp_client/facilitator_types.rs | 229 + crates/kiro-cli/src/mcp_client/mod.rs | 9 + crates/kiro-cli/src/mcp_client/server.rs | 293 ++ .../src/mcp_client/transport/base_protocol.rs | 108 + .../kiro-cli/src/mcp_client/transport/mod.rs | 56 + .../src/mcp_client/transport/stdio.rs | 272 ++ .../src/mcp_client/transport/websocket.rs | 0 crates/kiro-cli/src/request.rs | 188 + crates/kiro-cli/telemetry_definitions.json | 265 ++ 129 files changed, 32436 insertions(+) create mode 100644 crates/kiro-cli/.gitignore create mode 100644 crates/kiro-cli/Cargo.toml create mode 100644 crates/kiro-cli/build.rs create mode 100644 crates/kiro-cli/src/cli/chat/cli.rs create mode 100644 crates/kiro-cli/src/cli/chat/command.rs create mode 100644 crates/kiro-cli/src/cli/chat/consts.rs create mode 100644 crates/kiro-cli/src/cli/chat/context.rs create mode 100644 crates/kiro-cli/src/cli/chat/conversation_state.rs create mode 100644 crates/kiro-cli/src/cli/chat/hooks.rs create mode 100644 crates/kiro-cli/src/cli/chat/input_source.rs create mode 100644 crates/kiro-cli/src/cli/chat/message.rs create mode 100644 
crates/kiro-cli/src/cli/chat/mod.rs create mode 100644 crates/kiro-cli/src/cli/chat/parse.rs create mode 100644 crates/kiro-cli/src/cli/chat/parser.rs create mode 100644 crates/kiro-cli/src/cli/chat/prompt.rs create mode 100644 crates/kiro-cli/src/cli/chat/shared_writer.rs create mode 100644 crates/kiro-cli/src/cli/chat/skim_integration.rs create mode 100644 crates/kiro-cli/src/cli/chat/token_counter.rs create mode 100644 crates/kiro-cli/src/cli/chat/tool_manager.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/custom_tool.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/execute_bash.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/fs_read.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/fs_write.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/gh_issue.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/mod.rs create mode 100644 crates/kiro-cli/src/cli/chat/tools/tool_index.json create mode 100644 crates/kiro-cli/src/cli/chat/tools/use_aws.rs create mode 100644 crates/kiro-cli/src/cli/chat/util/issue.rs create mode 100644 crates/kiro-cli/src/cli/chat/util/mod.rs create mode 100644 crates/kiro-cli/src/cli/debug.rs create mode 100644 crates/kiro-cli/src/cli/diagnostics.rs create mode 100644 crates/kiro-cli/src/cli/feed.rs create mode 100644 crates/kiro-cli/src/cli/issue.rs create mode 100644 crates/kiro-cli/src/cli/mod.rs create mode 100644 crates/kiro-cli/src/cli/settings.rs create mode 100644 crates/kiro-cli/src/cli/telemetry.rs create mode 100644 crates/kiro-cli/src/cli/uninstall.rs create mode 100644 crates/kiro-cli/src/cli/update.rs create mode 100644 crates/kiro-cli/src/cli/user.rs create mode 100644 crates/kiro-cli/src/diagnostics.rs create mode 100644 crates/kiro-cli/src/fig_api_client/clients/client.rs create mode 100644 crates/kiro-cli/src/fig_api_client/clients/mod.rs create mode 100644 crates/kiro-cli/src/fig_api_client/clients/shared.rs create mode 100644 crates/kiro-cli/src/fig_api_client/clients/streaming_client.rs create mode 100644 crates/kiro-cli/src/fig_api_client/consts.rs create mode 100644 crates/kiro-cli/src/fig_api_client/credentials/mod.rs create mode 100644 crates/kiro-cli/src/fig_api_client/customization.rs create mode 100644 crates/kiro-cli/src/fig_api_client/endpoints.rs create mode 100644 crates/kiro-cli/src/fig_api_client/error.rs create mode 100644 crates/kiro-cli/src/fig_api_client/interceptor/mod.rs create mode 100644 crates/kiro-cli/src/fig_api_client/interceptor/opt_out.rs create mode 100644 crates/kiro-cli/src/fig_api_client/interceptor/session_id.rs create mode 100644 crates/kiro-cli/src/fig_api_client/mod.rs create mode 100644 crates/kiro-cli/src/fig_api_client/model.rs create mode 100644 crates/kiro-cli/src/fig_api_client/profile.rs create mode 100644 crates/kiro-cli/src/fig_api_client/stage.rs create mode 100644 crates/kiro-cli/src/fig_auth/builder_id.rs create mode 100644 crates/kiro-cli/src/fig_auth/consts.rs create mode 100644 crates/kiro-cli/src/fig_auth/error.rs create mode 100644 crates/kiro-cli/src/fig_auth/index.html create mode 100644 crates/kiro-cli/src/fig_auth/mod.rs create mode 100644 crates/kiro-cli/src/fig_auth/pkce.rs create mode 100644 crates/kiro-cli/src/fig_auth/scope.rs create mode 100644 crates/kiro-cli/src/fig_auth/secret_store/linux.rs create mode 100644 crates/kiro-cli/src/fig_auth/secret_store/macos.rs create mode 100644 crates/kiro-cli/src/fig_auth/secret_store/mod.rs create mode 100644 crates/kiro-cli/src/fig_auth/secret_store/sqlite.rs create mode 100644 
crates/kiro-cli/src/fig_aws_common/http_client.rs create mode 100644 crates/kiro-cli/src/fig_aws_common/mod.rs create mode 100644 crates/kiro-cli/src/fig_aws_common/sdk_error_display.rs create mode 100644 crates/kiro-cli/src/fig_aws_common/user_agent_override_interceptor.rs create mode 100644 crates/kiro-cli/src/fig_install.rs create mode 100644 crates/kiro-cli/src/fig_log.rs create mode 100644 crates/kiro-cli/src/fig_os_shim/env.rs create mode 100644 crates/kiro-cli/src/fig_os_shim/fs.rs create mode 100644 crates/kiro-cli/src/fig_os_shim/mod.rs create mode 100644 crates/kiro-cli/src/fig_os_shim/platform.rs create mode 100644 crates/kiro-cli/src/fig_os_shim/providers.rs create mode 100644 crates/kiro-cli/src/fig_os_shim/sysinfo.rs create mode 100644 crates/kiro-cli/src/fig_settings/actions.json create mode 100644 crates/kiro-cli/src/fig_settings/error.rs create mode 100644 crates/kiro-cli/src/fig_settings/keybindings.rs create mode 100644 crates/kiro-cli/src/fig_settings/mod.rs create mode 100644 crates/kiro-cli/src/fig_settings/settings.rs create mode 100644 crates/kiro-cli/src/fig_settings/sqlite.rs create mode 100644 crates/kiro-cli/src/fig_settings/sqlite_migrations/000_migration_table.sql create mode 100644 crates/kiro-cli/src/fig_settings/sqlite_migrations/001_history_table.sql create mode 100644 crates/kiro-cli/src/fig_settings/sqlite_migrations/002_drop_history_in_ssh_docker.sql create mode 100644 crates/kiro-cli/src/fig_settings/sqlite_migrations/003_improved_history_timing.sql create mode 100644 crates/kiro-cli/src/fig_settings/sqlite_migrations/004_state_table.sql create mode 100644 crates/kiro-cli/src/fig_settings/sqlite_migrations/005_auth_table.sql create mode 100644 crates/kiro-cli/src/fig_settings/state.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/cognito.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/definitions.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/endpoint.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/event.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/install_method.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/mod.rs create mode 100644 crates/kiro-cli/src/fig_telemetry/util.rs create mode 100644 crates/kiro-cli/src/fig_telemetry_core.rs create mode 100644 crates/kiro-cli/src/fig_util/cli_context.rs create mode 100644 crates/kiro-cli/src/fig_util/consts.rs create mode 100644 crates/kiro-cli/src/fig_util/directories.rs create mode 100644 crates/kiro-cli/src/fig_util/error.rs create mode 100644 crates/kiro-cli/src/fig_util/manifest.rs create mode 100644 crates/kiro-cli/src/fig_util/mod.rs create mode 100644 crates/kiro-cli/src/fig_util/open.rs create mode 100644 crates/kiro-cli/src/fig_util/pid_file.rs create mode 100644 crates/kiro-cli/src/fig_util/process_info/freebsd.rs create mode 100644 crates/kiro-cli/src/fig_util/process_info/linux.rs create mode 100644 crates/kiro-cli/src/fig_util/process_info/macos.rs create mode 100644 crates/kiro-cli/src/fig_util/process_info/mod.rs create mode 100644 crates/kiro-cli/src/fig_util/process_info/windows.rs create mode 100644 crates/kiro-cli/src/fig_util/region_check.rs create mode 100644 crates/kiro-cli/src/fig_util/spinner.rs create mode 100644 crates/kiro-cli/src/fig_util/system_info/linux.rs create mode 100644 crates/kiro-cli/src/fig_util/system_info/mod.rs create mode 100644 crates/kiro-cli/src/main.rs create mode 100644 crates/kiro-cli/src/mcp_client/client.rs create mode 100644 crates/kiro-cli/src/mcp_client/error.rs create mode 100644 
crates/kiro-cli/src/mcp_client/facilitator_types.rs create mode 100644 crates/kiro-cli/src/mcp_client/mod.rs create mode 100644 crates/kiro-cli/src/mcp_client/server.rs create mode 100644 crates/kiro-cli/src/mcp_client/transport/base_protocol.rs create mode 100644 crates/kiro-cli/src/mcp_client/transport/mod.rs create mode 100644 crates/kiro-cli/src/mcp_client/transport/stdio.rs create mode 100644 crates/kiro-cli/src/mcp_client/transport/websocket.rs create mode 100644 crates/kiro-cli/src/request.rs create mode 100644 crates/kiro-cli/telemetry_definitions.json diff --git a/crates/kiro-cli/.gitignore b/crates/kiro-cli/.gitignore new file mode 100644 index 0000000000..0b0c025e2a --- /dev/null +++ b/crates/kiro-cli/.gitignore @@ -0,0 +1,2 @@ +build/ +spec.ts \ No newline at end of file diff --git a/crates/kiro-cli/Cargo.toml b/crates/kiro-cli/Cargo.toml new file mode 100644 index 0000000000..1554937843 --- /dev/null +++ b/crates/kiro-cli/Cargo.toml @@ -0,0 +1,186 @@ +[package] +name = "q_cli" +authors.workspace = true +edition.workspace = true +homepage.workspace = true +publish.workspace = true +version.workspace = true +license.workspace = true + +[lints] +workspace = true + +[features] +default = [] +wayland = ["arboard/wayland-data-control"] + +[dependencies] +amzn-codewhisperer-client = { path = "../amzn-codewhisperer-client" } +amzn-codewhisperer-streaming-client = { path = "../amzn-codewhisperer-streaming-client" } +amzn-consolas-client = { path = "../amzn-consolas-client" } +amzn-qdeveloper-streaming-client = { path = "../amzn-qdeveloper-streaming-client" } +amzn-toolkit-telemetry-client = { path = "../amzn-toolkit-telemetry-client" } +anstream = "0.6.13" +arboard = { version = "3.5.0", default-features = false } +async-trait = "0.1.87" +aws-config = "1.0.3" +aws-credential-types = "1.0.3" +aws-runtime = "1.4.4" +aws-sdk-cognitoidentity = "1.51.0" +aws-sdk-ssooidc = "1.51.0" +aws-smithy-async = "1.2.2" +aws-smithy-runtime-api = "1.6.1" +aws-smithy-types = "1.2.10" +aws-types = "1.3.0" +base64 = "0.22.1" +bitflags = "2.9.0" +bstr = "1.12.0" +bytes = "1.10.1" +camino = { version = "1.1.3", features = ["serde1"] } +cfg-if = "1.0.0" +clap = { version = "4.5.32", features = [ + "deprecated", + "derive", + "string", + "unicode", + "wrap_help", +] } +clap_complete = "4.5.46" +clap_complete_fig = "4.4.0" +color-eyre = "0.6.2" +color-print = "0.3.5" +convert_case = "0.8.0" +cookie = "0.18.1" +crossterm = { version = "0.28.1", features = ["event-stream", "events"] } +ctrlc = "3.4.6" +dialoguer = { version = "0.11.0", features = ["fuzzy-select"] } +dirs = "5.0.0" +eyre = "0.6.8" +fd-lock = "4.0.4" +futures = "0.3.26" +glob = "0.3.2" +globset = "0.4.16" +hex = "0.4.3" +http = "1.2.0" +http-body-util = "0.1.3" +hyper = { version = "1.6.0", features = ["server"] } +hyper-util = { version = "0.1.11", features = ["tokio"] } +indicatif = "0.17.11" +indoc = "2.0.6" +insta = "1.43.1" +libc = "0.2.172" +mimalloc = "0.1.46" +nix = { version = "0.29.0", features = [ + "feature", + "fs", + "ioctl", + "process", + "signal", + "term", + "user", +] } +owo-colors = "4.2.0" +parking_lot = "0.12.3" +paste = "1.0.11" +percent-encoding = "2.2.0" +r2d2 = "0.8.10" +r2d2_sqlite = "0.25.0" +rand = "0.9.0" +regex = "1.7.0" +reqwest = { version = "0.12.14", default-features = false, features = [ + "http2", + "charset", + "rustls-tls", + "rustls-tls-native-roots", + "gzip", + "json", + "socks", + "cookies", +] } +ring = "0.17.14" +rusqlite = { version = "0.32.1", features = ["bundled", "serde_json"] } +rustls = "0.23.23" 
+rustls-native-certs = "0.8.1" +rustls-pemfile = "2.1.0" +rustyline = { version = "15.0.0", features = [ + "custom-bindings", + "derive", + "with-file-history", +], default-features = false } +self_update = "0.42.0" +semver = { version = "1.0.26", features = ["serde"] } +serde = { version = "1.0.219", features = ["derive", "rc"] } +serde_json = "1.0.140" +sha2 = "0.10.9" +shell-color = "1.0.0" +shell-words = "1.1.0" +shellexpand = "3.0.0" +shlex = "1.3.0" +similar = "2.7.0" +skim = { version = "0.16.2" } +spinners = "4.1.0" +strip-ansi-escapes = "0.2.1" +strum = { version = "0.27.1", features = ["derive"] } +syntect = "5.2.0" +sysinfo = "0.33.1" +tempfile = "3.18.0" +thiserror = "2.0.12" +time = { version = "0.3.39", features = [ + "parsing", + "formatting", + "local-offset", + "macros", + "serde", +] } +tokio = { version = "1.44.2", features = ["full"] } +tokio-tungstenite = "0.26.2" +tokio-util = { version = "0.7.15", features = ["codec", "compat"] } +toml = "0.8.12" +tracing = { version = "0.1.40", features = ["log"] } +tracing-appender = "0.2.2" +tracing-subscriber = { version = "0.3.19", features = [ + "env-filter", + "fmt", + "parking_lot", + "time", +] } +unicode-width = "0.2.0" +url = "2.5.4" +uuid = { version = "1.15.1", features = ["v4", "serde"] } +walkdir = "2.5.0" +webpki-roots = "0.26.8" +whoami = "1.6.0" +winnow = "=0.6.2" + +[target.'cfg(unix)'.dependencies] +nix = { version = "0.29.0", features = [ + "feature", + "fs", + "ioctl", + "process", + "signal", + "term", + "user", +] } + +[target.'cfg(target_os = "macos")'.dependencies] +objc2 = "0.5.2" +objc2-app-kit = { version = "0.2.2", features = ["NSWorkspace"] } +objc2-foundation = { version = "0.2.2", features = ["NSString", "NSURL"] } +security-framework = "3.2.0" + +[dev-dependencies] +assert_cmd = "2.0" +criterion = "0.5.1" +mockito = "1.7.0" +paste = "1.0.11" +predicates = "3.0" +tracing-test = "0.2.4" + +[build-dependencies] +convert_case = "0.8.0" +prettyplease = "0.2.32" +quote = "1.0.40" +serde = { version = "1.0.219", features = ["derive", "rc"] } +serde_json = "1.0.140" +syn = "2.0.101" diff --git a/crates/kiro-cli/build.rs b/crates/kiro-cli/build.rs new file mode 100644 index 0000000000..b7320f27d5 --- /dev/null +++ b/crates/kiro-cli/build.rs @@ -0,0 +1,280 @@ +use convert_case::{ + Case, + Casing, +}; +use quote::{ + format_ident, + quote, +}; + +const DEF: &str = include_str!("./telemetry_definitions.json"); + +#[derive(Debug, Clone, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +struct TypeDef { + name: String, + r#type: Option, + allowed_values: Option>, + description: String, +} + +#[derive(Debug, Clone, serde::Deserialize)] +struct MetricDef { + name: String, + description: String, + metadata: Option>, + passive: Option, + unit: Option, +} + +#[derive(Debug, Clone, serde::Deserialize)] +struct MetricMetadata { + r#type: String, + required: Option, +} + +#[derive(Debug, Clone, serde::Deserialize)] +struct Def { + types: Vec, + metrics: Vec, +} + +fn main() { + println!("cargo:rerun-if-changed=def.json"); + + let outdir = std::env::var("OUT_DIR").unwrap(); + + let data = serde_json::from_str::(DEF).unwrap(); + + let mut out = " + #[allow(rustdoc::invalid_html_tags)] + #[allow(rustdoc::bare_urls)] + mod inner { + " + .to_string(); + + out.push_str("pub mod types {"); + for t in data.types { + let name = format_ident!("{}", t.name.to_case(Case::Pascal)); + + let rust_type = match t.allowed_values { + // enum + Some(allowed_values) => { + let mut variants = vec![]; + let mut variant_as_str = vec![]; 
+ + for v in allowed_values { + let ident = format_ident!("{}", v.replace('.', "").to_case(Case::Pascal)); + variants.push(quote!( + #[doc = concat!("`", #v, "`")] + #ident + )); + variant_as_str.push(quote!( + #name::#ident => #v + )); + } + + let description = t.description; + + quote::quote!( + #[doc = #description] + #[derive(Debug, Clone, PartialEq)] + #[non_exhaustive] + pub enum #name { + #( + #variants, + )* + } + + impl #name { + pub fn as_str(&self) -> &'static str { + match self { + #( #variant_as_str, )* + } + } + } + + impl ::std::fmt::Display for #name { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + f.write_str(self.as_str()) + } + } + ) + .to_string() + }, + // struct + None => { + let r#type = match t.r#type.as_deref() { + Some("string") | None => quote!(::std::string::String), + Some("int") => quote!(::std::primitive::i64), + Some("double") => quote!(::std::primitive::f64), + Some("boolean") => quote!(::std::primitive::bool), + Some(other) => panic!("{}", other), + }; + let description = t.description; + + quote::quote!( + #[doc = #description] + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] + #[serde(transparent)] + pub struct #name(pub #r#type); + + impl #name { + pub fn new(t: #r#type) -> Self { + Self(t) + } + + pub fn value(&self) -> &#r#type { + &self.0 + } + + pub fn into_value(self) -> #r#type { + self.0 + } + } + + impl ::std::fmt::Display for #name { + fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { + write!(f, "{}", self.0) + } + } + + impl From<#r#type> for #name { + fn from(t: #r#type) -> Self { + Self(t) + } + } + ) + .to_string() + }, + }; + + out.push_str(&rust_type); + } + out.push('}'); + + out.push_str("pub mod metrics {"); + for m in data.metrics.clone() { + let raw_name = m.name; + let name = format_ident!("{}", raw_name.to_case(Case::Pascal)); + let description = m.description; + + let passive = m.passive.unwrap_or_default(); + + let unit = match m.unit.map(|u| u.to_lowercase()).as_deref() { + Some("bytes") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Bytes), + Some("count") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Count), + Some("milliseconds") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Milliseconds), + Some("percent") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Percent), + Some("none") | None => quote!(::amzn_toolkit_telemetry_client::types::Unit::None), + Some(unknown) => { + panic!("unknown unit: {:?}", unknown); + }, + }; + + let metadata = m.metadata.unwrap_or_default(); + + let mut fields = Vec::new(); + for field in &metadata { + let field_name = format_ident!("{}", &field.r#type.to_case(Case::Snake)); + let ty_name = format_ident!("{}", field.r#type.to_case(Case::Pascal)); + let ty = if field.required.unwrap_or_default() { + quote!(crate::fig_telemetry::definitions::types::#ty_name) + } else { + quote!(::std::option::Option) + }; + + fields.push(quote!( + #field_name: #ty + )); + } + + let metadata_entries = metadata.iter().map(|m| { + let raw_name = &m.r#type; + let key = format_ident!("{}", m.r#type.to_case(Case::Snake)); + + let value = if m.required.unwrap_or_default() { + quote!(.value(self.#key.to_string())) + } else { + quote!(.value(self.#key.map(|v| v.to_string()).unwrap_or_default())) + }; + + quote!( + ::amzn_toolkit_telemetry_client::types::MetadataEntry::builder() + .key(#raw_name) + #value + .build() + ) + }); + + let rust_type = quote::quote!( + #[doc = #description] + #[derive(Debug, Clone, 
PartialEq, ::serde::Serialize, ::serde::Deserialize)] + #[serde(rename_all = "camelCase")] + pub struct #name { + /// The time that the event took place, + pub create_time: ::std::option::Option<::std::time::SystemTime>, + /// Value based on unit and call type, + pub value: ::std::option::Option, + #( pub #fields, )* + } + + impl #name { + const NAME: &'static ::std::primitive::str = #raw_name; + const PASSIVE: ::std::primitive::bool = #passive; + const UNIT: ::amzn_toolkit_telemetry_client::types::Unit = #unit; + } + + impl crate::fig_telemetry::definitions::IntoMetricDatum for #name { + fn into_metric_datum(self) -> ::amzn_toolkit_telemetry_client::types::MetricDatum { + let metadata_entries = vec![ + #( + #metadata_entries, + )* + ]; + + let epoch_timestamp = self.create_time + .map_or_else( + || ::std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as ::std::primitive::i64, + |t| t.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as ::std::primitive::i64 + ); + + ::amzn_toolkit_telemetry_client::types::MetricDatum::builder() + .metric_name(#name::NAME) + .passive(#name::PASSIVE) + .unit(#name::UNIT) + .epoch_timestamp(epoch_timestamp) + .value(self.value.unwrap_or(1.0)) + .set_metadata(Some(metadata_entries)) + .build() + .unwrap() + } + } + ); + + out.push_str(&rust_type.to_string()); + } + out.push('}'); + + // enum of all metrics + let mut metrics = Vec::new(); + for m in data.metrics { + let name = format_ident!("{}", m.name.to_case(Case::Pascal)); + metrics.push(quote!( + #name + )); + } + out.push_str("#[derive(Debug, Clone, PartialEq, ::serde::Serialize, ::serde::Deserialize)]\n#[serde(tag = \"type\", content = \"content\")]\npub enum Metric {\n"); + for m in metrics { + out.push_str(&format!("{m}(crate::fig_telemetry::definitions::metrics::{m}),\n")); + } + out.push('}'); + + out.push_str("}\npub use inner::*;"); + + let file: syn::File = syn::parse_str(&out).unwrap(); + let pp = prettyplease::unparse(&file); + + // write an empty file to the output directory + std::fs::write(format!("{}/mod.rs", outdir), pp).unwrap(); +} diff --git a/crates/kiro-cli/src/cli/chat/cli.rs b/crates/kiro-cli/src/cli/chat/cli.rs new file mode 100644 index 0000000000..a441887b9b --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/cli.rs @@ -0,0 +1,25 @@ +use clap::Parser; + +#[derive(Debug, Clone, PartialEq, Eq, Default, Parser)] +pub struct Chat { + /// (Deprecated, use --trust-all-tools) Enabling this flag allows the model to execute + /// all commands without first accepting them. + #[arg(short, long, hide = true)] + pub accept_all: bool, + /// Print the first response to STDOUT without interactive mode. This will fail if the + /// prompt requests permissions to use a tool, unless --trust-all-tools is also used. + #[arg(long)] + pub no_interactive: bool, + /// The first question to ask + pub input: Option, + /// Context profile to use + #[arg(long = "profile")] + pub profile: Option, + /// Allows the model to use any tool to run commands without asking for confirmation. + #[arg(long)] + pub trust_all_tools: bool, + /// Trust only this set of tools. 
Example: trust some tools: + /// '--trust-tools=fs_read,fs_write', trust no tools: '--trust-tools=' + #[arg(long, value_delimiter = ',', value_name = "TOOL_NAMES")] + pub trust_tools: Option>, +} diff --git a/crates/kiro-cli/src/cli/chat/command.rs b/crates/kiro-cli/src/cli/chat/command.rs new file mode 100644 index 0000000000..43d07f1169 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/command.rs @@ -0,0 +1,1093 @@ +use std::collections::HashSet; +use std::io::Write; + +use clap::{ + Parser, + Subcommand, +}; +use crossterm::style::Color; +use crossterm::{ + queue, + style, +}; +use eyre::Result; +use serde::{ + Deserialize, + Serialize, +}; + +#[derive(Debug, PartialEq, Eq)] +pub enum Command { + Ask { + prompt: String, + }, + Execute { + command: String, + }, + Clear, + Help, + Issue { + prompt: Option, + }, + Quit, + Profile { + subcommand: ProfileSubcommand, + }, + Context { + subcommand: ContextSubcommand, + }, + PromptEditor { + initial_text: Option, + }, + Compact { + prompt: Option, + show_summary: bool, + help: bool, + }, + Tools { + subcommand: Option, + }, + Prompts { + subcommand: Option, + }, + Usage, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ProfileSubcommand { + List, + Create { name: String }, + Delete { name: String }, + Set { name: String }, + Rename { old_name: String, new_name: String }, + Help, +} + +impl ProfileSubcommand { + const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available commands + help Show an explanation for the profile command + list List all available profiles + create <> Create a new profile with the specified name + delete <> Delete the specified profile + set <> Switch to the specified profile + rename <> <> Rename a profile"}; + const CREATE_USAGE: &str = "/profile create "; + const DELETE_USAGE: &str = "/profile delete "; + const RENAME_USAGE: &str = "/profile rename "; + const SET_USAGE: &str = "/profile set "; + + fn usage_msg(header: impl AsRef) -> String { + format!("{}\n\n{}", header.as_ref(), Self::AVAILABLE_COMMANDS) + } + + pub fn help_text() -> String { + color_print::cformat!( + r#" +(Beta) Profile Management + +Profiles allow you to organize and manage different sets of context files for different projects or tasks. 
+ +{} + +Notes +• The "global" profile contains context files that are available in all profiles +• The "default" profile is used when no profile is specified +• You can switch between profiles to work on different projects +• Each profile maintains its own set of context files +"#, + Self::AVAILABLE_COMMANDS + ) + } +} + +#[derive(Parser, Debug, Clone)] +#[command(name = "hooks", disable_help_flag = true, disable_help_subcommand = true)] +struct HooksCommand { + #[command(subcommand)] + command: HooksSubcommand, +} + +#[derive(Subcommand, Debug, Clone, Eq, PartialEq)] +pub enum HooksSubcommand { + Add { + name: String, + + #[arg(long, value_parser = ["per_prompt", "conversation_start"])] + trigger: String, + + #[arg(long, value_parser = clap::value_parser!(String))] + command: String, + + #[arg(long)] + global: bool, + }, + #[command(name = "rm")] + Remove { + name: String, + + #[arg(long)] + global: bool, + }, + Enable { + name: String, + + #[arg(long)] + global: bool, + }, + Disable { + name: String, + + #[arg(long)] + global: bool, + }, + EnableAll { + #[arg(long)] + global: bool, + }, + DisableAll { + #[arg(long)] + global: bool, + }, + Help, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ContextSubcommand { + Show { + expand: bool, + }, + Add { + global: bool, + force: bool, + paths: Vec, + }, + Remove { + global: bool, + paths: Vec, + }, + Clear { + global: bool, + }, + Hooks { + subcommand: Option, + }, + Help, +} + +impl ContextSubcommand { + const ADD_USAGE: &str = "/context add [--global] [--force] [path2...]"; + const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available commands + help Show an explanation for the context command + + show [--expand] Display the context rule configuration and matched files + --expand: Print out each matched file's content + + add [--global] [--force] <> + Add context rules (filenames or glob patterns) + --global: Add to global rules (available in all profiles) + --force: Include even if matched files exceed size limits + + rm [--global] <> Remove specified rules from current profile + --global: Remove specified rules globally + + clear [--global] Remove all rules from current profile + --global: Remove global rules + + hooks View and manage context hooks"}; + const CLEAR_USAGE: &str = "/context clear [--global]"; + const HOOKS_AVAILABLE_COMMANDS: &str = color_print::cstr! 
{"Available subcommands + hooks help Show an explanation for context hooks commands + + hooks add [--global] <> Add a new command context hook + --global: Add to global hooks + --trigger <> When to trigger the hook, valid options: `per_prompt` or `conversation_start` + --command <> Shell command to execute + + hooks rm [--global] <> Remove an existing context hook + --global: Remove from global hooks + + hooks enable [--global] <> Enable an existing context hook + --global: Enable in global hooks + + hooks disable [--global] <> Disable an existing context hook + --global: Disable in global hooks + + hooks enable-all [--global] Enable all existing context hooks + --global: Enable all in global hooks + + hooks disable-all [--global] Disable all existing context hooks + --global: Disable all in global hooks"}; + const REMOVE_USAGE: &str = "/context rm [--global] [path2...]"; + const SHOW_USAGE: &str = "/context show [--expand]"; + + fn usage_msg(header: impl AsRef) -> String { + format!("{}\n\n{}", header.as_ref(), Self::AVAILABLE_COMMANDS) + } + + fn hooks_usage_msg(header: impl AsRef) -> String { + format!("{}\n\n{}", header.as_ref(), Self::HOOKS_AVAILABLE_COMMANDS) + } + + pub fn help_text() -> String { + color_print::cformat!( + r#" +(Beta) Context Rule Management + +Context rules determine which files are included in your Amazon Q session. +The files matched by these rules provide Amazon Q with additional information +about your project or environment. Adding relevant files helps Q generate +more accurate and helpful responses. + +In addition to files, you can specify hooks that will run commands and return +the output as context to Amazon Q. + +{} + +Notes +• You can add specific files or use glob patterns (e.g., "*.py", "src/**/*.js") +• Profile rules apply only to the current profile +• Global rules apply across all profiles +• Context is preserved between chat sessions +"#, + Self::AVAILABLE_COMMANDS + ) + } + + pub fn hooks_help_text() -> String { + color_print::cformat!( + r#" +(Beta) Context Hooks + +Use context hooks to specify shell commands to run. The output from these +commands will be appended to the prompt to Amazon Q. Hooks can be defined +in global or local profiles. + +Usage: /context hooks [SUBCOMMAND] + +Description + Show existing global or profile-specific hooks. + Alternatively, specify a subcommand to modify the hooks. + +{} + +Notes +• Hooks are executed in parallel +• 'conversation_start' hooks run on the first user prompt and are attached once to the conversation history sent to Amazon Q +• 'per_prompt' hooks run on each user prompt and are attached to the prompt, but are not stored in conversation history +"#, + Self::HOOKS_AVAILABLE_COMMANDS + ) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ToolsSubcommand { + Schema, + Trust { tool_names: HashSet }, + Untrust { tool_names: HashSet }, + TrustAll, + Reset, + ResetSingle { tool_name: String }, + Help, +} + +impl ToolsSubcommand { + const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available subcommands + help Show an explanation for the tools command + schema Show the input schema for all available tools + trust <> Trust a specific tool or tools for the session + untrust <> Revert a tool or tools to per-request confirmation + trustall Trust all tools (equivalent to deprecated /acceptall) + reset Reset all tools to default permission levels + reset <> Reset a single tool to default permission level"}; + const BASE_COMMAND: &str = color_print::cstr! 
{"Usage: /tools [SUBCOMMAND] + +Description + Show the current set of tools and their permission setting. + The permission setting states when user confirmation is required. Trusted tools never require confirmation. + Alternatively, specify a subcommand to modify the tool permissions."}; + + fn usage_msg(header: impl AsRef) -> String { + format!( + "{}\n\n{}\n\n{}", + header.as_ref(), + Self::BASE_COMMAND, + Self::AVAILABLE_COMMANDS + ) + } + + pub fn help_text() -> String { + color_print::cformat!( + r#" +Tool Permissions + +By default, Amazon Q will ask for your permission to use certain tools. You can control which tools you +trust so that no confirmation is required. These settings will last only for this session. + +{} + +{}"#, + Self::BASE_COMMAND, + Self::AVAILABLE_COMMANDS + ) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PromptsSubcommand { + List { search_word: Option }, + Get { get_command: PromptsGetCommand }, + Help, +} + +impl PromptsSubcommand { + const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available subcommands + help Show an explanation for the prompts command + list [search word] List available prompts from a tool or show all available prompts"}; + const BASE_COMMAND: &str = color_print::cstr! {"Usage: /prompts [SUBCOMMAND] + +Description + Show the current set of reusuable prompts from the current fleet of mcp servers."}; + + fn usage_msg(header: impl AsRef) -> String { + format!( + "{}\n\n{}\n\n{}", + header.as_ref(), + Self::BASE_COMMAND, + Self::AVAILABLE_COMMANDS + ) + } + + pub fn help_text() -> String { + color_print::cformat!( + r#" +Prompts + +Prompts are reusable templates that help you quickly access common workflows and tasks. +These templates are provided by the mcp servers you have installed and configured. + +To actually retrieve a prompt, directly start with the following command (without prepending /prompt get): + @<> [arg] Retrieve prompt specified +Or if you prefer the long way: + /prompts get <> [arg] Retrieve prompt specified + +{} + +{}"#, + Self::BASE_COMMAND, + Self::AVAILABLE_COMMANDS + ) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptsGetCommand { + pub orig_input: Option, + pub params: PromptsGetParam, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct PromptsGetParam { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, +} + +impl Command { + // Check if input is a common single-word command that should use slash prefix + fn check_common_command(input: &str) -> Option { + let input_lower = input.trim().to_lowercase(); + match input_lower.as_str() { + "exit" | "quit" | "q" | "exit()" => { + Some("Did you mean to use the command '/quit' to exit? Type '/quit' to exit.".to_string()) + }, + "clear" | "cls" => Some( + "Did you mean to use the command '/clear' to clear the conversation? Type '/clear' to clear." + .to_string(), + ), + "help" | "?" => Some( + "Did you mean to use the command '/help' for help? 
Type '/help' to see available commands.".to_string(), + ), + _ => None, + } + } + + pub fn parse(input: &str, output: &mut impl Write) -> Result { + let input = input.trim(); + + // Check for common single-word commands without slash prefix + if let Some(suggestion) = Self::check_common_command(input) { + return Err(suggestion); + } + + // Check if the input starts with a literal backslash followed by a slash + // This allows users to escape the slash if they actually want to start with one + if input.starts_with("\\/") { + return Ok(Self::Ask { + prompt: input[1..].to_string(), // Remove the backslash but keep the slash + }); + } + + if let Some(command) = input.strip_prefix("/") { + let parts: Vec<&str> = command.split_whitespace().collect(); + + if parts.is_empty() { + return Err("Empty command".to_string()); + } + + return Ok(match parts[0].to_lowercase().as_str() { + "clear" => Self::Clear, + "help" => Self::Help, + "compact" => { + let mut prompt = None; + let show_summary = true; + let mut help = false; + + // Check if "help" is the first subcommand + if parts.len() > 1 && parts[1].to_lowercase() == "help" { + help = true; + } else { + let mut remaining_parts = Vec::new(); + + remaining_parts.extend_from_slice(&parts[1..]); + + // If we have remaining parts after parsing flags, join them as the prompt + if !remaining_parts.is_empty() { + prompt = Some(remaining_parts.join(" ")); + } + } + + Self::Compact { + prompt, + show_summary, + help, + } + }, + "acceptall" => { + let _ = queue!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("\n/acceptall is deprecated. Use /tools instead.\n\n"), + style::SetForegroundColor(Color::Reset) + ); + + Self::Tools { + subcommand: Some(ToolsSubcommand::TrustAll), + } + }, + "editor" => { + if parts.len() > 1 { + Self::PromptEditor { + initial_text: Some(parts[1..].join(" ")), + } + } else { + Self::PromptEditor { initial_text: None } + } + }, + "issue" => { + if parts.len() > 1 { + Self::Issue { + prompt: Some(parts[1..].join(" ")), + } + } else { + Self::Issue { prompt: None } + } + }, + "q" | "exit" | "quit" => Self::Quit, + "profile" => { + if parts.len() < 2 { + return Ok(Self::Profile { + subcommand: ProfileSubcommand::Help, + }); + } + + macro_rules! 
usage_err { + ($usage_str:expr) => { + return Err(format!( + "Invalid /profile arguments.\n\nUsage:\n {}", + $usage_str + )) + }; + } + + match parts[1].to_lowercase().as_str() { + "list" => Self::Profile { + subcommand: ProfileSubcommand::List, + }, + "create" => { + let name = parts.get(2); + match name { + Some(name) => Self::Profile { + subcommand: ProfileSubcommand::Create { + name: (*name).to_string(), + }, + }, + None => usage_err!(ProfileSubcommand::CREATE_USAGE), + } + }, + "delete" => { + let name = parts.get(2); + match name { + Some(name) => Self::Profile { + subcommand: ProfileSubcommand::Delete { + name: (*name).to_string(), + }, + }, + None => usage_err!(ProfileSubcommand::DELETE_USAGE), + } + }, + "rename" => { + let old_name = parts.get(2); + let new_name = parts.get(3); + match (old_name, new_name) { + (Some(old), Some(new)) => Self::Profile { + subcommand: ProfileSubcommand::Rename { + old_name: (*old).to_string(), + new_name: (*new).to_string(), + }, + }, + _ => usage_err!(ProfileSubcommand::RENAME_USAGE), + } + }, + "set" => { + let name = parts.get(2); + match name { + Some(name) => Self::Profile { + subcommand: ProfileSubcommand::Set { + name: (*name).to_string(), + }, + }, + None => usage_err!(ProfileSubcommand::SET_USAGE), + } + }, + "help" => Self::Profile { + subcommand: ProfileSubcommand::Help, + }, + other => { + return Err(ProfileSubcommand::usage_msg(format!("Unknown subcommand '{}'.", other))); + }, + } + }, + "context" => { + if parts.len() < 2 { + return Ok(Self::Context { + subcommand: ContextSubcommand::Help, + }); + } + + macro_rules! usage_err { + ($usage_str:expr) => { + return Err(format!( + "Invalid /context arguments.\n\nUsage:\n {}", + $usage_str + )) + }; + } + + match parts[1].to_lowercase().as_str() { + "show" => { + let mut expand = false; + for part in &parts[2..] { + if *part == "--expand" { + expand = true; + } else { + usage_err!(ContextSubcommand::SHOW_USAGE); + } + } + Self::Context { + subcommand: ContextSubcommand::Show { expand }, + } + }, + "add" => { + // Parse add command with paths and flags + let mut global = false; + let mut force = false; + let mut paths = Vec::new(); + + let args = match shlex::split(&parts[2..].join(" ")) { + Some(args) => args, + None => return Err("Failed to parse quoted arguments".to_string()), + }; + + for arg in &args { + if arg == "--global" { + global = true; + } else if arg == "--force" || arg == "-f" { + force = true; + } else { + paths.push(arg.to_string()); + } + } + + if paths.is_empty() { + usage_err!(ContextSubcommand::ADD_USAGE); + } + + Self::Context { + subcommand: ContextSubcommand::Add { global, force, paths }, + } + }, + "rm" => { + // Parse rm command with paths and --global flag + let mut global = false; + let mut paths = Vec::new(); + let args = match shlex::split(&parts[2..].join(" ")) { + Some(args) => args, + None => return Err("Failed to parse quoted arguments".to_string()), + }; + + for arg in &args { + if arg == "--global" { + global = true; + } else { + paths.push(arg.to_string()); + } + } + + if paths.is_empty() { + usage_err!(ContextSubcommand::REMOVE_USAGE); + } + + Self::Context { + subcommand: ContextSubcommand::Remove { global, paths }, + } + }, + "clear" => { + // Parse clear command with optional --global flag + let mut global = false; + + for part in &parts[2..] 
{ + if *part == "--global" { + global = true; + } else { + usage_err!(ContextSubcommand::CLEAR_USAGE); + } + } + + Self::Context { + subcommand: ContextSubcommand::Clear { global }, + } + }, + "help" => Self::Context { + subcommand: ContextSubcommand::Help, + }, + "hooks" => { + if parts.get(2).is_none() { + return Ok(Self::Context { + subcommand: ContextSubcommand::Hooks { subcommand: None }, + }); + }; + + match Self::parse_hooks(&parts) { + Ok(command) => command, + Err(err) => return Err(ContextSubcommand::hooks_usage_msg(err)), + } + }, + other => { + return Err(ContextSubcommand::usage_msg(format!("Unknown subcommand '{}'.", other))); + }, + } + }, + "tools" => { + if parts.len() < 2 { + return Ok(Self::Tools { subcommand: None }); + } + + match parts[1].to_lowercase().as_str() { + "schema" => Self::Tools { + subcommand: Some(ToolsSubcommand::Schema), + }, + "trust" => { + let mut tool_names = HashSet::new(); + for part in &parts[2..] { + tool_names.insert((*part).to_string()); + } + + if tool_names.is_empty() { + let _ = queue!( + output, + style::SetForegroundColor(Color::DarkGrey), + style::Print("\nPlease use"), + style::SetForegroundColor(Color::DarkGreen), + style::Print(" /tools trust "), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to trust tools.\n\n"), + style::Print("Use "), + style::SetForegroundColor(Color::DarkGreen), + style::Print("/tools"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to see all available tools.\n\n"), + style::SetForegroundColor(Color::Reset), + ); + } + + Self::Tools { + subcommand: Some(ToolsSubcommand::Trust { tool_names }), + } + }, + "untrust" => { + let mut tool_names = HashSet::new(); + for part in &parts[2..] { + tool_names.insert((*part).to_string()); + } + + if tool_names.is_empty() { + let _ = queue!( + output, + style::SetForegroundColor(Color::DarkGrey), + style::Print("\nPlease use"), + style::SetForegroundColor(Color::DarkGreen), + style::Print(" /tools untrust "), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to untrust tools.\n\n"), + style::Print("Use "), + style::SetForegroundColor(Color::DarkGreen), + style::Print("/tools"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to see all available tools.\n\n"), + style::SetForegroundColor(Color::Reset), + ); + } + + Self::Tools { + subcommand: Some(ToolsSubcommand::Untrust { tool_names }), + } + }, + "trustall" => Self::Tools { + subcommand: Some(ToolsSubcommand::TrustAll), + }, + "reset" => { + let tool_name = parts.get(2); + match tool_name { + Some(tool_name) => Self::Tools { + subcommand: Some(ToolsSubcommand::ResetSingle { + tool_name: (*tool_name).to_string(), + }), + }, + None => Self::Tools { + subcommand: Some(ToolsSubcommand::Reset), + }, + } + }, + "help" => Self::Tools { + subcommand: Some(ToolsSubcommand::Help), + }, + other => { + return Err(ToolsSubcommand::usage_msg(format!("Unknown subcommand '{}'.", other))); + }, + } + }, + "prompts" => { + let subcommand = parts.get(1); + match subcommand { + Some(c) if c.to_lowercase() == "list" => Self::Prompts { + subcommand: Some(PromptsSubcommand::List { + search_word: parts.get(2).map(|v| (*v).to_string()), + }), + }, + Some(c) if c.to_lowercase() == "help" => Self::Prompts { + subcommand: Some(PromptsSubcommand::Help), + }, + Some(c) if c.to_lowercase() == "get" => { + // Need to reconstruct the input because simple splitting of + // white space might not be sufficient + let command = parts[2..].join(" "); + let get_command = 
parse_input_to_prompts_get_command(command.as_str())?; + let subcommand = Some(PromptsSubcommand::Get { get_command }); + Self::Prompts { subcommand } + }, + Some(other) => { + return Err(PromptsSubcommand::usage_msg(format!( + "Unknown subcommand '{}'\n", + other + ))); + }, + None => Self::Prompts { + subcommand: Some(PromptsSubcommand::List { + search_word: parts.get(2).map(|v| (*v).to_string()), + }), + }, + } + }, + "usage" => Self::Usage, + unknown_command => { + // If the command starts with a slash but isn't recognized, + // return an error instead of treating it as a prompt + return Err(format!( + "Unknown command: '/{}'. Type '/help' to see available commands.\nTo use a literal slash at the beginning of your message, escape it with a backslash (e.g., '\\//hey' for '/hey').", + unknown_command + )); + }, + }); + } + + if let Some(command) = input.strip_prefix('@') { + let get_command = parse_input_to_prompts_get_command(command)?; + let subcommand = Some(PromptsSubcommand::Get { get_command }); + return Ok(Self::Prompts { subcommand }); + } + + if let Some(command) = input.strip_prefix("!") { + return Ok(Self::Execute { + command: command.to_string(), + }); + } + + Ok(Self::Ask { + prompt: input.to_string(), + }) + } + + // NOTE: Here we use clap to parse the hooks subcommand instead of parsing manually + // like the rest of the file. + // Since the hooks subcommand has a lot of options, this makes more sense. + // Ideally, we parse everything with clap instead of trying to do it manually. + fn parse_hooks(parts: &[&str]) -> Result { + // Skip the first two parts ("/context" and "hooks") + let args = match shlex::split(&parts[1..].join(" ")) { + Some(args) => args, + None => return Err("Failed to parse arguments".to_string()), + }; + + // Parse with Clap + HooksCommand::try_parse_from(args) + .map(|hooks_command| Self::Context { + subcommand: ContextSubcommand::Hooks { + subcommand: Some(hooks_command.command), + }, + }) + .map_err(|e| e.to_string()) + } +} + +fn parse_input_to_prompts_get_command(command: &str) -> Result { + let input = shell_words::split(command).map_err(|e| format!("Error splitting command for prompts: {:?}", e))?; + let mut iter = input.into_iter(); + let prompt_name = iter.next().ok_or("Prompt name needs to be specified")?; + let args = iter.collect::>(); + let params = PromptsGetParam { + name: prompt_name, + arguments: { if args.is_empty() { None } else { Some(args) } }, + }; + let orig_input = Some(command.to_string()); + Ok(PromptsGetCommand { orig_input, params }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_command_parse() { + let mut stdout = std::io::stdout(); + + macro_rules! profile { + ($subcommand:expr) => { + Command::Profile { + subcommand: $subcommand, + } + }; + } + macro_rules! context { + ($subcommand:expr) => { + Command::Context { + subcommand: $subcommand, + } + }; + } + macro_rules! 
compact { + ($prompt:expr, $show_summary:expr) => { + Command::Compact { + prompt: $prompt, + show_summary: $show_summary, + help: false, + } + }; + } + let tests = &[ + ("/compact", compact!(None, true)), + ( + "/compact custom prompt", + compact!(Some("custom prompt".to_string()), true), + ), + ("/profile list", profile!(ProfileSubcommand::List)), + ( + "/profile create new_profile", + profile!(ProfileSubcommand::Create { + name: "new_profile".to_string(), + }), + ), + ( + "/profile delete p", + profile!(ProfileSubcommand::Delete { name: "p".to_string() }), + ), + ( + "/profile rename old new", + profile!(ProfileSubcommand::Rename { + old_name: "old".to_string(), + new_name: "new".to_string(), + }), + ), + ( + "/profile set p", + profile!(ProfileSubcommand::Set { name: "p".to_string() }), + ), + ( + "/profile set p", + profile!(ProfileSubcommand::Set { name: "p".to_string() }), + ), + ("/context show", context!(ContextSubcommand::Show { expand: false })), + ( + "/context show --expand", + context!(ContextSubcommand::Show { expand: true }), + ), + ( + "/context add p1 p2", + context!(ContextSubcommand::Add { + global: false, + force: false, + paths: vec!["p1".into(), "p2".into()] + }), + ), + ( + "/context add --global --force p1 p2", + context!(ContextSubcommand::Add { + global: true, + force: true, + paths: vec!["p1".into(), "p2".into()] + }), + ), + ( + "/context rm p1 p2", + context!(ContextSubcommand::Remove { + global: false, + paths: vec!["p1".into(), "p2".into()] + }), + ), + ( + "/context rm --global p1 p2", + context!(ContextSubcommand::Remove { + global: true, + paths: vec!["p1".into(), "p2".into()] + }), + ), + ("/context clear", context!(ContextSubcommand::Clear { global: false })), + ( + "/context clear --global", + context!(ContextSubcommand::Clear { global: true }), + ), + ("/issue", Command::Issue { prompt: None }), + ("/issue there was an error in the chat", Command::Issue { + prompt: Some("there was an error in the chat".to_string()), + }), + ("/issue \"there was an error in the chat\"", Command::Issue { + prompt: Some("\"there was an error in the chat\"".to_string()), + }), + ( + "/context hooks", + context!(ContextSubcommand::Hooks { subcommand: None }), + ), + ( + "/context hooks add test --trigger per_prompt --command 'echo 1' --global", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::Add { + name: "test".to_string(), + global: true, + trigger: "per_prompt".to_string(), + command: "echo 1".to_string() + }) + }), + ), + ( + "/context hooks rm test --global", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::Remove { + name: "test".to_string(), + global: true + }) + }), + ), + ( + "/context hooks enable test --global", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::Enable { + name: "test".to_string(), + global: true + }) + }), + ), + ( + "/context hooks disable test", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::Disable { + name: "test".to_string(), + global: false + }) + }), + ), + ( + "/context hooks enable-all --global", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::EnableAll { global: true }) + }), + ), + ( + "/context hooks disable-all", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::DisableAll { global: false }) + }), + ), + ( + "/context hooks help", + context!(ContextSubcommand::Hooks { + subcommand: Some(HooksSubcommand::Help) + }), + ), + ]; + + for (input, parsed) in tests { + 
assert_eq!(&Command::parse(input, &mut stdout).unwrap(), parsed, "{}", input); + } + } + + #[test] + fn test_common_command_suggestions() { + let mut stdout = std::io::stdout(); + let test_cases = vec![ + ( + "exit", + "Did you mean to use the command '/quit' to exit? Type '/quit' to exit.", + ), + ( + "quit", + "Did you mean to use the command '/quit' to exit? Type '/quit' to exit.", + ), + ( + "q", + "Did you mean to use the command '/quit' to exit? Type '/quit' to exit.", + ), + ( + "clear", + "Did you mean to use the command '/clear' to clear the conversation? Type '/clear' to clear.", + ), + ( + "cls", + "Did you mean to use the command '/clear' to clear the conversation? Type '/clear' to clear.", + ), + ( + "help", + "Did you mean to use the command '/help' for help? Type '/help' to see available commands.", + ), + ( + "?", + "Did you mean to use the command '/help' for help? Type '/help' to see available commands.", + ), + ]; + + for (input, expected_message) in test_cases { + let result = Command::parse(input, &mut stdout); + assert!(result.is_err(), "Expected error for input: {}", input); + assert_eq!(result.unwrap_err(), expected_message); + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/consts.rs b/crates/kiro-cli/src/cli/chat/consts.rs new file mode 100644 index 0000000000..6850f7efab --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/consts.rs @@ -0,0 +1,19 @@ +use super::token_counter::TokenCounter; + +// These limits are the internal undocumented values from the service for each item + +pub const MAX_CURRENT_WORKING_DIRECTORY_LEN: usize = 256; + +/// Limit to send the number of messages as part of chat. +pub const MAX_CONVERSATION_STATE_HISTORY_LEN: usize = 250; + +pub const MAX_TOOL_RESPONSE_SIZE: usize = 800_000; + +/// TODO: Use this to gracefully handle user message sizes. +#[allow(dead_code)] +pub const MAX_USER_MESSAGE_SIZE: usize = 600_000; + +/// In tokens +pub const CONTEXT_WINDOW_SIZE: usize = 200_000; + +pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold diff --git a/crates/kiro-cli/src/cli/chat/context.rs b/crates/kiro-cli/src/cli/chat/context.rs new file mode 100644 index 0000000000..756d055740 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/context.rs @@ -0,0 +1,1016 @@ +use std::collections::HashMap; +use std::io::Write; +use std::path::{ + Path, + PathBuf, +}; +use std::sync::Arc; + +use eyre::{ + Result, + eyre, +}; +use glob::glob; +use regex::Regex; +use serde::{ + Deserialize, + Serialize, +}; +use tracing::debug; + +use super::hooks::{ + Hook, + HookExecutor, +}; +use crate::fig_os_shim::Context; +use crate::fig_util::directories; + +pub const AMAZONQ_FILENAME: &str = "AmazonQ.md"; + +/// Configuration for context files, containing paths to include in the context. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(default)] +pub struct ContextConfig { + /// List of file paths or glob patterns to include in the context. + pub paths: Vec, + + /// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID. + pub hooks: HashMap, +} + +#[allow(dead_code)] +/// Manager for context files and profiles. +#[derive(Debug, Clone)] +pub struct ContextManager { + ctx: Arc, + + /// Global context configuration that applies to all profiles. + pub global_config: ContextConfig, + + /// Name of the current active profile. + pub current_profile: String, + + /// Context configuration for the current profile. 
+ pub profile_config: ContextConfig, + + pub hook_executor: HookExecutor, +} + +#[allow(dead_code)] +impl ContextManager { + /// Create a new ContextManager with default settings. + /// + /// This will: + /// 1. Create the necessary directories if they don't exist + /// 2. Load the global configuration + /// 3. Load the default profile configuration + /// + /// # Returns + /// A Result containing the new ContextManager or an error + pub async fn new(ctx: Arc) -> Result { + let profiles_dir = directories::chat_profiles_dir(&ctx)?; + + ctx.fs().create_dir_all(&profiles_dir).await?; + + let global_config = load_global_config(&ctx).await?; + let current_profile = "default".to_string(); + let profile_config = load_profile_config(&ctx, ¤t_profile).await?; + + Ok(Self { + ctx, + global_config, + current_profile, + profile_config, + hook_executor: HookExecutor::new(), + }) + } + + /// Save the current configuration to disk. + /// + /// # Arguments + /// * `global` - If true, save the global configuration; otherwise, save the current profile + /// configuration + /// + /// # Returns + /// A Result indicating success or an error + async fn save_config(&self, global: bool) -> Result<()> { + if global { + let global_path = directories::chat_global_context_path(&self.ctx)?; + let contents = serde_json::to_string_pretty(&self.global_config) + .map_err(|e| eyre!("Failed to serialize global configuration: {}", e))?; + + self.ctx.fs().write(&global_path, contents).await?; + } else { + let profile_path = profile_context_path(&self.ctx, &self.current_profile)?; + if let Some(parent) = profile_path.parent() { + self.ctx.fs().create_dir_all(parent).await?; + } + let contents = serde_json::to_string_pretty(&self.profile_config) + .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; + + self.ctx.fs().write(&profile_path, contents).await?; + } + + Ok(()) + } + + /// Add paths to the context configuration. + /// + /// # Arguments + /// * `paths` - List of paths to add + /// * `global` - If true, add to global configuration; otherwise, add to current profile + /// configuration + /// * `force` - If true, skip validation that the path exists + /// + /// # Returns + /// A Result indicating success or an error + pub async fn add_paths(&mut self, paths: Vec, global: bool, force: bool) -> Result<()> { + let mut all_paths = self.global_config.paths.clone(); + all_paths.append(&mut self.profile_config.paths.clone()); + + // Validate paths exist before adding them + if !force { + let mut context_files = Vec::new(); + + // Check each path to make sure it exists or matches at least one file + for path in &paths { + // We're using a temporary context_files vector just for validation + // Pass is_validation=true to ensure we error if glob patterns don't match any files + match process_path(&self.ctx, path, &mut context_files, false, true).await { + Ok(_) => {}, // Path is valid + Err(e) => return Err(eyre!("Invalid path '{}': {}. Use --force to add anyway.", path, e)), + } + } + } + + // Add each path, checking for duplicates + for path in paths { + if all_paths.contains(&path) { + return Err(eyre!("Rule '{}' already exists.", path)); + } + if global { + self.global_config.paths.push(path); + } else { + self.profile_config.paths.push(path); + } + } + + // Save the updated configuration + self.save_config(global).await?; + + Ok(()) + } + + /// Remove paths from the context configuration. 
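+ ///
+ /// Illustrative usage (the glob below is a hypothetical, previously added rule, and
+ /// `manager` is a `ContextManager`):
+ ///
+ /// ```ignore
+ /// manager.remove_paths(vec!["docs/*.md".to_string()], false).await?;
+ /// ```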
+ /// + /// # Arguments + /// * `paths` - List of paths to remove + /// * `global` - If true, remove from global configuration; otherwise, remove from current + /// profile configuration + /// + /// # Returns + /// A Result indicating success or an error + pub async fn remove_paths(&mut self, paths: Vec, global: bool) -> Result<()> { + // Get reference to the appropriate config + let config = self.get_config_mut(global); + + // Track if any paths were removed + let mut removed_any = false; + + // Remove each path if it exists + for path in paths { + let original_len = config.paths.len(); + config.paths.retain(|p| p != &path); + + if config.paths.len() < original_len { + removed_any = true; + } + } + + if !removed_any { + return Err(eyre!("None of the specified paths were found in the context")); + } + + // Save the updated configuration + self.save_config(global).await?; + + Ok(()) + } + + /// List all available profiles. + /// + /// # Returns + /// A Result containing a vector of profile names, with "default" always first + pub async fn list_profiles(&self) -> Result> { + let mut profiles = Vec::new(); + + // Always include default profile + profiles.push("default".to_string()); + + // Read profile directory and extract profile names + let profiles_dir = directories::chat_profiles_dir(&self.ctx)?; + if profiles_dir.exists() { + let mut read_dir = self.ctx.fs().read_dir(&profiles_dir).await?; + while let Some(entry) = read_dir.next_entry().await? { + let path = entry.path(); + if let (true, Some(name)) = (path.is_dir(), path.file_name()) { + if name != "default" { + profiles.push(name.to_string_lossy().to_string()); + } + } + } + } + + // Sort non-default profiles alphabetically + if profiles.len() > 1 { + profiles[1..].sort(); + } + + Ok(profiles) + } + + /// List all available profiles using blocking operations. + /// + /// Similar to list_profiles but uses synchronous filesystem operations. + /// + /// # Returns + /// A Result containing a vector of profile names, with "default" always first + pub fn list_profiles_blocking(&self) -> Result> { + let mut profiles = Vec::new(); + + // Always include default profile + profiles.push("default".to_string()); + + // Read profile directory and extract profile names + let profiles_dir = directories::chat_profiles_dir(&self.ctx)?; + if profiles_dir.exists() { + for entry in std::fs::read_dir(profiles_dir)? { + let entry = entry?; + let path = entry.path(); + if let (true, Some(name)) = (path.is_dir(), path.file_name()) { + if name != "default" { + profiles.push(name.to_string_lossy().to_string()); + } + } + } + } + + // Sort non-default profiles alphabetically + if profiles.len() > 1 { + profiles[1..].sort(); + } + + Ok(profiles) + } + + /// Clear all paths from the context configuration. + /// + /// # Arguments + /// * `global` - If true, clear global configuration; otherwise, clear current profile + /// configuration + /// + /// # Returns + /// A Result indicating success or an error + pub async fn clear(&mut self, global: bool) -> Result<()> { + // Clear the appropriate config + if global { + self.global_config.paths.clear(); + } else { + self.profile_config.paths.clear(); + } + + // Save the updated configuration + self.save_config(global).await?; + + Ok(()) + } + + /// Create a new profile. 
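+ ///
+ /// Illustrative usage (the profile name is only an example; `manager` is a `ContextManager`):
+ ///
+ /// ```ignore
+ /// manager.create_profile("code-review").await?;
+ /// ```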
+ /// + /// # Arguments + /// * `name` - Name of the profile to create + /// + /// # Returns + /// A Result indicating success or an error + pub async fn create_profile(&self, name: &str) -> Result<()> { + validate_profile_name(name)?; + + // Check if profile already exists + let profile_path = profile_context_path(&self.ctx, name)?; + if profile_path.exists() { + return Err(eyre!("Profile '{}' already exists", name)); + } + + // Create empty profile configuration + let config = ContextConfig::default(); + let contents = serde_json::to_string_pretty(&config) + .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; + + // Create the file + if let Some(parent) = profile_path.parent() { + self.ctx.fs().create_dir_all(parent).await?; + } + self.ctx.fs().write(&profile_path, contents).await?; + + Ok(()) + } + + /// Delete a profile. + /// + /// # Arguments + /// * `name` - Name of the profile to delete + /// + /// # Returns + /// A Result indicating success or an error + pub async fn delete_profile(&self, name: &str) -> Result<()> { + if name == "default" { + return Err(eyre!("Cannot delete the default profile")); + } else if name == self.current_profile { + return Err(eyre!( + "Cannot delete the active profile. Switch to another profile first" + )); + } + + let profile_path = profile_dir_path(&self.ctx, name)?; + if !profile_path.exists() { + return Err(eyre!("Profile '{}' does not exist", name)); + } + + self.ctx.fs().remove_dir_all(&profile_path).await?; + + Ok(()) + } + + /// Rename a profile. + /// + /// # Arguments + /// * `old_name` - Current name of the profile + /// * `new_name` - New name for the profile + /// + /// # Returns + /// A Result indicating success or an error + pub async fn rename_profile(&mut self, old_name: &str, new_name: &str) -> Result<()> { + // Validate profile names + if old_name == "default" { + return Err(eyre!("Cannot rename the default profile")); + } + if new_name == "default" { + return Err(eyre!("Cannot rename to 'default' as it's a reserved profile name")); + } + + validate_profile_name(new_name)?; + + let old_profile_path = profile_dir_path(&self.ctx, old_name)?; + if !old_profile_path.exists() { + return Err(eyre!("Profile '{}' not found", old_name)); + } + + let new_profile_path = profile_dir_path(&self.ctx, new_name)?; + if new_profile_path.exists() { + return Err(eyre!("Profile '{}' already exists", new_name)); + } + + self.ctx.fs().rename(&old_profile_path, &new_profile_path).await?; + + // If the current profile is being renamed, update the current_profile field + if self.current_profile == old_name { + self.current_profile = new_name.to_string(); + self.profile_config = load_profile_config(&self.ctx, new_name).await?; + } + + Ok(()) + } + + /// Switch to a different profile. 
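+ ///
+ /// Illustrative usage (assumes a profile with this name already exists):
+ ///
+ /// ```ignore
+ /// manager.switch_profile("code-review").await?;
+ /// ```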
+ /// + /// # Arguments + /// * `name` - Name of the profile to switch to + /// + /// # Returns + /// A Result indicating success or an error + pub async fn switch_profile(&mut self, name: &str) -> Result<()> { + validate_profile_name(name)?; + self.hook_executor.profile_cache.clear(); + + // Special handling for default profile - it always exists + if name == "default" { + // Load the default profile configuration + let profile_config = load_profile_config(&self.ctx, name).await?; + + // Update the current profile + self.current_profile = name.to_string(); + self.profile_config = profile_config; + + return Ok(()); + } + + // Check if profile exists + let profile_path = profile_context_path(&self.ctx, name)?; + if !profile_path.exists() { + return Err(eyre!("Profile '{}' does not exist. Use 'create' to create it", name)); + } + + // Update the current profile + self.current_profile = name.to_string(); + self.profile_config = load_profile_config(&self.ctx, name).await?; + + Ok(()) + } + + /// Get all context files (global + profile-specific). + /// + /// This method: + /// 1. Processes all paths in the global and profile configurations + /// 2. Expands glob patterns to include matching files + /// 3. Reads the content of each file + /// 4. Returns a vector of (filename, content) pairs + /// + /// # Arguments + /// * `force` - If true, include paths that don't exist yet + /// + /// # Returns + /// A Result containing a vector of (filename, content) pairs or an error + pub async fn get_context_files(&self, force: bool) -> Result> { + let mut context_files = Vec::new(); + + self.collect_context_files(&self.global_config.paths, &mut context_files, force) + .await?; + self.collect_context_files(&self.profile_config.paths, &mut context_files, force) + .await?; + + context_files.sort_by(|a, b| a.0.cmp(&b.0)); + context_files.dedup_by(|a, b| a.0 == b.0); + + Ok(context_files) + } + + pub async fn get_context_files_by_path(&self, force: bool, path: &str) -> Result> { + let mut context_files = Vec::new(); + process_path(&self.ctx, path, &mut context_files, force, true).await?; + Ok(context_files) + } + + /// Get all context files from the global configuration. + pub async fn get_global_context_files(&self, force: bool) -> Result> { + let mut context_files = Vec::new(); + + self.collect_context_files(&self.global_config.paths, &mut context_files, force) + .await?; + + Ok(context_files) + } + + /// Get all context files from the current profile configuration. + pub async fn get_current_profile_context_files(&self, force: bool) -> Result> { + let mut context_files = Vec::new(); + + self.collect_context_files(&self.profile_config.paths, &mut context_files, force) + .await?; + + Ok(context_files) + } + + async fn collect_context_files( + &self, + paths: &[String], + context_files: &mut Vec<(String, String)>, + force: bool, + ) -> Result<()> { + for path in paths { + // Use is_validation=false to handle non-matching globs gracefully + process_path(&self.ctx, path, context_files, force, false).await?; + } + Ok(()) + } + + fn get_config_mut(&mut self, global: bool) -> &mut ContextConfig { + if global { + &mut self.global_config + } else { + &mut self.profile_config + } + } + + /// Add hooks to the context config. If another hook with the same name already exists, throw an + /// error. + /// + /// # Arguments + /// * `hook` - name of the hook to delete + /// * `global` - If true, the add to the global config. If false, add to the current profile + /// config. 
+ /// * `conversation_start` - If true, add the hook to conversation_start. Otherwise, it will be + /// added to per_prompt. + pub async fn add_hook(&mut self, name: String, hook: Hook, global: bool) -> Result<()> { + let config = self.get_config_mut(global); + + if config.hooks.contains_key(&name) { + return Err(eyre!("name already exists.")); + } + + config.hooks.insert(name, hook); + self.save_config(global).await + } + + /// Delete hook(s) by name + /// # Arguments + /// * `name` - name of the hook to delete + /// * `global` - If true, the delete from the global config. If false, delete from the current + /// profile config + pub async fn remove_hook(&mut self, name: &str, global: bool) -> Result<()> { + let config = self.get_config_mut(global); + + if !config.hooks.contains_key(name) { + return Err(eyre!("does not exist.")); + } + + config.hooks.remove(name); + + self.save_config(global).await + } + + /// Sets the "disabled" field on any [`Hook`] with the given name + /// # Arguments + /// * `disable` - Set "disabled" field to this value + pub async fn set_hook_disabled(&mut self, name: &str, global: bool, disable: bool) -> Result<()> { + let config = self.get_config_mut(global); + + if !config.hooks.contains_key(name) { + return Err(eyre!("does not exist.")); + } + + if let Some(hook) = config.hooks.get_mut(name) { + hook.disabled = disable; + } + + self.save_config(global).await + } + + /// Sets the "disabled" field on all [`Hook`]s + /// # Arguments + /// * `disable` - Set all "disabled" fields to this value + pub async fn set_all_hooks_disabled(&mut self, global: bool, disable: bool) -> Result<()> { + let config = self.get_config_mut(global); + + config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable); + + self.save_config(global).await + } + + /// Run all the currently enabled hooks from both the global and profile contexts. + /// Skipped hooks (disabled) will not appear in the output. + /// # Arguments + /// * `updates` - output stream to write hook run status to if Some, else do nothing if None + /// # Returns + /// A vector containing pairs of a [`Hook`] definition and its execution output + pub async fn run_hooks(&mut self, updates: Option<&mut impl Write>) -> Vec<(Hook, String)> { + let mut hooks: Vec<&Hook> = Vec::new(); + + // Set internal hook states + let configs = [ + (&mut self.global_config.hooks, true), + (&mut self.profile_config.hooks, false), + ]; + + for (hook_list, is_global) in configs { + hooks.extend(hook_list.iter_mut().map(|(name, h)| { + h.name = name.to_string(); + h.is_global = is_global; + &*h + })); + } + + self.hook_executor.run_hooks(hooks, updates).await + } +} + +fn profile_dir_path(ctx: &Context, profile_name: &str) -> Result { + Ok(directories::chat_profiles_dir(ctx)?.join(profile_name)) +} + +/// Path to the context config file for `profile_name`. +pub fn profile_context_path(ctx: &Context, profile_name: &str) -> Result { + Ok(directories::chat_profiles_dir(ctx)? + .join(profile_name) + .join("context.json")) +} + +/// Load the global context configuration. +/// +/// If the global configuration file doesn't exist, returns a default configuration. 
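+ ///
+ /// For reference, the on-disk file is the JSON serialization of [`ContextConfig`]; a
+ /// minimal example (values are illustrative only) looks like:
+ ///
+ /// ```json
+ /// { "paths": [".amazonq/rules/**/*.md", "README.md"], "hooks": {} }
+ /// ```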
+async fn load_global_config(ctx: &Context) -> Result { + let global_path = directories::chat_global_context_path(&ctx)?; + debug!(?global_path, "loading profile config"); + if ctx.fs().exists(&global_path) { + let contents = ctx.fs().read_to_string(&global_path).await?; + let config: ContextConfig = + serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse global configuration: {}", e))?; + Ok(config) + } else { + // Return default global configuration with predefined paths + Ok(ContextConfig { + paths: vec![ + ".amazonq/rules/**/*.md".to_string(), + "README.md".to_string(), + AMAZONQ_FILENAME.to_string(), + ], + hooks: HashMap::new(), + }) + } +} + +/// Load a profile's context configuration. +/// +/// If the profile configuration file doesn't exist, creates a default configuration. +async fn load_profile_config(ctx: &Context, profile_name: &str) -> Result { + let profile_path = profile_context_path(ctx, profile_name)?; + debug!(?profile_path, "loading profile config"); + if ctx.fs().exists(&profile_path) { + let contents = ctx.fs().read_to_string(&profile_path).await?; + let config: ContextConfig = + serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse profile configuration: {}", e))?; + Ok(config) + } else { + // Return empty configuration for new profiles + Ok(ContextConfig::default()) + } +} + +/// Process a path, handling glob patterns and file types. +/// +/// This method: +/// 1. Expands the path (handling ~ for home directory) +/// 2. If the path contains glob patterns, expands them +/// 3. For each resulting path, adds the file to the context collection +/// 4. Handles directories by including all files in the directory (non-recursive) +/// 5. With force=true, includes paths that don't exist yet +/// +/// # Arguments +/// * `path` - The path to process +/// * `context_files` - The collection to add files to +/// * `force` - If true, include paths that don't exist yet +/// * `is_validation` - If true, error when glob patterns don't match; if false, silently skip +/// +/// # Returns +/// A Result indicating success or an error +async fn process_path( + ctx: &Context, + path: &str, + context_files: &mut Vec<(String, String)>, + force: bool, + is_validation: bool, +) -> Result<()> { + // Expand ~ to home directory + let expanded_path = if path.starts_with('~') { + if let Some(home_dir) = ctx.env().home() { + home_dir.join(&path[2..]).to_string_lossy().to_string() + } else { + return Err(eyre!("Could not determine home directory")); + } + } else { + path.to_string() + }; + + // Handle absolute, relative paths, and glob patterns + let full_path = if expanded_path.starts_with('/') { + expanded_path + } else { + ctx.env() + .current_dir()? + .join(&expanded_path) + .to_string_lossy() + .to_string() + }; + + // Required in chroot testing scenarios so that we can use `Path::exists`. 
+ let full_path = ctx.fs().chroot_path_str(full_path); + + // Check if the path contains glob patterns + if full_path.contains('*') || full_path.contains('?') || full_path.contains('[') { + // Expand glob pattern + match glob(&full_path) { + Ok(entries) => { + let mut found_any = false; + + for entry in entries { + match entry { + Ok(path) => { + if path.is_file() { + add_file_to_context(ctx, &path, context_files).await?; + found_any = true; + } + }, + Err(e) => return Err(eyre!("Glob error: {}", e)), + } + } + + if !found_any && !force && is_validation { + // When validating paths (e.g., for /context add), error if no files match + return Err(eyre!("No files found matching glob pattern '{}'", full_path)); + } + // When just showing expanded files (e.g., for /context show --expand), + // silently skip non-matching patterns (don't add anything to context_files) + }, + Err(e) => return Err(eyre!("Invalid glob pattern '{}': {}", full_path, e)), + } + } else { + // Regular path + let path = Path::new(&full_path); + if path.exists() { + if path.is_file() { + add_file_to_context(ctx, path, context_files).await?; + } else if path.is_dir() { + // For directories, add all files in the directory (non-recursive) + let mut read_dir = ctx.fs().read_dir(path).await?; + while let Some(entry) = read_dir.next_entry().await? { + let path = entry.path(); + if path.is_file() { + add_file_to_context(ctx, &path, context_files).await?; + } + } + } + } else if !force && is_validation { + // When validating paths (e.g., for /context add), error if the path doesn't exist + return Err(eyre!("Path '{}' does not exist", full_path)); + } else if force { + // When using --force, we'll add the path even though it doesn't exist + // This allows users to add paths that will exist in the future + context_files.push((full_path.clone(), format!("(Path '{}' does not exist yet)", full_path))); + } + // When just showing expanded files (e.g., for /context show --expand), + // silently skip non-existent paths if is_validation is false + } + + Ok(()) +} + +/// Add a file to the context collection. +/// +/// This method: +/// 1. Reads the content of the file +/// 2. Adds the (filename, content) pair to the context collection +/// +/// # Arguments +/// * `path` - The path to the file +/// * `context_files` - The collection to add the file to +/// +/// # Returns +/// A Result indicating success or an error +async fn add_file_to_context(ctx: &Context, path: &Path, context_files: &mut Vec<(String, String)>) -> Result<()> { + let filename = path.to_string_lossy().to_string(); + let content = ctx.fs().read_to_string(path).await?; + context_files.push((filename, content)); + Ok(()) +} + +/// Validate a profile name. +/// +/// Profile names can only contain alphanumeric characters, hyphens, and underscores. 
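+ /// They must also start with an alphanumeric character. Illustrative checks, consistent
+ /// with the regex below:
+ ///
+ /// ```ignore
+ /// assert!(validate_profile_name("dev-profile_2").is_ok());
+ /// assert!(validate_profile_name("-invalid").is_err());
+ /// assert!(validate_profile_name("my profile").is_err());
+ /// ```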
+/// +/// # Arguments +/// * `name` - Name to validate +/// +/// # Returns +/// A Result indicating if the name is valid +fn validate_profile_name(name: &str) -> Result<()> { + // Check if name is empty + if name.is_empty() { + return Err(eyre!("Profile name cannot be empty")); + } + + // Check if name contains only allowed characters and starts with an alphanumeric character + let re = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$").unwrap(); + if !re.is_match(name) { + return Err(eyre!( + "Profile name must start with an alphanumeric character and can only contain alphanumeric characters, hyphens, and underscores" + )); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::io::Stdout; + + use super::super::hooks::HookTrigger; + use super::*; + + // Helper function to create a test ContextManager with Context + pub async fn create_test_context_manager() -> Result { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let manager = ContextManager::new(ctx).await?; + Ok(manager) + } + + #[tokio::test] + async fn test_validate_profile_name() { + // Test valid names + assert!(validate_profile_name("valid").is_ok()); + assert!(validate_profile_name("valid-name").is_ok()); + assert!(validate_profile_name("valid_name").is_ok()); + assert!(validate_profile_name("valid123").is_ok()); + assert!(validate_profile_name("1valid").is_ok()); + assert!(validate_profile_name("9test").is_ok()); + + // Test invalid names + assert!(validate_profile_name("").is_err()); + assert!(validate_profile_name("invalid/name").is_err()); + assert!(validate_profile_name("invalid.name").is_err()); + assert!(validate_profile_name("invalid name").is_err()); + assert!(validate_profile_name("_invalid").is_err()); + assert!(validate_profile_name("-invalid").is_err()); + } + + #[tokio::test] + async fn test_profile_ops() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let ctx = Arc::clone(&manager.ctx); + + assert_eq!(manager.current_profile, "default"); + + // Create ops + manager.create_profile("test_profile").await?; + assert!(profile_context_path(&ctx, "test_profile")?.exists()); + assert!(manager.create_profile("test_profile").await.is_err()); + manager.create_profile("alt").await?; + + // Listing + let profiles = manager.list_profiles().await?; + assert!(profiles.contains(&"default".to_string())); + assert!(profiles.contains(&"test_profile".to_string())); + assert!(profiles.contains(&"alt".to_string())); + + // Switching + manager.switch_profile("test_profile").await?; + assert!(manager.switch_profile("notexists").await.is_err()); + + // Renaming + manager.rename_profile("alt", "renamed").await?; + assert!(!profile_context_path(&ctx, "alt")?.exists()); + assert!(profile_context_path(&ctx, "renamed")?.exists()); + + // Delete ops + assert!(manager.delete_profile("test_profile").await.is_err()); + manager.switch_profile("default").await?; + manager.delete_profile("test_profile").await?; + assert!(!profile_context_path(&ctx, "test_profile")?.exists()); + assert!(manager.delete_profile("test_profile").await.is_err()); + assert!(manager.delete_profile("default").await.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_path_ops() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let ctx = Arc::clone(&manager.ctx); + + // Create some test files for matching. 
+ ctx.fs().create_dir_all("test").await?; + ctx.fs().write("test/p1.md", "p1").await?; + ctx.fs().write("test/p2.md", "p2").await?; + + assert!( + manager.get_context_files(false).await?.is_empty(), + "no files should be returned for an empty profile when force is false" + ); + assert_eq!( + manager.get_context_files(true).await?.len(), + 2, + "default non-glob global files should be included when force is true" + ); + + manager.add_paths(vec!["test/*.md".to_string()], false, false).await?; + let files = manager.get_context_files(false).await?; + assert!(files[0].0.ends_with("p1.md")); + assert_eq!(files[0].1, "p1"); + assert!(files[1].0.ends_with("p2.md")); + assert_eq!(files[1].1, "p2"); + + assert!( + manager + .add_paths(vec!["test/*.txt".to_string()], false, false) + .await + .is_err(), + "adding a glob with no matching and without force should fail" + ); + + Ok(()) + } + + #[tokio::test] + async fn test_add_hook() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + + // Test adding hook to profile config + manager.add_hook("test_hook".to_string(), hook.clone(), false).await?; + assert!(manager.profile_config.hooks.contains_key("test_hook")); + + // Test adding hook to global config + manager.add_hook("global_hook".to_string(), hook.clone(), true).await?; + assert!(manager.global_config.hooks.contains_key("global_hook")); + + // Test adding duplicate hook name + assert!(manager.add_hook("test_hook".to_string(), hook, false).await.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_remove_hook() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + + manager.add_hook("test_hook".to_string(), hook, false).await?; + + // Test removing existing hook + manager.remove_hook("test_hook", false).await?; + assert!(!manager.profile_config.hooks.contains_key("test_hook")); + + // Test removing non-existent hook + assert!(manager.remove_hook("test_hook", false).await.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_set_hook_disabled() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + + manager.add_hook("test_hook".to_string(), hook, false).await?; + + // Test disabling hook + manager.set_hook_disabled("test_hook", false, true).await?; + assert!(manager.profile_config.hooks.get("test_hook").unwrap().disabled); + + // Test enabling hook + manager.set_hook_disabled("test_hook", false, false).await?; + assert!(!manager.profile_config.hooks.get("test_hook").unwrap().disabled); + + // Test with non-existent hook + assert!(manager.set_hook_disabled("nonexistent", false, true).await.is_err()); + + Ok(()) + } + + #[tokio::test] + async fn test_set_all_hooks_disabled() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + + manager.add_hook("hook1".to_string(), hook1, false).await?; + manager.add_hook("hook2".to_string(), hook2, false).await?; + + // Test disabling all hooks + manager.set_all_hooks_disabled(false, true).await?; + assert!(manager.profile_config.hooks.values().all(|h| h.disabled)); + + // Test enabling all hooks + 
manager.set_all_hooks_disabled(false, false).await?; + assert!(manager.profile_config.hooks.values().all(|h| !h.disabled)); + + Ok(()) + } + + #[tokio::test] + async fn test_run_hooks() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + + manager.add_hook("hook1".to_string(), hook1, false).await?; + manager.add_hook("hook2".to_string(), hook2, false).await?; + + // Run the hooks + let results = manager.run_hooks(None::<&mut Stdout>).await; + assert_eq!(results.len(), 2); // Should include both hooks + + Ok(()) + } + + #[tokio::test] + async fn test_hooks_across_profiles() -> Result<()> { + let mut manager = create_test_context_manager().await?; + let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); + + manager.add_hook("profile_hook".to_string(), hook1, false).await?; + manager.add_hook("global_hook".to_string(), hook2, true).await?; + + let results = manager.run_hooks(None::<&mut Stdout>).await; + assert_eq!(results.len(), 2); // Should include both hooks + + // Create and switch to a new profile + manager.create_profile("test_profile").await?; + manager.switch_profile("test_profile").await?; + + let results = manager.run_hooks(None::<&mut Stdout>).await; + assert_eq!(results.len(), 1); // Should include global hook + assert_eq!(results[0].0.name, "global_hook"); + + Ok(()) + } +} diff --git a/crates/kiro-cli/src/cli/chat/conversation_state.rs b/crates/kiro-cli/src/cli/chat/conversation_state.rs new file mode 100644 index 0000000000..0a6b151e05 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/conversation_state.rs @@ -0,0 +1,1049 @@ +use std::collections::{ + HashMap, + VecDeque, +}; +use std::sync::Arc; + +use tracing::{ + debug, + error, + warn, +}; + +use super::consts::{ + MAX_CHARS, + MAX_CONVERSATION_STATE_HISTORY_LEN, +}; +use super::context::ContextManager; +use super::hooks::{ + Hook, + HookTrigger, +}; +use super::message::{ + AssistantMessage, + ToolUseResult, + ToolUseResultBlock, + UserMessage, + UserMessageContent, + build_env_state, +}; +use super::token_counter::{ + CharCount, + CharCounter, +}; +use super::tools::{ + InputSchema, + QueuedTool, + ToolOrigin, + ToolSpec, + serde_value_to_document, +}; +use crate::cli::chat::shared_writer::SharedWriter; +use crate::fig_api_client::model::{ + AssistantResponseMessage, + ChatMessage, + ConversationState as FigConversationState, + Tool, + ToolInputSchema, + ToolResult, + ToolResultContentBlock, + ToolSpecification, + ToolUse, + UserInputMessage, + UserInputMessageContext, +}; +use crate::fig_os_shim::Context; +use crate::mcp_client::Prompt; + +const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; +const CONTEXT_ENTRY_END_HEADER: &str = "--- CONTEXT ENTRY END ---\n\n"; + +/// Tracks state related to an ongoing conversation. +#[derive(Debug, Clone)] +pub struct ConversationState { + /// Randomly generated on creation. + conversation_id: String, + /// The next user message to be sent as part of the conversation. Required to be [Some] before + /// calling [Self::as_sendable_conversation_state]. + next_message: Option, + history: VecDeque<(UserMessage, AssistantMessage)>, + /// The range in the history sendable to the backend (start inclusive, end exclusive). 
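+ /// For example (illustrative), a value of `(2, 5)` means only history entries 2, 3, and 4
+ /// are sent.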
+ valid_history_range: (usize, usize), + /// Similar to history in that stores user and assistant responses, except that it is not used + /// in message requests. Instead, the responses are expected to be in human-readable format, + /// e.g user messages prefixed with '> '. Should also be used to store errors posted in the + /// chat. + pub transcript: VecDeque, + pub tools: HashMap>, + /// Context manager for handling sticky context files + pub context_manager: Option, + /// Cached value representing the length of the user context message. + context_message_length: Option, + /// Stores the latest conversation summary created by /compact + latest_summary: Option, + updates: Option, +} + +impl ConversationState { + pub async fn new( + ctx: Arc, + conversation_id: &str, + tool_config: HashMap, + profile: Option, + updates: Option, + ) -> Self { + // Initialize context manager + let context_manager = match ContextManager::new(ctx).await { + Ok(mut manager) => { + // Switch to specified profile if provided + if let Some(profile_name) = profile { + if let Err(e) = manager.switch_profile(&profile_name).await { + warn!("Failed to switch to profile {}: {}", profile_name, e); + } + } + Some(manager) + }, + Err(e) => { + warn!("Failed to initialize context manager: {}", e); + None + }, + }; + + Self { + conversation_id: conversation_id.to_string(), + next_message: None, + history: VecDeque::new(), + valid_history_range: Default::default(), + transcript: VecDeque::with_capacity(MAX_CONVERSATION_STATE_HISTORY_LEN), + tools: tool_config + .into_values() + .fold(HashMap::>::new(), |mut acc, v| { + let tool = Tool::ToolSpecification(ToolSpecification { + name: v.name, + description: v.description, + input_schema: v.input_schema.into(), + }); + acc.entry(v.tool_origin) + .and_modify(|tools| tools.push(tool.clone())) + .or_insert(vec![tool]); + acc + }), + context_manager, + context_message_length: None, + latest_summary: None, + updates, + } + } + + pub fn history(&self) -> &VecDeque<(UserMessage, AssistantMessage)> { + &self.history + } + + /// Clears the conversation history and optionally the summary. + pub fn clear(&mut self, preserve_summary: bool) { + self.next_message = None; + self.history.clear(); + if !preserve_summary { + self.latest_summary = None; + } + } + + /// Appends a collection prompts into history and returns the last message in the collection. + /// It asserts that the collection ends with a prompt that assumes the role of user. 
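+ ///
+ /// For example (illustrative), given prompts `[user, assistant, user]`, the leading
+ /// user/assistant pair is appended to the history and the content of the final user
+ /// prompt is returned so it can be sent as the next message.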
+ pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { + debug_assert!(self.next_message.is_none(), "next_message should not exist"); + debug_assert!(prompts.back().is_some_and(|p| p.role == crate::mcp_client::Role::User)); + let last_msg = prompts.pop_back()?; + let (mut candidate_user, mut candidate_asst) = (None::, None::); + while let Some(prompt) = prompts.pop_front() { + let Prompt { role, content } = prompt; + match role { + crate::mcp_client::Role::User => { + let user_msg = UserMessage::new_prompt(content.to_string()); + candidate_user.replace(user_msg); + }, + crate::mcp_client::Role::Assistant => { + let assistant_msg = AssistantMessage::new_response(None, content.into()); + candidate_asst.replace(assistant_msg); + }, + } + if candidate_asst.is_some() && candidate_user.is_some() { + let asst = candidate_asst.take().unwrap(); + let user = candidate_user.take().unwrap(); + self.append_assistant_transcript(&asst); + self.history.push_back((user, asst)); + } + } + Some(last_msg.content.to_string()) + } + + pub fn next_user_message(&self) -> Option<&UserMessage> { + self.next_message.as_ref() + } + + pub fn reset_next_user_message(&mut self) { + self.next_message = None; + } + + pub async fn set_next_user_message(&mut self, input: String) { + debug_assert!(self.next_message.is_none(), "next_message should not exist"); + if let Some(next_message) = self.next_message.as_ref() { + warn!(?next_message, "next_message should not exist"); + } + + let input = if input.is_empty() { + warn!("input must not be empty when adding new messages"); + "Empty prompt".to_string() + } else { + input + }; + + let msg = UserMessage::new_prompt(input); + self.next_message = Some(msg); + } + + /// Sets the response message according to the currently set [Self::next_message]. + pub fn push_assistant_message(&mut self, message: AssistantMessage) { + debug_assert!(self.next_message.is_some(), "next_message should exist"); + let next_user_message = self.next_message.take().expect("next user message should exist"); + + self.append_assistant_transcript(&message); + self.history.push_back((next_user_message, message)); + } + + /// Returns the conversation id. + pub fn conversation_id(&self) -> &str { + self.conversation_id.as_ref() + } + + /// Returns the message id associated with the last assistant message, if present. + /// + /// This is equivalent to `utterance_id` in the Q API. + pub fn message_id(&self) -> Option<&str> { + self.history.back().and_then(|(_, msg)| msg.message_id()) + } + + /// Updates the history so that, when non-empty, the following invariants are in place: + /// 1. The history length is `<= MAX_CONVERSATION_STATE_HISTORY_LEN`. Oldest messages are + /// dropped. + /// 2. The first message is from the user, and does not contain tool results. Oldest messages + /// are dropped. + /// 3. If the last message from the assistant contains tool results, and a next user message is + /// set without tool results, then the user message will have "cancelled" tool results. + pub fn enforce_conversation_invariants(&mut self) { + // First set the valid range as the entire history - this will be truncated as necessary + // later below. + self.valid_history_range = (0, self.history.len()); + + // Trim the conversation history by finding the second oldest message from the user without + // tool results - this will be the new oldest message in the history. + // + // Note that we reserve extra slots for [ConversationState::context_messages]. 
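+ // (Each history entry is a user/assistant pair, so `len * 2` approximates the flattened
+ // message count; the `- 6` leaves headroom for the context messages added separately.)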
+ if (self.history.len() * 2) > MAX_CONVERSATION_STATE_HISTORY_LEN - 6 { + match self + .history + .iter() + .enumerate() + .skip(1) + .find(|(_, (m, _))| -> bool { !m.has_tool_use_results() }) + .map(|v| v.0) + { + Some(i) => { + debug!("removing the first {i} user/assistant response pairs in the history"); + self.valid_history_range.0 = i; + }, + None => { + debug!("no valid starting user message found in the history, clearing"); + self.valid_history_range = (0, 0); + // Edge case: if the next message contains tool results, then we have to just + // abandon them. + if self.next_message.as_ref().is_some_and(|m| m.has_tool_use_results()) { + debug!("abandoning tool results"); + self.next_message = Some(UserMessage::new_prompt( + "The conversation history has overflowed, clearing state".to_string(), + )); + } + }, + } + } + + // If the last message from the assistant contains tool uses AND next_message is set, we need to + // ensure that next_message contains tool results. + if let (Some((_, AssistantMessage::ToolUse { tool_uses, .. })), Some(user_msg)) = ( + self.history + .range(self.valid_history_range.0..self.valid_history_range.1) + .last(), + &mut self.next_message, + ) { + if !user_msg.has_tool_use_results() { + debug!( + "last assistant message contains tool uses, but next message is set and does not contain tool results. setting tool results as cancelled" + ); + *user_msg = UserMessage::new_cancelled_tool_uses( + user_msg.prompt().map(|p| p.to_string()), + tool_uses.iter().map(|t| t.id.as_str()), + ); + } + } + } + + pub fn add_tool_results(&mut self, tool_results: Vec) { + debug_assert!(self.next_message.is_none()); + self.next_message = Some(UserMessage::new_tool_use_results(tool_results)); + } + + /// Sets the next user message with "cancelled" tool results. + pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: Vec, deny_input: String) { + self.next_message = Some(UserMessage::new_cancelled_tool_uses( + Some(deny_input), + tools_to_be_abandoned.iter().map(|t| t.id.as_str()), + )); + } + + /// Returns a [FigConversationState] capable of being sent by [fig_api_client::StreamingClient]. + /// + /// Params: + /// - `run_hooks` - whether hooks should be executed and included as context + pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState { + debug_assert!(self.next_message.is_some()); + self.enforce_conversation_invariants(); + self.history.drain(self.valid_history_range.1..); + self.history.drain(..self.valid_history_range.0); + + self.backend_conversation_state(run_hooks, false) + .await + .into_fig_conversation_state() + .expect("unable to construct conversation state") + } + + /// Returns a conversation state representation which reflects the exact conversation to send + /// back to the model. + pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> { + self.enforce_conversation_invariants(); + + // Run hooks and add to conversation start and next user message. 
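+ // Conversation-start hook output becomes part of the context messages, while per-prompt
+ // hook output is attached to the next user message as additional context.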
+ let mut conversation_start_context = None; + if let (true, Some(cm)) = (run_hooks, self.context_manager.as_mut()) { + let mut null_writer = SharedWriter::null(); + let updates = if quiet { + None + } else { + Some(self.updates.as_mut().unwrap_or(&mut null_writer)) + }; + + let hook_results = cm.run_hooks(updates).await; + conversation_start_context = Some(format_hook_context(hook_results.iter(), HookTrigger::ConversationStart)); + + // add per prompt content to next_user_message if available + if let Some(next_message) = self.next_message.as_mut() { + next_message.additional_context = format_hook_context(hook_results.iter(), HookTrigger::PerPrompt); + } + } + + let context_messages = self.context_messages(conversation_start_context).await; + + BackendConversationState { + conversation_id: self.conversation_id.as_str(), + next_user_message: self.next_message.as_ref(), + history: self + .history + .range(self.valid_history_range.0..self.valid_history_range.1), + context_messages, + tools: &self.tools, + } + } + + /// Returns a [FigConversationState] capable of replacing the history of the current + /// conversation with a summary generated by the model. + pub async fn create_summary_request(&mut self, custom_prompt: Option>) -> FigConversationState { + let summary_content = match custom_prompt { + Some(custom_prompt) => { + // Make the custom instructions much more prominent and directive + format!( + "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ + FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ + IMPORTANT CUSTOM INSTRUCTION: {}\n\n\ + Your task is to create a structured summary document containing:\n\ + 1) A bullet-point list of key topics/questions covered\n\ + 2) Bullet points for all significant tools executed and their results\n\ + 3) Bullet points for any code or technical information shared\n\ + 4) A section of key insights gained\n\n\ + FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ + ## CONVERSATION SUMMARY\n\ + * Topic 1: Key information\n\ + * Topic 2: Key information\n\n\ + ## TOOLS EXECUTED\n\ + * Tool X: Result Y\n\n\ + Remember this is a DOCUMENT not a chat response. The custom instruction above modifies what to prioritize.\n\ + FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).", + custom_prompt.as_ref() + ) + }, + None => { + // Default prompt + "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ + FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ + Your task is to create a structured summary document containing:\n\ + 1) A bullet-point list of key topics/questions covered\n\ + 2) Bullet points for all significant tools executed and their results\n\ + 3) Bullet points for any code or technical information shared\n\ + 4) A section of key insights gained\n\n\ + FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ + ## CONVERSATION SUMMARY\n\ + * Topic 1: Key information\n\ + * Topic 2: Key information\n\n\ + ## TOOLS EXECUTED\n\ + * Tool X: Result Y\n\n\ + Remember this is a DOCUMENT not a chat response.\n\ + FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).".to_string() + }, + }; + + let conv_state = self.backend_conversation_state(false, true).await; + + // Include everything but the last message in the history. 
+ let history_len = conv_state.history.len(); + let history = if history_len < 2 { + vec![] + } else { + flatten_history(conv_state.history.take(history_len.saturating_sub(1))) + }; + + let mut summary_message = UserInputMessage { + content: summary_content, + user_input_message_context: None, + user_intent: None, + }; + + // If the last message contains tool uses, then add cancelled tool results to the summary + // message. + if let Some(ChatMessage::AssistantResponseMessage(AssistantResponseMessage { + tool_uses: Some(tool_uses), + .. + })) = history.last() + { + self.set_cancelled_tool_results(&mut summary_message, tool_uses); + } + + FigConversationState { + conversation_id: Some(self.conversation_id.clone()), + user_input_message: summary_message, + history: Some(history), + } + } + + pub fn replace_history_with_summary(&mut self, summary: String) { + self.history.drain(..(self.history.len().saturating_sub(1))); + self.latest_summary = Some(summary); + // If the last message contains tool results, then we add the results to the content field + // instead. This is required to avoid validation errors. + // TODO: this can break since the max user content size is less than the max tool response + // size! Alternative could be to set the last tool use as part of the context messages. + if let Some((user, _)) = self.history.back_mut() { + if let Some(tool_results) = user.tool_use_results() { + let tool_content: Vec = tool_results + .iter() + .flat_map(|tr| { + tr.content.iter().map(|c| match c { + ToolUseResultBlock::Json(document) => serde_json::to_string(&document) + .map_err(|err| error!(?err, "failed to serialize tool result")) + .unwrap_or_default(), + ToolUseResultBlock::Text(s) => s.clone(), + }) + }) + .collect::<_>(); + let mut tool_content = tool_content.join(" "); + if tool_content.is_empty() { + // To avoid validation errors with empty content, we need to make sure + // something is set. + tool_content.push_str(""); + } + user.content = UserMessageContent::Prompt { prompt: tool_content }; + } + } + } + + pub fn current_profile(&self) -> Option<&str> { + if let Some(cm) = self.context_manager.as_ref() { + Some(cm.current_profile.as_str()) + } else { + None + } + } + + /// Returns pairs of user and assistant messages to include as context in the message history + /// including both summaries and context files if available. + /// + /// TODO: + /// - Either add support for multiple context messages if the context is too large to fit inside + /// a single user message, or handle this case more gracefully. For now, always return 2 + /// messages. + /// - Cache this return for some period of time. + async fn context_messages( + &mut self, + conversation_start_context: Option, + ) -> Option> { + let mut context_content = String::new(); + + if let Some(summary) = &self.latest_summary { + context_content.push_str(CONTEXT_ENTRY_START_HEADER); + context_content.push_str("This summary contains ALL relevant information from our previous conversation including tool uses, results, code analysis, and file operations. 
YOU MUST reference this information when answering questions and explicitly acknowledge specific details from the summary when they're relevant to the current question.\n\n"); + context_content.push_str("SUMMARY CONTENT:\n"); + context_content.push_str(summary); + context_content.push('\n'); + context_content.push_str(CONTEXT_ENTRY_END_HEADER); + } + + // Add context files if available + if let Some(context_manager) = self.context_manager.as_mut() { + match context_manager.get_context_files(true).await { + Ok(files) => { + if !files.is_empty() { + context_content.push_str(CONTEXT_ENTRY_START_HEADER); + for (filename, content) in files { + context_content.push_str(&format!("[{}]\n{}\n", filename, content)); + } + context_content.push_str(CONTEXT_ENTRY_END_HEADER); + } + }, + Err(e) => { + warn!("Failed to get context files: {}", e); + }, + } + } + + if let Some(context) = conversation_start_context { + context_content.push_str(&context); + } + + if !context_content.is_empty() { + self.context_message_length = Some(context_content.len()); + let user_msg = UserMessage::new_prompt(context_content); + let assistant_msg = AssistantMessage::new_response(None, "I will fully incorporate this information when generating my responses, and explicitly acknowledge relevant parts of the summary when answering questions.".into()); + Some(vec![(user_msg, assistant_msg)]) + } else { + None + } + } + + /// The length of the user message used as context, if any. + pub fn context_message_length(&self) -> Option { + self.context_message_length + } + + /// Calculate the total character count in the conversation + pub async fn calculate_char_count(&mut self) -> CharCount { + self.backend_conversation_state(false, true).await.char_count() + } + + /// Get the current token warning level + pub async fn get_token_warning_level(&mut self) -> TokenWarningLevel { + let total_chars = self.calculate_char_count().await; + + if *total_chars >= MAX_CHARS { + TokenWarningLevel::Critical + } else { + TokenWarningLevel::None + } + } + + pub fn append_user_transcript(&mut self, message: &str) { + self.append_transcript(format!("> {}", message.replace("\n", "> \n"))); + } + + pub fn append_assistant_transcript(&mut self, message: &AssistantMessage) { + let tool_uses = message.tool_uses().map_or("none".to_string(), |tools| { + tools.iter().map(|tool| tool.name.clone()).collect::>().join(",") + }); + self.append_transcript(format!("{}\n[Tool uses: {tool_uses}]", message.content())); + } + + pub fn append_transcript(&mut self, message: String) { + if self.transcript.len() >= MAX_CONVERSATION_STATE_HISTORY_LEN { + self.transcript.pop_front(); + } + self.transcript.push_back(message); + } + + /// Mutates `msg` so that it will contain an appropriate [UserInputMessageContext] that + /// contains "cancelled" tool results for `tool_uses`. + fn set_cancelled_tool_results(&self, msg: &mut UserInputMessage, tool_uses: &[ToolUse]) { + match msg.user_input_message_context.as_mut() { + Some(ctx) => { + if ctx.tool_results.as_ref().is_none_or(|r| r.is_empty()) { + debug!( + "last assistant message contains tool uses, but next message is set and does not contain tool results. 
setting tool results as cancelled" + ); + ctx.tool_results = Some( + tool_uses + .iter() + .map(|tool_use| ToolResult { + tool_use_id: tool_use.tool_use_id.clone(), + content: vec![ToolResultContentBlock::Text( + "Tool use was cancelled by the user".to_string(), + )], + status: crate::fig_api_client::model::ToolResultStatus::Error, + }) + .collect::>(), + ); + } + }, + None => { + debug!( + "last assistant message contains tool uses, but next message is set and does not contain tool results. setting tool results as cancelled" + ); + let tool_results = tool_uses + .iter() + .map(|tool_use| ToolResult { + tool_use_id: tool_use.tool_use_id.clone(), + content: vec![ToolResultContentBlock::Text( + "Tool use was cancelled by the user".to_string(), + )], + status: crate::fig_api_client::model::ToolResultStatus::Error, + }) + .collect::>(); + let user_input_message_context = UserInputMessageContext { + env_state: Some(build_env_state()), + tool_results: Some(tool_results), + tools: if self.tools.is_empty() { + None + } else { + Some(self.tools.values().flatten().cloned().collect::>()) + }, + ..Default::default() + }; + msg.user_input_message_context = Some(user_input_message_context); + }, + } + } +} + +/// Represents a conversation state that can be converted into a [FigConversationState] (the type +/// used by the API client). Represents borrowed data, and reflects an exact [FigConversationState] +/// that can be generated from [ConversationState] at any point in time. +/// +/// This is intended to provide us ways to accurately assess the exact state that is sent to the +/// model without having to needlessly clone and mutate [ConversationState] in strange ways. +pub type BackendConversationState<'a> = BackendConversationStateImpl< + 'a, + std::collections::vec_deque::Iter<'a, (UserMessage, AssistantMessage)>, + Option>, +>; + +/// See [BackendConversationState] +#[derive(Debug, Clone)] +pub struct BackendConversationStateImpl<'a, T, U> { + pub conversation_id: &'a str, + pub next_user_message: Option<&'a UserMessage>, + pub history: T, + pub context_messages: U, + pub tools: &'a HashMap>, +} + +impl + BackendConversationStateImpl< + '_, + std::collections::vec_deque::Iter<'_, (UserMessage, AssistantMessage)>, + Option>, + > +{ + fn into_fig_conversation_state(self) -> eyre::Result { + let history = flatten_history(self.context_messages.unwrap_or_default().iter().chain(self.history)); + let mut user_input_message: UserInputMessage = self + .next_user_message + .cloned() + .map(UserMessage::into_user_input_message) + .ok_or(eyre::eyre!("next user message is not set"))?; + if let Some(ctx) = user_input_message.user_input_message_context.as_mut() { + ctx.tools = Some(self.tools.values().flatten().cloned().collect::>()); + } + + Ok(FigConversationState { + conversation_id: Some(self.conversation_id.to_string()), + user_input_message, + history: Some(history), + }) + } + + pub fn calculate_conversation_size(&self) -> ConversationSize { + let mut user_chars = 0; + let mut assistant_chars = 0; + let mut context_chars = 0; + + // Count the chars used by the messages in the history. + // this clone is cheap + let history = self.history.clone(); + for (user, assistant) in history { + user_chars += *user.char_count(); + assistant_chars += *assistant.char_count(); + } + + // Add any chars from context messages, if available. 
+ context_chars += self + .context_messages + .as_ref() + .map(|v| { + v.iter().fold(0, |acc, (user, assistant)| { + acc + *user.char_count() + *assistant.char_count() + }) + }) + .unwrap_or_default(); + + ConversationSize { + context_messages: context_chars.into(), + user_messages: user_chars.into(), + assistant_messages: assistant_chars.into(), + } + } +} + +/// Reflects a detailed accounting of the context window utilization for a given conversation. +#[derive(Debug, Clone, Copy)] +pub struct ConversationSize { + pub context_messages: CharCount, + pub user_messages: CharCount, + pub assistant_messages: CharCount, +} + +/// Converts a list of user/assistant message pairs into a flattened list of ChatMessage. +fn flatten_history<'a, T>(history: T) -> Vec +where + T: Iterator, +{ + history.fold(Vec::new(), |mut acc, (user, assistant)| { + acc.push(ChatMessage::UserInputMessage(user.clone().into_history_entry())); + acc.push(ChatMessage::AssistantResponseMessage(assistant.clone().into())); + acc + }) +} + +/// Character count warning levels for conversation size +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TokenWarningLevel { + /// No warning, conversation is within normal limits + None, + /// Critical level - at single warning threshold (600K characters) + Critical, +} + +impl From for ToolInputSchema { + fn from(value: InputSchema) -> Self { + Self { + json: Some(serde_value_to_document(value.0)), + } + } +} + +fn format_hook_context<'a>(hook_results: impl IntoIterator, trigger: HookTrigger) -> String { + let mut context_content = String::new(); + + context_content.push_str(CONTEXT_ENTRY_START_HEADER); + context_content.push_str("This section (like others) contains important information that I want you to use in your responses. I have gathered this context from valuable programmatic script hooks. You must follow any requests and consider all of the information in this section"); + if trigger == HookTrigger::ConversationStart { + context_content.push_str(" for the entire conversation"); + } + context_content.push_str("\n\n"); + + for (hook, output) in hook_results.into_iter().filter(|(h, _)| h.trigger == trigger) { + context_content.push_str(&format!("'{}': {output}\n\n", &hook.name)); + } + context_content.push_str(CONTEXT_ENTRY_END_HEADER); + context_content +} + +#[cfg(test)] +mod tests { + use super::super::context::{ + AMAZONQ_FILENAME, + profile_context_path, + }; + use super::super::message::AssistantToolUse; + use super::*; + use crate::cli::chat::tool_manager::ToolManager; + use crate::fig_api_client::model::{ + AssistantResponseMessage, + ToolResultStatus, + }; + + fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) { + if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) { + assert!( + matches!(msg, ChatMessage::UserInputMessage(_)), + "{assertion_iteration}: First message in the history must be from the user, instead found: {:?}", + msg + ); + } + if let Some(Some(msg)) = state.history.as_ref().map(|h| h.last()) { + assert!( + matches!(msg, ChatMessage::AssistantResponseMessage(_)), + "{assertion_iteration}: Last message in the history must be from the assistant, instead found: {:?}", + msg + ); + // If the last message from the assistant contains tool uses, then the next user + // message must contain tool results. + match (state.user_input_message.user_input_message_context.as_ref(), msg) { + ( + Some(ctx), + ChatMessage::AssistantResponseMessage(AssistantResponseMessage { + tool_uses: Some(tool_uses), + .. 
+ }), + ) if !tool_uses.is_empty() => { + assert!( + ctx.tool_results.as_ref().is_some_and(|r| !r.is_empty()), + "The user input message must contain tool results when the last assistant message contains tool uses" + ); + }, + _ => {}, + } + } + + if let Some(history) = state.history.as_ref() { + for (i, msg) in history.iter().enumerate() { + // User message checks. + if let ChatMessage::UserInputMessage(user) = msg { + assert!( + user.user_input_message_context + .as_ref() + .is_none_or(|ctx| ctx.tools.is_none()), + "the tool specification should be empty for all user messages in the history" + ); + + // Check that messages with tool results are immediately preceded by an + // assistant message with tool uses. + if user + .user_input_message_context + .as_ref() + .is_some_and(|ctx| ctx.tool_results.as_ref().is_some_and(|r| !r.is_empty())) + { + match history.get(i.checked_sub(1).unwrap_or_else(|| { + panic!( + "{assertion_iteration}: first message in the history should not contain tool results" + ) + })) { + Some(ChatMessage::AssistantResponseMessage(assistant)) => { + assert!(assistant.tool_uses.is_some()); + }, + _ => panic!( + "expected an assistant response message with tool uses at index: {}", + i - 1 + ), + } + } + } + } + } + + let actual_history_len = state.history.unwrap_or_default().len(); + assert!( + actual_history_len <= MAX_CONVERSATION_STATE_HISTORY_LEN, + "history should not extend past the max limit of {}, instead found length {}", + MAX_CONVERSATION_STATE_HISTORY_LEN, + actual_history_len + ); + + let ctx = state + .user_input_message + .user_input_message_context + .as_ref() + .expect("user input message context must exist"); + assert!( + ctx.tools.is_some(), + "Currently, the tool spec must be included in the next user message" + ); + } + + #[tokio::test] + async fn test_conversation_state_history_handling_truncation() { + let mut tool_manager = ToolManager::default(); + let mut conversation_state = ConversationState::new( + Context::new_fake(), + "fake_conv_id", + tool_manager.load_tools().await.unwrap(), + None, + None, + ) + .await; + + // First, build a large conversation history. We need to ensure that the order is always + // User -> Assistant -> User -> Assistant ...and so on. + conversation_state.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation_state.as_sendable_conversation_state(true).await; + assert_conversation_state_invariants(s, i); + conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); + conversation_state.set_next_user_message(i.to_string()).await; + } + } + + #[tokio::test] + async fn test_conversation_state_history_handling_with_tool_results() { + // Build a long conversation history of tool use results. 
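+        // (Unlike the plain-text truncation test above, every assistant turn here is a tool
+        // use answered by a tool result, so the invariant checks also cover the rule that a
+        // truncated history must never leave a tool result without its preceding tool use.)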
+ let mut tool_manager = ToolManager::default(); + let mut conversation_state = ConversationState::new( + Context::new_fake(), + "fake_conv_id", + tool_manager.load_tools().await.unwrap(), + None, + None, + ) + .await; + conversation_state.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation_state.as_sendable_conversation_state(true).await; + assert_conversation_state_invariants(s, i); + + conversation_state.push_assistant_message(AssistantMessage::new_tool_use(None, i.to_string(), vec![ + AssistantToolUse { + id: "tool_id".to_string(), + name: "tool name".to_string(), + args: serde_json::Value::Null, + }, + ])); + conversation_state.add_tool_results(vec![ToolUseResult { + tool_use_id: "tool_id".to_string(), + content: vec![], + status: ToolResultStatus::Success, + }]); + } + + // Build a long conversation history of user messages mixed in with tool results. + let mut conversation_state = ConversationState::new( + Context::new_fake(), + "fake_conv_id", + tool_manager.load_tools().await.unwrap(), + None, + None, + ) + .await; + conversation_state.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation_state.as_sendable_conversation_state(true).await; + assert_conversation_state_invariants(s, i); + if i % 3 == 0 { + conversation_state.push_assistant_message(AssistantMessage::new_tool_use(None, i.to_string(), vec![ + AssistantToolUse { + id: "tool_id".to_string(), + name: "tool name".to_string(), + args: serde_json::Value::Null, + }, + ])); + conversation_state.add_tool_results(vec![ToolUseResult { + tool_use_id: "tool_id".to_string(), + content: vec![], + status: ToolResultStatus::Success, + }]); + } else { + conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); + conversation_state.set_next_user_message(i.to_string()).await; + } + } + } + + #[tokio::test] + async fn test_conversation_state_with_context_files() { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + ctx.fs().write(AMAZONQ_FILENAME, "test context").await.unwrap(); + + let mut tool_manager = ToolManager::default(); + let mut conversation_state = ConversationState::new( + ctx, + "fake_conv_id", + tool_manager.load_tools().await.unwrap(), + None, + None, + ) + .await; + + // First, build a large conversation history. We need to ensure that the order is always + // User -> Assistant -> User -> Assistant ...and so on. + conversation_state.set_next_user_message("start".to_string()).await; + for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + let s = conversation_state.as_sendable_conversation_state(true).await; + + // Ensure that the first two messages are the fake context messages. 
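+            // (Context files are surfaced to the model as a synthetic user/assistant pair at
+            // the front of the history, which is why the first two entries are inspected here.)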
+ let hist = s.history.as_ref().unwrap(); + let user = &hist[0]; + let assistant = &hist[1]; + match (user, assistant) { + (ChatMessage::UserInputMessage(user), ChatMessage::AssistantResponseMessage(_)) => { + assert!( + user.content.contains("test context"), + "expected context message to contain context file, instead found: {}", + user.content + ); + }, + _ => panic!("Expected the first two messages to be from the user and the assistant"), + } + + assert_conversation_state_invariants(s, i); + + conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); + conversation_state.set_next_user_message(i.to_string()).await; + } + } + + #[tokio::test] + async fn test_conversation_state_additional_context() { + tracing_subscriber::fmt::try_init().ok(); + + let mut tool_manager = ToolManager::default(); + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let conversation_start_context = "conversation start context"; + let prompt_context = "prompt context"; + let config = serde_json::json!({ + "hooks": { + "test_per_prompt": { + "trigger": "per_prompt", + "type": "inline", + "command": format!("echo {}", prompt_context) + }, + "test_conversation_start": { + "trigger": "conversation_start", + "type": "inline", + "command": format!("echo {}", conversation_start_context) + } + } + }); + let config_path = profile_context_path(&ctx, "default").unwrap(); + ctx.fs().create_dir_all(config_path.parent().unwrap()).await.unwrap(); + ctx.fs() + .write(&config_path, serde_json::to_string(&config).unwrap()) + .await + .unwrap(); + let mut conversation_state = ConversationState::new( + ctx, + "fake_conv_id", + tool_manager.load_tools().await.unwrap(), + None, + Some(SharedWriter::stdout()), + ) + .await; + + // Simulate conversation flow + conversation_state.set_next_user_message("start".to_string()).await; + for i in 0..=5 { + let s = conversation_state.as_sendable_conversation_state(true).await; + let hist = s.history.as_ref().unwrap(); + #[allow(clippy::match_wildcard_for_single_variants)] + match &hist[0] { + ChatMessage::UserInputMessage(user) => { + assert!( + user.content.contains(conversation_start_context), + "expected to contain '{conversation_start_context}', instead found: {}", + user.content + ); + }, + _ => panic!("Expected user message."), + } + assert!( + s.user_input_message.content.contains(prompt_context), + "expected to contain '{prompt_context}', instead found: {}", + s.user_input_message.content + ); + + conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); + conversation_state.set_next_user_message(i.to_string()).await; + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/hooks.rs b/crates/kiro-cli/src/cli/chat/hooks.rs new file mode 100644 index 0000000000..036195ba17 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/hooks.rs @@ -0,0 +1,557 @@ +use std::collections::HashMap; +use std::io::Write; +use std::process::Stdio; +use std::time::{ + Duration, + Instant, +}; + +use bstr::ByteSlice; +use crossterm::style::{ + Color, + Stylize, +}; +use crossterm::{ + cursor, + execute, + queue, + style, + terminal, +}; +use eyre::{ + Result, + eyre, +}; +use futures::stream::{ + FuturesUnordered, + StreamExt, +}; +use serde::{ + Deserialize, + Serialize, +}; +use spinners::{ + Spinner, + Spinners, +}; + +use super::util::truncate_safe; + +const DEFAULT_TIMEOUT_MS: u64 = 30_000; +const DEFAULT_MAX_OUTPUT_SIZE: usize = 1024 * 10; +const DEFAULT_CACHE_TTL_SECONDS: u64 = 0; + +#[derive(Debug, Clone, 
Serialize, Deserialize)] +pub struct Hook { + pub trigger: HookTrigger, + + pub r#type: HookType, + + #[serde(default = "Hook::default_disabled")] + pub disabled: bool, + + /// Max time the hook can run before it throws a timeout error + #[serde(default = "Hook::default_timeout_ms")] + pub timeout_ms: u64, + + /// Max output size of the hook before it is truncated + #[serde(default = "Hook::default_max_output_size")] + pub max_output_size: usize, + + /// How long the hook output is cached before it will be executed again + #[serde(default = "Hook::default_cache_ttl_seconds")] + pub cache_ttl_seconds: u64, + + // Type-specific fields + /// The bash command to execute + pub command: Option, // For inline hooks + + // Internal data + #[serde(skip)] + pub name: String, + #[serde(skip)] + pub is_global: bool, +} + +impl Hook { + pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self { + Self { + trigger, + r#type: HookType::Inline, + disabled: Self::default_disabled(), + timeout_ms: Self::default_timeout_ms(), + max_output_size: Self::default_max_output_size(), + cache_ttl_seconds: Self::default_cache_ttl_seconds(), + command: Some(command), + is_global: false, + name: "new hook".to_string(), + } + } + + fn default_disabled() -> bool { + false + } + + fn default_timeout_ms() -> u64 { + DEFAULT_TIMEOUT_MS + } + + fn default_max_output_size() -> usize { + DEFAULT_MAX_OUTPUT_SIZE + } + + fn default_cache_ttl_seconds() -> u64 { + DEFAULT_CACHE_TTL_SECONDS + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "lowercase")] +pub enum HookType { + // Execute an inline shell command + Inline, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum HookTrigger { + ConversationStart, + PerPrompt, +} + +#[derive(Debug, Clone)] +pub struct CachedHook { + output: String, + expiry: Option, +} + +/// Maps a hook name to a [`CachedHook`] +#[derive(Debug, Clone, Default)] +pub struct HookExecutor { + pub global_cache: HashMap, + pub profile_cache: HashMap, +} + +impl HookExecutor { + pub fn new() -> Self { + Self { + global_cache: HashMap::new(), + profile_cache: HashMap::new(), + } + } + + /// Run and cache [`Hook`]s. Any hooks that are already cached will be returned without + /// executing. Hooks that fail to execute will not be returned. + /// + /// If `updates` is `Some`, progress on hook execution will be written to it. + /// Errors encountered with write operations to `updates` are ignored. + /// + /// Note: [`HookTrigger::ConversationStart`] hooks never leave the cache. + pub async fn run_hooks(&mut self, hooks: Vec<&Hook>, mut updates: Option<&mut impl Write>) -> Vec<(Hook, String)> { + let mut results = Vec::with_capacity(hooks.len()); + let mut futures = FuturesUnordered::new(); + + // Start all hook future OR fetch from cache if available + // Why enumerate? We want to return the hook results in the order of hooks that we received, + // however, for output display we want to process hooks as they complete rather than the + // order they were started in. The index will be used later to sort them back to output order. 
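+        // (Concretely: cache hits are pushed straight into `results` as (index, (hook, output))
+        // pairs, misses are spawned into `futures` tagged with their original index, and the
+        // combined set is sorted back into request order once the futures drain.)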
+ for (index, hook) in hooks.into_iter().enumerate() { + if hook.disabled { + continue; + } + + if let Some(cached) = self.get_cache(hook) { + results.push((index, (hook.clone(), cached.clone()))); + continue; + } + let future = self.execute_hook(hook); + futures.push(async move { (index, future.await) }); + } + + // Start caching the results added after whats already their (they are from the cache already) + let start_cache_index = results.len(); + + let mut succeeded = 0; + let total = futures.len(); + + let mut spinner = None; + let spinner_text = |complete: usize, total: usize| { + format!( + "{} of {} hooks finished", + complete.to_string().blue(), + total.to_string().blue(), + ) + }; + if total != 0 && updates.is_some() { + spinner = Some(Spinner::new(Spinners::Dots12, spinner_text(succeeded, total))); + } + + // Process results as they complete + let start_time = Instant::now(); + while let Some((index, (hook, result, duration))) = futures.next().await { + // If output is enabled, handle that first + if let Some(updates) = updates.as_deref_mut() { + if let Some(spinner) = spinner.as_mut() { + spinner.stop(); + + // Erase the spinner + let _ = execute!( + updates, + cursor::MoveToColumn(0), + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::Hide, + ); + } + match &result { + Ok(_) => { + let _ = queue!( + updates, + style::SetForegroundColor(style::Color::Green), + style::Print("✓ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(&hook.name), + style::ResetColor, + style::Print(" finished in "), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!("{:.2} s\n", duration.as_secs_f32())), + style::ResetColor, + ); + }, + Err(e) => { + let _ = queue!( + updates, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(&hook.name), + style::ResetColor, + style::Print(" failed after "), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!("{:.2} s", duration.as_secs_f32())), + style::ResetColor, + style::Print(format!(": {}\n", e)), + ); + }, + } + } + + // Process results regardless of output enabled + if let Ok(output) = result { + succeeded += 1; + results.push((index, (hook.clone(), output))); + } + + // Display ending summary or add a new spinner + if let Some(updates) = updates.as_deref_mut() { + // The futures set size decreases each time we process one + if futures.is_empty() { + let symbol = if total == succeeded { + "✓".to_string().green() + } else { + "✗".to_string().red() + }; + + let _ = queue!( + updates, + style::SetForegroundColor(Color::Blue), + style::Print(format!("{symbol} {} in ", spinner_text(succeeded, total))), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!("{:.2} s\n", start_time.elapsed().as_secs_f32())), + style::ResetColor, + ); + } else { + spinner = Some(Spinner::new(Spinners::Dots, spinner_text(succeeded, total))); + } + } + } + drop(futures); + + // Fill cache with executed results, skipping what was already from cache + results.iter().skip(start_cache_index).for_each(|(_, (hook, output))| { + let expiry = match hook.trigger { + HookTrigger::ConversationStart => None, + HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), + }; + self.insert_cache(hook, CachedHook { + output: output.clone(), + expiry, + }); + }); + + // Return back to order at request start + results.sort_by_key(|(idx, _)| *idx); + results.into_iter().map(|(_, r)| r).collect() + 
} + + async fn execute_hook<'a>(&self, hook: &'a Hook) -> (&'a Hook, Result, Duration) { + let start_time = Instant::now(); + let result = match hook.r#type { + HookType::Inline => self.execute_inline_hook(hook).await, + }; + + (hook, result, start_time.elapsed()) + } + + async fn execute_inline_hook(&self, hook: &Hook) -> Result { + let command = hook.command.as_ref().ok_or_else(|| eyre!("no command specified"))?; + + let command_future = tokio::process::Command::new("bash") + .arg("-c") + .arg(command) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .output(); + let timeout = Duration::from_millis(hook.timeout_ms); + + // Run with timeout + match tokio::time::timeout(timeout, command_future).await { + Ok(result) => { + let result = result?; + if result.status.success() { + let stdout = result.stdout.to_str_lossy(); + let stdout = format!( + "{}{}", + truncate_safe(&stdout, hook.max_output_size), + if stdout.len() > hook.max_output_size { + " ... truncated" + } else { + "" + } + ); + Ok(stdout) + } else { + Err(eyre!("command returned non-zero exit code: {}", result.status)) + } + }, + Err(_) => Err(eyre!("command timed out after {} ms", timeout.as_millis())), + } + } + + /// Will return a cached hook's output if it exists and isn't expired. + fn get_cache(&self, hook: &Hook) -> Option { + let cache = if hook.is_global { + &self.global_cache + } else { + &self.profile_cache + }; + + cache.get(&hook.name).and_then(|o| { + if let Some(expiry) = o.expiry { + if Instant::now() < expiry { + Some(o.output.clone()) + } else { + None + } + } else { + Some(o.output.clone()) + } + }) + } + + fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) { + let cache = if hook.is_global { + &mut self.global_cache + } else { + &mut self.profile_cache + }; + + cache.insert(hook.name.clone(), hook_output); + } +} + +#[cfg(test)] +mod tests { + use std::io::Stdout; + use std::time::Duration; + + use tokio::time::sleep; + + use super::*; + + #[test] + fn test_hook_creation() { + let command = "echo 'hello'"; + let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, command.to_string()); + + assert_eq!(hook.r#type, HookType::Inline); + assert!(!hook.disabled); + assert_eq!(hook.timeout_ms, DEFAULT_TIMEOUT_MS); + assert_eq!(hook.max_output_size, DEFAULT_MAX_OUTPUT_SIZE); + assert_eq!(hook.cache_ttl_seconds, DEFAULT_CACHE_TTL_SECONDS); + assert_eq!(hook.command, Some(command.to_string())); + assert_eq!(hook.trigger, HookTrigger::PerPrompt); + assert!(!hook.is_global); + } + + #[tokio::test] + async fn test_hook_executor_cached_conversation_start() { + let mut executor = HookExecutor::new(); + let mut hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo 'test1'".to_string()); + hook1.is_global = true; + + let mut hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo 'test2'".to_string()); + hook2.is_global = false; + + // First execution should run the command + let mut output = Vec::new(); + let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; + + assert_eq!(results.len(), 2); + assert!(results[0].1.contains("test1")); + assert!(results[1].1.contains("test2")); + assert!(!output.is_empty()); + + // Second execution should use cache + let mut output = Vec::new(); + let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; + + assert_eq!(results.len(), 2); + assert!(results[0].1.contains("test1")); + assert!(results[1].1.contains("test2")); + assert!(output.is_empty()); // Should not have run the 
hook, so no output. + } + + #[tokio::test] + async fn test_hook_executor_cached_per_prompt() { + let mut executor = HookExecutor::new(); + let mut hook1 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test1'".to_string()); + hook1.is_global = true; + hook1.cache_ttl_seconds = 60; + + let mut hook2 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test2'".to_string()); + hook2.is_global = false; + hook2.cache_ttl_seconds = 60; + + // First execution should run the command + let mut output = Vec::new(); + let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; + + assert_eq!(results.len(), 2); + assert!(results[0].1.contains("test1")); + assert!(results[1].1.contains("test2")); + assert!(!output.is_empty()); + + // Second execution should use cache + let mut output = Vec::new(); + let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; + + assert_eq!(results.len(), 2); + assert!(results[0].1.contains("test1")); + assert!(results[1].1.contains("test2")); + assert!(output.is_empty()); // Should not have run the hook, so no output. + } + + #[tokio::test] + async fn test_hook_executor_not_cached_per_prompt() { + let mut executor = HookExecutor::new(); + let mut hook1 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test1'".to_string()); + hook1.is_global = true; + + let mut hook2 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test2'".to_string()); + hook2.is_global = false; + + // First execution should run the command + let mut output = Vec::new(); + let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; + + assert_eq!(results.len(), 2); + assert!(results[0].1.contains("test1")); + assert!(results[1].1.contains("test2")); + assert!(!output.is_empty()); + + // Second execution should use cache + let mut output = Vec::new(); + let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; + + assert_eq!(results.len(), 2); + assert!(results[0].1.contains("test1")); + assert!(results[1].1.contains("test2")); + assert!(!output.is_empty()); + } + + #[tokio::test] + async fn test_hook_timeout() { + let mut executor = HookExecutor::new(); + let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "sleep 2".to_string()); + hook.timeout_ms = 100; // Set very short timeout + + let results = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; + + assert_eq!(results.len(), 0); // Should fail due to timeout + } + + #[tokio::test] + async fn test_disabled_hook() { + let mut executor = HookExecutor::new(); + let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test'".to_string()); + hook.disabled = true; + + let results = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; + + assert_eq!(results.len(), 0); // Disabled hook should not run + } + + #[tokio::test] + async fn test_cache_expiration() { + let mut executor = HookExecutor::new(); + let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test'".to_string()); + hook.cache_ttl_seconds = 1; + + // First execution + let results1 = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; + assert_eq!(results1.len(), 1); + + // Wait for cache to expire + sleep(Duration::from_millis(1001)).await; + + // Second execution should run command again + let results2 = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; + assert_eq!(results2.len(), 1); + } + + #[test] + fn test_hook_cache_storage() { + let mut executor: HookExecutor = HookExecutor::new(); + let hook = 
Hook::new_inline_hook(HookTrigger::PerPrompt, "".to_string()); + + let cached_hook = CachedHook { + output: "test output".to_string(), + expiry: None, + }; + + executor.insert_cache(&hook, cached_hook.clone()); + + assert_eq!(executor.get_cache(&hook), Some("test output".to_string())); + } + + #[test] + fn test_hook_cache_storage_expired() { + let mut executor: HookExecutor = HookExecutor::new(); + let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "".to_string()); + + let cached_hook = CachedHook { + output: "test output".to_string(), + expiry: Some(Instant::now()), + }; + + executor.insert_cache(&hook, cached_hook.clone()); + + // Item should not return since it is expired + assert_eq!(executor.get_cache(&hook), None); + } + + #[tokio::test] + async fn test_max_output_size() { + let mut executor = HookExecutor::new(); + let mut hook = Hook::new_inline_hook( + HookTrigger::PerPrompt, + "for i in {1..1000}; do echo $i; done".to_string(), + ); + hook.max_output_size = 100; + + let results = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; + + assert!(results[0].1.len() <= hook.max_output_size + " ... truncated".len()); + } +} diff --git a/crates/kiro-cli/src/cli/chat/input_source.rs b/crates/kiro-cli/src/cli/chat/input_source.rs new file mode 100644 index 0000000000..4e43bb2d13 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/input_source.rs @@ -0,0 +1,107 @@ +use std::sync::Arc; + +use eyre::Result; +use rustyline::error::ReadlineError; +use rustyline::{ + EventHandler, + KeyEvent, +}; + +use super::context::ContextManager; +use super::prompt::rl; +use super::skim_integration::SkimCommandSelector; + +#[derive(Debug)] +pub struct InputSource(inner::Inner); + +mod inner { + use rustyline::Editor; + use rustyline::history::FileHistory; + + use super::super::prompt::ChatHelper; + + #[derive(Debug)] + pub enum Inner { + Readline(Editor), + #[allow(dead_code)] + Mock { + index: usize, + lines: Vec, + }, + } +} + +impl InputSource { + pub fn new( + sender: std::sync::mpsc::Sender>, + receiver: std::sync::mpsc::Receiver>, + ) -> Result { + Ok(Self(inner::Inner::Readline(rl(sender, receiver)?))) + } + + pub fn put_skim_command_selector(&mut self, context_manager: Arc, tool_names: Vec) { + if let inner::Inner::Readline(rl) = &mut self.0 { + let key_char = match crate::fig_settings::settings::get_string_opt("chat.skimCommandKey").as_deref() { + Some(key) if key.len() == 1 => key.chars().next().unwrap_or('s'), + _ => 's', // Default to 's' if setting is missing or invalid + }; + rl.bind_sequence( + KeyEvent::ctrl(key_char), + EventHandler::Conditional(Box::new(SkimCommandSelector::new(context_manager, tool_names))), + ); + } + } + + #[allow(dead_code)] + pub fn new_mock(lines: Vec) -> Self { + Self(inner::Inner::Mock { index: 0, lines }) + } + + pub fn read_line(&mut self, prompt: Option<&str>) -> Result, ReadlineError> { + match &mut self.0 { + inner::Inner::Readline(rl) => { + let prompt = prompt.unwrap_or_default(); + let curr_line = rl.readline(prompt); + match curr_line { + Ok(line) => { + let _ = rl.add_history_entry(line.as_str()); + Ok(Some(line)) + }, + Err(ReadlineError::Interrupted | ReadlineError::Eof) => Ok(None), + Err(err) => Err(err), + } + }, + inner::Inner::Mock { index, lines } => { + *index += 1; + Ok(lines.get(*index - 1).cloned()) + }, + } + } + + // We're keeping this method for potential future use + #[allow(dead_code)] + pub fn set_buffer(&mut self, content: &str) { + if let inner::Inner::Readline(rl) = &mut self.0 { + // Add to history so user can access it 
with up arrow + let _ = rl.add_history_entry(content); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mock_input_source() { + let l1 = "Hello,".to_string(); + let l2 = "Line 2".to_string(); + let l3 = "World!".to_string(); + let mut input = InputSource::new_mock(vec![l1.clone(), l2.clone(), l3.clone()]); + + assert_eq!(input.read_line(None).unwrap().unwrap(), l1); + assert_eq!(input.read_line(None).unwrap().unwrap(), l2); + assert_eq!(input.read_line(None).unwrap().unwrap(), l3); + assert!(input.read_line(None).unwrap().is_none()); + } +} diff --git a/crates/kiro-cli/src/cli/chat/message.rs b/crates/kiro-cli/src/cli/chat/message.rs new file mode 100644 index 0000000000..512a8fb276 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/message.rs @@ -0,0 +1,384 @@ +use std::env; + +use serde::{ + Deserialize, + Serialize, +}; +use tracing::error; + +use super::consts::MAX_CURRENT_WORKING_DIRECTORY_LEN; +use super::tools::{ + InvokeOutput, + OutputKind, + document_to_serde_value, + serde_value_to_document, +}; +use super::util::truncate_safe; +use crate::fig_api_client::model::{ + AssistantResponseMessage, + EnvState, + ToolResult, + ToolResultContentBlock, + ToolResultStatus, + ToolUse, + UserInputMessage, + UserInputMessageContext, +}; + +const USER_ENTRY_START_HEADER: &str = "--- USER MESSAGE BEGIN ---\n"; +const USER_ENTRY_END_HEADER: &str = "--- USER MESSAGE END ---\n\n"; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserMessage { + pub additional_context: String, + pub env_context: UserEnvContext, + pub content: UserMessageContent, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum UserMessageContent { + Prompt { + /// The original prompt as input by the user. + prompt: String, + }, + CancelledToolUses { + /// The original prompt as input by the user, if any. + prompt: Option, + tool_use_results: Vec, + }, + ToolUseResults { + tool_use_results: Vec, + }, +} + +impl UserMessage { + /// Creates a new [UserMessage::Prompt], automatically detecting and adding the user's + /// environment [UserEnvContext]. + pub fn new_prompt(prompt: String) -> Self { + Self { + additional_context: String::new(), + env_context: UserEnvContext::generate_new(), + content: UserMessageContent::Prompt { prompt }, + } + } + + pub fn new_cancelled_tool_uses<'a>(prompt: Option, tool_use_ids: impl Iterator) -> Self { + Self { + additional_context: String::new(), + env_context: UserEnvContext::generate_new(), + content: UserMessageContent::CancelledToolUses { + prompt, + tool_use_results: tool_use_ids + .map(|id| ToolUseResult { + tool_use_id: id.to_string(), + content: vec![ToolUseResultBlock::Text( + "Tool use was cancelled by the user".to_string(), + )], + status: ToolResultStatus::Error, + }) + .collect(), + }, + } + } + + pub fn new_tool_use_results(results: Vec) -> Self { + Self { + additional_context: String::new(), + env_context: UserEnvContext::generate_new(), + content: UserMessageContent::ToolUseResults { + tool_use_results: results, + }, + } + } + + /// Converts this message into a [UserInputMessage] to be stored in the history of + /// [fig_api_client::model::ConversationState]. + pub fn into_history_entry(self) -> UserInputMessage { + UserInputMessage { + content: self.prompt().unwrap_or_default().to_string(), + user_input_message_context: Some(UserInputMessageContext { + env_state: self.env_context.env_state, + tool_results: match self.content { + UserMessageContent::CancelledToolUses { tool_use_results, .. 
} + | UserMessageContent::ToolUseResults { tool_use_results } => { + Some(tool_use_results.into_iter().map(Into::into).collect()) + }, + UserMessageContent::Prompt { .. } => None, + }, + tools: None, + ..Default::default() + }), + user_intent: None, + } + } + + /// Converts this message into a [UserInputMessage] to be sent as + /// [FigConversationState::user_input_message]. + pub fn into_user_input_message(self) -> UserInputMessage { + let formatted_prompt = match self.prompt() { + Some(prompt) if !prompt.is_empty() => { + format!("{}{}{}", USER_ENTRY_START_HEADER, prompt, USER_ENTRY_END_HEADER) + }, + _ => String::new(), + }; + UserInputMessage { + content: format!("{} {}", self.additional_context, formatted_prompt) + .trim() + .to_string(), + user_input_message_context: Some(UserInputMessageContext { + env_state: self.env_context.env_state, + tool_results: match self.content { + UserMessageContent::CancelledToolUses { tool_use_results, .. } + | UserMessageContent::ToolUseResults { tool_use_results } => { + Some(tool_use_results.into_iter().map(Into::into).collect()) + }, + UserMessageContent::Prompt { .. } => None, + }, + tools: None, + ..Default::default() + }), + user_intent: None, + } + } + + pub fn has_tool_use_results(&self) -> bool { + match self.content() { + UserMessageContent::CancelledToolUses { .. } | UserMessageContent::ToolUseResults { .. } => true, + UserMessageContent::Prompt { .. } => false, + } + } + + pub fn tool_use_results(&self) -> Option<&[ToolUseResult]> { + match self.content() { + UserMessageContent::Prompt { .. } => None, + UserMessageContent::CancelledToolUses { tool_use_results, .. } => Some(tool_use_results.as_slice()), + UserMessageContent::ToolUseResults { tool_use_results } => Some(tool_use_results.as_slice()), + } + } + + pub fn additional_context(&self) -> &str { + &self.additional_context + } + + pub fn content(&self) -> &UserMessageContent { + &self.content + } + + pub fn prompt(&self) -> Option<&str> { + match self.content() { + UserMessageContent::Prompt { prompt } => Some(prompt.as_str()), + UserMessageContent::CancelledToolUses { prompt, .. } => prompt.as_ref().map(|s| s.as_str()), + UserMessageContent::ToolUseResults { .. } => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolUseResult { + /// The ID for the tool request. + pub tool_use_id: String, + /// Content of the tool result. + pub content: Vec, + /// Status of the tool result. 
+ pub status: ToolResultStatus, +} + +impl From for ToolUseResult { + fn from(value: ToolResult) -> Self { + Self { + tool_use_id: value.tool_use_id, + content: value.content.into_iter().map(Into::into).collect(), + status: value.status, + } + } +} + +impl From for ToolResult { + fn from(value: ToolUseResult) -> Self { + Self { + tool_use_id: value.tool_use_id, + content: value.content.into_iter().map(Into::into).collect(), + status: value.status, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolUseResultBlock { + Json(serde_json::Value), + Text(String), +} + +impl From for ToolResultContentBlock { + fn from(value: ToolUseResultBlock) -> Self { + match value { + ToolUseResultBlock::Json(v) => Self::Json(serde_value_to_document(v)), + ToolUseResultBlock::Text(s) => Self::Text(s), + } + } +} + +impl From for ToolUseResultBlock { + fn from(value: ToolResultContentBlock) -> Self { + match value { + ToolResultContentBlock::Json(v) => Self::Json(document_to_serde_value(v)), + ToolResultContentBlock::Text(s) => Self::Text(s), + } + } +} + +impl From for ToolUseResultBlock { + fn from(value: InvokeOutput) -> Self { + match value.output { + OutputKind::Text(text) => Self::Text(text), + OutputKind::Json(value) => Self::Json(value), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserEnvContext { + env_state: Option, +} + +impl UserEnvContext { + pub fn generate_new() -> Self { + Self { + env_state: Some(build_env_state()), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum AssistantMessage { + /// Normal response containing no tool uses. + Response { + message_id: Option, + content: String, + }, + /// An assistant message containing tool uses. + ToolUse { + message_id: Option, + content: String, + tool_uses: Vec, + }, +} + +impl AssistantMessage { + pub fn new_response(message_id: Option, content: String) -> Self { + Self::Response { message_id, content } + } + + pub fn new_tool_use(message_id: Option, content: String, tool_uses: Vec) -> Self { + Self::ToolUse { + message_id, + content, + tool_uses, + } + } + + pub fn message_id(&self) -> Option<&str> { + match self { + AssistantMessage::Response { message_id, .. } => message_id.as_ref().map(|s| s.as_str()), + AssistantMessage::ToolUse { message_id, .. } => message_id.as_ref().map(|s| s.as_str()), + } + } + + pub fn content(&self) -> &str { + match self { + AssistantMessage::Response { content, .. } => content.as_str(), + AssistantMessage::ToolUse { content, .. } => content.as_str(), + } + } + + pub fn tool_uses(&self) -> Option<&[AssistantToolUse]> { + match self { + AssistantMessage::ToolUse { tool_uses, .. } => Some(tool_uses.as_slice()), + AssistantMessage::Response { .. } => None, + } + } +} + +impl From for AssistantResponseMessage { + fn from(value: AssistantMessage) -> Self { + let (message_id, content, tool_uses) = match value { + AssistantMessage::Response { message_id, content } => (message_id, content, None), + AssistantMessage::ToolUse { + message_id, + content, + tool_uses, + } => ( + message_id, + content, + Some(tool_uses.into_iter().map(Into::into).collect()), + ), + }; + Self { + message_id, + content, + tool_uses, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AssistantToolUse { + /// The ID for the tool request. + pub id: String, + /// The name for the tool. + pub name: String, + /// The input to pass to the tool. 
+ pub args: serde_json::Value, +} + +impl From for ToolUse { + fn from(value: AssistantToolUse) -> Self { + Self { + tool_use_id: value.id, + name: value.name, + input: serde_value_to_document(value.args), + } + } +} + +impl From for AssistantToolUse { + fn from(value: ToolUse) -> Self { + Self { + id: value.tool_use_id, + name: value.name, + args: document_to_serde_value(value.input), + } + } +} + +pub fn build_env_state() -> EnvState { + let mut env_state = EnvState { + operating_system: Some(env::consts::OS.into()), + ..Default::default() + }; + + match env::current_dir() { + Ok(current_dir) => { + env_state.current_working_directory = + Some(truncate_safe(¤t_dir.to_string_lossy(), MAX_CURRENT_WORKING_DIRECTORY_LEN).into()); + }, + Err(err) => { + error!(?err, "Attempted to fetch the CWD but it did not exist."); + }, + } + + env_state +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_env_state() { + let env_state = build_env_state(); + assert!(env_state.current_working_directory.is_some()); + assert!(env_state.operating_system.as_ref().is_some_and(|os| !os.is_empty())); + println!("{env_state:?}"); + } +} diff --git a/crates/kiro-cli/src/cli/chat/mod.rs b/crates/kiro-cli/src/cli/chat/mod.rs new file mode 100644 index 0000000000..6e4c7beb7b --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/mod.rs @@ -0,0 +1,3937 @@ +pub mod cli; +mod command; +mod consts; +mod context; +mod conversation_state; +mod hooks; +mod input_source; +mod message; +mod parse; +mod parser; +mod prompt; +mod shared_writer; +mod skim_integration; +mod token_counter; +mod tool_manager; +mod tools; +pub mod util; + +use std::borrow::Cow; +use std::collections::{ + HashMap, + HashSet, + VecDeque, +}; +use std::io::{ + IsTerminal, + Read, + Write, +}; +use std::process::{ + Command as ProcessCommand, + ExitCode, +}; +use std::sync::Arc; +use std::time::Duration; +use std::{ + env, + fs, +}; + +use command::{ + Command, + PromptsSubcommand, + ToolsSubcommand, +}; +use consts::CONTEXT_WINDOW_SIZE; +use context::ContextManager; +use conversation_state::{ + ConversationState, + TokenWarningLevel, +}; +use crossterm::style::{ + Attribute, + Color, + Stylize, +}; +use crossterm::terminal::ClearType; +use crossterm::{ + cursor, + execute, + queue, + style, + terminal, +}; +use eyre::{ + ErrReport, + Result, + bail, +}; +use hooks::{ + Hook, + HookTrigger, +}; +use message::{ + AssistantMessage, + AssistantToolUse, + ToolUseResult, + ToolUseResultBlock, +}; +use rand::distr::{ + Alphanumeric, + SampleString, +}; +use shared_writer::SharedWriter; + +use crate::fig_api_client::StreamingClient; +use crate::fig_api_client::clients::SendMessageOutput; +use crate::fig_api_client::model::{ + ChatResponseStream, + Tool as FigTool, + ToolResultStatus, +}; +use crate::fig_os_shim::Context; +use crate::fig_settings::{ + Settings, + State, +}; +use crate::fig_util::CLI_BINARY_NAME; + +/// Help text for the compact command +fn compact_help_text() -> String { + color_print::cformat!( + r#" +Conversation Compaction + +The /compact command summarizes the conversation history to free up context space +while preserving essential information. This is useful for long-running conversations +that may eventually reach memory constraints. 
+ +Usage + /compact Summarize the conversation and clear history + /compact [prompt] Provide custom guidance for summarization + +When to use +• When you see the memory constraint warning message +• When a conversation has been running for a long time +• Before starting a new topic within the same session +• After completing complex tool operations + +How it works +• Creates an AI-generated summary of your conversation +• Retains key information, code, and tool executions in the summary +• Clears the conversation history to free up space +• The assistant will reference the summary context in future responses +"# + ) +} +use input_source::InputSource; +use parse::{ + ParseState, + interpret_markdown, +}; +use parser::{ + RecvErrorKind, + ResponseParser, +}; +use regex::Regex; +use serde_json::Map; +use spinners::{ + Spinner, + Spinners, +}; +use thiserror::Error; +use token_counter::{ + TokenCount, + TokenCounter, +}; +use tokio::signal::unix::{ + SignalKind, + signal, +}; +use tool_manager::{ + GetPromptError, + McpServerConfig, + PromptBundle, + ToolManager, + ToolManagerBuilder, +}; +use tools::gh_issue::GhIssueContext; +use tools::{ + QueuedTool, + Tool, + ToolPermissions, + ToolSpec, +}; +use tracing::{ + debug, + error, + info, + trace, + warn, +}; +use unicode_width::UnicodeWidthStr; +use util::{ + animate_output, + play_notification_bell, + region_check, +}; +use uuid::Uuid; +use winnow::Partial; +use winnow::stream::Offset; + +use crate::mcp_client::{ + Prompt, + PromptGetResult, +}; + +const WELCOME_TEXT: &str = color_print::cstr! {" +Welcome to + + █████╗ ███╗ ███╗ █████╗ ███████╗ ██████╗ ███╗ ██╗ ██████╗ +██╔══██╗████╗ ████║██╔══██╗╚══███╔╝██╔═══██╗████╗ ██║ ██╔═══██╗ +███████║██╔████╔██║███████║ ███╔╝ ██║ ██║██╔██╗ ██║ ██║ ██║ +██╔══██║██║╚██╔╝██║██╔══██║ ███╔╝ ██║ ██║██║╚██╗██║ ██║▄▄ ██║ +██║ ██║██║ ╚═╝ ██║██║ ██║███████╗╚██████╔╝██║ ╚████║ ╚██████╔╝ +╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ ╚═════╝ ╚═╝ ╚═══╝ ╚══▀▀═╝ + +"}; + +const SMALL_SCREEN_WECLOME_TEXT: &str = color_print::cstr! {" +Welcome to Amazon Q! +"}; + +const ROTATING_TIPS: [&str; 9] = [ + color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings chat.enableNotifications true"}, + color_print::cstr! {"You can use /editor to edit your prompt with a vim-like experience"}, + color_print::cstr! {"You can execute bash commands by typing ! followed by the command"}, + color_print::cstr! {"Q can use tools without asking for confirmation every time. Give /tools trust a try"}, + color_print::cstr! {"You can programmatically inject context to your prompts by using hooks. Check out /context hooks help"}, + color_print::cstr! {"You can use /compact to replace the conversation history with its summary to free up the context space"}, + color_print::cstr! {"/usage shows you a visual breakdown of your current context window usage"}, + color_print::cstr! {"If you want to file an issue to the Q CLI team, just tell me, or run q issue"}, + color_print::cstr! {"You can enable custom tools with MCP servers. Learn more with /help"}, +]; + +const GREETING_BREAK_POINT: usize = 67; + +const POPULAR_SHORTCUTS: &str = color_print::cstr! {" + +/help all commands ctrl + j new lines ctrl + s fuzzy search +"}; + +const SMALL_SCREEN_POPULAR_SHORTCUTS: &str = color_print::cstr! {" + +/help all commands +ctrl + j new lines +ctrl + s fuzzy search + +"}; +const HELP_TEXT: &str = color_print::cstr! 
{" + +q (Amazon Q Chat) + +Commands: +/clear Clear the conversation history +/issue Report an issue or make a feature request +/editor Open $EDITOR (defaults to vi) to compose a prompt +/help Show this help dialogue +/quit Quit the application +/compact Summarize the conversation to free up context space + help Show help for the compact command + [prompt] Optional custom prompt to guide summarization +/tools View and manage tools and permissions + help Show an explanation for the trust command + trust Trust a specific tool or tools for the session + untrust Revert a tool or tools to per-request confirmation + trustall Trust all tools (equivalent to deprecated /acceptall) + reset Reset all tools to default permission levels +/profile Manage profiles + help Show profile help + list List profiles + set Set the current profile + create Create a new profile + delete Delete a profile + rename Rename a profile +/prompts View and retrieve prompts + help Show prompts help + list List or search available prompts + get Retrieve and send a prompt +/context Manage context files and hooks for the chat session + help Show context help + show Display current context rules configuration [--expand] + add Add file(s) to context [--global] [--force] + rm Remove file(s) from context [--global] + clear Clear all files from current context [--global] + hooks View and manage context hooks +/usage Show current session's context window usage + +MCP: +You can now configure the Amazon Q CLI to use MCP servers. \nLearn how: https://docs.aws.amazon.com/en_us/amazonq/latest/qdeveloper-ug/command-line-mcp.html + +Tips: +!{command} Quickly execute a command in your current session +Ctrl(^) + j Insert new-line to provide multi-line prompt. Alternatively, [Alt(⌥) + Enter(⏎)] +Ctrl(^) + s Fuzzy search commands and context files. Use Tab to select multiple items. + Change the keybind to ctrl+x with: q settings chat.skimCommandKey x (where x is any key) + +"}; + +const RESPONSE_TIMEOUT_CONTENT: &str = "Response timed out - message took too long to generate"; +const TRUST_ALL_TEXT: &str = color_print::cstr! {"All tools are now trusted (!). Amazon Q will execute tools without asking for confirmation.\ +\nAgents can sometimes do unexpected things so understand the risks. +\nLearn more at https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-chat-security.html#command-line-chat-trustall-safety"}; + +const TOOL_BULLET: &str = " ● "; +const CONTINUATION_LINE: &str = " ⋮ "; + +pub async fn launch_chat(args: cli::Chat) -> Result { + let trust_tools = args.trust_tools.map(|mut tools| { + if tools.len() == 1 && tools[0].is_empty() { + tools.pop(); + } + tools + }); + chat( + args.input, + args.no_interactive, + args.accept_all, + args.profile, + args.trust_all_tools, + trust_tools, + ) + .await +} + +pub async fn chat( + input: Option, + no_interactive: bool, + accept_all: bool, + profile: Option, + trust_all_tools: bool, + trust_tools: Option>, +) -> Result { + if !crate::fig_util::system_info::in_cloudshell() && !crate::fig_auth::is_logged_in().await { + bail!( + "You are not logged in, please log in with {}", + format!("{CLI_BINARY_NAME} login",).bold() + ); + } + + region_check("chat")?; + + let ctx = Context::new(); + + let stdin = std::io::stdin(); + // no_interactive flag or part of a pipe + let interactive = !no_interactive && stdin.is_terminal(); + let input = if !interactive && !stdin.is_terminal() { + // append to input string any extra info that was provided, e.g. 
via pipe + let mut input = input.unwrap_or_default(); + stdin.lock().read_to_string(&mut input)?; + Some(input) + } else { + input + }; + + let mut output = match interactive { + true => SharedWriter::stderr(), + false => SharedWriter::stdout(), + }; + + let client = match ctx.env().get("Q_MOCK_CHAT_RESPONSE") { + Ok(json) => create_stream(serde_json::from_str(std::fs::read_to_string(json)?.as_str())?), + _ => StreamingClient::new().await?, + }; + + let mcp_server_configs = match McpServerConfig::load_config(&mut output).await { + Ok(config) => { + execute!( + output, + style::Print( + "To learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n" + ) + )?; + config + }, + Err(e) => { + warn!("No mcp server config loaded: {}", e); + McpServerConfig::default() + }, + }; + + // If profile is specified, verify it exists before starting the chat + if let Some(ref profile_name) = profile { + // Create a temporary context manager to check if the profile exists + match ContextManager::new(Arc::clone(&ctx)).await { + Ok(context_manager) => { + let profiles = context_manager.list_profiles().await?; + if !profiles.contains(profile_name) { + bail!( + "Profile '{}' does not exist. Available profiles: {}", + profile_name, + profiles.join(", ") + ); + } + }, + Err(e) => { + warn!("Failed to initialize context manager to verify profile: {}", e); + // Continue without verification if context manager can't be initialized + }, + } + } + + let conversation_id = Alphanumeric.sample_string(&mut rand::rng(), 9); + info!(?conversation_id, "Generated new conversation id"); + let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); + let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let mut tool_manager = ToolManagerBuilder::default() + .mcp_server_config(mcp_server_configs) + .prompt_list_sender(prompt_response_sender) + .prompt_list_receiver(prompt_request_receiver) + .conversation_id(&conversation_id) + .build()?; + let tool_config = tool_manager.load_tools().await?; + let mut tool_permissions = ToolPermissions::new(tool_config.len()); + if accept_all || trust_all_tools { + for tool in tool_config.values() { + tool_permissions.trust_tool(&tool.name); + } + + // Deprecation notice for --accept-all users + if accept_all && interactive { + queue!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("\n--accept-all, -a is deprecated. Use --trust-all-tools instead."), + style::SetForegroundColor(Color::Reset), + )?; + } + } else if let Some(trusted) = trust_tools.map(|vec| vec.into_iter().collect::>()) { + // --trust-all-tools takes precedence over --trust-tools=... 
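+        // (Example, using a hypothetical tool name: `--trust-tools=fs_read` trusts only a tool
+        // named "fs_read" and reverts every other loaded tool to per-request confirmation.)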
+ for tool in tool_config.values() { + if trusted.contains(&tool.name) { + tool_permissions.trust_tool(&tool.name); + } else { + tool_permissions.untrust_tool(&tool.name); + } + } + } + + let mut chat = ChatContext::new( + ctx, + &conversation_id, + Settings::new(), + State::new(), + output, + input, + InputSource::new(prompt_request_sender, prompt_response_receiver)?, + interactive, + client, + || terminal::window_size().map(|s| s.columns.into()).ok(), + tool_manager, + profile, + tool_config, + tool_permissions, + ) + .await?; + + let result = chat.try_chat().await.map(|_| ExitCode::SUCCESS); + drop(chat); // Explicit drop for clarity + + result +} + +/// Enum used to denote the origin of a tool use event +enum ToolUseStatus { + /// Variant denotes that the tool use event associated with chat context is a direct result of + /// a user request + Idle, + /// Variant denotes that the tool use event associated with the chat context is a result of a + /// retry for one or more previously attempted tool use. The tuple is the utterance id + /// associated with the original user request that necessitated the tool use + RetryInProgress(String), +} + +#[derive(Debug, Error)] +pub enum ChatError { + #[error("{0}")] + Client(#[from] crate::fig_api_client::Error), + #[error("{0}")] + ResponseStream(#[from] parser::RecvError), + #[error("{0}")] + Std(#[from] std::io::Error), + #[error("{0}")] + Readline(#[from] rustyline::error::ReadlineError), + #[error("{0}")] + Custom(Cow<'static, str>), + #[error("interrupted")] + Interrupted { tool_uses: Option> }, + #[error( + "Tool approval required but --no-interactive was specified. Use --trust-all-tools to automatically approve tools." + )] + NonInteractiveToolApproval, + #[error(transparent)] + GetPromptError(#[from] GetPromptError), +} + +pub struct ChatContext { + ctx: Arc, + settings: Settings, + /// The [State] to use for the chat context. + state: State, + /// The [Write] destination for printing conversation text. + output: SharedWriter, + initial_input: Option, + input_source: InputSource, + interactive: bool, + /// The client to use to interact with the model. + client: StreamingClient, + /// Width of the terminal, required for [ParseState]. + terminal_width_provider: fn() -> Option, + spinner: Option, + /// [ConversationState]. + conversation_state: ConversationState, + /// State to track tools that need confirmation. + tool_permissions: ToolPermissions, + /// Telemetry events to be sent as part of the conversation. 
+ tool_use_telemetry_events: HashMap, + /// State used to keep track of tool use relation + tool_use_status: ToolUseStatus, + /// Abstraction that consolidates custom tools with native ones + tool_manager: ToolManager, + /// Any failed requests that could be useful for error report/debugging + failed_request_ids: Vec, + /// Pending prompts to be sent + pending_prompts: VecDeque, +} + +impl ChatContext { + #[allow(clippy::too_many_arguments)] + pub async fn new( + ctx: Arc, + conversation_id: &str, + settings: Settings, + state: State, + output: SharedWriter, + input: Option, + input_source: InputSource, + interactive: bool, + client: StreamingClient, + terminal_width_provider: fn() -> Option, + tool_manager: ToolManager, + profile: Option, + tool_config: HashMap, + tool_permissions: ToolPermissions, + ) -> Result { + let ctx_clone = Arc::clone(&ctx); + let output_clone = output.clone(); + let conversation_state = + ConversationState::new(ctx_clone, conversation_id, tool_config, profile, Some(output_clone)).await; + Ok(Self { + ctx, + settings, + state, + output, + initial_input: input, + input_source, + interactive, + client, + terminal_width_provider, + spinner: None, + tool_permissions, + conversation_state, + tool_use_telemetry_events: HashMap::new(), + tool_use_status: ToolUseStatus::Idle, + tool_manager, + failed_request_ids: Vec::new(), + pending_prompts: VecDeque::new(), + }) + } +} + +impl Drop for ChatContext { + fn drop(&mut self) { + if let Some(spinner) = &mut self.spinner { + spinner.stop(); + } + + if self.interactive { + queue!( + self.output, + cursor::MoveToColumn(0), + style::SetAttribute(Attribute::Reset), + style::ResetColor, + cursor::Show + ) + .ok(); + } + + self.output.flush().ok(); + } +} + +/// The chat execution state. +/// +/// Intended to provide more robust handling around state transitions while dealing with, e.g., +/// tool validation, execution, response stream handling, etc. +#[derive(Debug)] +enum ChatState { + /// Prompt the user with `tool_uses`, if available. + PromptUser { + /// Tool uses to present to the user. + tool_uses: Option>, + /// Tracks the next tool in tool_uses that needs user acceptance. + pending_tool_index: Option, + /// Used to avoid displaying the tool info at inappropriate times, e.g. after clear or help + /// commands. + skip_printing_tools: bool, + }, + /// Handle the user input, depending on if any tools require execution. + HandleInput { + input: String, + tool_uses: Option>, + pending_tool_index: Option, + }, + /// Validate the list of tool uses provided by the model. + ValidateTools(Vec), + /// Execute the list of tools. + ExecuteTools(Vec), + /// Consume the response stream and display to the user. + HandleResponseStream(SendMessageOutput), + /// Compact the chat history. + CompactHistory { + tool_uses: Option>, + pending_tool_index: Option, + /// Custom prompt to include as part of history compaction. + prompt: Option, + /// Whether or not the summary should be shown on compact success. + show_summary: bool, + /// Whether or not to show the /compact help text. + help: bool, + }, + /// Exit the chat. 
+ Exit, +} + +impl Default for ChatState { + fn default() -> Self { + Self::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: false, + } + } +} + +impl ChatContext { + /// Opens the user's preferred editor to compose a prompt + fn open_editor(initial_text: Option) -> Result { + // Create a temporary file with a unique name + let temp_dir = std::env::temp_dir(); + let file_name = format!("q_prompt_{}.md", Uuid::new_v4()); + let temp_file_path = temp_dir.join(file_name); + + // Get the editor from environment variable or use a default + let editor_cmd = env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); + + // Parse the editor command to handle arguments + let mut parts = + shlex::split(&editor_cmd).ok_or_else(|| ChatError::Custom("Failed to parse EDITOR command".into()))?; + + if parts.is_empty() { + return Err(ChatError::Custom("EDITOR environment variable is empty".into())); + } + + let editor_bin = parts.remove(0); + + // Write initial content to the file if provided + let initial_content = initial_text.unwrap_or_default(); + fs::write(&temp_file_path, &initial_content) + .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?; + + // Open the editor with the parsed command and arguments + let mut cmd = ProcessCommand::new(editor_bin); + // Add any arguments that were part of the EDITOR variable + for arg in parts { + cmd.arg(arg); + } + // Add the file path as the last argument + let status = cmd + .arg(&temp_file_path) + .status() + .map_err(|e| ChatError::Custom(format!("Failed to open editor: {}", e).into()))?; + + if !status.success() { + return Err(ChatError::Custom("Editor exited with non-zero status".into())); + } + + // Read the content back + let content = fs::read_to_string(&temp_file_path) + .map_err(|e| ChatError::Custom(format!("Failed to read temporary file: {}", e).into()))?; + + // Clean up the temporary file + let _ = fs::remove_file(&temp_file_path); + + Ok(content.trim().to_string()) + } + + fn draw_tip_box(&mut self, text: &str) -> Result<()> { + let box_width = GREETING_BREAK_POINT; + let inner_width = box_width - 4; // account for │ and padding + + // wrap the single line into multiple lines respecting inner width + // Manually wrap the text by splitting at word boundaries + let mut wrapped_lines = Vec::new(); + let mut line = String::new(); + + for word in text.split_whitespace() { + if line.len() + word.len() < inner_width { + if !line.is_empty() { + line.push(' '); + } + line.push_str(word); + } else { + // Here we need to account for words that are too long as well + if word.len() >= inner_width { + let mut start = 0_usize; + for (i, _) in word.chars().enumerate() { + if i - start >= inner_width { + wrapped_lines.push(word[start..i].to_string()); + start = i; + } + } + wrapped_lines.push(word[start..].to_string()); + line = String::new(); + } else { + wrapped_lines.push(line); + line = word.to_string(); + } + } + } + + if !line.is_empty() { + wrapped_lines.push(line); + } + + // ───── Did you know? ───── + let label = " Did you know? 
"; + let side_len = (box_width.saturating_sub(label.len())) / 2; + let top_border = format!( + "╭{}{}{}╮", + "─".repeat(side_len - 1), + label, + "─".repeat(box_width - side_len - label.len() - 1) + ); + + // Build output + execute!( + self.output, + terminal::Clear(ClearType::CurrentLine), + cursor::MoveToColumn(0), + style::Print(format!("{top_border}\n")), + )?; + + // Top vertical padding + execute!( + self.output, + style::Print(format!("│{: Result<()> { + let is_small_screen = self.terminal_width() < GREETING_BREAK_POINT; + if self.interactive && self.settings.get_bool_or("chat.greeting.enabled", true) { + execute!( + self.output, + style::Print(if is_small_screen { + SMALL_SCREEN_WECLOME_TEXT + } else { + WELCOME_TEXT + }), + style::Print("\n\n"), + )?; + + let current_tip_index = + (self.state.get_int_or("chat.greeting.rotating_tips_current_index", 0) as usize) % ROTATING_TIPS.len(); + + let tip = ROTATING_TIPS[current_tip_index]; + if is_small_screen { + // If the screen is small, print the tip in a single line + execute!( + self.output, + style::Print("💡 ".to_string()), + style::Print(tip), + style::Print("\n") + )?; + } else { + self.draw_tip_box(tip)?; + } + + execute!( + self.output, + style::Print(if is_small_screen { + SMALL_SCREEN_POPULAR_SHORTCUTS + } else { + POPULAR_SHORTCUTS + }), + style::Print( + "━" + .repeat(if is_small_screen { 0 } else { GREETING_BREAK_POINT }) + .dark_grey() + ) + )?; + execute!(self.output, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; + + // update the current tip index + let next_tip_index = (current_tip_index + 1) % ROTATING_TIPS.len(); + self.state + .set_value("chat.greeting.rotating_tips_current_index", next_tip_index)?; + } + + if self.interactive && self.all_tools_trusted() { + queue!( + self.output, + style::Print(format!( + "{}{TRUST_ALL_TEXT}\n\n", + if !is_small_screen { "\n" } else { "" } + )) + )?; + } + self.output.flush()?; + + let mut ctrl_c_stream = signal(SignalKind::interrupt())?; + + let mut next_state = Some(ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: true, + }); + + if let Some(user_input) = self.initial_input.take() { + if self.interactive { + execute!( + self.output, + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Magenta), + style::Print("> "), + style::SetAttribute(Attribute::Reset), + style::Print(&user_input), + style::Print("\n") + )?; + } + next_state = Some(ChatState::HandleInput { + input: user_input, + tool_uses: None, + pending_tool_index: None, + }); + } + + loop { + debug_assert!(next_state.is_some()); + let chat_state = next_state.take().unwrap_or_default(); + debug!(?chat_state, "changing to state"); + + let result = match chat_state { + ChatState::PromptUser { + tool_uses, + pending_tool_index, + skip_printing_tools, + } => { + // Cannot prompt in non-interactive mode no matter what. + if !self.interactive { + return Ok(()); + } + self.prompt_user(tool_uses, pending_tool_index, skip_printing_tools) + .await + }, + ChatState::HandleInput { + input, + tool_uses, + pending_tool_index, + } => { + let tool_uses_clone = tool_uses.clone(); + tokio::select! { + res = self.handle_input(input, tool_uses, pending_tool_index) => res, + Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: tool_uses_clone }) + } + }, + ChatState::CompactHistory { + tool_uses, + pending_tool_index, + prompt, + show_summary, + help, + } => { + let tool_uses_clone = tool_uses.clone(); + tokio::select! 
{ + res = self.compact_history(tool_uses, pending_tool_index, prompt, show_summary, help) => res, + Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: tool_uses_clone }) + } + }, + ChatState::ExecuteTools(tool_uses) => { + let tool_uses_clone = tool_uses.clone(); + tokio::select! { + res = self.tool_use_execute(tool_uses) => res, + Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) + } + }, + ChatState::ValidateTools(tool_uses) => { + tokio::select! { + res = self.validate_tools(tool_uses) => res, + Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: None }) + } + }, + ChatState::HandleResponseStream(response) => tokio::select! { + res = self.handle_response(response) => res, + Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: None }) + }, + ChatState::Exit => return Ok(()), + }; + + next_state = Some(self.handle_state_execution_result(result).await?); + } + } + + /// Handles the result of processing a [ChatState], returning the next [ChatState] to change + /// to. + async fn handle_state_execution_result( + &mut self, + result: Result, + ) -> Result { + // Remove non-ASCII and ANSI characters. + let re = Regex::new(r"((\x9B|\x1B\[)[0-?]*[ -\/]*[@-~])|([^\x00-\x7F]+)").unwrap(); + match result { + Ok(state) => Ok(state), + Err(e) => { + macro_rules! print_err { + ($prepend_msg:expr, $err:expr) => {{ + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + )?; + + let report = eyre::Report::from($err); + + let text = re + .replace_all(&format!("{}: {:?}\n", $prepend_msg, report), "") + .into_owned(); + + queue!(self.output, style::Print(&text),)?; + self.conversation_state.append_transcript(text); + + execute!( + self.output, + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Reset), + )?; + }}; + } + + macro_rules! print_default_error { + ($err:expr) => { + print_err!("Amazon Q is having trouble responding right now", $err); + }; + } + + error!(?e, "An error occurred processing the current state"); + if self.interactive && self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + )?; + } + match e { + ChatError::Interrupted { tool_uses: inter } => { + execute!(self.output, style::Print("\n\n"))?; + // If there was an interrupt during tool execution, then we add fake + // messages to "reset" the chat state. + match inter { + Some(tool_uses) if !tool_uses.is_empty() => { + self.conversation_state.abandon_tool_use( + tool_uses, + "The user interrupted the tool execution.".to_string(), + ); + let _ = self.conversation_state.as_sendable_conversation_state(false).await; + self.conversation_state + .push_assistant_message(AssistantMessage::new_response( + None, + "Tool uses were interrupted, waiting for the next user prompt".to_string(), + )); + }, + _ => (), + } + }, + ChatError::Client(err) => match err { + // Errors from attempting to send too large of a conversation history. In + // this case, attempt to automatically compact the history for the user. 
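+                        // Recovery flow for the arm below: if the backend history already holds
+                        // fewer than two messages there is nothing left to compact, so the error
+                        // is surfaced and we return to the prompt; otherwise we transition to
+                        // ChatState::CompactHistory so the history is summarized and the request
+                        // can be retried.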
+ crate::fig_api_client::Error::ContextWindowOverflow => { + let history_too_small = self + .conversation_state + .backend_conversation_state(false, true) + .await + .history + .len() + < 2; + if history_too_small { + print_err!( + "Your conversation is too large - try reducing the size of + the context being passed", + err + ); + return Ok(ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: false, + }); + } + + return Ok(ChatState::CompactHistory { + tool_uses: None, + pending_tool_index: None, + prompt: None, + show_summary: false, + help: false, + }); + }, + crate::fig_api_client::Error::QuotaBreach(msg) => { + print_err!(msg, err); + }, + _ => { + print_default_error!(err); + }, + }, + _ => { + print_default_error!(e); + }, + } + self.conversation_state.enforce_conversation_invariants(); + self.conversation_state.reset_next_user_message(); + Ok(ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: false, + }) + }, + } + } + + /// Compacts the conversation history, replacing the history with a summary generated by the + /// model. + /// + /// The last two user messages in the history are not included in the compaction process. + async fn compact_history( + &mut self, + tool_uses: Option>, + pending_tool_index: Option, + custom_prompt: Option, + show_summary: bool, + help: bool, + ) -> Result { + let hist = self.conversation_state.history(); + debug!(?hist, "compacting history"); + + // If help flag is set, show compact command help + if help { + execute!( + self.output, + style::Print("\n"), + style::Print(compact_help_text()), + style::Print("\n") + )?; + + return Ok(ChatState::PromptUser { + tool_uses, + pending_tool_index, + skip_printing_tools: true, + }); + } + + if self.conversation_state.history().len() < 2 { + execute!( + self.output, + style::SetForegroundColor(Color::Yellow), + style::Print("\nConversation too short to compact.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + return Ok(ChatState::PromptUser { + tool_uses, + pending_tool_index, + skip_printing_tools: true, + }); + } + + // Send a request for summarizing the history. + let summary_state = self + .conversation_state + .create_summary_request(custom_prompt.as_ref()) + .await; + if self.interactive { + execute!(self.output, cursor::Hide, style::Print("\n"))?; + self.spinner = Some(Spinner::new(Spinners::Dots, "Creating summary...".to_string())); + } + let response = self.client.send_message(summary_state).await; + + // TODO(brandonskiser): This is a temporary hotfix for failing compaction. We should instead + // retry except with less context included. + let response = match response { + Ok(res) => res, + Err(e) => match e { + crate::fig_api_client::Error::ContextWindowOverflow => { + self.conversation_state.clear(true); + if self.interactive { + self.spinner.take(); + execute!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + style::SetForegroundColor(Color::Yellow), + style::Print( + "The context window usage has overflowed. 
Clearing the conversation history.\n\n" + ), + style::SetAttribute(Attribute::Reset) + )?; + } + return Ok(ChatState::PromptUser { + tool_uses, + pending_tool_index, + skip_printing_tools: true, + }); + }, + e => return Err(e.into()), + }, + }; + + let summary = { + let mut parser = ResponseParser::new(response); + loop { + match parser.recv().await { + Ok(parser::ResponseEvent::EndStream { message }) => { + break message.content().to_string(); + }, + Ok(_) => (), + Err(err) => { + if let Some(request_id) = &err.request_id { + self.failed_request_ids.push(request_id.clone()); + }; + return Err(err.into()); + }, + } + } + }; + + if self.interactive && self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + + if let Some(message_id) = self.conversation_state.message_id() { + crate::fig_telemetry::send_chat_added_message( + self.conversation_state.conversation_id().to_owned(), + message_id.to_owned(), + self.conversation_state.context_message_length(), + ) + .await; + } + + self.conversation_state.replace_history_with_summary(summary.clone()); + + // Print output to the user. + { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print("✔ Conversation history has been compacted successfully!\n\n"), + style::SetForegroundColor(Color::DarkGrey) + )?; + + let mut output = Vec::new(); + if let Some(custom_prompt) = &custom_prompt { + execute!( + output, + style::Print(format!("• Custom prompt applied: {}\n", custom_prompt)) + )?; + } + animate_output(&mut self.output, &output)?; + + // Display the summary if the show_summary flag is set + if show_summary { + // Add a border around the summary for better visual separation + let terminal_width = self.terminal_width(); + let border = "═".repeat(terminal_width.min(80)); + execute!( + self.output, + style::Print("\n"), + style::SetForegroundColor(Color::Cyan), + style::Print(&border), + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::Print(" CONVERSATION SUMMARY"), + style::Print("\n"), + style::Print(&border), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + + execute!( + output, + style::Print(&summary), + style::Print("\n\n"), + style::SetForegroundColor(Color::Cyan), + style::Print("The conversation history has been replaced with this summary.\n"), + style::Print("It contains all important details from previous interactions.\n"), + )?; + animate_output(&mut self.output, &output)?; + + execute!( + self.output, + style::Print(&border), + style::Print("\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + } + + // If a next message is set, then retry the request. + if self.conversation_state.next_user_message().is_some() { + Ok(ChatState::HandleResponseStream( + self.client + .send_message(self.conversation_state.as_sendable_conversation_state(false).await) + .await?, + )) + } else { + // Otherwise, return back to the prompt for any pending tool uses. + Ok(ChatState::PromptUser { + tool_uses, + pending_tool_index, + skip_printing_tools: true, + }) + } + } + + /// Read input from the user. 
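+    ///
+    /// When a tool use is pending approval this first prints the y/n/t confirmation dialog,
+    /// then refreshes the skim command selector with the current context manager and tool
+    /// names, and finally returns [`ChatState::HandleInput`] with the entered text, or
+    /// [`ChatState::Exit`] if input was closed.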
+ async fn prompt_user( + &mut self, + mut tool_uses: Option>, + pending_tool_index: Option, + skip_printing_tools: bool, + ) -> Result { + execute!(self.output, cursor::Show)?; + let tool_uses = tool_uses.take().unwrap_or_default(); + + // Check token usage and display warnings if needed + if pending_tool_index.is_none() { + // Only display warnings when not waiting for tool approval + if let Err(e) = self.display_char_warnings().await { + warn!("Failed to display character limit warnings: {}", e); + } + } + + let show_tool_use_confirmation_dialog = !skip_printing_tools && pending_tool_index.is_some(); + if show_tool_use_confirmation_dialog { + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print("\nAllow this action? Use '"), + style::SetForegroundColor(Color::Green), + style::Print("t"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("' to trust (always allow) this tool for the session. ["), + style::SetForegroundColor(Color::Green), + style::Print("y"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("/"), + style::SetForegroundColor(Color::Green), + style::Print("n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("/"), + style::SetForegroundColor(Color::Green), + style::Print("t"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("]:\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + } + + // Do this here so that the skim integration sees an updated view of the context *during the current + // q session*. (e.g., if I add files to context, that won't show up for skim for the current + // q session unless we do this in prompt_user... unless you can find a better way) + if let Some(ref context_manager) = self.conversation_state.context_manager { + let tool_names = self.tool_manager.tn_map.keys().cloned().collect::>(); + self.input_source + .put_skim_command_selector(Arc::new(context_manager.clone()), tool_names); + } + execute!( + self.output, + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset) + )?; + let user_input = match self.read_user_input(&self.generate_tool_trust_prompt(), false) { + Some(input) => input, + None => return Ok(ChatState::Exit), + }; + + self.conversation_state.append_user_transcript(&user_input); + Ok(ChatState::HandleInput { + input: user_input, + tool_uses: Some(tool_uses), + pending_tool_index, + }) + } + + async fn handle_input( + &mut self, + mut user_input: String, + tool_uses: Option>, + pending_tool_index: Option, + ) -> Result { + let command_result = Command::parse(&user_input, &mut self.output); + + if let Err(error_message) = &command_result { + // Display error message for command parsing errors + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError: {}\n\n", error_message)), + style::SetForegroundColor(Color::Reset) + )?; + + return Ok(ChatState::PromptUser { + tool_uses, + pending_tool_index, + skip_printing_tools: true, + }); + } + + let command = command_result.unwrap(); + let mut tool_uses: Vec = tool_uses.unwrap_or_default(); + + Ok(match command { + Command::Ask { prompt } => { + // Check for a pending tool approval + if let Some(index) = pending_tool_index { + let tool_use = &mut tool_uses[index]; + + let is_trust = ["t", "T"].contains(&prompt.as_str()); + if ["y", "Y"].contains(&prompt.as_str()) || is_trust { + if is_trust { + self.tool_permissions.trust_tool(&tool_use.name); + } + tool_use.accepted = true; + + return Ok(ChatState::ExecuteTools(tool_uses)); + } + } else if 
!self.pending_prompts.is_empty() { + let prompts = self.pending_prompts.drain(0..).collect(); + user_input = self + .conversation_state + .append_prompts(prompts) + .ok_or(ChatError::Custom("Prompt append failed".into()))?; + } + + // Otherwise continue with normal chat on 'n' or other responses + self.tool_use_status = ToolUseStatus::Idle; + + if pending_tool_index.is_some() { + self.conversation_state.abandon_tool_use(tool_uses, user_input); + } else { + self.conversation_state.set_next_user_message(user_input).await; + } + + let conv_state = self.conversation_state.as_sendable_conversation_state(true).await; + + if self.interactive { + queue!(self.output, style::SetForegroundColor(Color::Magenta))?; + queue!(self.output, style::SetForegroundColor(Color::Reset))?; + queue!(self.output, cursor::Hide)?; + execute!(self.output, style::Print("\n"))?; + self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); + } + + self.send_tool_use_telemetry().await; + + ChatState::HandleResponseStream(self.client.send_message(conv_state).await?) + }, + Command::Execute { command } => { + queue!(self.output, style::Print('\n'))?; + std::process::Command::new("bash").args(["-c", &command]).status().ok(); + queue!(self.output, style::Print('\n'))?; + ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: false, + } + }, + Command::Clear => { + execute!(self.output, cursor::Show)?; + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print( + "\nAre you sure? This will erase the conversation history and context from hooks for the current session. " + ), + style::Print("["), + style::SetForegroundColor(Color::Green), + style::Print("y"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("/"), + style::SetForegroundColor(Color::Green), + style::Print("n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print("]:\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + + // Setting `exit_on_single_ctrl_c` for better ux: exit the confirmation dialog rather than the CLI + let user_input = match self.read_user_input("> ".yellow().to_string().as_str(), true) { + Some(input) => input, + None => "".to_string(), + }; + + if ["y", "Y"].contains(&user_input.as_str()) { + self.conversation_state.clear(true); + if let Some(cm) = self.conversation_state.context_manager.as_mut() { + cm.hook_executor.global_cache.clear(); + cm.hook_executor.profile_cache.clear(); + } + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print("\nConversation history cleared.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + + ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: true, + } + }, + Command::Compact { + prompt, + show_summary, + help, + } => { + self.compact_history(Some(tool_uses), pending_tool_index, prompt, show_summary, help) + .await? 
+ }, + Command::Help => { + execute!(self.output, style::Print(HELP_TEXT))?; + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + Command::Issue { prompt } => { + let input = "I would like to report an issue or make a feature request"; + ChatState::HandleInput { + input: if let Some(prompt) = prompt { + format!("{input}: {prompt}") + } else { + input.to_string() + }, + tool_uses: Some(tool_uses), + pending_tool_index, + } + }, + Command::PromptEditor { initial_text } => { + match Self::open_editor(initial_text) { + Ok(content) => { + if content.trim().is_empty() { + execute!( + self.output, + style::SetForegroundColor(Color::Yellow), + style::Print("\nEmpty content from editor, not submitting.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + } else { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print("\nContent loaded from editor. Submitting prompt...\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + // Display the content as if the user typed it + execute!( + self.output, + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Magenta), + style::Print("> "), + style::SetAttribute(Attribute::Reset), + style::Print(&content), + style::Print("\n") + )?; + + // Process the content as user input + ChatState::HandleInput { + input: content, + tool_uses: Some(tool_uses), + pending_tool_index, + } + } + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError opening editor: {}\n\n", e)), + style::SetForegroundColor(Color::Reset) + )?; + + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + } + }, + Command::Quit => ChatState::Exit, + Command::Profile { subcommand } => { + if let Some(context_manager) = &mut self.conversation_state.context_manager { + macro_rules! print_err { + ($err:expr) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError: {}\n\n", $err)), + style::SetForegroundColor(Color::Reset) + )? 
+ }; + } + + match subcommand { + command::ProfileSubcommand::List => { + let profiles = match context_manager.list_profiles().await { + Ok(profiles) => profiles, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError listing profiles: {}\n\n", e)), + style::SetForegroundColor(Color::Reset) + )?; + vec![] + }, + }; + + execute!(self.output, style::Print("\n"))?; + for profile in profiles { + if profile == context_manager.current_profile { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print("* "), + style::Print(&profile), + style::SetForegroundColor(Color::Reset), + style::Print("\n") + )?; + } else { + execute!( + self.output, + style::Print(" "), + style::Print(&profile), + style::Print("\n") + )?; + } + } + execute!(self.output, style::Print("\n"))?; + }, + command::ProfileSubcommand::Create { name } => { + match context_manager.create_profile(&name).await { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nCreated profile: {}\n\n", name)), + style::SetForegroundColor(Color::Reset) + )?; + context_manager + .switch_profile(&name) + .await + .map_err(|e| warn!(?e, "failed to switch to newly created profile")) + .ok(); + }, + Err(e) => print_err!(e), + } + }, + command::ProfileSubcommand::Delete { name } => { + match context_manager.delete_profile(&name).await { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nDeleted profile: {}\n\n", name)), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => print_err!(e), + } + }, + command::ProfileSubcommand::Set { name } => match context_manager.switch_profile(&name).await { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nSwitched to profile: {}\n\n", name)), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => print_err!(e), + }, + command::ProfileSubcommand::Rename { old_name, new_name } => { + match context_manager.rename_profile(&old_name, &new_name).await { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nRenamed profile: {} -> {}\n\n", old_name, new_name)), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => print_err!(e), + } + }, + command::ProfileSubcommand::Help => { + execute!( + self.output, + style::Print("\n"), + style::Print(command::ProfileSubcommand::help_text()), + style::Print("\n") + )?; + }, + } + } + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + Command::Context { subcommand } => { + if let Some(context_manager) = &mut self.conversation_state.context_manager { + match subcommand { + command::ContextSubcommand::Show { expand } => { + // Display global context + execute!( + self.output, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Magenta), + style::Print("\n🌍 global:\n"), + style::SetAttribute(Attribute::Reset), + )?; + let mut global_context_files = HashSet::new(); + let mut profile_context_files = HashSet::new(); + if context_manager.global_config.paths.is_empty() { + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(" \n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + for path in &context_manager.global_config.paths { + execute!(self.output, style::Print(format!(" {} ", path)))?; + if let Ok(context_files) = + 
context_manager.get_context_files_by_path(false, path).await + { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "({} match{})", + context_files.len(), + if context_files.len() == 1 { "" } else { "es" } + )), + style::SetForegroundColor(Color::Reset) + )?; + global_context_files.extend(context_files); + } + execute!(self.output, style::Print("\n"))?; + } + } + + // Display profile context + execute!( + self.output, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Magenta), + style::Print(format!("\n👤 profile ({}):\n", context_manager.current_profile)), + style::SetAttribute(Attribute::Reset), + )?; + + if context_manager.profile_config.paths.is_empty() { + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(" \n\n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + for path in &context_manager.profile_config.paths { + execute!(self.output, style::Print(format!(" {} ", path)))?; + if let Ok(context_files) = + context_manager.get_context_files_by_path(false, path).await + { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "({} match{})", + context_files.len(), + if context_files.len() == 1 { "" } else { "es" } + )), + style::SetForegroundColor(Color::Reset) + )?; + profile_context_files.extend(context_files); + } + execute!(self.output, style::Print("\n"))?; + } + execute!(self.output, style::Print("\n"))?; + } + + if global_context_files.is_empty() && profile_context_files.is_empty() { + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print("No files in the current directory matched the rules above.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + let total = global_context_files.len() + profile_context_files.len(); + let total_tokens = global_context_files + .iter() + .map(|(_, content)| TokenCounter::count_tokens(content)) + .sum::() + + profile_context_files + .iter() + .map(|(_, content)| TokenCounter::count_tokens(content)) + .sum::(); + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::SetAttribute(Attribute::Bold), + style::Print(format!( + "{} matched file{} in use:\n", + total, + if total == 1 { "" } else { "s" } + )), + style::SetForegroundColor(Color::Reset), + style::SetAttribute(Attribute::Reset) + )?; + + for (filename, content) in global_context_files { + let est_tokens = TokenCounter::count_tokens(&content); + execute!( + self.output, + style::Print(format!("🌍 {} ", filename)), + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("(~{} tkns)\n", est_tokens)), + style::SetForegroundColor(Color::Reset), + )?; + if expand { + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("{}\n\n", content)), + style::SetForegroundColor(Color::Reset) + )?; + } + } + + for (filename, content) in profile_context_files { + let est_tokens = TokenCounter::count_tokens(&content); + execute!( + self.output, + style::Print(format!("👤 {} ", filename)), + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("(~{} tkns)\n", est_tokens)), + style::SetForegroundColor(Color::Reset), + )?; + if expand { + execute!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("{}\n\n", content)), + style::SetForegroundColor(Color::Reset) + )?; + } + } + + if expand { + execute!(self.output, style::Print(format!("{}\n\n", "▔".repeat(3))),)?; + } + + execute!( + 
self.output, + style::Print(format!("\nTotal: ~{} tokens\n\n", total_tokens)), + )?; + + execute!(self.output, style::Print("\n"))?; + } + }, + command::ContextSubcommand::Add { global, force, paths } => { + match context_manager.add_paths(paths.clone(), global, force).await { + Ok(_) => { + let target = if global { "global" } else { "profile" }; + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nAdded {} path(s) to {} context.\n\n", + paths.len(), + target + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError: {}\n\n", e)), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + }, + command::ContextSubcommand::Remove { global, paths } => { + match context_manager.remove_paths(paths.clone(), global).await { + Ok(_) => { + let target = if global { "global" } else { "profile" }; + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nRemoved {} path(s) from {} context.\n\n", + paths.len(), + target + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError: {}\n\n", e)), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + }, + command::ContextSubcommand::Clear { global } => match context_manager.clear(global).await { + Ok(_) => { + let target = if global { + "global".to_string() + } else { + format!("profile '{}'", context_manager.current_profile) + }; + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nCleared context for {}\n\n", target)), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError: {}\n\n", e)), + style::SetForegroundColor(Color::Reset) + )?; + }, + }, + command::ContextSubcommand::Help => { + execute!( + self.output, + style::Print("\n"), + style::Print(command::ContextSubcommand::help_text()), + style::Print("\n") + )?; + }, + command::ContextSubcommand::Hooks { subcommand } => { + fn map_chat_error(e: ErrReport) -> ChatError { + ChatError::Custom(e.to_string().into()) + } + + let scope = |g: bool| if g { "global" } else { "profile" }; + if let Some(subcommand) = subcommand { + match subcommand { + command::HooksSubcommand::Add { + name, + trigger, + command, + global, + } => { + let trigger = if trigger == "conversation_start" { + HookTrigger::ConversationStart + } else { + HookTrigger::PerPrompt + }; + + let result = context_manager + .add_hook(name.clone(), Hook::new_inline_hook(trigger, command), global) + .await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nAdded {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot add {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + }, + command::HooksSubcommand::Remove { name, global } => { + let result = context_manager.remove_hook(&name, global).await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nRemoved {} hook '{name}'.\n\n", + scope(global) + )), + 
style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot remove {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + }, + command::HooksSubcommand::Enable { name, global } => { + let result = context_manager.set_hook_disabled(&name, global, false).await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nEnabled {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot enable {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + }, + command::HooksSubcommand::Disable { name, global } => { + let result = context_manager.set_hook_disabled(&name, global, true).await; + match result { + Ok(_) => { + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!( + "\nDisabled {} hook '{name}'.\n\n", + scope(global) + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + Err(e) => { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nCannot disable {} hook '{name}': {}\n\n", + scope(global), + e + )), + style::SetForegroundColor(Color::Reset) + )?; + }, + } + }, + command::HooksSubcommand::EnableAll { global } => { + context_manager + .set_all_hooks_disabled(global, false) + .await + .map_err(map_chat_error)?; + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nEnabled all {} hooks.\n\n", scope(global))), + style::SetForegroundColor(Color::Reset) + )?; + }, + command::HooksSubcommand::DisableAll { global } => { + context_manager + .set_all_hooks_disabled(global, true) + .await + .map_err(map_chat_error)?; + execute!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nDisabled all {} hooks.\n\n", scope(global))), + style::SetForegroundColor(Color::Reset) + )?; + }, + command::HooksSubcommand::Help => { + execute!( + self.output, + style::Print("\n"), + style::Print(command::ContextSubcommand::hooks_help_text()), + style::Print("\n") + )?; + }, + } + } else { + fn print_hook_section( + output: &mut impl Write, + hooks: &HashMap, + trigger: HookTrigger, + ) -> Result<()> { + let section = match trigger { + HookTrigger::ConversationStart => "Conversation Start", + HookTrigger::PerPrompt => "Per Prompt", + }; + let hooks: Vec<(&String, &Hook)> = + hooks.iter().filter(|(_, h)| h.trigger == trigger).collect(); + + queue!( + output, + style::SetForegroundColor(Color::Cyan), + style::Print(format!(" {section}:\n")), + style::SetForegroundColor(Color::Reset), + )?; + + if hooks.is_empty() { + queue!( + output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(" \n"), + style::SetForegroundColor(Color::Reset) + )?; + } else { + for (name, hook) in hooks { + if hook.disabled { + queue!( + output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!(" {} (disabled)\n", name)), + style::SetForegroundColor(Color::Reset) + )?; + } else { + queue!(output, style::Print(format!(" {}\n", name)),)?; + } + } + } + Ok(()) + } + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Magenta), + style::Print("\n🌍 global:\n"), + 
style::SetAttribute(Attribute::Reset), + )?; + + print_hook_section( + &mut self.output, + &context_manager.global_config.hooks, + HookTrigger::ConversationStart, + ) + .map_err(map_chat_error)?; + print_hook_section( + &mut self.output, + &context_manager.global_config.hooks, + HookTrigger::PerPrompt, + ) + .map_err(map_chat_error)?; + + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Magenta), + style::Print(format!("\n👤 profile ({}):\n", &context_manager.current_profile)), + style::SetAttribute(Attribute::Reset), + )?; + + print_hook_section( + &mut self.output, + &context_manager.profile_config.hooks, + HookTrigger::ConversationStart, + ) + .map_err(map_chat_error)?; + print_hook_section( + &mut self.output, + &context_manager.profile_config.hooks, + HookTrigger::PerPrompt, + ) + .map_err(map_chat_error)?; + + execute!( + self.output, + style::Print(format!( + "\nUse {} to manage hooks.\n\n", + "/context hooks help".to_string().dark_green() + )), + )?; + } + }, + } + // crate::fig_telemetry::send_context_command_executed + } else { + execute!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print("\nContext management is not available.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + } + + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + Command::Tools { subcommand } => { + let existing_tools: HashSet<&String> = self + .conversation_state + .tools + .values() + .flatten() + .map(|FigTool::ToolSpecification(spec)| &spec.name) + .collect(); + + match subcommand { + Some(ToolsSubcommand::Schema) => { + let schema_json = serde_json::to_string_pretty(&self.tool_manager.schema).map_err(|e| { + ChatError::Custom(format!("Error converting tool schema to string: {e}").into()) + })?; + queue!(self.output, style::Print(schema_json), style::Print("\n"))?; + }, + Some(ToolsSubcommand::Trust { tool_names }) => { + let (valid_tools, invalid_tools): (Vec, Vec) = tool_names + .into_iter() + .partition(|tool_name| existing_tools.contains(tool_name)); + + if !invalid_tools.is_empty() { + queue!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nCannot trust '{}', ", invalid_tools.join("', '"))), + if invalid_tools.len() > 1 { + style::Print("they do not exist.") + } else { + style::Print("it does not exist.") + }, + style::SetForegroundColor(Color::Reset), + )?; + } + if !valid_tools.is_empty() { + valid_tools.iter().for_each(|t| self.tool_permissions.trust_tool(t)); + queue!( + self.output, + style::SetForegroundColor(Color::Green), + if valid_tools.len() > 1 { + style::Print(format!("\nTools '{}' are ", valid_tools.join("', '"))) + } else { + style::Print(format!("\nTool '{}' is ", valid_tools[0])) + }, + style::Print("now trusted. 
I will "), + style::SetAttribute(Attribute::Bold), + style::Print("not"), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Green), + style::Print(format!( + " ask for confirmation before running {}.", + if valid_tools.len() > 1 { + "these tools" + } else { + "this tool" + } + )), + style::SetForegroundColor(Color::Reset), + )?; + } + }, + Some(ToolsSubcommand::Untrust { tool_names }) => { + let (valid_tools, invalid_tools): (Vec, Vec) = tool_names + .into_iter() + .partition(|tool_name| existing_tools.contains(tool_name)); + + if !invalid_tools.is_empty() { + queue!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nCannot untrust '{}', ", invalid_tools.join("', '"))), + if invalid_tools.len() > 1 { + style::Print("they do not exist.") + } else { + style::Print("it does not exist.") + }, + style::SetForegroundColor(Color::Reset), + )?; + } + if !valid_tools.is_empty() { + valid_tools.iter().for_each(|t| self.tool_permissions.untrust_tool(t)); + queue!( + self.output, + style::SetForegroundColor(Color::Green), + if valid_tools.len() > 1 { + style::Print(format!("\nTools '{}' are ", valid_tools.join("', '"))) + } else { + style::Print(format!("\nTool '{}' is ", valid_tools[0])) + }, + style::Print("set to per-request confirmation."), + style::SetForegroundColor(Color::Reset), + )?; + } + }, + Some(ToolsSubcommand::TrustAll) => { + self.conversation_state.tools.values().flatten().for_each( + |FigTool::ToolSpecification(spec)| { + self.tool_permissions.trust_tool(spec.name.as_str()); + }, + ); + queue!(self.output, style::Print(TRUST_ALL_TEXT),)?; + }, + Some(ToolsSubcommand::Reset) => { + self.tool_permissions.reset(); + queue!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print("\nReset all tools to the default permission levels."), + style::SetForegroundColor(Color::Reset), + )?; + }, + Some(ToolsSubcommand::ResetSingle { tool_name }) => { + if self.tool_permissions.has(&tool_name) { + self.tool_permissions.reset_tool(&tool_name); + queue!( + self.output, + style::SetForegroundColor(Color::Green), + style::Print(format!("\nReset tool '{}' to the default permission level.", tool_name)), + style::SetForegroundColor(Color::Reset), + )?; + } else { + queue!( + self.output, + style::SetForegroundColor(Color::Red), + style::Print(format!( + "\nTool '{}' does not exist or is already in default settings.", + tool_name + )), + style::SetForegroundColor(Color::Reset), + )?; + } + }, + Some(ToolsSubcommand::Help) => { + queue!( + self.output, + style::Print("\n"), + style::Print(command::ToolsSubcommand::help_text()), + )?; + }, + None => { + // No subcommand - print the current tools and their permissions. + // Determine how to format the output nicely. 
+ let terminal_width = self.terminal_width(); + let longest = self + .conversation_state + .tools + .values() + .flatten() + .map(|FigTool::ToolSpecification(spec)| spec.name.len()) + .max() + .unwrap_or(0); + + queue!( + self.output, + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::Print({ + // Adding 2 because of "- " preceding every tool name + let width = longest + 2 - "Tool".len() + 4; + format!("Tool{:>width$}Permission", "", width = width) + }), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + style::Print("▔".repeat(terminal_width)), + )?; + + self.conversation_state.tools.iter().for_each(|(origin, tools)| { + let to_display = + tools + .iter() + .fold(String::new(), |mut acc, FigTool::ToolSpecification(spec)| { + let width = longest - spec.name.len() + 4; + acc.push_str( + format!( + "- {}{:>width$}{}\n", + spec.name, + "", + self.tool_permissions.display_label(&spec.name), + width = width + ) + .as_str(), + ); + acc + }); + let _ = queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::Print(format!("{}:\n", origin)), + style::SetAttribute(Attribute::Reset), + style::Print(to_display), + style::Print("\n") + ); + }); + + queue!( + self.output, + style::Print("\nTrusted tools can be run without confirmation\n"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("\n{}\n", "* Default settings")), + style::Print("\n💡 Use "), + style::SetForegroundColor(Color::Green), + style::Print("/tools help"), + style::SetForegroundColor(Color::Reset), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to edit permissions."), + style::SetForegroundColor(Color::Reset), + )?; + }, + }; + + // Put spacing between previous output as to not be overwritten by + // during PromptUser. + self.output.flush()?; + + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + Command::Prompts { subcommand } => { + match subcommand { + Some(PromptsSubcommand::Help) => { + queue!(self.output, style::Print(command::PromptsSubcommand::help_text()))?; + }, + Some(PromptsSubcommand::Get { mut get_command }) => { + let orig_input = get_command.orig_input.take(); + let prompts = match self.tool_manager.get_prompt(get_command).await { + Ok(resp) => resp, + Err(e) => { + match e { + GetPromptError::AmbiguousPrompt(prompt_name, alt_msg) => { + queue!( + self.output, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(prompt_name), + style::SetForegroundColor(Color::Yellow), + style::Print(" is ambiguous. Use one of the following "), + style::SetForegroundColor(Color::Cyan), + style::Print(alt_msg), + style::SetForegroundColor(Color::Reset), + )?; + }, + GetPromptError::PromptNotFound(prompt_name) => { + queue!( + self.output, + style::Print("\n"), + style::SetForegroundColor(Color::Yellow), + style::Print("Prompt "), + style::SetForegroundColor(Color::Cyan), + style::Print(prompt_name), + style::SetForegroundColor(Color::Yellow), + style::Print(" not found. 
Use "), + style::SetForegroundColor(Color::Cyan), + style::Print("/prompts list"), + style::SetForegroundColor(Color::Yellow), + style::Print(" to see available prompts.\n"), + style::SetForegroundColor(Color::Reset), + )?; + }, + _ => return Err(ChatError::Custom(e.to_string().into())), + } + execute!(self.output, style::Print("\n"))?; + return Ok(ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + }); + }, + }; + if let Some(err) = prompts.error { + // If we are running into error we should just display the error + // and abort. + let to_display = serde_json::json!(err); + queue!( + self.output, + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::Print("Error encountered while retrieving prompt:"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print( + serde_json::to_string_pretty(&to_display) + .unwrap_or_else(|_| format!("{:?}", &to_display)) + ), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + )?; + } else { + let prompts = prompts + .result + .ok_or(ChatError::Custom("Result field missing from prompt/get request".into()))?; + let prompts = serde_json::from_value::(prompts).map_err(|e| { + ChatError::Custom(format!("Failed to deserialize prompt/get result: {:?}", e).into()) + })?; + self.pending_prompts.clear(); + self.pending_prompts.append(&mut VecDeque::from(prompts.messages)); + return Ok(ChatState::HandleInput { + input: orig_input.unwrap_or_default(), + tool_uses: Some(tool_uses), + pending_tool_index, + }); + } + }, + subcommand => { + let search_word = match subcommand { + Some(PromptsSubcommand::List { search_word }) => search_word, + _ => None, + }; + let terminal_width = self.terminal_width(); + let mut prompts_wl = self.tool_manager.prompts.write().map_err(|e| { + ChatError::Custom( + format!("Poison error encountered while retrieving prompts: {}", e).into(), + ) + })?; + self.tool_manager.refresh_prompts(&mut prompts_wl)?; + let mut longest_name = ""; + let arg_pos = { + let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; + if optimal_case > terminal_width { + terminal_width / 3 + } else { + optimal_case + } + }; + queue!( + self.output, + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::Print("Prompt"), + style::SetAttribute(Attribute::Reset), + style::Print({ + let name_width = UnicodeWidthStr::width("Prompt"); + let padding = arg_pos.saturating_sub(name_width); + " ".repeat(padding) + }), + style::SetAttribute(Attribute::Bold), + style::Print("Arguments (* = required)"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + style::Print(format!("{}\n", "▔".repeat(terminal_width))), + )?; + let prompts_by_server = prompts_wl.iter().fold( + HashMap::<&String, Vec<&PromptBundle>>::new(), + |mut acc, (prompt_name, bundles)| { + if prompt_name.contains(search_word.as_deref().unwrap_or("")) { + if prompt_name.len() > longest_name.len() { + longest_name = prompt_name.as_str(); + } + for bundle in bundles { + acc.entry(&bundle.server_name) + .and_modify(|b| b.push(bundle)) + .or_insert(vec![bundle]); + } + } + acc + }, + ); + for (i, (server_name, bundles)) in prompts_by_server.iter().enumerate() { + if i > 0 { + queue!(self.output, style::Print("\n"))?; + } + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::Print(server_name), + style::Print(" (MCP):"), + style::SetAttribute(Attribute::Reset), + style::Print("\n"), + )?; + for bundle in 
bundles { + queue!( + self.output, + style::Print("- "), + style::Print(&bundle.prompt_get.name), + style::Print({ + if bundle + .prompt_get + .arguments + .as_ref() + .is_some_and(|args| !args.is_empty()) + { + let name_width = UnicodeWidthStr::width(bundle.prompt_get.name.as_str()); + let padding = + arg_pos.saturating_sub(name_width) - UnicodeWidthStr::width("- "); + " ".repeat(padding) + } else { + "\n".to_owned() + } + }) + )?; + if let Some(args) = bundle.prompt_get.arguments.as_ref() { + for (i, arg) in args.iter().enumerate() { + queue!( + self.output, + style::SetForegroundColor(Color::DarkGrey), + style::Print(match arg.required { + Some(true) => format!("{}*", arg.name), + _ => arg.name.clone(), + }), + style::SetForegroundColor(Color::Reset), + style::Print(if i < args.len() - 1 { ", " } else { "\n" }), + )?; + } + } + } + } + }, + } + execute!(self.output, style::Print("\n"))?; + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + Command::Usage => { + let state = self.conversation_state.backend_conversation_state(true, true).await; + let data = state.calculate_conversation_size(); + + let context_token_count: TokenCount = data.context_messages.into(); + let assistant_token_count: TokenCount = data.assistant_messages.into(); + let user_token_count: TokenCount = data.user_messages.into(); + let total_token_used: TokenCount = + (data.context_messages + data.user_messages + data.assistant_messages).into(); + + let window_width = self.terminal_width(); + // set a max width for the progress bar for better aesthetic + let progress_bar_width = std::cmp::min(window_width, 80); + + let context_width = ((context_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) + * progress_bar_width as f64) as usize; + let assistant_width = ((assistant_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) + * progress_bar_width as f64) as usize; + let user_width = ((user_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) + * progress_bar_width as f64) as usize; + + let left_over_width = progress_bar_width + - std::cmp::min(context_width + assistant_width + user_width, progress_bar_width); + + queue!( + self.output, + style::Print(format!( + "\nCurrent context window ({} of {}k tokens used)\n", + total_token_used, + CONTEXT_WINDOW_SIZE / 1000 + )), + style::SetForegroundColor(Color::DarkCyan), + // add a nice visual to mimic "tiny" progress, so the overral progress bar doesn't look too + // empty + style::Print("|".repeat(if context_width == 0 && *context_token_count > 0 { + 1 + } else { + 0 + })), + style::Print("█".repeat(context_width)), + style::SetForegroundColor(Color::Blue), + style::Print("|".repeat(if assistant_width == 0 && *assistant_token_count > 0 { + 1 + } else { + 0 + })), + style::Print("█".repeat(assistant_width)), + style::SetForegroundColor(Color::Magenta), + style::Print("|".repeat(if user_width == 0 && *user_token_count > 0 { 1 } else { 0 })), + style::Print("█".repeat(user_width)), + style::SetForegroundColor(Color::DarkGrey), + style::Print("█".repeat(left_over_width)), + style::Print(" "), + style::SetForegroundColor(Color::Reset), + style::Print(format!( + "{:.2}%", + (total_token_used.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 + )), + )?; + + queue!(self.output, style::Print("\n\n"))?; + self.output.flush()?; + + queue!( + self.output, + style::SetForegroundColor(Color::DarkCyan), + style::Print("█ Context files: "), + style::SetForegroundColor(Color::Reset), + style::Print(format!( + "~{} 
tokens ({:.2}%)\n", + context_token_count, + (context_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 + )), + style::SetForegroundColor(Color::Blue), + style::Print("█ Q responses: "), + style::SetForegroundColor(Color::Reset), + style::Print(format!( + " ~{} tokens ({:.2}%)\n", + assistant_token_count, + (assistant_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 + )), + style::SetForegroundColor(Color::Magenta), + style::Print("█ Your prompts: "), + style::SetForegroundColor(Color::Reset), + style::Print(format!( + " ~{} tokens ({:.2}%)\n\n", + user_token_count, + (user_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 + )), + )?; + + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::Print("\n💡 Pro Tips:\n"), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::DarkGrey), + style::Print("Run "), + style::SetForegroundColor(Color::DarkGreen), + style::Print("/compact"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to replace the conversation history with its summary\n"), + style::Print("Run "), + style::SetForegroundColor(Color::DarkGreen), + style::Print("/clear"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to erase the entire chat history\n"), + style::Print("Run "), + style::SetForegroundColor(Color::DarkGreen), + style::Print("/context show"), + style::SetForegroundColor(Color::DarkGrey), + style::Print(" to see tokens per context file\n\n"), + style::SetForegroundColor(Color::Reset), + )?; + + ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: true, + } + }, + }) + } + + async fn tool_use_execute(&mut self, mut tool_uses: Vec) -> Result { + // Verify tools have permissions. + for (index, tool) in tool_uses.iter_mut().enumerate() { + // Manually accepted by the user or otherwise verified already. + if tool.accepted { + continue; + } + + // If there is an override, we will use it. Otherwise fall back to Tool's default. + let allowed = if self.tool_permissions.has(&tool.name) { + self.tool_permissions.is_trusted(&tool.name) + } else { + !tool.tool.requires_acceptance(&self.ctx) + }; + + if self.settings.get_bool_or("chat.enableNotifications", false) { + play_notification_bell(!allowed); + } + + self.print_tool_descriptions(tool, allowed).await?; + + if allowed { + tool.accepted = true; + continue; + } + + let pending_tool_index = Some(index); + if !self.interactive { + // Cannot request in non-interactive, so fail. + return Err(ChatError::NonInteractiveToolApproval); + } + + return Ok(ChatState::PromptUser { + tool_uses: Some(tool_uses), + pending_tool_index, + skip_printing_tools: false, + }); + } + + // Execute the requested tools. 
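+        // Each accepted tool is invoked in order: the invocation is timed for telemetry, a
+        // completion or failure line is printed, and a ToolUseResult is collected per tool so the
+        // whole batch can be added to the conversation and sent back to the model in a single
+        // follow-up request.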
+ let mut tool_results = vec![]; + + for tool in tool_uses { + let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone()); + tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_accepted = true); + + let tool_start = std::time::Instant::now(); + let invoke_result = tool.tool.invoke(&self.ctx, &mut self.output).await; + + if self.interactive && self.spinner.is_some() { + queue!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + execute!(self.output, style::Print("\n"))?; + + let tool_time = std::time::Instant::now().duration_since(tool_start); + if let Tool::Custom(ct) = &tool.tool { + tool_telemetry = tool_telemetry.and_modify(|ev| { + ev.custom_tool_call_latency = Some(tool_time.as_secs() as usize); + ev.input_token_size = Some(ct.get_input_token_size()); + ev.is_custom_tool = true; + }); + } + let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis()); + + match invoke_result { + Ok(result) => { + debug!("tool result output: {:#?}", result); + execute!( + self.output, + style::Print(CONTINUATION_LINE), + style::Print("\n"), + style::SetForegroundColor(Color::Green), + style::SetAttribute(Attribute::Bold), + style::Print(format!(" ● Completed in {}s", tool_time)), + style::SetForegroundColor(Color::Reset), + style::Print("\n"), + )?; + + tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); + if let Tool::Custom(_) = &tool.tool { + tool_telemetry + .and_modify(|ev| ev.output_token_size = Some(TokenCounter::count_tokens(result.as_str()))); + } + tool_results.push(ToolUseResult { + tool_use_id: tool.id, + content: vec![result.into()], + status: ToolResultStatus::Success, + }); + }, + Err(err) => { + error!(?err, "An error occurred processing the tool"); + execute!( + self.output, + style::Print(CONTINUATION_LINE), + style::Print("\n"), + style::SetAttribute(Attribute::Bold), + style::SetForegroundColor(Color::Red), + style::Print(format!(" ● Execution failed after {}s:\n", tool_time)), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(Color::Red), + style::Print(&err), + style::SetAttribute(Attribute::Reset), + style::Print("\n\n"), + )?; + + tool_telemetry.and_modify(|ev| ev.is_success = Some(false)); + tool_results.push(ToolUseResult { + tool_use_id: tool.id, + content: vec![ToolUseResultBlock::Text(format!( + "An error occurred processing the tool: \n{}", + &err + ))], + status: ToolResultStatus::Error, + }); + if let ToolUseStatus::Idle = self.tool_use_status { + self.tool_use_status = ToolUseStatus::RetryInProgress( + self.conversation_state + .message_id() + .map_or("No utterance id found".to_string(), |v| v.to_string()), + ); + } + }, + } + } + + self.conversation_state.add_tool_results(tool_results); + + self.send_tool_use_telemetry().await; + return Ok(ChatState::HandleResponseStream( + self.client + .send_message(self.conversation_state.as_sendable_conversation_state(false).await) + .await?, + )); + } + + async fn handle_response(&mut self, response: SendMessageOutput) -> Result { + let request_id = response.request_id().map(|s| s.to_string()); + let mut buf = String::new(); + let mut offset = 0; + let mut ended = false; + let mut parser = ResponseParser::new(response); + let mut state = ParseState::new(Some(self.terminal_width())); + + let mut tool_uses = Vec::new(); + let mut tool_name_being_recvd: Option = None; + + loop { + match parser.recv().await { + Ok(msg_event) => { + trace!("Consumed: {:?}", msg_event); + match 
msg_event { + parser::ResponseEvent::ToolUseStart { name } => { + // We need to flush the buffer here, otherwise text will not be + // printed while we are receiving tool use events. + buf.push('\n'); + tool_name_being_recvd = Some(name); + }, + parser::ResponseEvent::AssistantText(text) => { + buf.push_str(&text); + }, + parser::ResponseEvent::ToolUse(tool_use) => { + if self.interactive && self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + tool_uses.push(tool_use); + tool_name_being_recvd = None; + }, + parser::ResponseEvent::EndStream { message } => { + // This log is attempting to help debug instances where users encounter + // the response timeout message. + if message.content() == RESPONSE_TIMEOUT_CONTENT { + error!(?request_id, ?message, "Encountered an unexpected model response"); + } + self.conversation_state.push_assistant_message(message); + ended = true; + }, + } + }, + Err(recv_error) => { + if let Some(request_id) = &recv_error.request_id { + self.failed_request_ids.push(request_id.clone()); + }; + + match recv_error.source { + RecvErrorKind::StreamTimeout { source, duration } => { + error!( + recv_error.request_id, + ?source, + "Encountered a stream timeout after waiting for {}s", + duration.as_secs() + ); + if self.interactive { + execute!(self.output, cursor::Hide)?; + self.spinner = + Some(Spinner::new(Spinners::Dots, "Dividing up the work...".to_string())); + } + // For stream timeouts, we'll tell the model to try and split its response into + // smaller chunks. + self.conversation_state + .push_assistant_message(AssistantMessage::new_response( + None, + RESPONSE_TIMEOUT_CONTENT.to_string(), + )); + self.conversation_state + .set_next_user_message( + "You took too long to respond - try to split up the work into smaller steps." 
+ .to_string(), + ) + .await; + self.send_tool_use_telemetry().await; + return Ok(ChatState::HandleResponseStream( + self.client + .send_message(self.conversation_state.as_sendable_conversation_state(false).await) + .await?, + )); + }, + RecvErrorKind::UnexpectedToolUseEos { + tool_use_id, + name, + message, + time_elapsed, + } => { + error!( + recv_error.request_id, + tool_use_id, name, "The response stream ended before the entire tool use was received" + ); + if self.interactive { + drop(self.spinner.take()); + queue!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + style::SetForegroundColor(Color::Yellow), + style::SetAttribute(Attribute::Bold), + style::Print(format!( + "Warning: received an unexpected error from the model after {:.2}s", + time_elapsed.as_secs_f64() + )), + )?; + if let Some(request_id) = recv_error.request_id { + queue!( + self.output, + style::Print(format!("\n request_id: {}", request_id)) + )?; + } + execute!(self.output, style::Print("\n\n"), style::SetAttribute(Attribute::Reset))?; + self.spinner = Some(Spinner::new( + Spinners::Dots, + "Trying to divide up the work...".to_string(), + )); + } + + self.conversation_state.push_assistant_message(*message); + let tool_results = vec![ToolUseResult { + tool_use_id, + content: vec![ToolUseResultBlock::Text( + "The generated tool was too large, try again but this time split up the work between multiple tool uses".to_string(), + )], + status: ToolResultStatus::Error, + }]; + self.conversation_state.add_tool_results(tool_results); + self.send_tool_use_telemetry().await; + return Ok(ChatState::HandleResponseStream( + self.client + .send_message(self.conversation_state.as_sendable_conversation_state(false).await) + .await?, + )); + }, + _ => return Err(recv_error.into()), + } + }, + } + + // Fix for the markdown parser copied over from q chat: + // this is a hack since otherwise the parser might report Incomplete with useful data + // still left in the buffer. I'm not sure how this is intended to be handled. + if ended { + buf.push('\n'); + } + + if tool_name_being_recvd.is_none() && !buf.is_empty() && self.interactive && self.spinner.is_some() { + drop(self.spinner.take()); + queue!( + self.output, + terminal::Clear(terminal::ClearType::CurrentLine), + cursor::MoveToColumn(0), + cursor::Show + )?; + } + + // Print the response for normal cases + loop { + let input = Partial::new(&buf[offset..]); + match interpret_markdown(input, &mut self.output, &mut state) { + Ok(parsed) => { + offset += parsed.offset_from(&input); + self.output.flush()?; + state.newline = state.set_newline; + state.set_newline = false; + }, + Err(err) => match err.into_inner() { + Some(err) => return Err(ChatError::Custom(err.to_string().into())), + None => break, // Data was incomplete + }, + } + + // TODO: We should buffer output based on how much we have to parse, not as a constant + // Do not remove unless you are nabochay :) + std::thread::sleep(Duration::from_millis(8)); + } + + // Set spinner after showing all of the assistant text content so far. 
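+ // While a tool use is still streaming in we only know its name, so a "Thinking..."
+ // spinner is shown; it is cleared again once the complete ToolUse event arrives
+ // (handled in the ResponseEvent::ToolUse arm above).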
+ if let (Some(_name), true) = (&tool_name_being_recvd, self.interactive) { + queue!(self.output, cursor::Hide)?; + self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_string())); + } + + if ended { + if let Some(message_id) = self.conversation_state.message_id() { + crate::fig_telemetry::send_chat_added_message( + self.conversation_state.conversation_id().to_owned(), + message_id.to_owned(), + self.conversation_state.context_message_length(), + ) + .await; + } + + if self.interactive && self.settings.get_bool_or("chat.enableNotifications", false) { + // For final responses (no tools suggested), always play the bell + play_notification_bell(tool_uses.is_empty()); + } + + if self.interactive { + queue!(self.output, style::ResetColor, style::SetAttribute(Attribute::Reset))?; + execute!(self.output, style::Print("\n"))?; + + for (i, citation) in &state.citations { + queue!( + self.output, + style::Print("\n"), + style::SetForegroundColor(Color::Blue), + style::Print(format!("[^{i}]: ")), + style::SetForegroundColor(Color::DarkGrey), + style::Print(format!("{citation}\n")), + style::SetForegroundColor(Color::Reset) + )?; + } + } + + break; + } + } + + if !tool_uses.is_empty() { + Ok(ChatState::ValidateTools(tool_uses)) + } else { + Ok(ChatState::PromptUser { + tool_uses: None, + pending_tool_index: None, + skip_printing_tools: false, + }) + } + } + + async fn validate_tools(&mut self, tool_uses: Vec) -> Result { + let conv_id = self.conversation_state.conversation_id().to_owned(); + debug!(?tool_uses, "Validating tool uses"); + let mut queued_tools: Vec = Vec::new(); + let mut tool_results: Vec = Vec::new(); + + for tool_use in tool_uses { + let tool_use_id = tool_use.id.clone(); + let tool_use_name = tool_use.name.clone(); + let mut tool_telemetry = ToolUseEventBuilder::new(conv_id.clone(), tool_use.id.clone()) + .set_tool_use_id(tool_use_id.clone()) + .set_tool_name(tool_use.name.clone()) + .utterance_id(self.conversation_state.message_id().map(|s| s.to_string())); + match self.tool_manager.get_tool_from_tool_use(tool_use) { + Ok(mut tool) => { + // Apply non-Q-generated context to tools + self.contextualize_tool(&mut tool); + + match tool.validate(&self.ctx).await { + Ok(()) => { + tool_telemetry.is_valid = Some(true); + queued_tools.push(QueuedTool { + id: tool_use_id.clone(), + name: tool_use_name, + tool, + accepted: false, + }); + }, + Err(err) => { + tool_telemetry.is_valid = Some(false); + tool_results.push(ToolUseResult { + tool_use_id: tool_use_id.clone(), + content: vec![ToolUseResultBlock::Text(format!( + "Failed to validate tool parameters: {err}" + ))], + status: ToolResultStatus::Error, + }); + }, + }; + }, + Err(err) => { + tool_telemetry.is_valid = Some(false); + tool_results.push(err.into()); + }, + } + self.tool_use_telemetry_events.insert(tool_use_id, tool_telemetry); + } + + // If we have any validation errors, then return them immediately to the model. 
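+ // The validation failure text is shown to the user in red and is also sent back as
+ // ToolResultStatus::Error blocks, giving the model a chance to correct its tool
+ // arguments on the next turn.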
+ if !tool_results.is_empty() { + debug!(?tool_results, "Error found in the model tools"); + queue!( + self.output, + style::SetAttribute(Attribute::Bold), + style::Print("Tool validation failed: "), + style::SetAttribute(Attribute::Reset), + )?; + for tool_result in &tool_results { + for block in &tool_result.content { + let content: Option> = match block { + ToolUseResultBlock::Text(t) => Some(t.as_str().into()), + ToolUseResultBlock::Json(d) => serde_json::to_string(d) + .map_err(|err| error!(?err, "failed to serialize tool result content")) + .map(Into::into) + .ok(), + }; + if let Some(content) = content { + queue!( + self.output, + style::Print("\n"), + style::SetForegroundColor(Color::Red), + style::Print(format!("{}\n", content)), + style::SetForegroundColor(Color::Reset), + )?; + } + } + } + self.conversation_state.add_tool_results(tool_results); + self.send_tool_use_telemetry().await; + if let ToolUseStatus::Idle = self.tool_use_status { + self.tool_use_status = ToolUseStatus::RetryInProgress( + self.conversation_state + .message_id() + .map_or("No utterance id found".to_string(), |v| v.to_string()), + ); + } + + let response = self + .client + .send_message(self.conversation_state.as_sendable_conversation_state(false).await) + .await?; + return Ok(ChatState::HandleResponseStream(response)); + } + + Ok(ChatState::ExecuteTools(queued_tools)) + } + + /// Apply program context to tools that Q may not have. + // We cannot attach this any other way because Tools are constructed by deserializing + // output from Amazon Q. + // TODO: Is there a better way? + fn contextualize_tool(&self, tool: &mut Tool) { + #[allow(clippy::single_match)] + match tool { + Tool::GhIssue(gh_issue) => { + gh_issue.set_context(GhIssueContext { + // Ideally we avoid cloning, but this function is not called very often. + // Using references with lifetimes requires a large refactor, and Arc> + // seems like overkill and may incur some performance cost anyway. 
+ context_manager: self.conversation_state.context_manager.clone(), + transcript: self.conversation_state.transcript.clone(), + failed_request_ids: self.failed_request_ids.clone(), + tool_permissions: self.tool_permissions.permissions.clone(), + interactive: self.interactive, + }); + }, + _ => (), + }; + } + + async fn print_tool_descriptions(&mut self, tool_use: &QueuedTool, trusted: bool) -> Result<(), ChatError> { + queue!( + self.output, + style::SetForegroundColor(Color::Magenta), + style::Print(format!( + "🛠️ Using tool: {}{}", + tool_use.tool.display_name(), + if trusted { " (trusted)".dark_green() } else { "".reset() } + )), + style::SetForegroundColor(Color::Reset) + )?; + if let Tool::Custom(ref tool) = tool_use.tool { + queue!( + self.output, + style::SetForegroundColor(Color::Reset), + style::Print(" from mcp server "), + style::SetForegroundColor(Color::Magenta), + style::Print(tool.client.get_server_name()), + style::SetForegroundColor(Color::Reset), + )?; + } + queue!(self.output, style::Print("\n"), style::Print(CONTINUATION_LINE))?; + queue!(self.output, style::Print("\n"))?; + queue!(self.output, style::Print(TOOL_BULLET))?; + + self.output.flush()?; + + tool_use + .tool + .queue_description(&self.ctx, &mut self.output) + .await + .map_err(|e| ChatError::Custom(format!("failed to print tool, `{}`: {}", tool_use.name, e).into()))?; + + Ok(()) + } + + /// Helper function to read user input with a prompt and Ctrl+C handling + fn read_user_input(&mut self, prompt: &str, exit_on_single_ctrl_c: bool) -> Option { + let mut ctrl_c = false; + loop { + match (self.input_source.read_line(Some(prompt)), ctrl_c) { + (Ok(Some(line)), _) => { + if line.trim().is_empty() { + continue; // Reprompt if the input is empty + } + return Some(line); + }, + (Ok(None), false) => { + if exit_on_single_ctrl_c { + return None; + } + execute!( + self.output, + style::Print(format!( + "\n(To exit the CLI, press Ctrl+C or Ctrl+D again or type {})\n\n", + "/quit".green() + )) + ) + .unwrap_or_default(); + ctrl_c = true; + }, + (Ok(None), true) => return None, // Exit if Ctrl+C was pressed twice + (Err(_), _) => return None, + } + } + } + + /// Helper function to generate a prompt based on the current context + fn generate_tool_trust_prompt(&self) -> String { + prompt::generate_prompt(self.conversation_state.current_profile(), self.all_tools_trusted()) + } + + async fn send_tool_use_telemetry(&mut self) { + for (_, mut event) in self.tool_use_telemetry_events.drain() { + event.user_input_id = match self.tool_use_status { + ToolUseStatus::Idle => self.conversation_state.message_id(), + ToolUseStatus::RetryInProgress(ref id) => Some(id.as_str()), + } + .map(|v| v.to_string()); + let event: crate::fig_telemetry::EventType = event.into(); + let app_event = crate::fig_telemetry::AppTelemetryEvent::new(event).await; + crate::fig_telemetry::send_event(app_event).await; + } + } + + fn terminal_width(&self) -> usize { + (self.terminal_width_provider)().unwrap_or(80) + } + + fn all_tools_trusted(&self) -> bool { + self.conversation_state.tools.values().flatten().all(|t| match t { + FigTool::ToolSpecification(t) => self.tool_permissions.is_trusted(&t.name), + }) + } + + /// Display character limit warnings based on current conversation size + async fn display_char_warnings(&mut self) -> Result<(), std::io::Error> { + let warning_level = self.conversation_state.get_token_warning_level().await; + + match warning_level { + TokenWarningLevel::Critical => { + // Memory constraint warning with gentler wording + execute!( 
+ self.output, + style::SetForegroundColor(Color::Yellow), + style::SetAttribute(Attribute::Bold), + style::Print("\n⚠️ This conversation is getting lengthy.\n"), + style::SetAttribute(Attribute::Reset), + style::Print( + "To ensure continued smooth operation, please use /compact to summarize the conversation.\n\n" + ), + style::SetForegroundColor(Color::Reset) + )?; + }, + TokenWarningLevel::None => { + // No warning needed + }, + } + + Ok(()) + } +} + +#[derive(Debug)] +struct ToolUseEventBuilder { + pub conversation_id: String, + pub utterance_id: Option, + pub user_input_id: Option, + pub tool_use_id: Option, + pub tool_name: Option, + pub is_accepted: bool, + pub is_success: Option, + pub is_valid: Option, + pub is_custom_tool: bool, + pub input_token_size: Option, + pub output_token_size: Option, + pub custom_tool_call_latency: Option, +} + +impl ToolUseEventBuilder { + pub fn new(conv_id: String, tool_use_id: String) -> Self { + Self { + conversation_id: conv_id, + utterance_id: None, + user_input_id: None, + tool_use_id: Some(tool_use_id), + tool_name: None, + is_accepted: false, + is_success: None, + is_valid: None, + is_custom_tool: false, + input_token_size: None, + output_token_size: None, + custom_tool_call_latency: None, + } + } + + pub fn utterance_id(mut self, id: Option) -> Self { + self.utterance_id = id; + self + } + + pub fn set_tool_use_id(mut self, id: String) -> Self { + self.tool_use_id.replace(id); + self + } + + pub fn set_tool_name(mut self, name: String) -> Self { + self.tool_name.replace(name); + self + } +} + +impl From for crate::fig_telemetry::EventType { + fn from(val: ToolUseEventBuilder) -> Self { + crate::fig_telemetry::EventType::ToolUseSuggested { + conversation_id: val.conversation_id, + utterance_id: val.utterance_id, + user_input_id: val.user_input_id, + tool_use_id: val.tool_use_id, + tool_name: val.tool_name, + is_accepted: val.is_accepted, + is_success: val.is_success, + is_valid: val.is_valid, + is_custom_tool: val.is_custom_tool, + input_token_size: val.input_token_size, + output_token_size: val.output_token_size, + custom_tool_call_latency: val.custom_tool_call_latency, + } + } +} + +/// Testing helper +fn split_tool_use_event(value: &Map) -> Vec { + let tool_use_id = value.get("tool_use_id").unwrap().as_str().unwrap().to_string(); + let name = value.get("name").unwrap().as_str().unwrap().to_string(); + let args_str = value.get("args").unwrap().to_string(); + let split_point = args_str.len() / 2; + vec![ + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: name.clone(), + input: None, + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: name.clone(), + input: Some(args_str.split_at(split_point).0.to_string()), + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: name.clone(), + input: Some(args_str.split_at(split_point).1.to_string()), + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: name.clone(), + input: None, + stop: Some(true), + }, + ] +} + +/// Testing helper +fn create_stream(model_responses: serde_json::Value) -> StreamingClient { + let mut mock = Vec::new(); + for response in model_responses.as_array().unwrap() { + let mut stream = Vec::new(); + for event in response.as_array().unwrap() { + match event { + serde_json::Value::String(assistant_text) => { + stream.push(ChatResponseStream::AssistantResponseEvent { + content: assistant_text.to_string(), + }); + }, + 
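+ // A JSON object in the fixture stands for one tool use; split_tool_use_event
+ // expands it into four streaming events, e.g. for
+ // {"tool_use_id": "1", "name": "fs_write", "args": {...}}:
+ //   ToolUseEvent { input: None,                      stop: None }
+ //   ToolUseEvent { input: Some(first half of args),  stop: None }
+ //   ToolUseEvent { input: Some(second half of args), stop: None }
+ //   ToolUseEvent { input: None,                      stop: Some(true) }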
serde_json::Value::Object(tool_use) => { + stream.append(&mut split_tool_use_event(tool_use)); + }, + other => panic!("Unexpected value: {:?}", other), + } + } + mock.push(stream); + } + StreamingClient::mock(mock) +} + +#[cfg(test)] +mod tests { + use bstr::ByteSlice; + use shared_writer::TestWriterWithSink; + + use super::*; + + #[tokio::test] + async fn test_flow() { + let _ = tracing_subscriber::fmt::try_init(); + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let test_client = create_stream(serde_json::json!([ + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file.txt", + } + } + ], + [ + "Hope that looks good to you!", + ], + ])); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatContext::new( + Arc::clone(&ctx), + "fake_conv_id", + Settings::new_fake(), + State::new_fake(), + SharedWriter::stdout(), + None, + InputSource::new_mock(vec![ + "create a new file".to_string(), + "y".to_string(), + "exit".to_string(), + ]), + true, + test_client, + || Some(80), + tool_manager, + None, + tool_config, + ToolPermissions::new(0), + ) + .await + .unwrap() + .try_chat() + .await + .unwrap(); + + assert_eq!(ctx.fs().read_to_string("/file.txt").await.unwrap(), "Hello, world!\n"); + } + + #[tokio::test] + async fn test_flow_tool_permissions() { + let _ = tracing_subscriber::fmt::try_init(); + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let test_client = create_stream(serde_json::json!([ + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file1.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file2.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file3.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file4.txt", + } + } + ], + [ + "Ok, I won't make it.", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file5.txt", + } + } + ], + [ + "Done", + ], + [ + "Ok", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file6.txt", + } + } + ], + [ + "Ok, I won't make it.", + ], + ])); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatContext::new( + Arc::clone(&ctx), + "fake_conv_id", + Settings::new_fake(), + State::new_fake(), + SharedWriter::stdout(), + None, + InputSource::new_mock(vec![ + "/tools".to_string(), + "/tools help".to_string(), + "create a new file".to_string(), + "y".to_string(), + "create a new file".to_string(), + "t".to_string(), + "create a new file".to_string(), // should make without prompting due to 't' + "/tools untrust fs_write".to_string(), + "create a file".to_string(), // prompt again due to untrust + "n".to_string(), // cancel + "/tools 
trust fs_write".to_string(), + "create a file".to_string(), // again without prompting due to '/tools trust' + "/tools reset".to_string(), + "create a file".to_string(), // prompt again due to reset + "n".to_string(), // cancel + "exit".to_string(), + ]), + true, + test_client, + || Some(80), + tool_manager, + None, + tool_config, + ToolPermissions::new(0), + ) + .await + .unwrap() + .try_chat() + .await + .unwrap(); + + assert_eq!(ctx.fs().read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(ctx.fs().read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); + assert!(!ctx.fs().exists("/file4.txt")); + assert_eq!(ctx.fs().read_to_string("/file5.txt").await.unwrap(), "Hello, world!\n"); + assert!(!ctx.fs().exists("/file6.txt")); + } + + #[tokio::test] + async fn test_flow_multiple_tools() { + let _ = tracing_subscriber::fmt::try_init(); + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let test_client = create_stream(serde_json::json!([ + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file1.txt", + } + }, + { + "tool_use_id": "2", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file2.txt", + } + } + ], + [ + "Done", + ], + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file3.txt", + } + }, + { + "tool_use_id": "2", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file4.txt", + } + } + ], + [ + "Done", + ], + ])); + + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatContext::new( + Arc::clone(&ctx), + "fake_conv_id", + Settings::new_fake(), + State::new_fake(), + SharedWriter::stdout(), + None, + InputSource::new_mock(vec![ + "create 2 new files parallel".to_string(), + "t".to_string(), + "/tools reset".to_string(), + "create 2 new files parallel".to_string(), + "y".to_string(), + "y".to_string(), + "exit".to_string(), + ]), + true, + test_client, + || Some(80), + tool_manager, + None, + tool_config, + ToolPermissions::new(0), + ) + .await + .unwrap() + .try_chat() + .await + .unwrap(); + + assert_eq!(ctx.fs().read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(ctx.fs().read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(ctx.fs().read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); + assert_eq!(ctx.fs().read_to_string("/file4.txt").await.unwrap(), "Hello, world!\n"); + } + + #[tokio::test] + async fn test_flow_tools_trust_all() { + let _ = tracing_subscriber::fmt::try_init(); + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let test_client = create_stream(serde_json::json!([ + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file1.txt", + } + } + ], + [ + "Done", + ], + [ + "Sure, I'll create a file for you", + { + "tool_use_id": "1", + "name": "fs_write", + "args": { + "command": "create", + "file_text": "Hello, world!", + "path": "/file3.txt", + } + } + ], + [ + "Ok I won't.", + ], + ])); + + let tool_manager = ToolManager::default(); + let tool_config = 
serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + ChatContext::new( + Arc::clone(&ctx), + "fake_conv_id", + Settings::new_fake(), + State::new_fake(), + SharedWriter::stdout(), + None, + InputSource::new_mock(vec![ + "/tools trustall".to_string(), + "create a new file".to_string(), + "/tools reset".to_string(), + "create a new file".to_string(), + "exit".to_string(), + ]), + true, + test_client, + || Some(80), + tool_manager, + None, + tool_config, + ToolPermissions::new(0), + ) + .await + .unwrap() + .try_chat() + .await + .unwrap(); + + assert_eq!(ctx.fs().read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); + assert!(!ctx.fs().exists("/file2.txt")); + } + + #[test] + fn test_editor_content_processing() { + // Since we no longer have template replacement, this test is simplified + let cases = vec![ + ("My content", "My content"), + ("My content with newline\n", "My content with newline"), + ("", ""), + ]; + + for (input, expected) in cases { + let processed = input.trim().to_string(); + assert_eq!(processed, expected.trim().to_string(), "Failed for input: {}", input); + } + } + + #[tokio::test] + async fn test_draw_tip_box() { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let buf = Arc::new(std::sync::Mutex::new(Vec::::new())); + let test_writer = TestWriterWithSink { sink: buf.clone() }; + let output = SharedWriter::new(test_writer.clone()); + let tool_manager = ToolManager::default(); + let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) + .expect("Tools failed to load"); + let test_client = create_stream(serde_json::json!([])); + + let mut chat_context = ChatContext::new( + Arc::clone(&ctx), + "fake_conv_id", + Settings::new_fake(), + State::new_fake(), + output, + None, + InputSource::new_mock(vec![]), + true, + test_client, + || Some(80), + tool_manager, + None, + tool_config, + ToolPermissions::new(0), + ) + .await + .unwrap(); + + // Test with a short tip + let short_tip = "This is a short tip"; + chat_context.draw_tip_box(short_tip).expect("Failed to draw tip box"); + + // Test with a longer tip that should wrap + let long_tip = "This is a much longer tip that should wrap to multiple lines because it exceeds the inner width of the tip box which is calculated based on the GREETING_BREAK_POINT constant"; + chat_context.draw_tip_box(long_tip).expect("Failed to draw tip box"); + + // Test with a long tip with two long words that should wrap + let long_tip_with_one_long_word = { + let mut s = "a".repeat(200); + s.push(' '); + s.push_str(&"a".repeat(200)); + s + }; + chat_context + .draw_tip_box(long_tip_with_one_long_word.as_str()) + .expect("Failed to draw tip box"); + + // Test with a long tip with two long words that should wrap + let long_tip_with_two_long_words = "a".repeat(200); + chat_context + .draw_tip_box(long_tip_with_two_long_words.as_str()) + .expect("Failed to draw tip box"); + + // Get the output and verify it contains expected formatting elements + let content = test_writer.get_content(); + let output_str = content.to_str_lossy(); + + // Check for box drawing characters + assert!(output_str.contains("╭"), "Output should contain top-left corner"); + assert!(output_str.contains("╮"), "Output should contain top-right corner"); + assert!(output_str.contains("│"), "Output should contain vertical lines"); + assert!(output_str.contains("╰"), "Output should contain bottom-left corner"); + assert!(output_str.contains("╯"), "Output should contain 
bottom-right corner"); + + // Check for the label + assert!( + output_str.contains("Did you know?"), + "Output should contain the 'Did you know?' label" + ); + + // Check that both tips are present + assert!(output_str.contains(short_tip), "Output should contain the short tip"); + + // For the long tip, we check for substrings since it will be wrapped + let long_tip_parts: Vec<&str> = long_tip.split_whitespace().collect(); + for part in long_tip_parts.iter().take(3) { + assert!(output_str.contains(part), "Output should contain parts of the long tip"); + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/parse.rs b/crates/kiro-cli/src/cli/chat/parse.rs new file mode 100644 index 0000000000..db3f0cf382 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/parse.rs @@ -0,0 +1,762 @@ +use std::io::Write; + +use crossterm::style::{ + Attribute, + Color, + Stylize, +}; +use crossterm::{ + Command, + style, +}; +use unicode_width::{ + UnicodeWidthChar, + UnicodeWidthStr, +}; +use winnow::Partial; +use winnow::ascii::{ + self, + digit1, + space0, + space1, + till_line_ending, +}; +use winnow::combinator::{ + alt, + delimited, + preceded, + repeat, + terminated, +}; +use winnow::error::{ + ErrMode, + ErrorKind, + ParserError, +}; +use winnow::prelude::*; +use winnow::stream::{ + AsChar, + Stream, +}; +use winnow::token::{ + any, + take_till, + take_until, + take_while, +}; + +const CODE_COLOR: Color = Color::Green; +const HEADING_COLOR: Color = Color::Magenta; +const BLOCKQUOTE_COLOR: Color = Color::DarkGrey; +const URL_TEXT_COLOR: Color = Color::Blue; +const URL_LINK_COLOR: Color = Color::DarkGrey; + +const DEFAULT_RULE_WIDTH: usize = 40; + +#[derive(Debug, thiserror::Error)] +pub enum Error<'a> { + #[error(transparent)] + Stdio(#[from] std::io::Error), + #[error("parse error {1}, input {0}")] + Winnow(Partial<&'a str>, ErrorKind), +} + +impl<'a> ParserError> for Error<'a> { + fn from_error_kind(input: &Partial<&'a str>, kind: ErrorKind) -> Self { + Self::Winnow(*input, kind) + } + + fn append( + self, + _input: &Partial<&'a str>, + _checkpoint: &winnow::stream::Checkpoint< + winnow::stream::Checkpoint<&'a str, &'a str>, + winnow::Partial<&'a str>, + >, + _kind: ErrorKind, + ) -> Self { + self + } +} + +#[derive(Debug)] +pub struct ParseState { + pub terminal_width: Option, + pub column: usize, + pub in_codeblock: bool, + pub bold: bool, + pub italic: bool, + pub strikethrough: bool, + pub set_newline: bool, + pub newline: bool, + pub citations: Vec<(String, String)>, +} + +impl ParseState { + pub fn new(terminal_width: Option) -> Self { + Self { + terminal_width, + column: 0, + in_codeblock: false, + bold: false, + italic: false, + strikethrough: false, + set_newline: false, + newline: true, + citations: vec![], + } + } +} + +pub fn interpret_markdown<'a, 'b>( + mut i: Partial<&'a str>, + mut o: impl Write + 'b, + state: &mut ParseState, +) -> PResult, Error<'a>> { + let mut error: Option> = None; + let start = i.checkpoint(); + + macro_rules! 
stateful_alt { + ($($fns:ident),*) => { + $({ + i.reset(&start); + match $fns(&mut o, state).parse_next(&mut i) { + Err(ErrMode::Backtrack(e)) => { + error = match error { + Some(error) => Some(error.or(e)), + None => Some(e), + }; + }, + res => { + return res.map(|_| i); + } + } + })* + }; + } + + match state.in_codeblock { + false => { + stateful_alt!( + // This pattern acts as a short circuit for alphanumeric plaintext + // More importantly, it's needed to support manual wordwrapping + text, + // multiline patterns + blockquote, + // linted_codeblock, + codeblock_begin, + // single line patterns + horizontal_rule, + heading, + bulleted_item, + numbered_item, + // inline patterns + code, + citation, + url, + bold, + italic, + strikethrough, + // symbols + less_than, + greater_than, + ampersand, + quot, + line_ending, + // fallback + fallback + ); + }, + true => { + stateful_alt!( + codeblock_less_than, + codeblock_greater_than, + codeblock_ampersand, + codeblock_quot, + codeblock_end, + codeblock_line_ending, + codeblock_fallback + ); + }, + } + + match error { + Some(e) => Err(ErrMode::Backtrack(e.append(&i, &start, ErrorKind::Alt))), + None => Err(ErrMode::assert(&i, "no parsers")), + } +} + +fn text<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + let content = take_while(1.., |t| AsChar::is_alphanum(t) || "+,.!?\"".contains(t)).parse_next(i)?; + queue_newline_or_advance(&mut o, state, content.width())?; + queue(&mut o, style::Print(content))?; + Ok(()) + } +} + +fn heading<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + if !state.newline { + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + } + + let level = terminated(take_while(1.., |c| c == '#'), space1).parse_next(i)?; + let print = format!("{level} "); + + queue_newline_or_advance(&mut o, state, print.width())?; + queue(&mut o, style::SetForegroundColor(HEADING_COLOR))?; + queue(&mut o, style::SetAttribute(Attribute::Bold))?; + queue(&mut o, style::Print(print)) + } +} + +fn bulleted_item<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + if !state.newline { + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + } + + let ws = (space0, alt(("-", "*")), space1).parse_next(i)?.0; + let print = format!("{ws}• "); + + queue_newline_or_advance(&mut o, state, print.width())?; + queue(&mut o, style::Print(print)) + } +} + +fn numbered_item<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + if !state.newline { + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + } + + let (ws, digits, _, _) = (space0, digit1, ".", space1).parse_next(i)?; + let print = format!("{ws}{digits}. 
"); + + queue_newline_or_advance(&mut o, state, print.width())?; + queue(&mut o, style::Print(print)) + } +} + +fn horizontal_rule<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + if !state.newline { + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + } + + ( + space0, + alt((take_while(3.., '-'), take_while(3.., '*'), take_while(3.., '_'))), + ) + .parse_next(i)?; + + state.column = 0; + state.set_newline = true; + + let rule_width = state.terminal_width.unwrap_or(DEFAULT_RULE_WIDTH); + queue(&mut o, style::Print(format!("{}\n", "━".repeat(rule_width)))) + } +} + +fn code<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "`".parse_next(i)?; + let code = terminated(take_until(0.., "`"), "`").parse_next(i)?; + let out = code.replace("&", "&").replace(">", ">").replace("<", "<"); + + queue_newline_or_advance(&mut o, state, out.width())?; + queue(&mut o, style::SetForegroundColor(Color::Green))?; + queue(&mut o, style::Print(out))?; + queue(&mut o, style::ResetColor) + } +} + +fn blockquote<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + if !state.newline { + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + } + + let level = repeat::<_, _, Vec<&'_ str>, _, _>(1.., terminated(">", space0)) + .parse_next(i)? + .len(); + let print = "│ ".repeat(level); + + queue(&mut o, style::SetForegroundColor(BLOCKQUOTE_COLOR))?; + queue_newline_or_advance(&mut o, state, print.width())?; + queue(&mut o, style::Print(print)) + } +} + +fn bold<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + match state.newline { + true => { + alt(("**", "__")).parse_next(i)?; + queue(&mut o, style::SetAttribute(Attribute::Bold))?; + }, + false => match state.bold { + true => { + alt(("**", "__")).parse_next(i)?; + queue(&mut o, style::SetAttribute(Attribute::NormalIntensity))?; + }, + false => { + preceded(space1, alt(("**", "__"))).parse_next(i)?; + queue(&mut o, style::Print(' '))?; + queue(&mut o, style::SetAttribute(Attribute::Bold))?; + }, + }, + }; + + state.bold = !state.bold; + + Ok(()) + } +} + +fn italic<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + match state.newline { + true => { + alt(("*", "_")).parse_next(i)?; + queue(&mut o, style::SetAttribute(Attribute::Italic))?; + }, + false => match state.italic { + true => { + alt(("*", "_")).parse_next(i)?; + queue(&mut o, style::SetAttribute(Attribute::NoItalic))?; + }, + false => { + preceded(space1, alt(("*", "_"))).parse_next(i)?; + queue(&mut o, style::Print(' '))?; + queue(&mut o, style::SetAttribute(Attribute::Italic))?; + }, + }, + }; + + state.italic = !state.italic; + + Ok(()) + } +} + +fn strikethrough<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "~~".parse_next(i)?; + state.strikethrough = !state.strikethrough; + match state.strikethrough { + true => queue(&mut o, style::SetAttribute(Attribute::CrossedOut)), + false => queue(&mut o, style::SetAttribute(Attribute::NotCrossedOut)), + } + } +} + +fn citation<'a, 'b>( + mut o: 
impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + let num = delimited("[[", digit1, "]]").parse_next(i)?; + let link = delimited("(", take_till(0.., ')'), ")").parse_next(i)?; + + state.citations.push((num.to_owned(), link.to_owned())); + + queue_newline_or_advance(&mut o, state, num.width() + 1)?; + queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; + queue(&mut o, style::Print(format!("[^{num}]")))?; + queue(&mut o, style::ResetColor) + } +} + +fn url<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + // Save the current input position + let start = i.checkpoint(); + + // Try to match the first part of URL pattern "[text]" + let display = match delimited::<_, _, _, _, Error<'a>, _, _, _>("[", take_until(1.., "]("), "]").parse_next(i) { + Ok(display) => display, + Err(_) => { + // If it doesn't match, reset position and fail + i.reset(&start); + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + }, + }; + + // Try to match the second part of URL pattern "(url)" + let link = match delimited::<_, _, _, _, Error<'a>, _, _, _>("(", take_till(0.., ')'), ")").parse_next(i) { + Ok(link) => link, + Err(_) => { + // If it doesn't match, reset position and fail + i.reset(&start); + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + }, + }; + + // Only generate output if the complete URL pattern matches + queue_newline_or_advance(&mut o, state, display.width() + 1)?; + queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; + queue(&mut o, style::Print(format!("{display} ")))?; + queue(&mut o, style::SetForegroundColor(URL_LINK_COLOR))?; + state.column += link.width(); + queue(&mut o, style::Print(link))?; + queue(&mut o, style::ResetColor) + } +} + +fn less_than<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "<".parse_next(i)?; + queue_newline_or_advance(&mut o, state, 1)?; + queue(&mut o, style::Print('<')) + } +} + +fn greater_than<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + ">".parse_next(i)?; + queue_newline_or_advance(&mut o, state, 1)?; + queue(&mut o, style::Print('>')) + } +} + +fn ampersand<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "&".parse_next(i)?; + queue_newline_or_advance(&mut o, state, 1)?; + queue(&mut o, style::Print('&')) + } +} + +fn quot<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + """.parse_next(i)?; + queue_newline_or_advance(&mut o, state, 1)?; + queue(&mut o, style::Print('"')) + } +} + +fn line_ending<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + ascii::line_ending.parse_next(i)?; + + state.column = 0; + state.set_newline = true; + + queue(&mut o, style::ResetColor)?; + queue(&mut o, style::SetAttribute(style::Attribute::Reset))?; + queue(&mut o, style::Print("\n")) + } +} + +fn fallback<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + 
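+ // Last resort: consume a single character, wrap to a new line if it would overflow
+ // the terminal width, and swallow a lone space that would otherwise start the
+ // freshly wrapped line.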
let fallback = any.parse_next(i)?; + if let Some(width) = fallback.width() { + queue_newline_or_advance(&mut o, state, width)?; + if fallback != ' ' || state.column != 1 { + queue(&mut o, style::Print(fallback))?; + } + } + + Ok(()) + } +} + +fn queue_newline_or_advance<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, + width: usize, +) -> Result<(), ErrMode>> { + if let Some(terminal_width) = state.terminal_width { + if state.column > 0 && state.column + width > terminal_width { + state.column = width; + queue(&mut o, style::Print('\n'))?; + return Ok(()); + } + } + + // else + state.column += width; + + Ok(()) +} + +fn queue<'a>(o: &mut impl Write, command: impl Command) -> Result<(), ErrMode>> { + use crossterm::QueueableCommand; + o.queue(command).map_err(|err| ErrMode::Cut(Error::Stdio(err)))?; + Ok(()) +} + +fn codeblock_begin<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + if !state.newline { + return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); + } + + // We don't want to do anything special to text inside codeblocks so we wait for all of it + // The alternative is to switch between parse rules at the top level but that's slightly involved + let language = preceded("```", till_line_ending).parse_next(i)?; + ascii::line_ending.parse_next(i)?; + + state.in_codeblock = true; + + if !language.is_empty() { + queue(&mut o, style::Print(format!("{}\n", language).bold()))?; + } + + queue(&mut o, style::SetForegroundColor(CODE_COLOR))?; + + Ok(()) + } +} + +fn codeblock_end<'a, 'b>( + mut o: impl Write + 'b, + state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "```".parse_next(i)?; + state.in_codeblock = false; + queue(&mut o, style::ResetColor) + } +} + +fn codeblock_less_than<'a, 'b>( + mut o: impl Write + 'b, + _state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "<".parse_next(i)?; + queue(&mut o, style::Print('<')) + } +} + +fn codeblock_greater_than<'a, 'b>( + mut o: impl Write + 'b, + _state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + ">".parse_next(i)?; + queue(&mut o, style::Print('>')) + } +} + +fn codeblock_ampersand<'a, 'b>( + mut o: impl Write + 'b, + _state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + "&".parse_next(i)?; + queue(&mut o, style::Print('&')) + } +} + +fn codeblock_quot<'a, 'b>( + mut o: impl Write + 'b, + _state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + """.parse_next(i)?; + queue(&mut o, style::Print('"')) + } +} + +fn codeblock_line_ending<'a, 'b>( + mut o: impl Write + 'b, + _state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + ascii::line_ending.parse_next(i)?; + queue(&mut o, style::Print("\n")) + } +} + +fn codeblock_fallback<'a, 'b>( + mut o: impl Write + 'b, + _state: &'b mut ParseState, +) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { + move |i| { + let fallback = any.parse_next(i)?; + queue(&mut o, style::Print(fallback)) + } +} + +#[cfg(test)] +mod tests { + use std::io::Write; + + use winnow::stream::Offset; + + use super::*; + + macro_rules! 
validate { + ($test:ident, $input:literal, [$($commands:expr),+ $(,)?]) => { + #[test] + fn $test() -> eyre::Result<()> { + use crossterm::ExecutableCommand; + + let mut input = $input.trim().to_owned(); + input.push(' '); + input.push(' '); + + let mut state = ParseState::new(Some(80)); + let mut presult = vec![]; + let mut offset = 0; + + loop { + let input = Partial::new(&input[offset..]); + match interpret_markdown(input, &mut presult, &mut state) { + Ok(parsed) => { + offset += parsed.offset_from(&input); + state.newline = state.set_newline; + state.set_newline = false; + }, + Err(err) => match err.into_inner() { + Some(err) => panic!("{err}"), + None => break, // Data was incomplete + }, + } + } + + presult.flush()?; + let presult = String::from_utf8(presult)?; + + let mut wresult: Vec = vec![]; + $(wresult.execute($commands)?;)+ + let wresult = String::from_utf8(wresult)?; + + assert_eq!(presult.trim(), wresult); + + Ok(()) + } + }; + } + + validate!(text_1, "hello world!", [style::Print("hello world!")]); + validate!(linted_codeblock_1, "```java\nhello world!```", [ + style::SetAttribute(Attribute::Bold), + style::Print("java\n"), + style::SetAttribute(Attribute::Reset), + style::SetForegroundColor(CODE_COLOR), + style::Print("hello world!"), + style::ResetColor, + ]); + validate!(code_1, "`print`", [ + style::SetForegroundColor(CODE_COLOR), + style::Print("print"), + style::ResetColor, + ]); + validate!(url_1, "[google](google.com)", [ + style::SetForegroundColor(URL_TEXT_COLOR), + style::Print("google "), + style::SetForegroundColor(URL_LINK_COLOR), + style::Print("google.com"), + style::ResetColor, + ]); + validate!(citation_1, "[[1]](google.com)", [ + style::SetForegroundColor(URL_TEXT_COLOR), + style::Print("[^1]"), + style::ResetColor, + ]); + validate!(bold_1, "**hello**", [ + style::SetAttribute(Attribute::Bold), + style::Print("hello"), + style::SetAttribute(Attribute::NormalIntensity) + ]); + validate!(italic_1, "*hello*", [ + style::SetAttribute(Attribute::Italic), + style::Print("hello"), + style::SetAttribute(Attribute::NoItalic) + ]); + validate!(strikethrough_1, "~~hello~~", [ + style::SetAttribute(Attribute::CrossedOut), + style::Print("hello"), + style::SetAttribute(Attribute::NotCrossedOut) + ]); + validate!(less_than_1, "<", [style::Print('<')]); + validate!(greater_than_1, ".>.", [style::Print(".>.")]); + validate!(ampersand_1, "&", [style::Print('&')]); + validate!(quote_1, """, [style::Print('"')]); + validate!(fallback_1, "+ % @ . ? ", [style::Print("+ % @ . ?")]); + validate!(horizontal_rule_1, "---", [style::Print("━".repeat(80))]); + validate!(heading_1, "# Hello World", [ + style::SetForegroundColor(HEADING_COLOR), + style::SetAttribute(Attribute::Bold), + style::Print("# Hello World"), + ]); + validate!(bulleted_item_1, "- bullet", [style::Print("• bullet")]); + validate!(bulleted_item_2, "* bullet", [style::Print("• bullet")]); + validate!(numbered_item_1, "1. number", [style::Print("1. 
number")]); + validate!(blockquote_1, "> hello", [ + style::SetForegroundColor(BLOCKQUOTE_COLOR), + style::Print("│ hello"), + ]); + validate!(square_bracket_1, "[test]", [style::Print("[test]")]); + validate!(square_bracket_2, "Text with [brackets]", [style::Print( + "Text with [brackets]" + )]); + validate!(square_bracket_empty, "[]", [style::Print("[]")]); + validate!(square_bracket_array, "a[i]", [style::Print("a[i]")]); + validate!(square_bracket_url_like_1, "[text] without url part", [style::Print( + "[text] without url part" + )]); + validate!(square_bracket_url_like_2, "[text](without url part", [style::Print( + "[text](without url part" + )]); +} diff --git a/crates/kiro-cli/src/cli/chat/parser.rs b/crates/kiro-cli/src/cli/chat/parser.rs new file mode 100644 index 0000000000..c95c6f88f2 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/parser.rs @@ -0,0 +1,375 @@ +use std::time::{ + Duration, + Instant, +}; + +use eyre::Result; +use rand::distr::{ + Alphanumeric, + SampleString, +}; +use thiserror::Error; +use tracing::{ + error, + info, + trace, +}; + +use super::message::{ + AssistantMessage, + AssistantToolUse, +}; +use crate::fig_api_client::clients::SendMessageOutput; +use crate::fig_api_client::model::ChatResponseStream; + +#[derive(Debug, Error)] +pub struct RecvError { + /// The request id associated with the [SendMessageOutput] stream. + pub request_id: Option, + #[source] + pub source: RecvErrorKind, +} + +impl std::fmt::Display for RecvError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Failed to receive the next message: ")?; + if let Some(request_id) = self.request_id.as_ref() { + write!(f, "request_id: {}, error: ", request_id)?; + } + write!(f, "{}", self.source)?; + Ok(()) + } +} + +#[derive(Debug, Error)] +pub enum RecvErrorKind { + #[error("{0}")] + Client(#[from] crate::fig_api_client::Error), + #[error("{0}")] + Json(#[from] serde_json::Error), + /// An error was encountered while waiting for the next event in the stream after a noticeably + /// long wait time. + /// + /// *Context*: the client can throw an error after ~100s of waiting with no response, likely due + /// to an exceptionally complex tool use taking too long to generate. + #[error("The stream ended after {}s: {source}", .duration.as_secs())] + StreamTimeout { + source: crate::fig_api_client::Error, + duration: std::time::Duration, + }, + /// Unexpected end of stream while receiving a tool use. + /// + /// *Context*: the stream can unexpectedly end with `Ok(None)` while waiting for an + /// exceptionally complex tool use. This is due to some proxy server dropping idle + /// connections after some timeout is reached. + /// + /// TODO: should this be removed? + #[error("Unexpected end of stream for tool: {} with id: {}", .name, .tool_use_id)] + UnexpectedToolUseEos { + tool_use_id: String, + name: String, + message: Box, + time_elapsed: Duration, + }, +} + +/// State associated with parsing a [ChatResponseStream] into a [Message]. +/// +/// # Usage +/// +/// You should repeatedly call [Self::recv] to receive [ResponseEvent]'s until a +/// [ResponseEvent::EndStream] value is returned. +#[derive(Debug)] +pub struct ResponseParser { + /// The response to consume and parse into a sequence of [Ev]. + response: SendMessageOutput, + /// Buffer to hold the next event in [SendMessageOutput]. + peek: Option, + /// Message identifier for the assistant's response. Randomly generated on creation. 
+ message_id: String, + /// Buffer for holding the accumulated assistant response. + assistant_text: String, + /// Tool uses requested by the model. + tool_uses: Vec, + /// Whether or not we are currently receiving tool use delta events. Tuple of + /// `Some((tool_use_id, name))` if true, [None] otherwise. + parsing_tool_use: Option<(String, String)>, +} + +impl ResponseParser { + pub fn new(response: SendMessageOutput) -> Self { + let message_id = Alphanumeric.sample_string(&mut rand::rng(), 9); + info!(?message_id, "Generated new message id"); + Self { + response, + peek: None, + message_id, + assistant_text: String::new(), + tool_uses: Vec::new(), + parsing_tool_use: None, + } + } + + /// Consumes the associated [ConverseStreamResponse] until a valid [ResponseEvent] is parsed. + pub async fn recv(&mut self) -> Result { + if let Some((id, name)) = self.parsing_tool_use.take() { + let tool_use = self.parse_tool_use(id, name).await?; + self.tool_uses.push(tool_use.clone()); + return Ok(ResponseEvent::ToolUse(tool_use)); + } + + // First, handle discarding AssistantResponseEvent's that immediately precede a + // CodeReferenceEvent. + let peek = self.peek().await?; + if let Some(ChatResponseStream::AssistantResponseEvent { content }) = peek { + // Cloning to bypass borrowchecker stuff. + let content = content.clone(); + self.next().await?; + match self.peek().await? { + Some(ChatResponseStream::CodeReferenceEvent(_)) => (), + _ => { + self.assistant_text.push_str(&content); + return Ok(ResponseEvent::AssistantText(content)); + }, + } + } + + loop { + match self.next().await { + Ok(Some(output)) => match output { + ChatResponseStream::AssistantResponseEvent { content } => { + self.assistant_text.push_str(&content); + return Ok(ResponseEvent::AssistantText(content)); + }, + ChatResponseStream::InvalidStateEvent { reason, message } => { + error!(%reason, %message, "invalid state event"); + }, + ChatResponseStream::ToolUseEvent { + tool_use_id, + name, + input, + stop, + } => { + debug_assert!(input.is_none(), "Unexpected initial content in first tool use event"); + debug_assert!( + stop.is_none_or(|v| !v), + "Unexpected immediate stop in first tool use event" + ); + self.parsing_tool_use = Some((tool_use_id.clone(), name.clone())); + return Ok(ResponseEvent::ToolUseStart { name }); + }, + _ => {}, + }, + Ok(None) => { + let message_id = Some(self.message_id.clone()); + let content = std::mem::take(&mut self.assistant_text); + let message = if self.tool_uses.is_empty() { + AssistantMessage::new_response(message_id, content) + } else { + AssistantMessage::new_tool_use( + message_id, + content, + self.tool_uses.clone().into_iter().collect(), + ) + }; + return Ok(ResponseEvent::EndStream { message }); + }, + Err(err) => return Err(err), + } + } + } + + /// Consumes the response stream until a valid [ToolUse] is parsed. + /// + /// The arguments are the fields from the first [ChatResponseStream::ToolUseEvent] consumed. + async fn parse_tool_use(&mut self, id: String, name: String) -> Result { + let mut tool_string = String::new(); + let start = Instant::now(); + while let Some(ChatResponseStream::ToolUseEvent { .. }) = self.peek().await? { + if let Some(ChatResponseStream::ToolUseEvent { input, stop, .. }) = self.next().await? 
{ + if let Some(i) = input { + tool_string.push_str(&i); + } + if let Some(true) = stop { + break; + } + } + } + + let args = match serde_json::from_str(&tool_string) { + Ok(args) => args, + Err(err) if !tool_string.is_empty() => { + // If we failed deserializing after waiting for a long time, then this is most + // likely bedrock responding with a stop event for some reason without actually + // including the tool contents. Essentially, the tool was too large. + // Timeouts have been seen as short as ~1 minute, so setting the time to 30. + let time_elapsed = start.elapsed(); + if self.peek().await?.is_none() && time_elapsed > Duration::from_secs(30) { + error!( + "Received an unexpected end of stream after spending ~{}s receiving tool events", + time_elapsed.as_secs_f64() + ); + self.tool_uses.push(AssistantToolUse { + id: id.clone(), + name: name.clone(), + args: serde_json::Value::Object( + [( + "key".to_string(), + serde_json::Value::String( + "WARNING: the actual tool use arguments were too complicated to be generated" + .to_string(), + ), + )] + .into_iter() + .collect(), + ), + }); + let message = Box::new(AssistantMessage::new_tool_use( + Some(self.message_id.clone()), + std::mem::take(&mut self.assistant_text), + self.tool_uses.clone().into_iter().collect(), + )); + return Err(self.error(RecvErrorKind::UnexpectedToolUseEos { + tool_use_id: id, + name, + message, + time_elapsed, + })); + } else { + return Err(self.error(err)); + } + }, + // if the tool just does not need any input + _ => serde_json::json!({}), + }; + Ok(AssistantToolUse { id, name, args }) + } + + /// Returns the next event in the [SendMessageOutput] without consuming it. + async fn peek(&mut self) -> Result, RecvError> { + if self.peek.is_some() { + return Ok(self.peek.as_ref()); + } + match self.next().await? { + Some(v) => { + self.peek = Some(v); + Ok(self.peek.as_ref()) + }, + None => Ok(None), + } + } + + /// Consumes the next [SendMessageOutput] event. + async fn next(&mut self) -> Result, RecvError> { + if let Some(ev) = self.peek.take() { + return Ok(Some(ev)); + } + trace!("Attempting to recv next event"); + let start = std::time::Instant::now(); + let result = self.response.recv().await; + let duration = std::time::Instant::now().duration_since(start); + match result { + Ok(r) => { + trace!(?r, "Received new event"); + Ok(r) + }, + Err(err) => { + if duration.as_secs() >= 59 { + Err(self.error(RecvErrorKind::StreamTimeout { source: err, duration })) + } else { + Err(self.error(err)) + } + }, + } + } + + fn request_id(&self) -> Option<&str> { + self.response.request_id() + } + + /// Helper to create a new [RecvError] populated with the associated request id for the stream. + fn error(&self, source: impl Into) -> RecvError { + RecvError { + request_id: self.request_id().map(str::to_string), + source: source.into(), + } + } +} + +#[derive(Debug)] +pub enum ResponseEvent { + /// Text returned by the assistant. This should be displayed to the user as it is received. + AssistantText(String), + /// Notification that a tool use is being received. + ToolUseStart { name: String }, + /// A tool use requested by the assistant. This should be displayed to the user as it is + /// received. + ToolUse(AssistantToolUse), + /// Represents the end of the response. No more events will be returned. + EndStream { + /// The completed message containing all of the assistant text and tool use events + /// previously emitted. This should be stored in the conversation history and sent in + /// subsequent requests. 
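+ /// (The chat loop does exactly that by calling
+ /// ConversationState::push_assistant_message when it sees EndStream.)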
+ message: AssistantMessage, + }, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_parse() { + let _ = tracing_subscriber::fmt::try_init(); + let tool_use_id = "TEST_ID".to_string(); + let tool_name = "execute_bash".to_string(); + let tool_args = serde_json::json!({ + "command": "echo hello" + }) + .to_string(); + let tool_use_split_at = 5; + let mut events = vec![ + ChatResponseStream::AssistantResponseEvent { + content: "hi".to_string(), + }, + ChatResponseStream::AssistantResponseEvent { + content: " there".to_string(), + }, + ChatResponseStream::AssistantResponseEvent { + content: "IGNORE ME PLEASE".to_string(), + }, + ChatResponseStream::CodeReferenceEvent(()), + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: None, + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: Some(tool_args.as_str().split_at(tool_use_split_at).0.to_string()), + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: Some(tool_args.as_str().split_at(tool_use_split_at).1.to_string()), + stop: None, + }, + ChatResponseStream::ToolUseEvent { + tool_use_id: tool_use_id.clone(), + name: tool_name.clone(), + input: None, + stop: Some(true), + }, + ]; + events.reverse(); + let mock = SendMessageOutput::Mock(events); + let mut parser = ResponseParser::new(mock); + + for _ in 0..5 { + println!("{:?}", parser.recv().await.unwrap()); + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/prompt.rs b/crates/kiro-cli/src/cli/chat/prompt.rs new file mode 100644 index 0000000000..0fa7cf6f62 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/prompt.rs @@ -0,0 +1,364 @@ +use std::borrow::Cow; + +use crossterm::style::Stylize; +use eyre::Result; +use rustyline::completion::{ + Completer, + FilenameCompleter, + extract_word, +}; +use rustyline::error::ReadlineError; +use rustyline::highlight::{ + CmdKind, + Highlighter, +}; +use rustyline::history::DefaultHistory; +use rustyline::validate::{ + ValidationContext, + ValidationResult, + Validator, +}; +use rustyline::{ + Cmd, + Completer, + CompletionType, + Config, + Context, + EditMode, + Editor, + EventHandler, + Helper, + Hinter, + KeyCode, + KeyEvent, + Modifiers, +}; +use winnow::stream::AsChar; + +pub const COMMANDS: &[&str] = &[ + "/clear", + "/help", + "/editor", + "/issue", + // "/acceptall", /// Functional, but deprecated in favor of /tools trustall + "/quit", + "/tools", + "/tools trust", + "/tools untrust", + "/tools trustall", + "/tools reset", + "/profile", + "/profile help", + "/profile list", + "/profile create", + "/profile delete", + "/profile rename", + "/profile set", + "/context help", + "/context show", + "/context show --expand", + "/context add", + "/context add --global", + "/context rm", + "/context rm --global", + "/context clear", + "/context clear --global", + "/context hooks help", + "/context hooks add", + "/context hooks rm", + "/context hooks enable", + "/context hooks disable", + "/context hooks enable-all", + "/context hooks disable-all", + "/compact", + "/compact help", + "/usage", +]; + +pub fn generate_prompt(current_profile: Option<&str>, warning: bool) -> String { + let warning_symbol = if warning { "!".red().to_string() } else { "".to_string() }; + let profile_part = current_profile + .filter(|&p| p != "default") + .map(|p| format!("[{p}] ").cyan().to_string()) + .unwrap_or_default(); + + 
format!("{profile_part}{warning_symbol}{}", "> ".magenta()) +} + +/// Complete commands that start with a slash +fn complete_command(word: &str, start: usize) -> (usize, Vec) { + ( + start, + COMMANDS + .iter() + .filter(|p| p.starts_with(word)) + .map(|s| (*s).to_owned()) + .collect(), + ) +} + +/// A wrapper around FilenameCompleter that provides enhanced path detection +/// and completion capabilities for the chat interface. +pub struct PathCompleter { + /// The underlying filename completer from rustyline + filename_completer: FilenameCompleter, +} + +impl PathCompleter { + /// Creates a new PathCompleter instance + pub fn new() -> Self { + Self { + filename_completer: FilenameCompleter::new(), + } + } + + /// Attempts to complete a file path at the given position in the line + pub fn complete_path( + &self, + line: &str, + pos: usize, + ctx: &Context<'_>, + ) -> Result<(usize, Vec), ReadlineError> { + // Use the filename completer to get path completions + match self.filename_completer.complete(line, pos, ctx) { + Ok((pos, completions)) => { + // Convert the filename completer's pairs to strings + let file_completions: Vec = completions.iter().map(|pair| pair.replacement.clone()).collect(); + + // Return the completions if we have any + Ok((pos, file_completions)) + }, + Err(err) => Err(err), + } + } +} + +pub struct PromptCompleter { + sender: std::sync::mpsc::Sender>, + receiver: std::sync::mpsc::Receiver>, +} + +impl PromptCompleter { + fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { + PromptCompleter { sender, receiver } + } + + fn complete_prompt(&self, word: &str) -> Result, ReadlineError> { + let sender = &self.sender; + let receiver = &self.receiver; + sender + .send(if !word.is_empty() { Some(word.to_string()) } else { None }) + .map_err(|e| ReadlineError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?; + let prompt_info = receiver + .recv() + .map_err(|e| ReadlineError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))? 
+ .iter() + .map(|n| format!("@{n}")) + .collect::>(); + + Ok(prompt_info) + } +} + +pub struct ChatCompleter { + path_completer: PathCompleter, + prompt_completer: PromptCompleter, +} + +impl ChatCompleter { + fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { + Self { + path_completer: PathCompleter::new(), + prompt_completer: PromptCompleter::new(sender, receiver), + } + } +} + +impl Completer for ChatCompleter { + type Candidate = String; + + fn complete( + &self, + line: &str, + pos: usize, + _ctx: &Context<'_>, + ) -> Result<(usize, Vec), ReadlineError> { + let (start, word) = extract_word(line, pos, None, |c| c.is_space()); + + // Handle command completion + if word.starts_with('/') { + return Ok(complete_command(word, start)); + } + + if line.starts_with('@') { + let search_word = line.strip_prefix('@').unwrap_or(""); + if let Ok(completions) = self.prompt_completer.complete_prompt(search_word) { + if !completions.is_empty() { + return Ok((0, completions)); + } + } + } + + // Handle file path completion as fallback + if let Ok((pos, completions)) = self.path_completer.complete_path(line, pos, _ctx) { + if !completions.is_empty() { + return Ok((pos, completions)); + } + } + + // Default: no completions + Ok((start, Vec::new())) + } +} + +/// Custom validator for multi-line input +pub struct MultiLineValidator; + +impl Validator for MultiLineValidator { + fn validate(&self, ctx: &mut ValidationContext<'_>) -> rustyline::Result { + let input = ctx.input(); + + // Check for explicit multi-line markers + if input.starts_with("```") && !input.ends_with("```") { + return Ok(ValidationResult::Incomplete); + } + + // Check for backslash continuation + if input.ends_with('\\') { + return Ok(ValidationResult::Incomplete); + } + + Ok(ValidationResult::Valid(None)) + } +} + +#[derive(Helper, Completer, Hinter)] +pub struct ChatHelper { + #[rustyline(Completer)] + completer: ChatCompleter, + #[rustyline(Hinter)] + hinter: (), + validator: MultiLineValidator, +} + +impl Validator for ChatHelper { + fn validate(&self, ctx: &mut ValidationContext<'_>) -> rustyline::Result { + self.validator.validate(ctx) + } +} + +impl Highlighter for ChatHelper { + fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> { + Cow::Owned(format!("\x1b[1m{hint}\x1b[m")) + } + + fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { + Cow::Borrowed(line) + } + + fn highlight_char(&self, _line: &str, _pos: usize, _kind: CmdKind) -> bool { + false + } +} + +pub fn rl( + sender: std::sync::mpsc::Sender>, + receiver: std::sync::mpsc::Receiver>, +) -> Result> { + let edit_mode = match crate::fig_settings::settings::get_string_opt("chat.editMode").as_deref() { + Some("vi" | "vim") => EditMode::Vi, + _ => EditMode::Emacs, + }; + let config = Config::builder() + .history_ignore_space(true) + .completion_type(CompletionType::List) + .edit_mode(edit_mode) + .build(); + let h = ChatHelper { + completer: ChatCompleter::new(sender, receiver), + hinter: (), + validator: MultiLineValidator, + }; + let mut rl = Editor::with_config(config)?; + rl.set_helper(Some(h)); + + // Add custom keybinding for Alt+Enter to insert a newline + rl.bind_sequence( + KeyEvent(KeyCode::Enter, Modifiers::ALT), + EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), + ); + + // Add custom keybinding for Ctrl+J to insert a newline + rl.bind_sequence( + KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), + EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), + ); + + Ok(rl) +} + +#[cfg(test)] +mod 
tests {
+    use super::*;
+
+    #[test]
+    fn test_generate_prompt() {
+        // Test default prompt (no profile)
+        assert_eq!(generate_prompt(None, false), "> ".magenta().to_string());
+        // Test default prompt with warning
+        assert_eq!(generate_prompt(None, true), format!("{}{}", "!".red(), "> ".magenta()));
+        // Test default profile (should be same as no profile)
+        assert_eq!(generate_prompt(Some("default"), false), "> ".magenta().to_string());
+        // Test custom profile
+        assert_eq!(
+            generate_prompt(Some("test-profile"), false),
+            format!("{}{}", "[test-profile] ".cyan(), "> ".magenta())
+        );
+        // Test another custom profile with warning
+        assert_eq!(
+            generate_prompt(Some("dev"), true),
+            format!("{}{}{}", "[dev] ".cyan(), "!".red(), "> ".magenta())
+        );
+    }
+
+    #[test]
+    fn test_chat_completer_command_completion() {
+        let (prompt_request_sender, _) = std::sync::mpsc::channel::<Option<String>>();
+        let (_, prompt_response_receiver) = std::sync::mpsc::channel::<Vec<String>>();
+        let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver);
+        let line = "/h";
+        let pos = 2; // Position at the end of "/h"
+
+        // Create a mock context with empty history
+        let empty_history = DefaultHistory::new();
+        let ctx = Context::new(&empty_history);
+
+        // Get completions
+        let (start, completions) = completer.complete(line, pos, &ctx).unwrap();
+
+        // Verify start position
+        assert_eq!(start, 0);
+
+        // Verify completions contain expected commands
+        assert!(completions.contains(&"/help".to_string()));
+    }
+
+    #[test]
+    fn test_chat_completer_no_completion() {
+        let (prompt_request_sender, _) = std::sync::mpsc::channel::<Option<String>>();
+        let (_, prompt_response_receiver) = std::sync::mpsc::channel::<Vec<String>>();
+        let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver);
+        let line = "Hello, how are you?";
+        let pos = line.len();
+
+        // Create a mock context with empty history
+        let empty_history = DefaultHistory::new();
+        let ctx = Context::new(&empty_history);
+
+        // Get completions
+        let (_, completions) = completer.complete(line, pos, &ctx).unwrap();
+
+        // Verify no completions are returned for regular text
+        assert!(completions.is_empty());
+    }
+}
diff --git a/crates/kiro-cli/src/cli/chat/shared_writer.rs b/crates/kiro-cli/src/cli/chat/shared_writer.rs
new file mode 100644
index 0000000000..c5a2f55c41
--- /dev/null
+++ b/crates/kiro-cli/src/cli/chat/shared_writer.rs
@@ -0,0 +1,89 @@
+use std::io::{
+    self,
+    Write,
+};
+use std::sync::{
+    Arc,
+    Mutex,
+};
+
+/// A thread-safe wrapper for any Write implementation.
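+///
+/// A minimal usage sketch (an illustrative assumption, not part of the original change):
+/// ```ignore
+/// let mut writer = SharedWriter::stdout();
+/// let mut clone = writer.clone(); // clones share the same underlying boxed writer
+/// writeln!(writer, "hello from one handle")?;
+/// writeln!(clone, "hello from another")?; // each write locks the inner Mutex
+/// ```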
+#[derive(Clone)]
+pub struct SharedWriter {
+    inner: Arc<Mutex<Box<dyn Write + Send>>>,
+}
+
+impl SharedWriter {
+    pub fn new<W>(writer: W) -> Self
+    where
+        W: Write + Send + 'static,
+    {
+        Self {
+            inner: Arc::new(Mutex::new(Box::new(writer))),
+        }
+    }
+
+    pub fn stdout() -> Self {
+        Self::new(io::stdout())
+    }
+
+    pub fn stderr() -> Self {
+        Self::new(io::stderr())
+    }
+
+    pub fn null() -> Self {
+        Self::new(NullWriter {})
+    }
+}
+
+impl std::fmt::Debug for SharedWriter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SharedWriter").finish()
+    }
+}
+
+impl Write for SharedWriter {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.inner.lock().expect("Mutex poisoned").write(buf)
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        self.inner.lock().expect("Mutex poisoned").flush()
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct NullWriter {}
+
+impl Write for NullWriter {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct TestWriterWithSink {
+    pub sink: Arc<Mutex<Vec<u8>>>,
+}
+
+impl TestWriterWithSink {
+    #[allow(dead_code)]
+    pub fn get_content(&self) -> Vec<u8> {
+        self.sink.lock().unwrap().clone()
+    }
+}
+
+impl Write for TestWriterWithSink {
+    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+        self.sink.lock().unwrap().append(&mut buf.to_vec());
+        Ok(buf.len())
+    }
+
+    fn flush(&mut self) -> io::Result<()> {
+        Ok(())
+    }
+}
diff --git a/crates/kiro-cli/src/cli/chat/skim_integration.rs b/crates/kiro-cli/src/cli/chat/skim_integration.rs
new file mode 100644
index 0000000000..026576be14
--- /dev/null
+++ b/crates/kiro-cli/src/cli/chat/skim_integration.rs
@@ -0,0 +1,378 @@
+use std::io::{
+    BufReader,
+    Cursor,
+    Write,
+    stdout,
+};
+
+use crossterm::execute;
+use crossterm::terminal::{
+    EnterAlternateScreen,
+    LeaveAlternateScreen,
+};
+use eyre::{
+    Result,
+    eyre,
+};
+use rustyline::{
+    Cmd,
+    ConditionalEventHandler,
+    EventContext,
+    RepeatCount,
+};
+use skim::prelude::*;
+use tempfile::NamedTempFile;
+
+use super::context::ContextManager;
+
+pub fn select_profile_with_skim(context_manager: &ContextManager) -> Result<Option<String>> {
+    let profiles = context_manager.list_profiles_blocking()?;
+
+    launch_skim_selector(&profiles, "Select profile: ", false)
+        .map(|selected| selected.and_then(|s| s.into_iter().next()))
+}
+
+pub struct SkimCommandSelector {
+    context_manager: Arc<ContextManager>,
+    tool_names: Vec<String>,
+}
+
+impl SkimCommandSelector {
+    /// This allows the ConditionalEventHandler handle function to be bound to a KeyEvent.
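+    ///
+    /// A hedged sketch of how such a binding might look (the key chosen here is an
+    /// assumption for illustration, not necessarily the binding used by this crate):
+    /// ```ignore
+    /// rl.bind_sequence(
+    ///     KeyEvent(KeyCode::Char('s'), Modifiers::CTRL),
+    ///     EventHandler::Conditional(Box::new(SkimCommandSelector::new(context_manager, tool_names))),
+    /// );
+    /// ```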
+ pub fn new(context_manager: Arc, tool_names: Vec) -> Self { + Self { + context_manager, + tool_names, + } + } +} + +impl ConditionalEventHandler for SkimCommandSelector { + fn handle( + &self, + _evt: &rustyline::Event, + _n: RepeatCount, + _positive: bool, + _ctx: &EventContext<'_>, + ) -> Option { + // Launch skim command selector with the context manager if available + match select_command(self.context_manager.as_ref(), &self.tool_names) { + Ok(Some(command)) => Some(Cmd::Insert(1, command)), + _ => { + // If cancelled or error, do nothing + Some(Cmd::Noop) + }, + } + } +} + +pub fn get_available_commands() -> Vec { + // Import the COMMANDS array directly from prompt.rs + // This is the single source of truth for available commands + let commands_array = super::prompt::COMMANDS; + + let mut commands = Vec::new(); + for &cmd in commands_array { + commands.push(cmd.to_string()); + } + + commands +} + +/// Format commands for skim display +/// Create a standard set of skim options with consistent styling +fn create_skim_options(prompt: &str, multi: bool) -> Result { + SkimOptionsBuilder::default() + .height("100%".to_string()) + .prompt(prompt.to_string()) + .reverse(true) + .multi(multi) + .build() + .map_err(|e| eyre!("Failed to build skim options: {}", e)) +} + +/// Run skim with the given options and items in an alternate screen +/// This helper function handles entering/exiting the alternate screen and running skim +fn run_skim_with_options(options: &SkimOptions, items: SkimItemReceiver) -> Result>>> { + // Enter alternate screen to prevent skim output from persisting in terminal history + execute!(stdout(), EnterAlternateScreen).map_err(|e| eyre!("Failed to enter alternate screen: {}", e))?; + + let selected_items = + Skim::run_with(options, Some(items)).and_then(|out| if out.is_abort { None } else { Some(out.selected_items) }); + + execute!(stdout(), LeaveAlternateScreen).map_err(|e| eyre!("Failed to leave alternate screen: {}", e))?; + + Ok(selected_items) +} + +/// Extract string selections from skim items +fn extract_selections(items: Vec>) -> Vec { + items.iter().map(|item| item.output().to_string()).collect() +} + +/// Launch skim with the given items and return the selected item +pub fn launch_skim_selector(items: &[String], prompt: &str, multi: bool) -> Result>> { + let mut temp_file_for_skim_input = NamedTempFile::new()?; + temp_file_for_skim_input.write_all(items.join("\n").as_bytes())?; + + let options = create_skim_options(prompt, multi)?; + let item_reader = SkimItemReader::default(); + let items = item_reader.of_bufread(BufReader::new(std::fs::File::open(temp_file_for_skim_input.path())?)); + + // Run skim and get selected items + match run_skim_with_options(&options, items)? { + Some(items) if !items.is_empty() => { + let selections = extract_selections(items); + Ok(Some(selections)) + }, + _ => Ok(None), // User cancelled or no selection + } +} + +/// Select files using skim +pub fn select_files_with_skim() -> Result>> { + // Create skim options with appropriate settings + let options = create_skim_options("Select files: ", true)?; + + // Create a command that will be executed by skim + // This avoids loading all files into memory at once + let find_cmd = "find . 
-type f -not -path '*/\\.*'"; + + // Create a command collector that will execute the find command + let item_reader = SkimItemReader::default(); + let items = item_reader.of_bufread(BufReader::new( + std::process::Command::new("sh") + .args(["-c", find_cmd]) + .stdout(std::process::Stdio::piped()) + .spawn()? + .stdout + .ok_or_else(|| eyre!("Failed to get stdout from command"))?, + )); + + // Run skim with the command output as a stream + match run_skim_with_options(&options, items)? { + Some(items) if !items.is_empty() => { + let selections = extract_selections(items); + Ok(Some(selections)) + }, + _ => Ok(None), // User cancelled or no selection + } +} + +/// Select context paths using skim +pub fn select_context_paths_with_skim(context_manager: &ContextManager) -> Result, bool)>> { + let mut global_paths = Vec::new(); + let mut profile_paths = Vec::new(); + + // Get global paths + for path in &context_manager.global_config.paths { + global_paths.push(format!("(global) {}", path)); + } + + // Get profile-specific paths + for path in &context_manager.profile_config.paths { + profile_paths.push(format!("(profile: {}) {}", context_manager.current_profile, path)); + } + + // Combine paths, but keep track of which are global + let mut all_paths = Vec::new(); + all_paths.extend(global_paths); + all_paths.extend(profile_paths); + + if all_paths.is_empty() { + return Ok(None); // No paths to select + } + + // Create skim options + let options = create_skim_options("Select paths to remove: ", true)?; + + // Create item reader + let item_reader = SkimItemReader::default(); + let items = item_reader.of_bufread(Cursor::new(all_paths.join("\n"))); + + // Run skim and get selected paths + match run_skim_with_options(&options, items)? { + Some(items) if !items.is_empty() => { + let selected_paths = extract_selections(items); + + // Check if any global paths were selected + let has_global = selected_paths.iter().any(|p| p.starts_with("(global)")); + + // Extract the actual paths from the formatted strings + let paths: Vec = selected_paths + .iter() + .map(|p| { + // Extract the path part after the prefix + let parts: Vec<&str> = p.splitn(2, ") ").collect(); + if parts.len() > 1 { + parts[1].to_string() + } else { + p.clone() + } + }) + .collect(); + + Ok(Some((paths, has_global))) + }, + _ => Ok(None), // User cancelled selection + } +} + +/// Launch the command selector and handle the selected command +pub fn select_command(context_manager: &ContextManager, tools: &[String]) -> Result> { + let commands = get_available_commands(); + + match launch_skim_selector(&commands, "Select command: ", false)? { + Some(selections) if !selections.is_empty() => { + let selected_command = &selections[0]; + + match CommandType::from_str(selected_command) { + Some(CommandType::ContextAdd(cmd)) => { + // For context add commands, we need to select files + match select_files_with_skim()? { + Some(files) if !files.is_empty() => { + // Construct the full command with selected files + let mut cmd = cmd.clone(); + for file in files { + cmd.push_str(&format!(" {}", file)); + } + Ok(Some(cmd)) + }, + _ => Ok(Some(selected_command.clone())), /* User cancelled file selection, return just the + * command */ + } + }, + Some(CommandType::ContextRemove(cmd)) => { + // For context rm commands, we need to select from existing context paths + match select_context_paths_with_skim(context_manager)? 
{ + Some((paths, has_global)) if !paths.is_empty() => { + // Construct the full command with selected paths + let mut full_cmd = cmd.clone(); + if has_global { + full_cmd.push_str(" --global"); + } + for path in paths { + full_cmd.push_str(&format!(" {}", path)); + } + Ok(Some(full_cmd)) + }, + Some((_, _)) => Ok(Some(format!("{} (No paths selected)", cmd))), + None => Ok(Some(selected_command.clone())), // User cancelled path selection + } + }, + Some(CommandType::Tools(_)) => { + let options = create_skim_options("Select tool: ", false)?; + let item_reader = SkimItemReader::default(); + let items = item_reader.of_bufread(Cursor::new(tools.join("\n"))); + let selected_tool = match run_skim_with_options(&options, items)? { + Some(items) if !items.is_empty() => Some(items[0].output().to_string()), + _ => None, + }; + + match selected_tool { + Some(tool) => Ok(Some(format!("{} {}", selected_command, tool))), + None => Ok(Some(selected_command.clone())), /* User cancelled tool selection, return just the + * command */ + } + }, + Some(cmd @ CommandType::Profile(_)) if cmd.needs_profile_selection() => { + // For profile operations that need a profile name, show profile selector + match select_profile_with_skim(context_manager)? { + Some(profile) => { + let full_cmd = format!("{} {}", selected_command, profile); + Ok(Some(full_cmd)) + }, + None => Ok(Some(selected_command.clone())), // User cancelled profile selection + } + }, + Some(CommandType::Profile(_)) => { + // For other profile operations (like create), just return the command + Ok(Some(selected_command.clone())) + }, + None => { + // Command doesn't need additional parameters + Ok(Some(selected_command.clone())) + }, + } + }, + _ => Ok(None), // User cancelled command selection + } +} + +#[derive(PartialEq)] +enum CommandType { + ContextAdd(String), + ContextRemove(String), + Tools(&'static str), + Profile(&'static str), +} + +impl CommandType { + fn needs_profile_selection(&self) -> bool { + matches!(self, CommandType::Profile("set" | "delete" | "rename")) + } + + fn from_str(cmd: &str) -> Option { + if cmd.starts_with("/context add") { + Some(CommandType::ContextAdd(cmd.to_string())) + } else if cmd.starts_with("/context rm") { + Some(CommandType::ContextRemove(cmd.to_string())) + } else { + match cmd { + "/tools trust" => Some(CommandType::Tools("trust")), + "/tools untrust" => Some(CommandType::Tools("untrust")), + "/profile set" => Some(CommandType::Profile("set")), + "/profile delete" => Some(CommandType::Profile("delete")), + "/profile rename" => Some(CommandType::Profile("rename")), + "/profile create" => Some(CommandType::Profile("create")), + _ => None, + } + } + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + /// Test to verify that all hardcoded command strings in select_command + /// are present in the COMMANDS array from prompt.rs + #[test] + fn test_hardcoded_commands_in_commands_array() { + // Get the set of available commands from prompt.rs + let available_commands: HashSet = get_available_commands().iter().cloned().collect(); + + // List of hardcoded commands used in select_command + let hardcoded_commands = vec![ + "/context add", + "/context add --global", + "/context rm", + "/context rm --global", + "/tools trust", + "/tools untrust", + "/profile set", + "/profile delete", + "/profile rename", + "/profile create", + ]; + + // Check that each hardcoded command is in the COMMANDS array + for cmd in hardcoded_commands { + assert!( + available_commands.contains(cmd), + "Command 
'{}' is used in select_command but not defined in COMMANDS array", + cmd + ); + + // This should assert that all the commands we assert are present in the match statement of + // select_command() + assert!( + CommandType::from_str(cmd).is_some(), + "Command '{}' cannot be parsed into a CommandType", + cmd + ); + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/token_counter.rs b/crates/kiro-cli/src/cli/chat/token_counter.rs new file mode 100644 index 0000000000..1e651b96b4 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/token_counter.rs @@ -0,0 +1,251 @@ +use std::ops::Deref; + +use super::conversation_state::{ + BackendConversationState, + ConversationSize, +}; +use super::message::{ + AssistantMessage, + ToolUseResult, + ToolUseResultBlock, + UserMessage, + UserMessageContent, +}; + +#[derive(Debug, Clone, Copy)] +pub struct CharCount(usize); + +impl CharCount { + pub fn value(&self) -> usize { + self.0 + } +} + +impl Deref for CharCount { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for CharCount { + fn from(value: usize) -> Self { + Self(value) + } +} + +impl std::ops::Add for CharCount { + type Output = CharCount; + + fn add(self, rhs: Self) -> Self::Output { + Self(self.value() + rhs.value()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct TokenCount(usize); + +impl TokenCount { + pub fn value(&self) -> usize { + self.0 + } +} + +impl Deref for TokenCount { + type Target = usize; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl From for TokenCount { + fn from(value: CharCount) -> Self { + Self(TokenCounter::count_tokens_char_count(value.value())) + } +} + +impl std::fmt::Display for TokenCount { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +pub struct TokenCounter; + +impl TokenCounter { + pub const TOKEN_TO_CHAR_RATIO: usize = 3; + + /// Estimates the number of tokens in the input content. + /// Currently uses a simple heuristic: content length / TOKEN_TO_CHAR_RATIO + /// + /// Rounds up to the nearest multiple of 10 to avoid giving users a false sense of precision. + pub fn count_tokens(content: &str) -> usize { + Self::count_tokens_char_count(content.len()) + } + + fn count_tokens_char_count(count: usize) -> usize { + (count / Self::TOKEN_TO_CHAR_RATIO + 5) / 10 * 10 + } + + pub const fn token_to_chars(token: usize) -> usize { + token * Self::TOKEN_TO_CHAR_RATIO + } +} + +/// A trait for types that represent some number of characters (aka bytes). For use in calculating +/// context window size utilization. +pub trait CharCounter { + /// Returns the number of characters contained within this type. 
+ /// + /// One "character" is essentially the same as one "byte" + fn char_count(&self) -> CharCount; +} + +impl CharCounter for BackendConversationState<'_> { + fn char_count(&self) -> CharCount { + self.calculate_conversation_size().char_count() + } +} + +impl CharCounter for ConversationSize { + fn char_count(&self) -> CharCount { + self.user_messages + self.assistant_messages + self.context_messages + } +} + +impl CharCounter for UserMessage { + fn char_count(&self) -> CharCount { + let mut total_chars = 0; + total_chars += self.additional_context().len(); + match self.content() { + UserMessageContent::Prompt { prompt } => { + total_chars += prompt.len(); + }, + UserMessageContent::CancelledToolUses { + prompt, + tool_use_results, + } => { + total_chars += prompt.as_ref().map_or(0, String::len); + total_chars += tool_use_results.as_slice().char_count().0; + }, + UserMessageContent::ToolUseResults { tool_use_results } => { + total_chars += tool_use_results.as_slice().char_count().0; + }, + } + total_chars.into() + } +} + +impl CharCounter for AssistantMessage { + fn char_count(&self) -> CharCount { + let mut total_chars = 0; + total_chars += self.content().len(); + if let Some(tool_uses) = self.tool_uses() { + total_chars += tool_uses + .iter() + .map(|v| calculate_value_char_count(&v.args)) + .reduce(|acc, e| acc + e) + .unwrap_or_default(); + } + total_chars.into() + } +} + +impl CharCounter for &[ToolUseResult] { + fn char_count(&self) -> CharCount { + self.iter() + .flat_map(|v| &v.content) + .fold(0, |acc, v| { + acc + match v { + ToolUseResultBlock::Json(v) => calculate_value_char_count(v), + ToolUseResultBlock::Text(s) => s.len(), + } + }) + .into() + } +} + +fn calculate_value_char_count(document: &serde_json::Value) -> usize { + match document { + serde_json::Value::Null => 1, + serde_json::Value::Bool(_) => 1, + serde_json::Value::Number(_) => 1, + serde_json::Value::String(s) => s.len(), + serde_json::Value::Array(vec) => vec.iter().fold(0, |acc, v| acc + calculate_value_char_count(v)), + serde_json::Value::Object(map) => map.values().fold(0, |acc, v| acc + calculate_value_char_count(v)), + } +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_token_count() { + let text = "This is a test sentence."; + let count = TokenCounter::count_tokens(text); + assert_eq!(count, (text.len() / 3 + 5) / 10 * 10); + } + + #[test] + fn test_calculate_value_char_count() { + // Test simple types + assert_eq!( + calculate_value_char_count(&serde_json::Value::String("hello".to_string())), + 5 + ); + assert_eq!( + calculate_value_char_count(&serde_json::Value::Number(serde_json::Number::from(123))), + 1 + ); + assert_eq!(calculate_value_char_count(&serde_json::Value::Bool(true)), 1); + assert_eq!(calculate_value_char_count(&serde_json::Value::Null), 1); + + // Test array + let array = serde_json::Value::Array(vec![ + serde_json::Value::String("test".to_string()), + serde_json::Value::Number(serde_json::Number::from(42)), + serde_json::Value::Bool(false), + ]); + assert_eq!(calculate_value_char_count(&array), 6); // "test" (4) + Number (1) + Bool (1) + + // Test object + let mut obj = serde_json::Map::new(); + obj.insert("key1".to_string(), serde_json::Value::String("value1".to_string())); + obj.insert( + "key2".to_string(), + serde_json::Value::Number(serde_json::Number::from(99)), + ); + let object = serde_json::Value::Object(obj); + assert_eq!(calculate_value_char_count(&object), 7); // "value1" (6) + Number (1) + + // Test nested structure + let mut nested_obj = 
serde_json::Map::new(); + let mut inner_obj = serde_json::Map::new(); + inner_obj.insert( + "inner_key".to_string(), + serde_json::Value::String("inner_value".to_string()), + ); + nested_obj.insert("outer_key".to_string(), serde_json::Value::Object(inner_obj)); + nested_obj.insert( + "array_key".to_string(), + serde_json::Value::Array(vec![ + serde_json::Value::String("item1".to_string()), + serde_json::Value::String("item2".to_string()), + ]), + ); + + let complex = serde_json::Value::Object(nested_obj); + assert_eq!(calculate_value_char_count(&complex), 21); // "inner_value" (11) + "item1" (5) + "item2" (5) + + // Test empty structures + assert_eq!(calculate_value_char_count(&serde_json::Value::Array(vec![])), 0); + assert_eq!( + calculate_value_char_count(&serde_json::Value::Object(serde_json::Map::new())), + 0 + ); + } +} diff --git a/crates/kiro-cli/src/cli/chat/tool_manager.rs b/crates/kiro-cli/src/cli/chat/tool_manager.rs new file mode 100644 index 0000000000..108c156327 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tool_manager.rs @@ -0,0 +1,1019 @@ +use std::collections::HashMap; +use std::hash::{ + DefaultHasher, + Hasher, +}; +use std::io::Write; +use std::path::PathBuf; +use std::sync::mpsc::RecvTimeoutError; +use std::sync::{ + Arc, + RwLock as SyncRwLock, +}; + +use convert_case::Casing; +use crossterm::{ + cursor, + execute, + queue, + style, + terminal, +}; +use futures::{ + StreamExt, + stream, +}; +use serde::{ + Deserialize, + Serialize, +}; +use thiserror::Error; +use tokio::sync::Mutex; +use tracing::error; + +use super::command::PromptsGetCommand; +use super::message::AssistantToolUse; +use super::tools::custom_tool::{ + CustomTool, + CustomToolClient, + CustomToolConfig, +}; +use super::tools::execute_bash::ExecuteBash; +use super::tools::fs_read::FsRead; +use super::tools::fs_write::FsWrite; +use super::tools::gh_issue::GhIssue; +use super::tools::use_aws::UseAws; +use super::tools::{ + Tool, + ToolOrigin, + ToolSpec, +}; +use crate::fig_api_client::model::{ + ToolResult, + ToolResultContentBlock, + ToolResultStatus, +}; +use crate::mcp_client::{ + JsonRpcResponse, + PromptGet, +}; + +const NAMESPACE_DELIMITER: &str = "___"; +// This applies for both mcp server and tool name since in the end the tool name as seen by the +// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} +const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; +const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; + +#[derive(Debug, Error)] +pub enum GetPromptError { + #[error("Prompt with name {0} does not exist")] + PromptNotFound(String), + #[error("Prompt {0} is offered by more than one server. Use one of the following {1}")] + AmbiguousPrompt(String, String), + #[error("Missing client")] + MissingClient, + #[error("Missing prompt name")] + MissingPromptName, + #[error("Synchronization error: {0}")] + Synchronization(String), + #[error("Missing prompt bundle")] + MissingPromptInfo, + #[error(transparent)] + General(#[from] eyre::Report), +} + +/// Messages used for communication between the tool initialization thread and the loading +/// display thread. These messages control the visual loading indicators shown to +/// the user during tool initialization. +enum LoadingMsg { + /// Indicates a new tool is being initialized and should be added to the loading + /// display. The String parameter is the name of the tool being initialized. 
+ Add(String), + /// Indicates a tool has finished initializing successfully and should be removed from + /// the loading display. The String parameter is the name of the tool that + /// completed initialization. + Done(String), + /// Represents an error that occurred during tool initialization. + /// Contains the name of the server that failed to initialize and the error message. + Error { name: String, msg: eyre::Report }, + /// Represents a warning that occurred during tool initialization. + /// Contains the name of the server that generated the warning and the warning message. + Warn { name: String, msg: eyre::Report }, +} + +/// Represents the state of a loading indicator for a tool being initialized. +/// +/// This struct tracks timing information for each tool's loading status display in the terminal. +/// +/// # Fields +/// * `init_time` - When initialization for this tool began, used to calculate load time +struct StatusLine { + init_time: std::time::Instant, +} + +// This is to mirror claude's config set up +#[derive(Clone, Serialize, Deserialize, Debug, Default)] +#[serde(rename_all = "camelCase")] +pub struct McpServerConfig { + mcp_servers: HashMap, +} + +impl McpServerConfig { + pub async fn load_config(output: &mut impl Write) -> eyre::Result { + let mut cwd = std::env::current_dir()?; + cwd.push(".amazonq/mcp.json"); + let expanded_path = shellexpand::tilde("~/.aws/amazonq/mcp.json"); + let global_path = PathBuf::from(expanded_path.as_ref()); + let global_buf = tokio::fs::read(global_path).await.ok(); + let local_buf = tokio::fs::read(cwd).await.ok(); + let conf = match (global_buf, local_buf) { + (Some(global_buf), Some(local_buf)) => { + let mut global_conf = Self::from_slice(&global_buf, output, "global")?; + let local_conf = Self::from_slice(&local_buf, output, "local")?; + for (server_name, config) in local_conf.mcp_servers { + if global_conf.mcp_servers.insert(server_name.clone(), config).is_some() { + queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("MCP config conflict for "), + style::SetForegroundColor(style::Color::Green), + style::Print(server_name), + style::ResetColor, + style::Print(". Using workspace version.\n") + )?; + } + } + global_conf + }, + (None, Some(local_buf)) => Self::from_slice(&local_buf, output, "local")?, + (Some(global_buf), None) => Self::from_slice(&global_buf, output, "global")?, + _ => Default::default(), + }; + output.flush()?; + Ok(conf) + } + + fn from_slice(slice: &[u8], output: &mut impl Write, location: &str) -> eyre::Result { + match serde_json::from_slice::(slice) { + Ok(config) => Ok(config), + Err(e) => { + queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print(format!("Error reading {location} mcp config: {e}\n")), + style::Print("Please check to make sure config is correct. 
Discarding.\n"), + )?; + Ok(McpServerConfig::default()) + }, + } + } +} + +#[derive(Default)] +pub struct ToolManagerBuilder { + mcp_server_config: Option, + prompt_list_sender: Option>>, + prompt_list_receiver: Option>>, + conversation_id: Option, +} + +impl ToolManagerBuilder { + pub fn mcp_server_config(mut self, config: McpServerConfig) -> Self { + self.mcp_server_config.replace(config); + self + } + + pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { + self.prompt_list_sender.replace(sender); + self + } + + pub fn prompt_list_receiver(mut self, receiver: std::sync::mpsc::Receiver>) -> Self { + self.prompt_list_receiver.replace(receiver); + self + } + + pub fn conversation_id(mut self, conversation_id: &str) -> Self { + self.conversation_id.replace(conversation_id.to_string()); + self + } + + pub fn build(mut self) -> eyre::Result { + let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; + debug_assert!(self.conversation_id.is_some()); + let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; + let regex = regex::Regex::new(VALID_TOOL_NAME)?; + let mut hasher = DefaultHasher::new(); + let pre_initialized = mcp_servers + .into_iter() + .map(|(server_name, server_config)| { + let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); + let sanitized_server_name = sanitize_name(snaked_cased_name, ®ex, &mut hasher); + let custom_tool_client = CustomToolClient::from_config(sanitized_server_name.clone(), server_config); + (sanitized_server_name, custom_tool_client) + }) + .collect::>(); + + // Send up task to update user on server loading status + let (tx, rx) = std::sync::mpsc::channel::(); + // Using a hand rolled thread because it's just easier to do this than do deal with the Send + // requirements that comes with holding onto the stdout lock. 
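+        // Rough shape of the protocol (illustrative summary): the initialization tasks
+        // send LoadingMsg::Add/Done/Error/Warn over `tx`, and this thread redraws a
+        // single status line (e.g. "2 of 5 mcp servers initialized") on each message,
+        // advancing the spinner whenever the 50ms recv timeout elapses.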
+ let loading_display_task = std::thread::spawn(move || { + let stdout = std::io::stdout(); + let mut stdout_lock = stdout.lock(); + let mut loading_servers = HashMap::::new(); + let mut spinner_logo_idx: usize = 0; + let mut complete: usize = 0; + let mut failed: usize = 0; + loop { + match rx.recv_timeout(std::time::Duration::from_millis(50)) { + Ok(recv_result) => match recv_result { + LoadingMsg::Add(name) => { + let init_time = std::time::Instant::now(); + let status_line = StatusLine { init_time }; + execute!(stdout_lock, cursor::MoveToColumn(0))?; + if !loading_servers.is_empty() { + // TODO: account for terminal width + execute!(stdout_lock, cursor::MoveUp(1))?; + } + loading_servers.insert(name.clone(), status_line); + let total = loading_servers.len(); + execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + }, + LoadingMsg::Done(name) => { + if let Some(status_line) = loading_servers.get(&name) { + complete += 1; + let time_taken = + (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); + let time_taken = format!("{:.2}", time_taken); + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_success_message(&name, &time_taken, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + } + }, + LoadingMsg::Error { name, msg } => { + failed += 1; + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + }, + LoadingMsg::Warn { name, msg } => { + complete += 1; + execute!( + stdout_lock, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, &mut stdout_lock)?; + let total = loading_servers.len(); + queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; + stdout_lock.flush()?; + }, + }, + Err(RecvTimeoutError::Timeout) => { + spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); + execute!( + stdout_lock, + cursor::SavePosition, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + style::Print(SPINNER_CHARS[spinner_logo_idx]), + cursor::RestorePosition + )?; + }, + _ => break, + } + } + Ok::<_, eyre::Report>(()) + }); + let mut clients = HashMap::>::new(); + for (mut name, init_res) in pre_initialized { + let _ = tx.send(LoadingMsg::Add(name.clone())); + match init_res { + Ok(client) => { + let mut client = Arc::new(client); + while let Some(collided_client) = clients.insert(name.clone(), client) { + // to avoid server name collision we are going to circumvent this by + // appending the name with 1 + name.push('1'); + client = collided_client; + } + }, + Err(e) => { + error!("Error initializing mcp client for server {}: {:?}", name, &e); + let event = crate::fig_telemetry::EventType::McpServerInit { + conversation_id: conversation_id.clone(), + init_failure_reason: Some(e.to_string()), + number_of_tools: 0, + }; + tokio::spawn(async move { + let app_event = crate::fig_telemetry::AppTelemetryEvent::new(event).await; + 
crate::fig_telemetry::send_event(app_event).await; + }); + let _ = tx.send(LoadingMsg::Error { + name: name.clone(), + msg: e, + }); + }, + } + } + let loading_display_task = Some(loading_display_task); + let loading_status_sender = Some(tx); + + // Set up task to handle prompt requests + let sender = self.prompt_list_sender.take(); + let receiver = self.prompt_list_receiver.take(); + let prompts = Arc::new(SyncRwLock::new(HashMap::default())); + // TODO: accommodate hot reload of mcp servers + if let (Some(sender), Some(receiver)) = (sender, receiver) { + let clients = clients.iter().fold(HashMap::new(), |mut acc, (n, c)| { + acc.insert(n.to_string(), Arc::downgrade(c)); + acc + }); + let prompts_clone = prompts.clone(); + tokio::task::spawn_blocking(move || { + let receiver = Arc::new(std::sync::Mutex::new(receiver)); + loop { + let search_word = receiver.lock().map_err(|e| eyre::eyre!("{:?}", e))?.recv()?; + if clients + .values() + .any(|client| client.upgrade().is_some_and(|c| c.is_prompts_out_of_date())) + { + let mut prompts_wl = prompts_clone.write().map_err(|e| { + eyre::eyre!( + "Error retrieving write lock on prompts for tab complete {}", + e.to_string() + ) + })?; + *prompts_wl = clients.iter().fold( + HashMap::>::new(), + |mut acc, (server_name, client)| { + let Some(client) = client.upgrade() else { + return acc; + }; + let prompt_gets = client.list_prompt_gets(); + let Ok(prompt_gets) = prompt_gets.read() else { + tracing::error!("Error retrieving read lock for prompt gets for tab complete"); + return acc; + }; + for (prompt_name, prompt_get) in prompt_gets.iter() { + acc.entry(prompt_name.to_string()) + .and_modify(|bundles| { + bundles.push(PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }); + }) + .or_insert(vec![PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }]); + } + client.prompts_updated(); + acc + }, + ); + } + let prompts_rl = prompts_clone.read().map_err(|e| { + eyre::eyre!( + "Error retrieving read lock on prompts for tab complete {}", + e.to_string() + ) + })?; + let filtered_prompts = prompts_rl + .iter() + .flat_map(|(prompt_name, bundles)| { + if bundles.len() > 1 { + bundles + .iter() + .map(|b| format!("{}/{}", b.server_name, prompt_name)) + .collect() + } else { + vec![prompt_name.to_owned()] + } + }) + .filter(|n| { + if let Some(p) = &search_word { + n.contains(p) + } else { + true + } + }) + .collect::>(); + if let Err(e) = sender.send(filtered_prompts) { + error!("Error sending prompts to chat helper: {:?}", e); + } + } + #[allow(unreachable_code)] + Ok::<(), eyre::Report>(()) + }); + } + + Ok(ToolManager { + conversation_id, + clients, + prompts, + loading_display_task, + loading_status_sender, + ..Default::default() + }) + } +} + +#[derive(Clone, Debug)] +/// A collection of information that is used for the following purposes: +/// - Checking if prompt info cached is out of date +/// - Retrieve new prompt info +pub struct PromptBundle { + /// The server name from which the prompt is offered / exposed + pub server_name: String, + /// The prompt get (info with which a prompt is retrieved) cached + pub prompt_get: PromptGet, +} + +/// Categorizes different types of tool name validation failures: +/// - `TooLong`: The tool name exceeds the maximum allowed length +/// - `IllegalChar`: The tool name contains characters that are not allowed +/// - `EmptyDescription`: The tool description is empty or missing +#[allow(dead_code)] +enum OutOfSpecName { + TooLong(String), + 
IllegalChar(String), + EmptyDescription(String), +} + +/// Manages the lifecycle and interactions with tools from various sources, including MCP servers. +/// This struct is responsible for initializing tools, handling tool requests, and maintaining +/// a cache of available prompts from connected servers. +#[derive(Default)] +pub struct ToolManager { + /// Unique identifier for the current conversation. + /// This ID is used to track and associate tools with a specific chat session. + pub conversation_id: String, + + /// Map of server names to their corresponding client instances. + /// These clients are used to communicate with MCP servers. + pub clients: HashMap>, + + /// Cache for prompts collected from different servers. + /// Key: prompt name + /// Value: a list of PromptBundle that has a prompt of this name. + /// This cache helps resolve prompt requests efficiently and handles + /// cases where multiple servers offer prompts with the same name. + pub prompts: Arc>>>, + + /// Handle to the thread that displays loading status for tool initialization. + /// This thread provides visual feedback to users during the tool loading process. + loading_display_task: Option>>, + + /// Channel sender for communicating with the loading display thread. + /// Used to send status updates about tool initialization progress. + loading_status_sender: Option>, + + /// Mapping from sanitized tool names to original tool names. + /// This is used to handle tool name transformations that may occur during initialization + /// to ensure tool names comply with naming requirements. + pub tn_map: HashMap, + + /// A cache of tool's input schema for all of the available tools. + /// This is mainly used to show the user what the tools look like from the perspective of the + /// model. + pub schema: HashMap, +} + +impl ToolManager { + pub async fn load_tools(&mut self) -> eyre::Result> { + let tx = self.loading_status_sender.take(); + let display_task = self.loading_display_task.take(); + let tool_specs = { + let tool_specs = serde_json::from_str::>(include_str!("tools/tool_index.json"))?; + Arc::new(Mutex::new(tool_specs)) + }; + let conversation_id = self.conversation_id.clone(); + let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); + let load_tool = self + .clients + .iter() + .map(|(server_name, client)| { + let client_clone = client.clone(); + let server_name_clone = server_name.clone(); + let tx_clone = tx.clone(); + let regex_clone = regex.clone(); + let tool_specs_clone = tool_specs.clone(); + let conversation_id = conversation_id.clone(); + async move { + let tool_spec = client_clone.init().await; + let mut sanitized_mapping = HashMap::::new(); + match tool_spec { + Ok((server_name, specs)) => { + // Each mcp server might have multiple tools. + // To avoid naming conflicts we are going to namespace it. + // This would also help us locate which mcp server to call the tool from. + let mut out_of_spec_tool_names = Vec::::new(); + let mut hasher = DefaultHasher::new(); + let number_of_tools = specs.len(); + // Sanitize tool names to ensure they comply with the naming requirements: + // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use it as is + // 2. 
Otherwise, remove invalid characters and handle special cases: + // - Remove namespace delimiters + // - Ensure the name starts with an alphabetic character + // - Generate a hash-based name if the sanitized result is empty + // This ensures all tool names are valid identifiers that can be safely used in the system + // If after all of the aforementioned modification the combined tool + // name we have exceeds a length of 64, we surface it as an error + for mut spec in specs { + let sn = if !regex_clone.is_match(&spec.name) { + let mut sn = sanitize_name(spec.name.clone(), ®ex_clone, &mut hasher); + while sanitized_mapping.contains_key(&sn) { + sn.push('1'); + } + sn + } else { + spec.name.clone() + }; + let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); + if full_name.len() > 64 { + out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name)); + continue; + } else if spec.description.is_empty() { + out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name)); + continue; + } + if sn != spec.name { + sanitized_mapping.insert(full_name.clone(), format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name)); + } + spec.name = full_name; + spec.tool_origin = ToolOrigin::McpServer(server_name.clone()); + tool_specs_clone.lock().await.insert(spec.name.clone(), spec); + } + // Send server load success metric datum + tokio::spawn(async move { + let event = crate::fig_telemetry::EventType::McpServerInit { conversation_id, init_failure_reason: None, number_of_tools }; + let app_event = crate::fig_telemetry::AppTelemetryEvent::new(event).await; + crate::fig_telemetry::send_event(app_event).await; + }); + // Tool name translation. This is beyond of the scope of what is + // considered a "server load". Reasoning being: + // - Failures here are not related to server load + // - There is not a whole lot we can do with this data + if let Some(tx_clone) = &tx_clone { + let send_result = if !out_of_spec_tool_names.is_empty() { + let msg = out_of_spec_tool_names.iter().fold( + String::from("The following tools are out of spec. They will be excluded from the list of available tools:\n"), + |mut acc, name| { + let (tool_name, msg) = match name { + OutOfSpecName::TooLong(tool_name) => (tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name"), + OutOfSpecName::IllegalChar(tool_name) => (tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$"), + OutOfSpecName::EmptyDescription(tool_name) => (tool_name.as_str(), "tool schema contains empty description"), + }; + acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); + acc + } + ); + tx_clone.send(LoadingMsg::Error { + name: server_name.clone(), + msg: eyre::eyre!(msg), + }) + // TODO: if no tools are valid, we need to offload the server + // from the fleet (i.e. 
kill the server) + } else if !sanitized_mapping.is_empty() { + let warn = sanitized_mapping.iter().fold(String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { + acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); + acc + }); + tx_clone.send(LoadingMsg::Warn { + name: server_name.clone(), + msg: eyre::eyre!(warn), + }) + } else { + tx_clone.send(LoadingMsg::Done(server_name.clone())) + }; + if let Err(e) = send_result { + error!("Error while sending status update to display task: {:?}", e); + } + } + }, + Err(e) => { + error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); + let init_failure_reason = Some(e.to_string()); + tokio::spawn(async move { + let event = crate::fig_telemetry::EventType::McpServerInit { conversation_id, init_failure_reason, number_of_tools: 0 }; + let app_event = crate::fig_telemetry::AppTelemetryEvent::new(event).await; + crate::fig_telemetry::send_event(app_event).await; + }); + if let Some(tx_clone) = &tx_clone { + if let Err(e) = tx_clone.send(LoadingMsg::Error { + name: server_name_clone, + msg: e, + }) { + error!("Error while sending status update to display task: {:?}", e); + } + } + }, + } + Ok::<_, eyre::Report>(Some(sanitized_mapping)) + } + }) + .collect::>(); + // TODO: do we want to introduce a timeout here? + self.tn_map = stream::iter(load_tool) + .map(|async_closure| tokio::task::spawn(async_closure)) + .buffer_unordered(20) + .collect::>() + .await + .into_iter() + .filter_map(|r| r.ok()) + .filter_map(|r| r.ok()) + .flatten() + .flatten() + .collect::>(); + drop(tx); + if let Some(display_task) = display_task { + if let Err(e) = display_task.join() { + error!("Error while joining status display task: {:?}", e); + } + } + let tool_specs = { + let mutex = + Arc::try_unwrap(tool_specs).map_err(|e| eyre::eyre!("Error unwrapping arc for tool specs {:?}", e))?; + mutex.into_inner() + }; + // caching the tool names for skim operations + for tool_name in tool_specs.keys() { + if !self.tn_map.contains_key(tool_name) { + self.tn_map.insert(tool_name.clone(), tool_name.clone()); + } + } + self.schema = tool_specs.clone(); + Ok(tool_specs) + } + + pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { + let map_err = |parse_error| ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "Failed to validate tool parameters: {parse_error}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools." 
+ ))], + status: ToolResultStatus::Error, + }; + + Ok(match value.name.as_str() { + "fs_read" => Tool::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), + "fs_write" => Tool::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), + "execute_bash" => Tool::ExecuteBash(serde_json::from_value::(value.args).map_err(map_err)?), + "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), + "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), + // Note that this name is namespaced with server_name{DELIMITER}tool_name + name => { + let name = self.tn_map.get(name).map_or(name, String::as_str); + let (server_name, tool_name) = name.split_once(NAMESPACE_DELIMITER).ok_or(ToolResult { + tool_use_id: value.id.clone(), + content: vec![ToolResultContentBlock::Text(format!( + "The tool, \"{name}\" is supplied with incorrect name" + ))], + status: ToolResultStatus::Error, + })?; + let Some(client) = self.clients.get(server_name) else { + return Err(ToolResult { + tool_use_id: value.id, + content: vec![ToolResultContentBlock::Text(format!( + "The tool, \"{server_name}\" is not supported by the client" + ))], + status: ToolResultStatus::Error, + }); + }; + // The tool input schema has the shape of { type, properties }. + // The field "params" expected by MCP is { name, arguments }, where name is the + // name of the tool being invoked, + // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. + // The field "arguments" is where ToolUse::args belong. + let mut params = serde_json::Map::::new(); + params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); + params.insert("arguments".to_owned(), value.args); + let params = serde_json::Value::Object(params); + let custom_tool = CustomTool { + name: tool_name.to_owned(), + client: client.clone(), + method: "tools/call".to_owned(), + params: Some(params), + }; + Tool::Custom(custom_tool) + }, + }) + } + + #[allow(clippy::await_holding_lock)] + pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result { + let (server_name, prompt_name) = match get_command.params.name.split_once('/') { + None => (None::, Some(get_command.params.name.clone())), + Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), + }; + let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; + // We need to use a sync lock here because this lock is also used in a blocking thread, + // necessitated by the fact that said thread is also responsible for using a sync channel, + // which is itself necessitated by the fact that consumer of said channel is calling from a + // sync function + let mut prompts_wl = self + .prompts + .write() + .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; + let mut maybe_bundles = prompts_wl.get(&prompt_name); + let mut has_retried = false; + 'blk: loop { + match (maybe_bundles, server_name.as_ref(), has_retried) { + // If we have more than one eligible clients but no server name specified + (Some(bundles), None, _) if bundles.len() > 1 => { + break 'blk Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })); + }, + // Normal case where we have enough info to proceed + // Note that if bundle exists, it should never be empty + (Some(bundles), sn, _) => { + let bundle = if bundles.len() > 1 { + let 
Some(server_name) = sn else { + maybe_bundles = None; + continue 'blk; + }; + let bundle = bundles.iter().find(|b| b.server_name == *server_name); + match bundle { + Some(bundle) => bundle, + None => { + maybe_bundles = None; + continue 'blk; + }, + } + } else { + bundles.first().ok_or(GetPromptError::MissingPromptInfo)? + }; + let server_name = bundle.server_name.clone(); + let client = self.clients.get(&server_name).ok_or(GetPromptError::MissingClient)?; + // Here we lazily update the out of date cache + if client.is_prompts_out_of_date() { + let prompt_gets = client.list_prompt_gets(); + let prompt_gets = prompt_gets + .read() + .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; + for (prompt_name, prompt_get) in prompt_gets.iter() { + prompts_wl + .entry(prompt_name.to_string()) + .and_modify(|bundles| { + let mut is_modified = false; + for bundle in &mut *bundles { + let mut updated_bundle = PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }; + if bundle.server_name == *server_name { + std::mem::swap(bundle, &mut updated_bundle); + is_modified = true; + break; + } + } + if !is_modified { + bundles.push(PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }); + } + }) + .or_insert(vec![PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }]); + } + client.prompts_updated(); + } + let PromptsGetCommand { params, .. } = get_command; + let PromptBundle { prompt_get, .. } = prompts_wl + .get(&prompt_name) + .and_then(|bundles| bundles.iter().find(|b| b.server_name == server_name)) + .ok_or(GetPromptError::MissingPromptInfo)?; + // Here we need to convert the positional arguments into key value pair + // The assignment order is assumed to be the order of args as they are + // presented in PromptGet::arguments + let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, ¶ms.arguments) { + let params = schema.iter().zip(value.iter()).fold( + HashMap::::new(), + |mut acc, (prompt_get_arg, value)| { + acc.insert(prompt_get_arg.name.clone(), value.clone()); + acc + }, + ); + Some(serde_json::json!(params)) + } else { + None + }; + let params = { + let mut params = serde_json::Map::new(); + params.insert("name".to_string(), serde_json::Value::String(prompt_name)); + if let Some(args) = args { + params.insert("arguments".to_string(), args); + } + Some(serde_json::Value::Object(params)) + }; + let resp = client.request("prompts/get", params).await?; + break 'blk Ok(resp); + }, + // If we have no eligible clients this would mean one of the following: + // - The prompt does not exist, OR + // - This is the first time we have a query / our cache is out of date + // Both of which means we would have to requery + (None, _, false) => { + has_retried = true; + self.refresh_prompts(&mut prompts_wl)?; + maybe_bundles = prompts_wl.get(&prompt_name); + continue 'blk; + }, + (_, _, true) => { + break 'blk Err(GetPromptError::PromptNotFound(prompt_name)); + }, + } + } + } + + pub fn refresh_prompts(&self, prompts_wl: &mut HashMap>) -> Result<(), GetPromptError> { + *prompts_wl = self.clients.iter().fold( + HashMap::>::new(), + |mut acc, (server_name, client)| { + let prompt_gets = client.list_prompt_gets(); + let Ok(prompt_gets) = prompt_gets.read() else { + tracing::error!("Error encountered while retrieving read lock"); + return acc; + }; + for (prompt_name, prompt_get) in prompt_gets.iter() { + acc.entry(prompt_name.to_string()) + .and_modify(|bundles| { + 
bundles.push(PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }); + }) + .or_insert(vec![PromptBundle { + server_name: server_name.to_owned(), + prompt_get: prompt_get.clone(), + }]); + } + acc + }, + ); + Ok(()) + } +} + +fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { + if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { + return orig; + } + let sanitized: String = orig + .chars() + .filter(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || *c == '_') + .collect::() + .replace(NAMESPACE_DELIMITER, ""); + if sanitized.is_empty() { + hasher.write(orig.as_bytes()); + let hash = format!("{:03}", hasher.finish() % 1000); + return format!("a{}", hash); + } + match sanitized.chars().next() { + Some(c) if c.is_ascii_alphabetic() => sanitized, + Some(_) => { + format!("a{}", sanitized) + }, + None => { + hasher.write(orig.as_bytes()); + format!("a{}", hasher.finish()) + }, + } +} + +fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Green), + style::Print("✓ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" loaded in "), + style::SetForegroundColor(style::Color::Yellow), + style::Print(format!("{time_taken} s\n")), + )?) +} + +fn queue_init_message( + spinner_logo_idx: usize, + complete: usize, + failed: usize, + total: usize, + output: &mut impl Write, +) -> eyre::Result<()> { + if total == complete { + queue!( + output, + style::SetForegroundColor(style::Color::Green), + style::Print("✓"), + style::ResetColor, + )?; + } else if total == complete + failed { + queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗"), + style::ResetColor, + )?; + } else { + queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; + } + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Blue), + style::Print(format!(" {}", complete)), + style::ResetColor, + style::Print(" of "), + style::SetForegroundColor(style::Color::Blue), + style::Print(format!("{} ", total)), + style::ResetColor, + style::Print("mcp servers initialized\n"), + )?) +} + +fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Red), + style::Print("✗ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" has failed to load:\n- "), + style::Print(fail_load_msg), + style::Print("\n"), + style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"), + style::ResetColor, + )?) +} + +fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::SetForegroundColor(style::Color::Yellow), + style::Print("⚠ "), + style::SetForegroundColor(style::Color::Blue), + style::Print(name), + style::ResetColor, + style::Print(" has the following warning:\n"), + style::Print(msg), + style::ResetColor, + )?) 
+} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sanitize_server_name() { + let regex = regex::Regex::new(VALID_TOOL_NAME).unwrap(); + let mut hasher = DefaultHasher::new(); + let orig_name = "@awslabs.cdk-mcp-server"; + let sanitized_server_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); + assert_eq!(sanitized_server_name, "awslabscdkmcpserver"); + + let orig_name = "good_name"; + let sanitized_good_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); + assert_eq!(sanitized_good_name, orig_name); + + let all_bad_name = "@@@@@"; + let sanitized_all_bad_name = sanitize_name(all_bad_name.to_string(), ®ex, &mut hasher); + assert!(regex.is_match(&sanitized_all_bad_name)); + + let with_delim = format!("a{}b{}c", NAMESPACE_DELIMITER, NAMESPACE_DELIMITER); + let sanitized = sanitize_name(with_delim, ®ex, &mut hasher); + assert_eq!(sanitized, "abc"); + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/custom_tool.rs b/crates/kiro-cli/src/cli/chat/tools/custom_tool.rs new file mode 100644 index 0000000000..ee0e5c5875 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/custom_tool.rs @@ -0,0 +1,241 @@ +use std::collections::HashMap; +use std::io::Write; +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use crossterm::{ + queue, + style, +}; +use eyre::Result; +use serde::{ + Deserialize, + Serialize, +}; +use tokio::sync::RwLock; +use tracing::warn; + +use super::{ + InvokeOutput, + ToolSpec, +}; +use crate::cli::chat::CONTINUATION_LINE; +use crate::cli::chat::token_counter::TokenCounter; +use crate::fig_os_shim::Context; +use crate::mcp_client::{ + Client as McpClient, + ClientConfig as McpClientConfig, + JsonRpcResponse, + JsonRpcStdioTransport, + MessageContent, + PromptGet, + ServerCapabilities, + StdioTransport, + ToolCallResult, +}; + +// TODO: support http transport type +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct CustomToolConfig { + pub command: String, + #[serde(default)] + pub args: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, + #[serde(default = "default_timeout")] + pub timeout: u64, +} + +fn default_timeout() -> u64 { + 120 * 1000 +} + +#[derive(Debug)] +pub enum CustomToolClient { + Stdio { + server_name: String, + client: McpClient, + server_capabilities: RwLock>, + }, +} + +impl CustomToolClient { + // TODO: add support for http transport + pub fn from_config(server_name: String, config: CustomToolConfig) -> Result { + let CustomToolConfig { + command, + args, + env, + timeout, + } = config; + let mcp_client_config = McpClientConfig { + server_name: server_name.clone(), + bin_path: command.clone(), + args, + timeout, + client_info: serde_json::json!({ + "name": "Q CLI Chat", + "version": "1.0.0" + }), + env, + }; + let client = McpClient::::from_config(mcp_client_config)?; + Ok(CustomToolClient::Stdio { + server_name, + client, + server_capabilities: RwLock::new(None), + }) + } + + pub async fn init(&self) -> Result<(String, Vec)> { + match self { + CustomToolClient::Stdio { + client, + server_name, + server_capabilities, + } => { + // We'll need to first initialize. 
This is the handshake every client and server + // needs to do before proceeding to anything else + let init_resp = client.init().await?; + // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 + // So don't worry about the tidiness for now + let is_tool_supported = init_resp + .get("result") + .is_some_and(|r| r.get("capabilities").is_some_and(|cap| cap.get("tools").is_some())); + server_capabilities.write().await.replace(init_resp); + // Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools + let tools = if is_tool_supported { + // And now we make the server tell us what tools they have + let resp = client.request("tools/list", None).await?; + match resp.result.and_then(|r| r.get("tools").cloned()) { + Some(value) => serde_json::from_value::>(value)?, + None => Default::default(), + } + } else { + Default::default() + }; + Ok((server_name.clone(), tools)) + }, + } + } + + pub fn get_server_name(&self) -> &str { + match self { + CustomToolClient::Stdio { server_name, .. } => server_name.as_str(), + } + } + + pub async fn request(&self, method: &str, params: Option) -> Result { + match self { + CustomToolClient::Stdio { client, .. } => Ok(client.request(method, params).await?), + } + } + + pub fn list_prompt_gets(&self) -> Arc>> { + match self { + CustomToolClient::Stdio { client, .. } => client.prompt_gets.clone(), + } + } + + #[allow(dead_code)] + pub async fn notify(&self, method: &str, params: Option) -> Result<()> { + match self { + CustomToolClient::Stdio { client, .. } => Ok(client.notify(method, params).await?), + } + } + + pub fn is_prompts_out_of_date(&self) -> bool { + match self { + CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.load(Ordering::Relaxed), + } + } + + pub fn prompts_updated(&self) { + match self { + CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.store(false, Ordering::Relaxed), + } + } +} + +/// Represents a custom tool that can be invoked through the Model Context Protocol (MCP). +#[derive(Clone, Debug)] +pub struct CustomTool { + /// Actual tool name as recognized by its MCP server. This differs from the tool names as they + /// are seen by the model since they are not prefixed by its MCP server name. + pub name: String, + /// Reference to the client that manages communication with the tool's server process. + pub client: Arc, + /// The method name to call on the tool's server, following the JSON-RPC convention. + /// This corresponds to a specific functionality provided by the tool. + pub method: String, + /// Optional parameters to pass to the tool when invoking the method. + /// Structured as a JSON value to accommodate various parameter types and structures. + pub params: Option, +} + +impl CustomTool { + pub async fn invoke(&self, _ctx: &Context, _updates: &mut impl Write) -> Result { + // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools + let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; + let result = resp + .result + .ok_or(eyre::eyre!("{} invocation failed to produce a result", self.name))?; + + match serde_json::from_value::(result.clone()) { + Ok(mut de_result) => { + for content in &mut de_result.content { + if let MessageContent::Image { data, .. 
} = content { + *data = format!("Redacted base64 encoded string of an image of size {}", data.len()); + } + } + Ok(InvokeOutput { + output: super::OutputKind::Json(serde_json::json!(de_result)), + }) + }, + Err(e) => { + warn!("Tool call result deserialization failed: {:?}", e); + Ok(InvokeOutput { + output: super::OutputKind::Json(result.clone()), + }) + }, + } + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + queue!( + updates, + style::Print("Running "), + style::SetForegroundColor(style::Color::Green), + style::Print(&self.name), + style::ResetColor, + )?; + if let Some(params) = &self.params { + let params = match serde_json::to_string_pretty(params) { + Ok(params) => params + .split("\n") + .map(|p| format!("{CONTINUATION_LINE} {p}")) + .collect::>() + .join("\n"), + _ => format!("{:?}", params), + }; + queue!( + updates, + style::Print(" with the param:\n"), + style::Print(params), + style::ResetColor, + )?; + } else { + queue!(updates, style::Print("\n"))?; + } + Ok(()) + } + + pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { + Ok(()) + } + + pub fn get_input_token_size(&self) -> usize { + TokenCounter::count_tokens(self.method.as_str()) + + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/execute_bash.rs b/crates/kiro-cli/src/cli/chat/tools/execute_bash.rs new file mode 100644 index 0000000000..427435bb37 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/execute_bash.rs @@ -0,0 +1,373 @@ +use std::collections::VecDeque; +use std::io::Write; +use std::process::{ + ExitStatus, + Stdio, +}; +use std::str::from_utf8; + +use crossterm::queue; +use crossterm::style::{ + self, + Color, +}; +use eyre::{ + Context as EyreContext, + Result, +}; +use serde::Deserialize; +use tokio::io::AsyncBufReadExt; +use tokio::select; +use tracing::error; + +use super::super::util::truncate_safe; +use super::{ + InvokeOutput, + MAX_TOOL_RESPONSE_SIZE, + OutputKind, +}; +use crate::fig_os_shim::Context; + +const READONLY_COMMANDS: &[&str] = &["ls", "cat", "echo", "pwd", "which", "head", "tail", "find", "grep"]; + +#[derive(Debug, Clone, Deserialize)] +pub struct ExecuteBash { + pub command: String, +} + +impl ExecuteBash { + pub fn requires_acceptance(&self) -> bool { + let Some(args) = shlex::split(&self.command) else { + return true; + }; + + const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";"]; + if args + .iter() + .any(|arg| DANGEROUS_PATTERNS.iter().any(|p| arg.contains(p))) + { + return true; + } + + // Split commands by pipe and check each one + let mut current_cmd = Vec::new(); + let mut all_commands = Vec::new(); + + for arg in args { + if arg == "|" { + if !current_cmd.is_empty() { + all_commands.push(current_cmd); + } + current_cmd = Vec::new(); + } else if arg.contains("|") { + // if pipe appears without spacing e.g. 
`echo myimportantfile|args rm` it won't get + // parsed out, in this case - we want to verify before running + return true; + } else { + current_cmd.push(arg); + } + } + if !current_cmd.is_empty() { + all_commands.push(current_cmd); + } + + // Check if each command in the pipe chain starts with a safe command + for cmd_args in all_commands { + match cmd_args.first() { + // Special casing for `find` so that we support most cases while safeguarding + // against unwanted mutations + Some(cmd) + if cmd == "find" + && cmd_args + .iter() + .any(|arg| arg.contains("-exec") || arg.contains("-delete")) => + { + return true; + }, + Some(cmd) if !READONLY_COMMANDS.contains(&cmd.as_str()) => return true, + None => return true, + _ => (), + } + } + + false + } + + pub async fn invoke(&self, updates: impl Write) -> Result { + let output = run_command(&self.command, MAX_TOOL_RESPONSE_SIZE / 3, Some(updates)).await?; + let result = serde_json::json!({ + "exit_status": output.exit_status.unwrap_or(0).to_string(), + "stdout": output.stdout, + "stderr": output.stderr, + }); + + Ok(InvokeOutput { + output: OutputKind::Json(result), + }) + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + queue!(updates, style::Print("I will run the following shell command: "),)?; + + // TODO: Could use graphemes for a better heuristic + if self.command.len() > 20 { + queue!(updates, style::Print("\n"),)?; + } + + Ok(queue!( + updates, + style::SetForegroundColor(Color::Green), + style::Print(&self.command), + style::Print("\n\n"), + style::ResetColor + )?) + } + + pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { + // TODO: probably some small amount of PATH checking + Ok(()) + } +} + +pub struct CommandResult { + pub exit_status: Option, + /// Truncated stdout + pub stdout: String, + /// Truncated stderr + pub stderr: String, +} + +/// Run a bash command. +/// # Arguments +/// * `max_result_size` - max size of output streams, truncating if required +/// * `updates` - output stream to push informational messages about the progress +/// # Returns +/// A [`CommandResult`] +pub async fn run_command( + command: &str, + max_result_size: usize, + mut updates: Option, +) -> Result { + // We need to maintain a handle on stderr and stdout, but pipe it to the terminal as well + let mut child = tokio::process::Command::new("bash") + .arg("-c") + .arg(command) + .stdin(Stdio::inherit()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .wrap_err_with(|| format!("Unable to spawn command '{}'", command))?; + + let stdout_final: String; + let stderr_final: String; + let exit_status: ExitStatus; + + // Buffered output vs all-at-once + if let Some(u) = updates.as_mut() { + let stdout = child.stdout.take().unwrap(); + let stdout = tokio::io::BufReader::new(stdout); + let mut stdout = stdout.lines(); + + let stderr = child.stderr.take().unwrap(); + let stderr = tokio::io::BufReader::new(stderr); + let mut stderr = stderr.lines(); + + const LINE_COUNT: usize = 1024; + let mut stdout_buf = VecDeque::with_capacity(LINE_COUNT); + let mut stderr_buf = VecDeque::with_capacity(LINE_COUNT); + + let mut stdout_done = false; + let mut stderr_done = false; + exit_status = loop { + select! 
{ + biased; + line = stdout.next_line(), if !stdout_done => match line { + Ok(Some(line)) => { + writeln!(u, "{line}")?; + if stdout_buf.len() >= LINE_COUNT { + stdout_buf.pop_front(); + } + stdout_buf.push_back(line); + }, + Ok(None) => stdout_done = true, + Err(err) => error!(%err, "Failed to read stdout of child process"), + }, + line = stderr.next_line(), if !stderr_done => match line { + Ok(Some(line)) => { + writeln!(u, "{line}")?; + if stderr_buf.len() >= LINE_COUNT { + stderr_buf.pop_front(); + } + stderr_buf.push_back(line); + }, + Ok(None) => stderr_done = true, + Err(err) => error!(%err, "Failed to read stderr of child process"), + }, + exit_status = child.wait() => { + break exit_status; + }, + }; + } + .wrap_err_with(|| format!("No exit status for '{}'", command))?; + + u.flush()?; + + stdout_final = stdout_buf.into_iter().collect::>().join("\n"); + stderr_final = stderr_buf.into_iter().collect::>().join("\n"); + } else { + // Take output all at once since we are not reporting anything in real time + // + // NOTE: If we don't split this logic, then any writes to stdout while calling + // this function concurrently may cause the piped child output to be ignored + + let output = child + .wait_with_output() + .await + .wrap_err_with(|| format!("No exit status for '{}'", command))?; + + exit_status = output.status; + stdout_final = from_utf8(&output.stdout).unwrap_or_default().to_string(); + stderr_final = from_utf8(&output.stderr).unwrap_or_default().to_string(); + } + + Ok(CommandResult { + exit_status: exit_status.code(), + stdout: format!( + "{}{}", + truncate_safe(&stdout_final, max_result_size), + if stdout_final.len() > max_result_size { + " ... truncated" + } else { + "" + } + ), + stderr: format!( + "{}{}", + truncate_safe(&stderr_final, max_result_size), + if stderr_final.len() > max_result_size { + " ... truncated" + } else { + "" + } + ), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[ignore = "todo: fix failing on musl for some reason"] + #[tokio::test] + async fn test_execute_bash_tool() { + let mut stdout = std::io::stdout(); + + // Verifying stdout + let v = serde_json::json!({ + "command": "echo Hello, world!", + }); + let out = serde_json::from_value::(v) + .unwrap() + .invoke(&mut stdout) + .await + .unwrap(); + + if let OutputKind::Json(json) = out.output { + assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); + assert_eq!(json.get("stdout").unwrap(), "Hello, world!"); + assert_eq!(json.get("stderr").unwrap(), ""); + } else { + panic!("Expected JSON output"); + } + + // Verifying stderr + let v = serde_json::json!({ + "command": "echo Hello, world! 
1>&2", + }); + let out = serde_json::from_value::(v) + .unwrap() + .invoke(&mut stdout) + .await + .unwrap(); + + if let OutputKind::Json(json) = out.output { + assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); + assert_eq!(json.get("stdout").unwrap(), ""); + assert_eq!(json.get("stderr").unwrap(), "Hello, world!"); + } else { + panic!("Expected JSON output"); + } + + // Verifying exit code + let v = serde_json::json!({ + "command": "exit 1", + "interactive": false + }); + let out = serde_json::from_value::(v) + .unwrap() + .invoke(&mut stdout) + .await + .unwrap(); + if let OutputKind::Json(json) = out.output { + assert_eq!(json.get("exit_status").unwrap(), &1.to_string()); + assert_eq!(json.get("stdout").unwrap(), ""); + assert_eq!(json.get("stderr").unwrap(), ""); + } else { + panic!("Expected JSON output"); + } + } + + #[test] + fn test_requires_acceptance_for_readonly_commands() { + let cmds = &[ + // Safe commands + ("ls ~", false), + ("ls -al ~", false), + ("pwd", false), + ("echo 'Hello, world!'", false), + ("which aws", false), + // Potentially dangerous readonly commands + ("echo hi > myimportantfile", true), + ("ls -al >myimportantfile", true), + ("echo hi 2> myimportantfile", true), + ("echo hi >> myimportantfile", true), + ("echo $(rm myimportantfile)", true), + ("echo `rm myimportantfile`", true), + ("echo hello && rm myimportantfile", true), + ("echo hello&&rm myimportantfile", true), + ("ls nonexistantpath || rm myimportantfile", true), + ("echo myimportantfile | xargs rm", true), + ("echo myimportantfile|args rm", true), + ("echo <(rm myimportantfile)", true), + ("cat <<< 'some string here' > myimportantfile", true), + ("echo '\n#!/usr/bin/env bash\necho hello\n' > myscript.sh", true), + ("cat < myimportantfile\nhello world\nEOF", true), + // Safe piped commands + ("find . -name '*.rs' | grep main", false), + ("ls -la | grep .git", false), + ("cat file.txt | grep pattern | head -n 5", false), + // Unsafe piped commands + ("find . -name '*.rs' | rm", true), + ("ls -la | grep .git | rm -rf", true), + ("echo hello | sudo rm -rf /", true), + // `find` command arguments + ("find important-dir/ -exec rm {} \\;", true), + ("find . 
-name '*.c' -execdir gcc -o '{}.out' '{}' \\;", true), + ("find important-dir/ -delete", true), + ("find important-dir/ -name '*.txt'", false), + ]; + for (cmd, expected) in cmds { + let tool = serde_json::from_value::(serde_json::json!({ + "command": cmd, + })) + .unwrap(); + assert_eq!( + tool.requires_acceptance(), + *expected, + "expected command: `{}` to have requires_acceptance: `{}`", + cmd, + expected + ); + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/fs_read.rs b/crates/kiro-cli/src/cli/chat/tools/fs_read.rs new file mode 100644 index 0000000000..6cb2bc96ed --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/fs_read.rs @@ -0,0 +1,669 @@ +use std::collections::VecDeque; +use std::fs::Metadata; +use std::io::Write; +use std::os::unix::fs::PermissionsExt; + +use crossterm::queue; +use crossterm::style::{ + self, + Color, +}; +use eyre::{ + Result, + bail, +}; +use serde::{ + Deserialize, + Serialize, +}; +use syntect::util::LinesWithEndings; +use tracing::{ + debug, + warn, +}; + +use super::{ + InvokeOutput, + MAX_TOOL_RESPONSE_SIZE, + OutputKind, + format_path, + sanitize_path_tool_arg, +}; +use crate::fig_os_shim::Context; + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "mode")] +pub enum FsRead { + Line(FsLine), + Directory(FsDirectory), + Search(FsSearch), +} + +impl FsRead { + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + match self { + FsRead::Line(fs_line) => fs_line.validate(ctx).await, + FsRead::Directory(fs_directory) => fs_directory.validate(ctx).await, + FsRead::Search(fs_search) => fs_search.validate(ctx).await, + } + } + + pub async fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { + match self { + FsRead::Line(fs_line) => fs_line.queue_description(ctx, updates).await, + FsRead::Directory(fs_directory) => fs_directory.queue_description(updates), + FsRead::Search(fs_search) => fs_search.queue_description(updates), + } + } + + pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { + match self { + FsRead::Line(fs_line) => fs_line.invoke(ctx, updates).await, + FsRead::Directory(fs_directory) => fs_directory.invoke(ctx, updates).await, + FsRead::Search(fs_search) => fs_search.invoke(ctx, updates).await, + } + } +} + +/// Read lines from a file. 
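+///
+/// A minimal usage sketch (illustrative values only, mirroring `test_fs_read_deser` below): the
+/// model-facing JSON is routed through the `mode`-tagged [`FsRead`] enum, with `start_line`
+/// defaulting to 1 and `end_line` defaulting to -1 (the last line) when omitted.
+/// ```ignore
+/// let read = serde_json::from_value::<FsRead>(serde_json::json!({
+///     "mode": "Line",
+///     "path": "/test_file.txt",
+///     "start_line": 2,
+///     "end_line": -1,
+/// }))
+/// .unwrap();
+/// ```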
+#[derive(Debug, Clone, Deserialize)] +pub struct FsLine { + pub path: String, + pub start_line: Option, + pub end_line: Option, +} + +impl FsLine { + const DEFAULT_END_LINE: i32 = -1; + const DEFAULT_START_LINE: i32 = 1; + + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + let path = sanitize_path_tool_arg(ctx, &self.path); + if !path.exists() { + bail!("'{}' does not exist", self.path); + } + let is_file = ctx.fs().symlink_metadata(&path).await?.is_file(); + if !is_file { + bail!("'{}' is not a file", self.path); + } + Ok(()) + } + + pub async fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { + let path = sanitize_path_tool_arg(ctx, &self.path); + let line_count = ctx.fs().read_to_string(&path).await?.lines().count(); + queue!( + updates, + style::Print("Reading file: "), + style::SetForegroundColor(Color::Green), + style::Print(&self.path), + style::ResetColor, + style::Print(", "), + )?; + + let start = convert_negative_index(line_count, self.start_line()) + 1; + let end = convert_negative_index(line_count, self.end_line()) + 1; + match (start, end) { + _ if start == 1 && end == line_count => Ok(queue!(updates, style::Print("all lines".to_string()))?), + _ if end == line_count => Ok(queue!( + updates, + style::Print("from line "), + style::SetForegroundColor(Color::Green), + style::Print(start), + style::ResetColor, + style::Print(" to end of file"), + )?), + _ => Ok(queue!( + updates, + style::Print("from line "), + style::SetForegroundColor(Color::Green), + style::Print(start), + style::ResetColor, + style::Print(" to "), + style::SetForegroundColor(Color::Green), + style::Print(end), + style::ResetColor, + )?), + } + } + + pub async fn invoke(&self, ctx: &Context, _updates: &mut impl Write) -> Result { + let path = sanitize_path_tool_arg(ctx, &self.path); + debug!(?path, "Reading"); + let file = ctx.fs().read_to_string(&path).await?; + let line_count = file.lines().count(); + let (start, end) = ( + convert_negative_index(line_count, self.start_line()), + convert_negative_index(line_count, self.end_line()), + ); + + // safety check to ensure end is always greater than start + let end = end.max(start); + + if start >= line_count { + bail!( + "starting index: {} is outside of the allowed range: ({}, {})", + self.start_line(), + -(line_count as i64), + line_count + ); + } + + // The range should be inclusive on both ends. + let file_contents = file + .lines() + .skip(start) + .take(end - start + 1) + .collect::>() + .join("\n"); + + let byte_count = file_contents.len(); + if byte_count > MAX_TOOL_RESPONSE_SIZE { + bail!( + "This tool only supports reading {MAX_TOOL_RESPONSE_SIZE} bytes at a +time. You tried to read {byte_count} bytes. Try executing with fewer lines specified." + ); + } + + Ok(InvokeOutput { + output: OutputKind::Text(file_contents), + }) + } + + fn start_line(&self) -> i32 { + self.start_line.unwrap_or(Self::DEFAULT_START_LINE) + } + + fn end_line(&self) -> i32 { + self.end_line.unwrap_or(Self::DEFAULT_END_LINE) + } +} + +/// Search in a file. 
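+///
+/// A minimal usage sketch (illustrative values only, mirroring the search test below). Matching
+/// is case-insensitive and `context_lines` defaults to 2 when omitted.
+/// ```ignore
+/// let search = serde_json::from_value::<FsRead>(serde_json::json!({
+///     "mode": "Search",
+///     "path": "/test_file.txt",
+///     "pattern": "hello",
+///     "context_lines": 2,
+/// }))
+/// .unwrap();
+/// ```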
+#[derive(Debug, Clone, Deserialize)] +pub struct FsSearch { + pub path: String, + pub pattern: String, + pub context_lines: Option, +} + +impl FsSearch { + const CONTEXT_LINE_PREFIX: &str = " "; + const DEFAULT_CONTEXT_LINES: usize = 2; + const MATCHING_LINE_PREFIX: &str = "→ "; + + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + let path = sanitize_path_tool_arg(ctx, &self.path); + let relative_path = format_path(ctx.env().current_dir()?, &path); + if !path.exists() { + bail!("File not found: {}", relative_path); + } + if !ctx.fs().symlink_metadata(path).await?.is_file() { + bail!("Path is not a file: {}", relative_path); + } + if self.pattern.is_empty() { + bail!("Search pattern cannot be empty"); + } + Ok(()) + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + queue!( + updates, + style::Print("Searching: "), + style::SetForegroundColor(Color::Green), + style::Print(&self.path), + style::ResetColor, + style::Print(" for pattern: "), + style::SetForegroundColor(Color::Green), + style::Print(&self.pattern.to_lowercase()), + style::ResetColor, + )?; + Ok(()) + } + + pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { + let file_path = sanitize_path_tool_arg(ctx, &self.path); + let pattern = &self.pattern; + let relative_path = format_path(ctx.env().current_dir()?, &file_path); + + let file_content = ctx.fs().read_to_string(&file_path).await?; + let lines: Vec<&str> = LinesWithEndings::from(&file_content).collect(); + + let mut results = Vec::new(); + let mut total_matches = 0; + + // Case insensitive search + let pattern_lower = pattern.to_lowercase(); + for (line_num, line) in lines.iter().enumerate() { + if line.to_lowercase().contains(&pattern_lower) { + total_matches += 1; + let start = line_num.saturating_sub(self.context_lines()); + let end = lines.len().min(line_num + self.context_lines() + 1); + let mut context_text = Vec::new(); + (start..end).for_each(|i| { + let prefix = if i == line_num { + Self::MATCHING_LINE_PREFIX + } else { + Self::CONTEXT_LINE_PREFIX + }; + let line_text = lines[i].to_string(); + context_text.push(format!("{}{}: {}", prefix, i + 1, line_text)); + }); + let match_text = context_text.join(""); + results.push(SearchMatch { + line_number: line_num + 1, + context: match_text, + }); + } + } + + queue!( + updates, + style::SetForegroundColor(Color::Yellow), + style::ResetColor, + style::Print(format!( + "Found {} matches for pattern '{}' in {}\n", + total_matches, pattern, relative_path + )), + style::Print("\n"), + style::ResetColor, + )?; + + Ok(InvokeOutput { + output: OutputKind::Text(serde_json::to_string(&results)?), + }) + } + + fn context_lines(&self) -> usize { + self.context_lines.unwrap_or(Self::DEFAULT_CONTEXT_LINES) + } +} + +/// List directory contents. 
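+///
+/// A minimal usage sketch (illustrative values only, mirroring the directory test below).
+/// `depth` defaults to 0, i.e. only the entries of the given directory itself are listed.
+/// ```ignore
+/// let list = serde_json::from_value::<FsRead>(serde_json::json!({
+///     "mode": "Directory",
+///     "path": "/",
+///     "depth": 1,
+/// }))
+/// .unwrap();
+/// ```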
+#[derive(Debug, Clone, Deserialize)] +pub struct FsDirectory { + pub path: String, + pub depth: Option, +} + +impl FsDirectory { + const DEFAULT_DEPTH: usize = 0; + + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + let path = sanitize_path_tool_arg(ctx, &self.path); + let relative_path = format_path(ctx.env().current_dir()?, &path); + if !path.exists() { + bail!("Directory not found: {}", relative_path); + } + if !ctx.fs().symlink_metadata(path).await?.is_dir() { + bail!("Path is not a directory: {}", relative_path); + } + Ok(()) + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + queue!( + updates, + style::Print("Reading directory: "), + style::SetForegroundColor(Color::Green), + style::Print(&self.path), + style::ResetColor, + style::Print(" "), + )?; + let depth = self.depth.unwrap_or_default(); + Ok(queue!( + updates, + style::Print(format!("with maximum depth of {}", depth)) + )?) + } + + pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { + let path = sanitize_path_tool_arg(ctx, &self.path); + let cwd = ctx.env().current_dir()?; + let max_depth = self.depth(); + debug!(?path, max_depth, "Reading directory at path with depth"); + let mut result = Vec::new(); + let mut dir_queue = VecDeque::new(); + dir_queue.push_back((path, 0)); + while let Some((path, depth)) = dir_queue.pop_front() { + if depth > max_depth { + break; + } + let relative_path = format_path(&cwd, &path); + if !relative_path.is_empty() { + queue!( + updates, + style::Print("Reading: "), + style::SetForegroundColor(Color::Green), + style::Print(&relative_path), + style::ResetColor, + style::Print("\n"), + )?; + } + let mut read_dir = ctx.fs().read_dir(path).await?; + while let Some(ent) = read_dir.next_entry().await? { + use std::os::unix::fs::MetadataExt; + let md = ent.metadata().await?; + let formatted_mode = format_mode(md.permissions().mode()).into_iter().collect::(); + + let modified_timestamp = md.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs(); + let datetime = time::OffsetDateTime::from_unix_timestamp(modified_timestamp as i64).unwrap(); + let formatted_date = datetime + .format(time::macros::format_description!( + "[month repr:short] [day] [hour]:[minute]" + )) + .unwrap(); + + // Mostly copying "The Long Format" from `man ls`. + // TODO: query user/group database to convert uid/gid to names? + result.push(format!( + "{}{} {} {} {} {} {} {}", + format_ftype(&md), + formatted_mode, + md.nlink(), + md.uid(), + md.gid(), + md.size(), + formatted_date, + ent.path().to_string_lossy() + )); + if md.is_dir() { + dir_queue.push_back((ent.path(), depth + 1)); + } + } + } + + let file_count = result.len(); + let result = result.join("\n"); + let byte_count = result.len(); + if byte_count > MAX_TOOL_RESPONSE_SIZE { + bail!( + "This tool only supports reading up to {MAX_TOOL_RESPONSE_SIZE} bytes at a time. You tried to read {byte_count} bytes ({file_count} files). Try executing with fewer lines specified." + ); + } + + Ok(InvokeOutput { + output: OutputKind::Text(result), + }) + } + + fn depth(&self) -> usize { + self.depth.unwrap_or(Self::DEFAULT_DEPTH) + } +} + +/// Converts negative 1-based indices to positive 0-based indices. 
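+///
+/// Illustrative mapping for a 5-line file (consistent with the unit tests below):
+/// ```ignore
+/// assert_eq!(convert_negative_index(5, 1), 0);    // positive indices are 1-based
+/// assert_eq!(convert_negative_index(5, -1), 4);   // -1 refers to the last line
+/// assert_eq!(convert_negative_index(5, -100), 0); // out-of-range negatives clamp to 0
+/// ```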
+fn convert_negative_index(line_count: usize, i: i32) -> usize { + if i <= 0 { + (line_count as i32 + i).max(0) as usize + } else { + i as usize - 1 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct SearchMatch { + line_number: usize, + context: String, +} + +fn format_ftype(md: &Metadata) -> char { + if md.is_symlink() { + 'l' + } else if md.is_file() { + '-' + } else if md.is_dir() { + 'd' + } else { + warn!("unknown file metadata: {:?}", md); + '-' + } +} + +/// Formats a permissions mode into the form used by `ls`, e.g. `0o644` to `rw-r--r--` +fn format_mode(mode: u32) -> [char; 9] { + let mut mode = mode & 0o777; + let mut res = ['-'; 9]; + fn octal_to_chars(val: u32) -> [char; 3] { + match val { + 1 => ['-', '-', 'x'], + 2 => ['-', 'w', '-'], + 3 => ['-', 'w', 'x'], + 4 => ['r', '-', '-'], + 5 => ['r', '-', 'x'], + 6 => ['r', 'w', '-'], + 7 => ['r', 'w', 'x'], + _ => ['-', '-', '-'], + } + } + for c in res.rchunks_exact_mut(3) { + c.copy_from_slice(&octal_to_chars(mode & 0o7)); + mode /= 0o10; + } + res +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + const TEST_FILE_CONTENTS: &str = "\ +1: Hello world! +2: This is line 2 +3: asdf +4: Hello world! +"; + + const TEST_FILE_PATH: &str = "/test_file.txt"; + const TEST_HIDDEN_FILE_PATH: &str = "/aaaa2/.hidden"; + + /// Sets up the following filesystem structure: + /// ```text + /// test_file.txt + /// /home/testuser/ + /// /aaaa1/ + /// /bbbb1/ + /// /cccc1/ + /// /aaaa2/ + /// .hidden + /// ``` + async fn setup_test_directory() -> Arc { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let fs = ctx.fs(); + fs.write(TEST_FILE_PATH, TEST_FILE_CONTENTS).await.unwrap(); + fs.create_dir_all("/aaaa1/bbbb1/cccc1").await.unwrap(); + fs.create_dir_all("/aaaa2").await.unwrap(); + fs.write(TEST_HIDDEN_FILE_PATH, "this is a hidden file").await.unwrap(); + ctx + } + + #[test] + fn test_negative_index_conversion() { + assert_eq!(convert_negative_index(5, -100), 0); + assert_eq!(convert_negative_index(5, -1), 4); + } + + #[test] + fn test_fs_read_deser() { + serde_json::from_value::(serde_json::json!({ "path": "/test_file.txt", "mode": "Line" })).unwrap(); + serde_json::from_value::( + serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "end_line": 5 }), + ) + .unwrap(); + serde_json::from_value::( + serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "start_line": -1 }), + ) + .unwrap(); + serde_json::from_value::( + serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "start_line": None:: }), + ) + .unwrap(); + serde_json::from_value::(serde_json::json!({ "path": "/", "mode": "Directory" })).unwrap(); + serde_json::from_value::( + serde_json::json!({ "path": "/test_file.txt", "mode": "Directory", "depth": 2 }), + ) + .unwrap(); + serde_json::from_value::( + serde_json::json!({ "path": "/test_file.txt", "mode": "Search", "pattern": "hello" }), + ) + .unwrap(); + } + + #[tokio::test] + async fn test_fs_read_line_invoke() { + let ctx = setup_test_directory().await; + let lines = TEST_FILE_CONTENTS.lines().collect::>(); + let mut stdout = std::io::stdout(); + + macro_rules! 
assert_lines { + ($start_line:expr, $end_line:expr, $expected:expr) => { + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "mode": "Line", + "start_line": $start_line, + "end_line": $end_line, + }); + let output = serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + if let OutputKind::Text(text) = output.output { + assert_eq!(text, $expected.join("\n"), "actual(left) does not equal + expected(right) for (start_line, end_line): ({:?}, {:?})", $start_line, $end_line); + } else { + panic!("expected text output"); + } + } + } + assert_lines!(None::, None::, lines[..]); + assert_lines!(1, 2, lines[..=1]); + assert_lines!(1, -1, lines[..]); + assert_lines!(2, 1, lines[1..=1]); + assert_lines!(-2, -1, lines[2..]); + assert_lines!(-2, None::, lines[2..]); + assert_lines!(2, None::, lines[1..]); + } + + #[tokio::test] + async fn test_fs_read_line_past_eof() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "mode": "Line", + "start_line": 100, + "end_line": None::, + }); + assert!( + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .is_err() + ); + } + + #[test] + fn test_format_mode() { + macro_rules! assert_mode { + ($actual:expr, $expected:expr) => { + assert_eq!(format_mode($actual).iter().collect::(), $expected); + }; + } + assert_mode!(0o000, "---------"); + assert_mode!(0o700, "rwx------"); + assert_mode!(0o744, "rwxr--r--"); + assert_mode!(0o641, "rw-r----x"); + } + + #[tokio::test] + async fn test_fs_read_directory_invoke() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + // Testing without depth + let v = serde_json::json!({ + "mode": "Directory", + "path": "/", + }); + let output = serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + if let OutputKind::Text(text) = output.output { + assert_eq!(text.lines().collect::>().len(), 4); + } else { + panic!("expected text output"); + } + + // Testing with depth level 1 + let v = serde_json::json!({ + "mode": "Directory", + "path": "/", + "depth": 1, + }); + let output = serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + if let OutputKind::Text(text) = output.output { + let lines = text.lines().collect::>(); + assert_eq!(lines.len(), 7); + assert!( + !lines.iter().any(|l| l.contains("cccc1")), + "directory at depth level 2 should not be included in output" + ); + } else { + panic!("expected text output"); + } + } + + #[tokio::test] + async fn test_fs_read_search_invoke() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + macro_rules! 
invoke_search { + ($value:tt) => {{ + let v = serde_json::json!($value); + let output = serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + if let OutputKind::Text(value) = output.output { + serde_json::from_str::>(&value).unwrap() + } else { + panic!("expected Text output") + } + }}; + } + + let matches = invoke_search!({ + "mode": "Search", + "path": TEST_FILE_PATH, + "pattern": "hello", + }); + assert_eq!(matches.len(), 2); + assert_eq!(matches[0].line_number, 1); + assert_eq!( + matches[0].context, + format!( + "{}1: 1: Hello world!\n{}2: 2: This is line 2\n{}3: 3: asdf\n", + FsSearch::MATCHING_LINE_PREFIX, + FsSearch::CONTEXT_LINE_PREFIX, + FsSearch::CONTEXT_LINE_PREFIX + ) + ); + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/fs_write.rs b/crates/kiro-cli/src/cli/chat/tools/fs_write.rs new file mode 100644 index 0000000000..a7eb02487e --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/fs_write.rs @@ -0,0 +1,953 @@ +use std::io::Write; +use std::path::Path; +use std::sync::LazyLock; + +use crossterm::queue; +use crossterm::style::{ + self, + Color, +}; +use eyre::{ + ContextCompat as _, + Result, + bail, + eyre, +}; +use serde::Deserialize; +use similar::DiffableStr; +use syntect::easy::HighlightLines; +use syntect::highlighting::ThemeSet; +use syntect::parsing::SyntaxSet; +use syntect::util::{ + LinesWithEndings, + as_24_bit_terminal_escaped, +}; +use tracing::{ + error, + warn, +}; + +use super::{ + InvokeOutput, + format_path, + sanitize_path_tool_arg, + supports_truecolor, +}; +use crate::fig_os_shim::Context; + +static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); +static THEME_SET: LazyLock = LazyLock::new(ThemeSet::load_defaults); + +#[derive(Debug, Clone, Deserialize)] +#[serde(tag = "command")] +pub enum FsWrite { + /// The tool spec should only require `file_text`, but the model sometimes doesn't want to + /// provide it. Thus, including `new_str` as a fallback check, if it's available. + #[serde(rename = "create")] + Create { + path: String, + file_text: Option, + new_str: Option, + }, + #[serde(rename = "str_replace")] + StrReplace { + path: String, + old_str: String, + new_str: String, + }, + #[serde(rename = "insert")] + Insert { + path: String, + insert_line: usize, + new_str: String, + }, + #[serde(rename = "append")] + Append { path: String, new_str: String }, +} + +impl FsWrite { + pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { + let fs = ctx.fs(); + let cwd = ctx.env().current_dir()?; + match self { + FsWrite::Create { path, .. 
} => { + let file_text = self.canonical_create_command_text(); + let path = sanitize_path_tool_arg(ctx, path); + if let Some(parent) = path.parent() { + fs.create_dir_all(parent).await?; + } + + let invoke_description = if fs.exists(&path) { "Replacing: " } else { "Creating: " }; + queue!( + updates, + style::Print(invoke_description), + style::SetForegroundColor(Color::Green), + style::Print(format_path(cwd, &path)), + style::ResetColor, + style::Print("\n"), + )?; + + write_to_file(ctx, path, file_text).await?; + Ok(Default::default()) + }, + FsWrite::StrReplace { path, old_str, new_str } => { + let path = sanitize_path_tool_arg(ctx, path); + let file = fs.read_to_string(&path).await?; + let matches = file.match_indices(old_str).collect::>(); + queue!( + updates, + style::Print("Updating: "), + style::SetForegroundColor(Color::Green), + style::Print(format_path(cwd, &path)), + style::ResetColor, + style::Print("\n"), + )?; + match matches.len() { + 0 => Err(eyre!("no occurrences of \"{old_str}\" were found")), + 1 => { + let file = file.replacen(old_str, new_str, 1); + fs.write(path, file).await?; + Ok(Default::default()) + }, + x => Err(eyre!("{x} occurrences of old_str were found when only 1 is expected")), + } + }, + FsWrite::Insert { + path, + insert_line, + new_str, + } => { + let path = sanitize_path_tool_arg(ctx, path); + let mut file = fs.read_to_string(&path).await?; + queue!( + updates, + style::Print("Updating: "), + style::SetForegroundColor(Color::Green), + style::Print(format_path(cwd, &path)), + style::ResetColor, + style::Print("\n"), + )?; + + // Get the index of the start of the line to insert at. + let num_lines = file.lines().enumerate().map(|(i, _)| i + 1).last().unwrap_or(1); + let insert_line = insert_line.clamp(&0, &num_lines); + let mut i = 0; + for _ in 0..*insert_line { + let line_len = &file[i..].find("\n").map_or(file[i..].len(), |i| i + 1); + i += line_len; + } + file.insert_str(i, new_str); + write_to_file(ctx, &path, file).await?; + Ok(Default::default()) + }, + FsWrite::Append { path, new_str } => { + let path = sanitize_path_tool_arg(ctx, path); + + queue!( + updates, + style::Print("Appending to: "), + style::SetForegroundColor(Color::Green), + style::Print(format_path(cwd, &path)), + style::ResetColor, + style::Print("\n"), + )?; + + let mut file = fs.read_to_string(&path).await?; + if !file.ends_with_newline() { + file.push('\n'); + } + file.push_str(new_str); + write_to_file(ctx, path, file).await?; + Ok(Default::default()) + }, + } + } + + pub fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { + let cwd = ctx.env().current_dir()?; + self.print_relative_path(ctx, updates)?; + match self { + FsWrite::Create { path, .. } => { + let file_text = self.canonical_create_command_text(); + let relative_path = format_path(cwd, path); + let prev = if ctx.fs().exists(path) { + let file = ctx.fs().read_to_string_sync(path)?; + stylize_output_if_able(ctx, path, &file) + } else { + Default::default() + }; + let new = stylize_output_if_able(ctx, &relative_path, &file_text); + print_diff(updates, &prev, &new, 1)?; + Ok(()) + }, + FsWrite::Insert { + path, + insert_line, + new_str, + } => { + let relative_path = format_path(cwd, path); + let file = ctx.fs().read_to_string_sync(&relative_path)?; + + // Diff the old with the new by adding extra context around the line being inserted + // at. 
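+                // (Illustrative example for a sufficiently long file: inserting at line 5 with
+                // 3 context lines diffs old lines 2..=8 against the same span with `new_str`
+                // added after line 5.)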
+ let (prefix, start_line, suffix, _) = get_lines_with_context(&file, *insert_line, *insert_line, 3); + let insert_line_content = LinesWithEndings::from(&file) + // don't include any content if insert_line is 0 + .nth(insert_line.checked_sub(1).unwrap_or(usize::MAX)) + .unwrap_or_default(); + let old = [prefix, insert_line_content, suffix].join(""); + let new = [prefix, insert_line_content, new_str, suffix].join(""); + + let old = stylize_output_if_able(ctx, &relative_path, &old); + let new = stylize_output_if_able(ctx, &relative_path, &new); + print_diff(updates, &old, &new, start_line)?; + Ok(()) + }, + FsWrite::StrReplace { path, old_str, new_str } => { + let relative_path = format_path(cwd, path); + let file = ctx.fs().read_to_string_sync(&relative_path)?; + let (start_line, _) = match line_number_at(&file, old_str) { + Some((start_line, end_line)) => (start_line, end_line), + _ => (0, 0), + }; + let old_str = stylize_output_if_able(ctx, &relative_path, old_str); + let new_str = stylize_output_if_able(ctx, &relative_path, new_str); + print_diff(updates, &old_str, &new_str, start_line)?; + + Ok(()) + }, + FsWrite::Append { path, new_str } => { + let relative_path = format_path(cwd, path); + let start_line = ctx.fs().read_to_string_sync(&relative_path)?.lines().count() + 1; + let file = stylize_output_if_able(ctx, &relative_path, new_str); + print_diff(updates, &Default::default(), &file, start_line)?; + Ok(()) + }, + } + } + + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + match self { + FsWrite::Create { path, .. } => { + if path.is_empty() { + bail!("Path must not be empty") + }; + }, + FsWrite::StrReplace { path, .. } | FsWrite::Insert { path, .. } => { + let path = sanitize_path_tool_arg(ctx, path); + if !path.exists() { + bail!("The provided path must exist in order to replace or insert contents into it") + } + }, + FsWrite::Append { path, new_str } => { + if path.is_empty() { + bail!("Path must not be empty") + }; + if new_str.is_empty() { + bail!("Content to append must not be empty") + }; + }, + } + + Ok(()) + } + + fn print_relative_path(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { + let cwd = ctx.env().current_dir()?; + let path = match self { + FsWrite::Create { path, .. } => path, + FsWrite::StrReplace { path, .. } => path, + FsWrite::Insert { path, .. } => path, + FsWrite::Append { path, .. } => path, + }; + let relative_path = format_path(cwd, path); + queue!( + updates, + style::Print("Path: "), + style::SetForegroundColor(Color::Green), + style::Print(&relative_path), + style::ResetColor, + style::Print("\n\n"), + )?; + Ok(()) + } + + /// Returns the text to use for the [FsWrite::Create] command. This is required since we can't + /// rely on the model always providing `file_text`. + fn canonical_create_command_text(&self) -> String { + match self { + FsWrite::Create { file_text, new_str, .. } => match (file_text, new_str) { + (Some(file_text), _) => file_text.clone(), + (None, Some(new_str)) => { + warn!("required field `file_text` is missing, using the provided `new_str` instead"); + new_str.clone() + }, + _ => { + warn!("no content provided for the create command"); + String::new() + }, + }, + _ => String::new(), + } + } +} + +/// Writes `content` to `path`, adding a newline if necessary. 
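+///
+/// Behaviour sketch (assuming a `ctx: &Context` is in scope; the path is hypothetical):
+/// ```ignore
+/// write_to_file(ctx, "/notes.txt", "abc".to_string()).await?; // file now contains "abc\n"
+/// ```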
+async fn write_to_file(ctx: &Context, path: impl AsRef, mut content: String) -> Result<()> { + if !content.ends_with_newline() { + content.push('\n'); + } + ctx.fs().write(path.as_ref(), content).await?; + Ok(()) +} + +/// Returns a prefix/suffix pair before and after the content dictated by `[start_line, end_line]` +/// within `content`. The updated start and end lines containing the original context along with +/// the suffix and prefix are returned. +/// +/// Params: +/// - `start_line` - 1-indexed starting line of the content. +/// - `end_line` - 1-indexed ending line of the content. +/// - `context_lines` - number of lines to include before the start and end. +/// +/// Returns `(prefix, new_start_line, suffix, new_end_line)` +fn get_lines_with_context( + content: &str, + start_line: usize, + end_line: usize, + context_lines: usize, +) -> (&str, usize, &str, usize) { + let line_count = content.lines().count(); + // We want to support end_line being 0, in which case we should be able to set the first line + // as the suffix. + let zero_check_inc = if end_line == 0 { 0 } else { 1 }; + + // Convert to 0-indexing. + let (start_line, end_line) = ( + start_line.saturating_sub(1).clamp(0, line_count - 1), + end_line.saturating_sub(1).clamp(0, line_count - 1), + ); + let new_start_line = 0.max(start_line.saturating_sub(context_lines)); + let new_end_line = (line_count - 1).min(end_line + context_lines); + + // Build prefix + let mut prefix_start = 0; + for line in LinesWithEndings::from(content).take(new_start_line) { + prefix_start += line.len(); + } + let mut prefix_end = prefix_start; + for line in LinesWithEndings::from(&content[prefix_start..]).take(start_line - new_start_line) { + prefix_end += line.len(); + } + + // Build suffix + let mut suffix_start = 0; + for line in LinesWithEndings::from(content).take(end_line + zero_check_inc) { + suffix_start += line.len(); + } + let mut suffix_end = suffix_start; + for line in LinesWithEndings::from(&content[suffix_start..]).take(new_end_line - end_line) { + suffix_end += line.len(); + } + + ( + &content[prefix_start..prefix_end], + new_start_line + 1, + &content[suffix_start..suffix_end], + new_end_line + zero_check_inc, + ) +} + +/// Prints a git-diff style comparison between `old_str` and `new_str`. +/// - `start_line` - 1-indexed line number that `old_str` and `new_str` start at. +fn print_diff( + updates: &mut impl Write, + old_str: &StylizedFile, + new_str: &StylizedFile, + start_line: usize, +) -> Result<()> { + let diff = similar::TextDiff::from_lines(&old_str.content, &new_str.content); + + // First, get the gutter width required for both the old and new lines. + let (mut max_old_i, mut max_new_i) = (1, 1); + for change in diff.iter_all_changes() { + if let Some(i) = change.old_index() { + max_old_i = i + start_line; + } + if let Some(i) = change.new_index() { + max_new_i = i + start_line; + } + } + let old_line_num_width = terminal_width_required_for_line_count(max_old_i); + let new_line_num_width = terminal_width_required_for_line_count(max_new_i); + + // Now, print + fn fmt_index(i: Option, start_line: usize) -> String { + match i { + Some(i) => (i + start_line).to_string(), + _ => " ".to_string(), + } + } + for change in diff.iter_all_changes() { + // Define the colors per line. 
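+        // (Summary for readability: with truecolor support, deletions get a red-tinted RGB
+        // background and insertions a green-tinted one; without it, plain red/green foreground
+        // colors are used instead.)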
+ let (text_color, gutter_bg_color, line_bg_color) = match (change.tag(), new_str.truecolor) { + (similar::ChangeTag::Equal, true) => (style::Color::Reset, new_str.gutter_bg, new_str.line_bg), + (similar::ChangeTag::Delete, true) => ( + style::Color::Reset, + style::Color::Rgb { r: 79, g: 40, b: 40 }, + style::Color::Rgb { r: 36, g: 25, b: 28 }, + ), + (similar::ChangeTag::Insert, true) => ( + style::Color::Reset, + style::Color::Rgb { r: 40, g: 67, b: 43 }, + style::Color::Rgb { r: 24, g: 38, b: 30 }, + ), + (similar::ChangeTag::Equal, false) => (style::Color::Reset, new_str.gutter_bg, new_str.line_bg), + (similar::ChangeTag::Delete, false) => (style::Color::Red, new_str.gutter_bg, new_str.line_bg), + (similar::ChangeTag::Insert, false) => (style::Color::Green, new_str.gutter_bg, new_str.line_bg), + }; + // Define the change tag character to print, if any. + let sign = match change.tag() { + similar::ChangeTag::Equal => " ", + similar::ChangeTag::Delete => "-", + similar::ChangeTag::Insert => "+", + }; + + let old_i_str = fmt_index(change.old_index(), start_line); + let new_i_str = fmt_index(change.new_index(), start_line); + + // Print the gutter and line numbers. + queue!(updates, style::SetBackgroundColor(gutter_bg_color))?; + queue!( + updates, + style::SetForegroundColor(text_color), + style::Print(sign), + style::Print(" ") + )?; + queue!( + updates, + style::Print(format!( + "{:>old_line_num_width$}", + old_i_str, + old_line_num_width = old_line_num_width + )) + )?; + if sign == " " { + queue!(updates, style::Print(", "))?; + } else { + queue!(updates, style::Print(" "))?; + } + queue!( + updates, + style::Print(format!( + "{:>new_line_num_width$}", + new_i_str, + new_line_num_width = new_line_num_width + )) + )?; + // Print the line. + queue!( + updates, + style::SetForegroundColor(style::Color::Reset), + style::Print(":"), + style::SetForegroundColor(text_color), + style::SetBackgroundColor(line_bg_color), + style::Print(" "), + style::Print(change), + style::ResetColor, + )?; + } + queue!( + updates, + crossterm::terminal::Clear(crossterm::terminal::ClearType::UntilNewLine), + style::Print("\n"), + )?; + + Ok(()) +} + +/// Returns a 1-indexed line number range of the start and end of `needle` inside `file`. +fn line_number_at(file: impl AsRef, needle: impl AsRef) -> Option<(usize, usize)> { + let file = file.as_ref(); + let needle = needle.as_ref(); + if let Some((i, _)) = file.match_indices(needle).next() { + let start = file[..i].matches("\n").count(); + let end = needle.matches("\n").count(); + Some((start + 1, start + end + 1)) + } else { + None + } +} + +/// Returns the number of terminal cells required for displaying line numbers. This is used to +/// determine how many characters the gutter should allocate when displaying line numbers for a +/// text file. +/// +/// For example, `10` and `99` both take 2 cells, whereas `100` and `999` take 3. +fn terminal_width_required_for_line_count(line_count: usize) -> usize { + line_count.to_string().chars().count() +} + +fn stylize_output_if_able(ctx: &Context, path: impl AsRef, file_text: &str) -> StylizedFile { + if supports_truecolor(ctx) { + match stylized_file(path, file_text) { + Ok(s) => return s, + Err(err) => { + error!(?err, "unable to syntax highlight the output"); + }, + } + } + StylizedFile { + truecolor: false, + content: file_text.to_string(), + gutter_bg: style::Color::Reset, + line_bg: style::Color::Reset, + } +} + +/// Represents a [String] that is potentially stylized with truecolor escape codes. 
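+///
+/// When the terminal does not support truecolor (see `stylize_output_if_able` above), the plain
+/// file text is carried through unchanged with `truecolor: false` and `Reset` background colors.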
+#[derive(Debug)] +struct StylizedFile { + /// Whether or not the file is stylized with 24bit color. + truecolor: bool, + /// File content. If [Self::truecolor] is true, then it has escape codes for styling with 24bit + /// color. + content: String, + /// Background color for the gutter. + gutter_bg: style::Color, + /// Background color for the line content. + line_bg: style::Color, +} + +impl Default for StylizedFile { + fn default() -> Self { + Self { + truecolor: false, + content: Default::default(), + gutter_bg: style::Color::Reset, + line_bg: style::Color::Reset, + } + } +} + +/// Returns a 24bit terminal escaped syntax-highlighted [String] of the file pointed to by `path`, +/// if able. +fn stylized_file(path: impl AsRef<Path>, file_text: impl AsRef<str>) -> Result<StylizedFile> { + let ps = &*SYNTAX_SET; + let ts = &*THEME_SET; + + let extension = path + .as_ref() + .extension() + .wrap_err("missing extension")? + .to_str() + .wrap_err("not utf8")?; + + let syntax = ps + .find_syntax_by_extension(extension) + .wrap_err_with(|| format!("missing extension: {}", extension))?; + + let theme = &ts.themes["base16-ocean.dark"]; + let mut highlighter = HighlightLines::new(syntax, theme); + let file_text = file_text.as_ref().lines(); + let mut file = String::new(); + for line in file_text { + let mut ranges = Vec::new(); + ranges.append(&mut highlighter.highlight_line(line, ps)?); + let mut escaped_line = as_24_bit_terminal_escaped(&ranges[..], false); + escaped_line.push_str(&format!( + "{}\n", + crossterm::terminal::Clear(crossterm::terminal::ClearType::UntilNewLine), + )); + file.push_str(&escaped_line); + } + + let (line_bg, gutter_bg) = match (theme.settings.background, theme.settings.gutter) { + (Some(line_bg), Some(gutter_bg)) => (line_bg, gutter_bg), + (Some(line_bg), None) => (line_bg, line_bg), + _ => bail!("missing theme"), + }; + Ok(StylizedFile { + truecolor: true, + content: file, + gutter_bg: syntect_to_crossterm_color(gutter_bg), + line_bg: syntect_to_crossterm_color(line_bg), + }) +} + +fn syntect_to_crossterm_color(syntect: syntect::highlighting::Color) -> style::Color { + style::Color::Rgb { + r: syntect.r, + g: syntect.g, + b: syntect.b, + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + + const TEST_FILE_CONTENTS: &str = "\ +1: Hello world! +2: This is line 2 +3: asdf +4: Hello world! +"; + + const TEST_FILE_PATH: &str = "/test_file.txt"; + const TEST_HIDDEN_FILE_PATH: &str = "/aaaa2/.hidden"; + + /// Sets up the following filesystem structure: + /// ```text + /// test_file.txt + /// /home/testuser/ + /// /aaaa1/ + /// /bbbb1/ + /// /cccc1/ + /// /aaaa2/ + /// .hidden + /// ``` + async fn setup_test_directory() -> Arc<Context> { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let fs = ctx.fs(); + fs.write(TEST_FILE_PATH, TEST_FILE_CONTENTS).await.unwrap(); + fs.create_dir_all("/aaaa1/bbbb1/cccc1").await.unwrap(); + fs.create_dir_all("/aaaa2").await.unwrap(); + fs.write(TEST_HIDDEN_FILE_PATH, "this is a hidden file").await.unwrap(); + ctx + } + + #[test] + fn test_fs_write_deserialize() { + let path = "/my-file"; + let file_text = "hello world"; + + // create + let v = serde_json::json!({ + "path": path, + "command": "create", + "file_text": file_text + }); + let fw = serde_json::from_value::<FsWrite>(v).unwrap(); + assert!(matches!(fw, FsWrite::Create { ..
})); + + // str_replace + let v = serde_json::json!({ + "path": path, + "command": "str_replace", + "old_str": "prev string", + "new_str": "new string", + }); + let fw = serde_json::from_value::(v).unwrap(); + assert!(matches!(fw, FsWrite::StrReplace { .. })); + + // insert + let v = serde_json::json!({ + "path": path, + "command": "insert", + "insert_line": 3, + "new_str": "new string", + }); + let fw = serde_json::from_value::(v).unwrap(); + assert!(matches!(fw, FsWrite::Insert { .. })); + + // append + let v = serde_json::json!({ + "path": path, + "command": "append", + "new_str": "appended content", + }); + let fw = serde_json::from_value::(v).unwrap(); + assert!(matches!(fw, FsWrite::Append { .. })); + } + + #[tokio::test] + async fn test_fs_write_tool_create() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + let file_text = "Hello, world!"; + let v = serde_json::json!({ + "path": "/my-file", + "command": "create", + "file_text": file_text + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + assert_eq!( + ctx.fs().read_to_string("/my-file").await.unwrap(), + format!("{}\n", file_text) + ); + + let file_text = "Goodbye, world!\nSee you later"; + let v = serde_json::json!({ + "path": "/my-file", + "command": "create", + "file_text": file_text + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + // File should end with a newline + assert_eq!( + ctx.fs().read_to_string("/my-file").await.unwrap(), + format!("{}\n", file_text) + ); + + let file_text = "This is a new string"; + let v = serde_json::json!({ + "path": "/my-file", + "command": "create", + "new_str": file_text + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + assert_eq!( + ctx.fs().read_to_string("/my-file").await.unwrap(), + format!("{}\n", file_text) + ); + } + + #[tokio::test] + async fn test_fs_write_tool_str_replace() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + // No instances found + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "command": "str_replace", + "old_str": "asjidfopjaieopr", + "new_str": "1623749", + }); + assert!( + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .is_err() + ); + + // Multiple instances found + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "command": "str_replace", + "old_str": "Hello world!", + "new_str": "Goodbye world!", + }); + assert!( + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .is_err() + ); + + // Single instance found and replaced + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "command": "str_replace", + "old_str": "1: Hello world!", + "new_str": "1: Goodbye world!", + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + assert_eq!( + ctx.fs() + .read_to_string(TEST_FILE_PATH) + .await + .unwrap() + .lines() + .next() + .unwrap(), + "1: Goodbye world!", + "expected the only occurrence to be replaced" + ); + } + + #[tokio::test] + async fn test_fs_write_tool_insert_at_beginning() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + let new_str = "1: New first line!\n"; + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "command": "insert", + "insert_line": 0, + "new_str": new_str, + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + 
.unwrap(); + let actual = ctx.fs().read_to_string(TEST_FILE_PATH).await.unwrap(); + assert_eq!( + format!("{}\n", actual.lines().next().unwrap()), + new_str, + "expected the first line to be updated to '{}'", + new_str + ); + assert_eq!( + actual.lines().skip(1).collect::>(), + TEST_FILE_CONTENTS.lines().collect::>(), + "the rest of the file should not have been updated" + ); + } + + #[tokio::test] + async fn test_fs_write_tool_insert_after_first_line() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + let new_str = "2: New second line!\n"; + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "command": "insert", + "insert_line": 1, + "new_str": new_str, + }); + + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + let actual = ctx.fs().read_to_string(TEST_FILE_PATH).await.unwrap(); + assert_eq!( + format!("{}\n", actual.lines().nth(1).unwrap()), + new_str, + "expected the second line to be updated to '{}'", + new_str + ); + assert_eq!( + actual.lines().skip(2).collect::>(), + TEST_FILE_CONTENTS.lines().skip(1).collect::>(), + "the rest of the file should not have been updated" + ); + } + + #[tokio::test] + async fn test_fs_write_tool_insert_when_no_newlines_in_file() { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let mut stdout = std::io::stdout(); + + let test_file_path = "/file.txt"; + let test_file_contents = "hello there"; + ctx.fs().write(test_file_path, test_file_contents).await.unwrap(); + + let new_str = "test"; + + // First, test appending + let v = serde_json::json!({ + "path": test_file_path, + "command": "insert", + "insert_line": 1, + "new_str": new_str, + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + let actual = ctx.fs().read_to_string(test_file_path).await.unwrap(); + assert_eq!(actual, format!("{}{}\n", test_file_contents, new_str)); + + // Then, test prepending + let v = serde_json::json!({ + "path": test_file_path, + "command": "insert", + "insert_line": 0, + "new_str": new_str, + }); + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + let actual = ctx.fs().read_to_string(test_file_path).await.unwrap(); + assert_eq!(actual, format!("{}{}{}\n", new_str, test_file_contents, new_str)); + } + + #[tokio::test] + async fn test_fs_write_tool_append() { + let ctx = setup_test_directory().await; + let mut stdout = std::io::stdout(); + + // Test appending to existing file + let content_to_append = "5: Appended line"; + let v = serde_json::json!({ + "path": TEST_FILE_PATH, + "command": "append", + "new_str": content_to_append, + }); + + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await + .unwrap(); + + let actual = ctx.fs().read_to_string(TEST_FILE_PATH).await.unwrap(); + assert_eq!( + actual, + format!("{}{}\n", TEST_FILE_CONTENTS, content_to_append), + "Content should be appended to the end of the file with a newline added" + ); + + // Test appending to non-existent file (should fail) + let new_file_path = "/new_append_file.txt"; + let content = "This is a new file created by append"; + let v = serde_json::json!({ + "path": new_file_path, + "command": "append", + "new_str": content, + }); + + let result = serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut stdout) + .await; + + assert!(result.is_err(), "Appending to non-existent file should fail"); + } + + #[test] + fn test_lines_with_context() { + let content = 
"Hello\nWorld!\nhow\nare\nyou\ntoday?"; + assert_eq!(get_lines_with_context(content, 1, 1, 1), ("", 1, "World!\n", 2)); + assert_eq!(get_lines_with_context(content, 0, 0, 2), ("", 1, "Hello\nWorld!\n", 2)); + assert_eq!( + get_lines_with_context(content, 2, 4, 50), + ("Hello\n", 1, "you\ntoday?", 6) + ); + assert_eq!(get_lines_with_context(content, 4, 100, 2), ("World!\nhow\n", 2, "", 6)); + } + + #[test] + fn test_gutter_width() { + assert_eq!(terminal_width_required_for_line_count(1), 1); + assert_eq!(terminal_width_required_for_line_count(9), 1); + assert_eq!(terminal_width_required_for_line_count(10), 2); + assert_eq!(terminal_width_required_for_line_count(99), 2); + assert_eq!(terminal_width_required_for_line_count(100), 3); + assert_eq!(terminal_width_required_for_line_count(999), 3); + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/gh_issue.rs b/crates/kiro-cli/src/cli/chat/tools/gh_issue.rs new file mode 100644 index 0000000000..58702170c5 --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/gh_issue.rs @@ -0,0 +1,222 @@ +use std::collections::{ + HashMap, + VecDeque, +}; +use std::io::Write; + +use crossterm::style::Color; +use crossterm::{ + queue, + style, +}; +use eyre::{ + Result, + WrapErr, + eyre, +}; +use serde::Deserialize; + +use super::super::context::ContextManager; +use super::super::util::issue::IssueCreator; +use super::{ + InvokeOutput, + ToolPermission, +}; +use crate::cli::chat::token_counter::TokenCounter; +use crate::fig_os_shim::Context; + +#[derive(Debug, Clone, Deserialize)] +pub struct GhIssue { + pub title: String, + pub expected_behavior: Option, + pub actual_behavior: Option, + pub steps_to_reproduce: Option, + + #[serde(skip_deserializing)] + pub context: Option, +} + +#[derive(Debug, Clone)] +pub struct GhIssueContext { + pub context_manager: Option, + pub transcript: VecDeque, + pub failed_request_ids: Vec, + pub tool_permissions: HashMap, + pub interactive: bool, +} + +/// Max amount of characters to include in the transcript. +const MAX_TRANSCRIPT_CHAR_LEN: usize = 3_000; + +impl GhIssue { + pub async fn invoke(&self, _updates: impl Write) -> Result { + let Some(context) = self.context.as_ref() else { + return Err(eyre!( + "report_issue: Required tool context (GhIssueContext) not set by the program." + )); + }; + + // Prepare additional details from the chat session + let additional_environment = [ + Self::get_chat_settings(context), + Self::get_request_ids(context), + Self::get_context(context).await, + ] + .join("\n\n"); + + // Add chat history to the actual behavior text. 
+ let actual_behavior = self.actual_behavior.as_ref().map_or_else( + || Self::get_transcript(context), + |behavior| format!("{behavior}\n\n{}\n", Self::get_transcript(context)), + ); + + let _ = IssueCreator { + title: Some(self.title.clone()), + expected_behavior: self.expected_behavior.clone(), + actual_behavior: Some(actual_behavior), + steps_to_reproduce: self.steps_to_reproduce.clone(), + additional_environment: Some(additional_environment), + } + .create_url() + .await + .wrap_err("failed to invoke gh issue tool"); + + Ok(Default::default()) + } + + pub fn set_context(&mut self, context: GhIssueContext) { + self.context = Some(context); + } + + fn get_transcript(context: &GhIssueContext) -> String { + let mut transcript_str = String::from("```\n[chat-transcript]\n"); + let mut is_truncated = false; + let transcript: Vec = context.transcript + .iter() + .rev() // To take last N items + .scan(0, |user_msg_char_count, line| { + if *user_msg_char_count >= MAX_TRANSCRIPT_CHAR_LEN { + is_truncated = true; + return None; + } + let remaining_chars = MAX_TRANSCRIPT_CHAR_LEN - *user_msg_char_count; + let trimmed_line = if line.len() > remaining_chars { + &line[..remaining_chars] + } else { + line + }; + *user_msg_char_count += trimmed_line.len(); + + // backticks will mess up the markdown + let text = trimmed_line.replace("```", r"\```"); + Some(text) + }) + .collect::>() + .into_iter() + .rev() // Now return items to the proper order + .collect(); + + if !transcript.is_empty() { + transcript_str.push_str(&transcript.join("\n\n")); + } else { + transcript_str.push_str("No chat history found."); + } + + if is_truncated { + transcript_str.push_str("\n\n(...truncated)"); + } + transcript_str.push_str("\n```"); + transcript_str + } + + fn get_request_ids(context: &GhIssueContext) -> String { + format!( + "[chat-failed_request_ids]\n{}", + if context.failed_request_ids.is_empty() { + "none".to_string() + } else { + context.failed_request_ids.join("\n") + } + ) + } + + async fn get_context(context: &GhIssueContext) -> String { + let mut ctx_str = "[chat-context]\n".to_string(); + let Some(ctx_manager) = &context.context_manager else { + ctx_str.push_str("No context available."); + return ctx_str; + }; + + ctx_str.push_str(&format!("current_profile={}\n", ctx_manager.current_profile)); + match ctx_manager.list_profiles().await { + Ok(profiles) if !profiles.is_empty() => { + ctx_str.push_str(&format!("profiles=\n{}\n\n", profiles.join("\n"))); + }, + _ => ctx_str.push_str("profiles=none\n\n"), + } + + // Context file categories + if ctx_manager.global_config.paths.is_empty() { + ctx_str.push_str("global_context=none\n\n"); + } else { + ctx_str.push_str(&format!( + "global_context=\n{}\n\n", + &ctx_manager.global_config.paths.join("\n") + )); + } + + if ctx_manager.profile_config.paths.is_empty() { + ctx_str.push_str("profile_context=none\n\n"); + } else { + ctx_str.push_str(&format!( + "profile_context=\n{}\n\n", + &ctx_manager.profile_config.paths.join("\n") + )); + } + + // Handle context files + match ctx_manager.get_context_files(false).await { + Ok(context_files) if !context_files.is_empty() => { + ctx_str.push_str("files=\n"); + let total_size: usize = context_files + .iter() + .map(|(file, content)| { + let size = TokenCounter::count_tokens(content); + ctx_str.push_str(&format!("{}, {} tkns\n", file, size)); + size + }) + .sum(); + ctx_str.push_str(&format!("total context size={total_size} tkns")); + }, + _ => ctx_str.push_str("files=none"), + } + + ctx_str + } + + fn get_chat_settings(context: 
&GhIssueContext) -> String { + let mut result_str = "[chat-settings]\n".to_string(); + result_str.push_str(&format!("interactive={}", context.interactive)); + + result_str.push_str("\n\n[chat-trusted_tools]"); + for (tool, permission) in context.tool_permissions.iter() { + result_str.push_str(&format!("\n{tool}={}", permission.trusted)); + } + + result_str + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + Ok(queue!( + updates, + style::Print("I will prepare a github issue with our conversation history.\n\n"), + style::SetForegroundColor(Color::Green), + style::Print(format!("Title: {}\n", &self.title)), + style::ResetColor + )?) + } + + pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { + Ok(()) + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/mod.rs b/crates/kiro-cli/src/cli/chat/tools/mod.rs new file mode 100644 index 0000000000..316363df5d --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/mod.rs @@ -0,0 +1,432 @@ +pub mod custom_tool; +pub mod execute_bash; +pub mod fs_read; +pub mod fs_write; +pub mod gh_issue; +pub mod use_aws; + +use std::collections::HashMap; +use std::io::Write; +use std::path::{ + Path, + PathBuf, +}; + +use aws_smithy_types::{ + Document, + Number as SmithyNumber, +}; +use crossterm::style::Stylize; +use custom_tool::CustomTool; +use execute_bash::ExecuteBash; +use eyre::Result; +use fs_read::FsRead; +use fs_write::FsWrite; +use gh_issue::GhIssue; +use serde::{ + Deserialize, + Serialize, +}; +use use_aws::UseAws; + +use super::consts::MAX_TOOL_RESPONSE_SIZE; +use crate::fig_os_shim::Context; + +/// Represents an executable tool use. +#[derive(Debug, Clone)] +pub enum Tool { + FsRead(FsRead), + FsWrite(FsWrite), + ExecuteBash(ExecuteBash), + UseAws(UseAws), + Custom(CustomTool), + GhIssue(GhIssue), +} + +impl Tool { + /// The display name of a tool + pub fn display_name(&self) -> String { + match self { + Tool::FsRead(_) => "fs_read", + Tool::FsWrite(_) => "fs_write", + Tool::ExecuteBash(_) => "execute_bash", + Tool::UseAws(_) => "use_aws", + Tool::Custom(custom_tool) => &custom_tool.name, + Tool::GhIssue(_) => "gh_issue", + } + .to_owned() + } + + /// Whether or not the tool should prompt the user to accept before [Self::invoke] is called. 
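+ /// fs_read and report_issue never prompt, fs_write and custom MCP tools always do, and
+ /// execute_bash / use_aws defer to their own read-only checks.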
+ pub fn requires_acceptance(&self, _ctx: &Context) -> bool { + match self { + Tool::FsRead(_) => false, + Tool::FsWrite(_) => true, + Tool::ExecuteBash(execute_bash) => execute_bash.requires_acceptance(), + Tool::UseAws(use_aws) => use_aws.requires_acceptance(), + Tool::Custom(_) => true, + Tool::GhIssue(_) => false, + } + } + + /// Invokes the tool asynchronously + pub async fn invoke(&self, context: &Context, updates: &mut impl Write) -> Result<InvokeOutput> { + match self { + Tool::FsRead(fs_read) => fs_read.invoke(context, updates).await, + Tool::FsWrite(fs_write) => fs_write.invoke(context, updates).await, + Tool::ExecuteBash(execute_bash) => execute_bash.invoke(updates).await, + Tool::UseAws(use_aws) => use_aws.invoke(context, updates).await, + Tool::Custom(custom_tool) => custom_tool.invoke(context, updates).await, + Tool::GhIssue(gh_issue) => gh_issue.invoke(updates).await, + } + } + + /// Queues up a tool's intention in a human readable format + pub async fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { + match self { + Tool::FsRead(fs_read) => fs_read.queue_description(ctx, updates).await, + Tool::FsWrite(fs_write) => fs_write.queue_description(ctx, updates), + Tool::ExecuteBash(execute_bash) => execute_bash.queue_description(updates), + Tool::UseAws(use_aws) => use_aws.queue_description(updates), + Tool::Custom(custom_tool) => custom_tool.queue_description(updates), + Tool::GhIssue(gh_issue) => gh_issue.queue_description(updates), + } + } + + /// Validates the tool with the arguments supplied + pub async fn validate(&mut self, ctx: &Context) -> Result<()> { + match self { + Tool::FsRead(fs_read) => fs_read.validate(ctx).await, + Tool::FsWrite(fs_write) => fs_write.validate(ctx).await, + Tool::ExecuteBash(execute_bash) => execute_bash.validate(ctx).await, + Tool::UseAws(use_aws) => use_aws.validate(ctx).await, + Tool::Custom(custom_tool) => custom_tool.validate(ctx).await, + Tool::GhIssue(gh_issue) => gh_issue.validate(ctx).await, + } + } +} + +#[derive(Debug, Clone)] +pub struct ToolPermission { + pub trusted: bool, +} + +#[derive(Debug, Clone)] +/// Holds overrides for tool permissions. +/// Tools that do not have an associated ToolPermission should use +/// their default logic to determine the permission. +pub struct ToolPermissions { + pub permissions: HashMap<String, ToolPermission>, +} + +impl ToolPermissions { + pub fn new(capacity: usize) -> Self { + Self { + permissions: HashMap::with_capacity(capacity), + } + } + + pub fn is_trusted(&self, tool_name: &str) -> bool { + self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) + } + + /// Returns a label to describe the permission status for a given tool.
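+ /// Falls back to [Self::default_permission_label] when no explicit trust override has been
+ /// recorded for `tool_name`.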
+ pub fn display_label(&self, tool_name: &str) -> String { + if self.has(tool_name) { + if self.is_trusted(tool_name) { + format!(" {}", "trusted".dark_green().bold()) + } else { + format!(" {}", "not trusted".dark_grey()) + } + } else { + Self::default_permission_label(tool_name) + } + } + + pub fn trust_tool(&mut self, tool_name: &str) { + self.permissions + .insert(tool_name.to_string(), ToolPermission { trusted: true }); + } + + pub fn untrust_tool(&mut self, tool_name: &str) { + self.permissions + .insert(tool_name.to_string(), ToolPermission { trusted: false }); + } + + pub fn reset(&mut self) { + self.permissions.clear(); + } + + pub fn reset_tool(&mut self, tool_name: &str) { + self.permissions.remove(tool_name); + } + + pub fn has(&self, tool_name: &str) -> bool { + self.permissions.contains_key(tool_name) + } + + /// Provide default permission labels for the built-in set of tools. + /// Unknown tools are assumed to be "Per-request" + // This "static" way avoids needing to construct a tool instance. + fn default_permission_label(tool_name: &str) -> String { + let label = match tool_name { + "fs_read" => "trusted".dark_green().bold(), + "fs_write" => "not trusted".dark_grey(), + "execute_bash" => "trust read-only commands".dark_grey(), + "use_aws" => "trust read-only commands".dark_grey(), + "report_issue" => "trusted".dark_green().bold(), + _ => "not trusted".dark_grey(), + }; + + format!("{} {label}", "*".reset()) + } +} + +/// A tool specification to be sent to the model as part of a conversation. Maps to +/// [BedrockToolSpecification]. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolSpec { + pub name: String, + pub description: String, + #[serde(alias = "inputSchema")] + pub input_schema: InputSchema, + #[serde(skip_serializing, default = "tool_origin")] + pub tool_origin: ToolOrigin, +} + +#[derive(Debug, Clone, Deserialize, Eq, PartialEq, Hash)] +pub enum ToolOrigin { + Native, + McpServer(String), +} + +impl std::fmt::Display for ToolOrigin { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ToolOrigin::Native => write!(f, "Built-in"), + ToolOrigin::McpServer(server) => write!(f, "{} (MCP)", server), + } + } +} + +fn tool_origin() -> ToolOrigin { + ToolOrigin::Native +} + +#[derive(Debug, Clone)] +pub struct QueuedTool { + pub id: String, + pub name: String, + pub accepted: bool, + pub tool: Tool, +} + +/// The schema specification describing a tool's fields. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct InputSchema(pub serde_json::Value); + +/// The output received from invoking a [Tool]. 
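+/// Holds either plain text or a JSON document; [InvokeOutput::as_str] returns the text form, or
+/// the JSON string value (empty if the JSON is not a string).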
+#[derive(Debug, Default)] +pub struct InvokeOutput { + pub output: OutputKind, +} + +impl InvokeOutput { + pub fn as_str(&self) -> &str { + match &self.output { + OutputKind::Text(s) => s.as_str(), + OutputKind::Json(j) => j.as_str().unwrap_or_default(), + } + } +} + +#[non_exhaustive] +#[derive(Debug)] +pub enum OutputKind { + Text(String), + Json(serde_json::Value), +} + +impl Default for OutputKind { + fn default() -> Self { + Self::Text(String::new()) + } +} + +pub fn serde_value_to_document(value: serde_json::Value) -> Document { + match value { + serde_json::Value::Null => Document::Null, + serde_json::Value::Bool(bool) => Document::Bool(bool), + serde_json::Value::Number(number) => { + if let Some(num) = number.as_u64() { + Document::Number(SmithyNumber::PosInt(num)) + } else if number.as_i64().is_some_and(|n| n < 0) { + Document::Number(SmithyNumber::NegInt(number.as_i64().unwrap())) + } else { + Document::Number(SmithyNumber::Float(number.as_f64().unwrap_or_default())) + } + }, + serde_json::Value::String(string) => Document::String(string), + serde_json::Value::Array(vec) => { + Document::Array(vec.clone().into_iter().map(serde_value_to_document).collect::<_>()) + }, + serde_json::Value::Object(map) => Document::Object( + map.into_iter() + .map(|(k, v)| (k, serde_value_to_document(v))) + .collect::<_>(), + ), + } +} + +pub fn document_to_serde_value(value: Document) -> serde_json::Value { + use serde_json::Value; + match value { + Document::Object(map) => Value::Object( + map.into_iter() + .map(|(k, v)| (k, document_to_serde_value(v))) + .collect::<_>(), + ), + Document::Array(vec) => Value::Array(vec.clone().into_iter().map(document_to_serde_value).collect::<_>()), + Document::Number(number) => { + if let Ok(v) = TryInto::::try_into(number) { + Value::Number(v.into()) + } else if let Ok(v) = TryInto::::try_into(number) { + Value::Number(v.into()) + } else { + Value::Number( + serde_json::Number::from_f64(number.to_f64_lossy()) + .unwrap_or(serde_json::Number::from_f64(0.0).expect("converting from 0.0 will not fail")), + ) + } + }, + Document::String(s) => serde_json::Value::String(s), + Document::Bool(b) => serde_json::Value::Bool(b), + Document::Null => serde_json::Value::Null, + } +} + +/// Performs tilde expansion and other required sanitization modifications for handling tool use +/// path arguments. +/// +/// Required since path arguments are defined by the model. +#[allow(dead_code)] +fn sanitize_path_tool_arg(ctx: &Context, path: impl AsRef) -> PathBuf { + let mut res = PathBuf::new(); + // Expand `~` only if it is the first part. + let mut path = path.as_ref().components(); + match path.next() { + Some(p) if p.as_os_str() == "~" => { + res.push(ctx.env().home().unwrap_or_default()); + }, + Some(p) => res.push(p), + None => return res, + } + for p in path { + res.push(p); + } + // For testing scenarios, we need to make sure paths are appropriately handled in chroot test + // file systems since they are passed directly from the model. + ctx.fs().chroot_path(res) +} + +/// Converts `path` to a relative path according to the current working directory `cwd`. 
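+/// Both paths are canonicalized first; the shared prefix is dropped and each remaining `cwd`
+/// component becomes a `..`, so e.g. cwd `/a/b/c` with path `/a/d` yields `../../d`.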
+fn absolute_to_relative(cwd: impl AsRef<Path>, path: impl AsRef<Path>) -> Result<PathBuf> { + let cwd = cwd.as_ref().canonicalize()?; + let path = path.as_ref().canonicalize()?; + let mut cwd_parts = cwd.components().peekable(); + let mut path_parts = path.components().peekable(); + + // Skip common prefix + while let (Some(a), Some(b)) = (cwd_parts.peek(), path_parts.peek()) { + if a == b { + cwd_parts.next(); + path_parts.next(); + } else { + break; + } + } + + // ".." for any uncommon parts, then just append the rest of the path. + let mut relative = PathBuf::new(); + for _ in cwd_parts { + relative.push(".."); + } + for part in path_parts { + relative.push(part); + } + + Ok(relative) +} + +/// Small helper for formatting the path as a relative path, if able. +fn format_path(cwd: impl AsRef<Path>, path: impl AsRef<Path>) -> String { + absolute_to_relative(cwd, path.as_ref()) + .map(|p| p.to_string_lossy().to_string()) + // If we have three consecutive ".." then it should probably just stay as an absolute path. + .map(|p| { + if p.starts_with("../../..") { + path.as_ref().to_string_lossy().to_string() + } else { + p + } + }) + .unwrap_or(path.as_ref().to_string_lossy().to_string()) +} + +fn supports_truecolor(ctx: &Context) -> bool { + // Simple override to disable truecolor since shell_color doesn't use Context. + !ctx.env().get("Q_DISABLE_TRUECOLOR").is_ok_and(|s| !s.is_empty()) + && shell_color::get_color_support().contains(shell_color::ColorSupport::TERM24BIT) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::fig_os_shim::EnvProvider; + + #[tokio::test] + async fn test_tilde_path_expansion() { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + + let actual = sanitize_path_tool_arg(&ctx, "~"); + assert_eq!( + actual, + ctx.fs().chroot_path(ctx.env().home().unwrap()), + "tilde should expand" + ); + let actual = sanitize_path_tool_arg(&ctx, "~/hello"); + assert_eq!( + actual, + ctx.fs().chroot_path(ctx.env().home().unwrap().join("hello")), + "tilde should expand" + ); + let actual = sanitize_path_tool_arg(&ctx, "/~"); + assert_eq!( + actual, + ctx.fs().chroot_path("/~"), + "tilde should not expand when not the first component" + ); + } + + #[tokio::test] + async fn test_format_path() { + async fn assert_paths(cwd: &str, path: &str, expected: &str) { + let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); + let fs = ctx.fs(); + let cwd = sanitize_path_tool_arg(&ctx, cwd); + let path = sanitize_path_tool_arg(&ctx, path); + fs.create_dir_all(&cwd).await.unwrap(); + fs.create_dir_all(&path).await.unwrap(); + // Using `contains` since the chroot test directory will prefix the formatted path with a tmpdir + // path.
+ assert!(format_path(cwd, path).contains(expected)); + } + assert_paths("/Users/testuser/src", "/Users/testuser/Downloads", "../Downloads").await; + assert_paths( + "/Users/testuser/projects/MyProject/src", + "/Volumes/projects/MyProject/src", + "/Volumes/projects/MyProject/src", + ) + .await; + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/tool_index.json b/crates/kiro-cli/src/cli/chat/tools/tool_index.json new file mode 100644 index 0000000000..397d856cfa --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/tool_index.json @@ -0,0 +1,176 @@ +{ + "execute_bash": { + "name": "execute_bash", + "description": "Execute the specified bash command.", + "input_schema": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Bash command to execute" + } + }, + "required": [ + "command" + ] + } + }, + "fs_read": { + "name": "fs_read", + "description": "Tool for reading files (for example, `cat -n`) and directories (for example, `ls -la`). The behavior of this tool is determined by the `mode` parameter. The available modes are:\n- line: Show lines in a file, given by an optional `start_line` and optional `end_line`.\n- directory: List directory contents. Content is returned in the \"long format\" of ls (that is, `ls -la`).\n- search: Search for a pattern in a file. The pattern is a string. The matching is case insensitive.\n\nExample Usage:\n1. Read all lines from a file: command=\"line\", path=\"/path/to/file.txt\"\n2. Read the last 5 lines from a file: command=\"line\", path=\"/path/to/file.txt\", start_line=-5\n3. List the files in the home directory: command=\"line\", path=\"~\"\n4. Recursively list files in a directory to a max depth of 2: command=\"line\", path=\"/path/to/directory\", depth=2\n5. Search for all instances of \"test\" in a file: command=\"search\", path=\"/path/to/file.txt\", pattern=\"test\"\n", + "input_schema": { + "type": "object", + "properties": { + "path": { + "description": "Path to the file or directory. The path should be absolute, or otherwise start with ~ for the user's home.", + "type": "string" + }, + "mode": { + "type": "string", + "enum": [ + "Line", + "Directory", + "Search" + ], + "description": "The mode to run in: `Line`, `Directory`, `Search`. `Line` and `Search` are only for text files, and `Directory` is only for directories." + }, + "start_line": { + "type": "integer", + "description": "Starting line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", + "default": 1 + }, + "end_line": { + "type": "integer", + "description": "Ending line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", + "default": -1 + }, + "pattern": { + "type": "string", + "description": "Pattern to search for (required, for Search mode). Case insensitive. The pattern matching is performed per line." 
+ }, + "context_lines": { + "type": "integer", + "description": "Number of context lines around search results (optional, for Search mode)", + "default": 2 + }, + "depth": { + "type": "integer", + "description": "Depth of a recursive directory listing (optional, for Directory mode)", + "default": 0 + } + }, + "required": [ + "path", + "mode" + ] + } + }, + "fs_write": { + "name": "fs_write", + "description": "A tool for creating and editing files\n * The `create` command will override the file at `path` if it already exists as a file, and otherwise create a new file\n * The `append` command will add content to the end of an existing file, automatically adding a newline if the file doesn't end with one. The file must exist.\n Notes for using the `str_replace` command:\n * The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!\n * If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique\n * The `new_str` parameter should contain the edited lines that should replace the `old_str`.", + "input_schema": { + "type": "object", + "properties": { + "command": { + "type": "string", + "enum": [ + "create", + "str_replace", + "insert", + "append" + ], + "description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`, `append`." + }, + "file_text": { + "description": "Required parameter of `create` command, with the content of the file to be created.", + "type": "string" + }, + "insert_line": { + "description": "Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.", + "type": "integer" + }, + "new_str": { + "description": "Required parameter of `str_replace` command containing the new string. Required parameter of `insert` command containing the string to insert. Required parameter of `append` command containing the content to append to the file.", + "type": "string" + }, + "old_str": { + "description": "Required parameter of `str_replace` command containing the string in `path` to replace.", + "type": "string" + }, + "path": { + "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", + "type": "string" + } + }, + "required": [ + "command", + "path" + ] + } + }, + "use_aws": { + "name": "use_aws", + "description": "Make an AWS CLI api call with the specified service, operation, and parameters. All arguments MUST conform to the AWS CLI specification. Should the output of the invocation indicate a malformed command, invoke help to obtain the the correct command.", + "input_schema": { + "type": "object", + "properties": { + "service_name": { + "type": "string", + "description": "The name of the AWS service. If you want to query s3, you should use s3api if possible." + }, + "operation_name": { + "type": "string", + "description": "The name of the operation to perform." + }, + "parameters": { + "type": "object", + "description": "The parameters for the operation. The parameter keys MUST conform to the AWS CLI specification. You should prefer to use JSON Syntax over shorthand syntax wherever possible. For parameters that are booleans, prioritize using flags with no value. Denote these flags with flag names as key and an empty string as their value. You should also prefer kebab case." + }, + "region": { + "type": "string", + "description": "Region name for calling the operation on AWS." 
+ }, + "profile_name": { + "type": "string", + "description": "Optional: AWS profile name to use from ~/.aws/credentials. Defaults to default profile if not specified." + }, + "label": { + "type": "string", + "description": "Human readable description of the api that is being called." + } + }, + "required": [ + "region", + "service_name", + "operation_name", + "label" + ] + } + }, + "gh_issue": { + "name": "report_issue", + "description": "Opens the browser to a pre-filled gh (GitHub) issue template to report chat issues, bugs, or feature requests. Pre-filled information includes the conversation transcript, chat context, and chat request IDs from the service.", + "input_schema": { + "type": "object", + "properties": { + "title": { + "type": "string", + "description": "The title of the GitHub issue." + }, + "expected_behavior": { + "type": "string", + "description": "Optional: The expected chat behavior or action that did not happen." + }, + "actual_behavior": { + "type": "string", + "description": "Optional: The actual chat behavior that happened and demonstrates the issue or lack of a feature." + }, + "steps_to_reproduce": { + "type": "string", + "description": "Optional: Previous user chat requests or steps that were taken that may have resulted in the issue or error response." + } + }, + "required": ["title"] + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/tools/use_aws.rs b/crates/kiro-cli/src/cli/chat/tools/use_aws.rs new file mode 100644 index 0000000000..68dfd0b2af --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/tools/use_aws.rs @@ -0,0 +1,315 @@ +use std::collections::HashMap; +use std::io::Write; +use std::process::Stdio; + +use bstr::ByteSlice; +use convert_case::{ + Case, + Casing, +}; +use crossterm::{ + queue, + style, +}; +use eyre::{ + Result, + WrapErr, +}; +use serde::Deserialize; + +use super::{ + InvokeOutput, + MAX_TOOL_RESPONSE_SIZE, + OutputKind, +}; +use crate::fig_os_shim::Context; + +const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; + +/// The environment variable name where we set additional metadata for the AWS CLI user agent. +const USER_AGENT_ENV_VAR: &str = "AWS_EXECUTION_ENV"; +const USER_AGENT_APP_NAME: &str = "AmazonQ-For-CLI"; +const USER_AGENT_VERSION_KEY: &str = "Version"; +const USER_AGENT_VERSION_VALUE: &str = env!("CARGO_PKG_VERSION"); + +// TODO: we should perhaps composite this struct with an interface that we can use to mock the +// actual cli with. That will allow us to more thoroughly test it. 
+#[derive(Debug, Clone, Deserialize)] +pub struct UseAws { + pub service_name: String, + pub operation_name: String, + pub parameters: Option<HashMap<String, serde_json::Value>>, + pub region: String, + pub profile_name: Option<String>, + pub label: Option<String>, +} + +impl UseAws { + pub fn requires_acceptance(&self) -> bool { + !READONLY_OPS.iter().any(|op| self.operation_name.starts_with(op)) + } + + pub async fn invoke(&self, _ctx: &Context, _updates: impl Write) -> Result<InvokeOutput> { + let mut command = tokio::process::Command::new("aws"); + + // Set up environment variables + let mut env_vars: std::collections::HashMap<String, String> = std::env::vars().collect(); + + // Set up additional metadata for the AWS CLI user agent + let user_agent_metadata_value = format!( + "{} {}/{}", + USER_AGENT_APP_NAME, USER_AGENT_VERSION_KEY, USER_AGENT_VERSION_VALUE + ); + + // If the user agent metadata env var already exists, append to it, otherwise set it + if let Some(existing_value) = env_vars.get(USER_AGENT_ENV_VAR) { + if !existing_value.is_empty() { + env_vars.insert( + USER_AGENT_ENV_VAR.to_string(), + format!("{} {}", existing_value, user_agent_metadata_value), + ); + } else { + env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value); + } + } else { + env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value); + } + + command.envs(env_vars).arg("--region").arg(&self.region); + if let Some(profile_name) = self.profile_name.as_deref() { + command.arg("--profile").arg(profile_name); + } + command.arg(&self.service_name).arg(&self.operation_name); + if let Some(parameters) = self.cli_parameters() { + for (name, val) in parameters { + command.arg(name); + if !val.is_empty() { + command.arg(val); + } + } + } + let output = command + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .wrap_err_with(|| format!("Unable to spawn command '{:?}'", self))? + .wait_with_output() + .await + .wrap_err_with(|| format!("Unable to spawn command '{:?}'", self))?; + let status = output.status.code().unwrap_or(0).to_string(); + let stdout = output.stdout.to_str_lossy(); + let stderr = output.stderr.to_str_lossy(); + + let stdout = format!( + "{}{}", + &stdout[0..stdout.len().min(MAX_TOOL_RESPONSE_SIZE / 3)], + if stdout.len() > MAX_TOOL_RESPONSE_SIZE / 3 { + " ... truncated" + } else { + "" + } + ); + + let stderr = format!( + "{}{}", + &stderr[0..stderr.len().min(MAX_TOOL_RESPONSE_SIZE / 3)], + if stderr.len() > MAX_TOOL_RESPONSE_SIZE / 3 { + " ... 
truncated" + } else { + "" + } + ); + + if status.eq("0") { + Ok(InvokeOutput { + output: OutputKind::Json(serde_json::json!({ + "exit_status": status, + "stdout": stdout, + "stderr": stderr.clone() + })), + }) + } else { + Err(eyre::eyre!(stderr)) + } + } + + pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { + queue!( + updates, + style::Print("Running aws cli command:\n\n"), + style::Print(format!("Service name: {}\n", self.service_name)), + style::Print(format!("Operation name: {}\n", self.operation_name)), + )?; + if let Some(parameters) = &self.parameters { + queue!(updates, style::Print("Parameters: \n".to_string()))?; + for (name, value) in parameters { + match value { + serde_json::Value::String(s) if s.is_empty() => { + queue!(updates, style::Print(format!("- {}\n", name)))?; + }, + _ => { + queue!(updates, style::Print(format!("- {}: {}\n", name, value)))?; + }, + } + } + } + + if let Some(ref profile_name) = self.profile_name { + queue!(updates, style::Print(format!("Profile name: {}\n", profile_name)))?; + } else { + queue!(updates, style::Print("Profile name: default\n".to_string()))?; + } + + queue!(updates, style::Print(format!("Region: {}", self.region)))?; + + if let Some(ref label) = self.label { + queue!(updates, style::Print(format!("\nLabel: {}", label)))?; + } + Ok(()) + } + + pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { + Ok(()) + } + + /// Returns the CLI arguments properly formatted as kebab case if parameters is + /// [Option::Some], otherwise None + fn cli_parameters(&self) -> Option> { + if let Some(parameters) = &self.parameters { + let mut params = vec![]; + for (param_name, val) in parameters { + let param_name = format!("--{}", param_name.trim_start_matches("--").to_case(Case::Kebab)); + let param_val = val.as_str().map(|s| s.to_string()).unwrap_or(val.to_string()); + params.push((param_name, param_val)); + } + Some(params) + } else { + None + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! use_aws { + ($value:tt) => { + serde_json::from_value::(serde_json::json!($value)).unwrap() + }; + } + + #[test] + fn test_requires_acceptance() { + let cmd = use_aws! {{ + "service_name": "ecs", + "operation_name": "list-task-definitions", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + assert!(!cmd.requires_acceptance()); + let cmd = use_aws! {{ + "service_name": "lambda", + "operation_name": "list-functions", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + assert!(!cmd.requires_acceptance()); + let cmd = use_aws! {{ + "service_name": "s3", + "operation_name": "put-object", + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + assert!(cmd.requires_acceptance()); + } + + #[test] + fn test_use_aws_deser() { + let cmd = use_aws! 
{{ + "service_name": "s3", + "operation_name": "put-object", + "parameters": { + "TableName": "table-name", + "KeyConditionExpression": "PartitionKey = :pkValue" + }, + "region": "us-west-2", + "profile_name": "default", + "label": "" + }}; + let params = cmd.cli_parameters().unwrap(); + assert!( + params.iter().any(|p| p.0 == "--table-name" && p.1 == "table-name"), + "not found in {:?}", + params + ); + assert!( + params + .iter() + .any(|p| p.0 == "--key-condition-expression" && p.1 == "PartitionKey = :pkValue"), + "not found in {:?}", + params + ); + } + + #[tokio::test] + #[ignore = "not in ci"] + async fn test_aws_read_only() { + let ctx = Context::new_fake(); + + let v = serde_json::json!({ + "service_name": "s3", + "operation_name": "put-object", + // technically this wouldn't be a valid request with an empty parameter set but it's + // okay for this test + "parameters": {}, + "region": "us-west-2", + "profile_name": "default", + "label": "" + }); + + assert!( + serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut std::io::stdout()) + .await + .is_err() + ); + } + + #[tokio::test] + #[ignore = "not in ci"] + async fn test_aws_output() { + let ctx = Context::new_fake(); + + let v = serde_json::json!({ + "service_name": "s3", + "operation_name": "ls", + "parameters": {}, + "region": "us-west-2", + "profile_name": "default", + "label": "" + }); + let out = serde_json::from_value::(v) + .unwrap() + .invoke(&ctx, &mut std::io::stdout()) + .await + .unwrap(); + + if let OutputKind::Json(json) = out.output { + // depending on where the test is ran we might get different outcome here but it does + // not mean the tool is not working + let exit_status = json.get("exit_status").unwrap(); + if exit_status == 0 { + assert_eq!(json.get("stderr").unwrap(), ""); + } else { + assert_ne!(json.get("stderr").unwrap(), ""); + } + } else { + panic!("Expected JSON output"); + } + } +} diff --git a/crates/kiro-cli/src/cli/chat/util/issue.rs b/crates/kiro-cli/src/cli/chat/util/issue.rs new file mode 100644 index 0000000000..d1d98e4bed --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/util/issue.rs @@ -0,0 +1,83 @@ +use anstream::{ + eprintln, + println, +}; +use crossterm::style::Stylize; +use eyre::Result; + +use crate::diagnostics::Diagnostics; +use crate::fig_util::GITHUB_REPO_NAME; +use crate::fig_util::system_info::is_remote; + +const TEMPLATE_NAME: &str = "1_bug_report_template.yml"; + +pub struct IssueCreator { + /// Issue title + pub title: Option, + /// Issue description + pub expected_behavior: Option, + /// Issue description + pub actual_behavior: Option, + /// Issue description + pub steps_to_reproduce: Option, + /// Issue description + pub additional_environment: Option, +} + +impl IssueCreator { + pub async fn create_url(&self) -> Result { + println!("Heading over to GitHub..."); + + let warning = |text: &String| { + format!("\n\n{text}") + }; + let diagnostics = Diagnostics::new().await; + + let os = match &diagnostics.system_info.os { + Some(os) => os.to_string(), + None => "None".to_owned(), + }; + + let diagnostic_info = match diagnostics.user_readable() { + Ok(diagnostics) => diagnostics, + Err(err) => { + eprintln!("Error getting diagnostics: {err}"); + "Error occurred while generating diagnostics".to_owned() + }, + }; + + let environment = match &self.additional_environment { + Some(ctx) => format!("{diagnostic_info}\n{ctx}"), + None => diagnostic_info, + }; + + let mut params = Vec::new(); + params.push(("template", TEMPLATE_NAME.to_string())); + params.push(("os", os)); + 
params.push(("environment", warning(&environment))); + + if let Some(t) = self.title.clone() { + params.push(("title", t)); + } + if let Some(t) = self.expected_behavior.as_ref() { + params.push(("expected", warning(t))); + } + if let Some(t) = self.actual_behavior.as_ref() { + params.push(("actual", warning(t))); + } + if let Some(t) = self.steps_to_reproduce.as_ref() { + params.push(("reproduce", warning(t))); + } + + let url = url::Url::parse_with_params( + &format!("https://github.com/{GITHUB_REPO_NAME}/issues/new"), + params.iter(), + )?; + + if is_remote() || crate::fig_util::open::open_url_async(url.as_str()).await.is_err() { + println!("Issue Url: {}", url.as_str().underlined()); + } + + Ok(url) + } +} diff --git a/crates/kiro-cli/src/cli/chat/util/mod.rs b/crates/kiro-cli/src/cli/chat/util/mod.rs new file mode 100644 index 0000000000..7a575db83b --- /dev/null +++ b/crates/kiro-cli/src/cli/chat/util/mod.rs @@ -0,0 +1,111 @@ +pub mod issue; + +use std::io::Write; +use std::time::Duration; + +use super::ChatError; +use crate::fig_util::system_info::in_cloudshell; + +const GOV_REGIONS: &[&str] = &["us-gov-east-1", "us-gov-west-1"]; + +pub fn region_check(capability: &'static str) -> eyre::Result<()> { + let Ok(region) = std::env::var("AWS_REGION") else { + return Ok(()); + }; + + if in_cloudshell() && GOV_REGIONS.contains(®ion.as_str()) { + eyre::bail!("AWS GovCloud ({region}) is not supported for {capability}."); + } + + Ok(()) +} + +pub fn truncate_safe(s: &str, max_bytes: usize) -> &str { + if s.len() <= max_bytes { + return s; + } + + let mut byte_count = 0; + let mut char_indices = s.char_indices(); + + for (byte_idx, _) in &mut char_indices { + if byte_count + (byte_idx - byte_count) > max_bytes { + break; + } + byte_count = byte_idx; + } + + &s[..byte_count] +} + +pub fn animate_output(output: &mut impl Write, bytes: &[u8]) -> Result<(), ChatError> { + for b in bytes.chunks(12) { + output.write_all(b)?; + std::thread::sleep(Duration::from_millis(16)); + } + Ok(()) +} + +/// Play the terminal bell notification sound +pub fn play_notification_bell(requires_confirmation: bool) { + // Don't play bell for tools that don't require confirmation + if !requires_confirmation { + return; + } + + // Check if we should play the bell based on terminal type + if should_play_bell() { + print!("\x07"); // ASCII bell character + std::io::stdout().flush().unwrap(); + } +} + +/// Determine if we should play the bell based on terminal type +fn should_play_bell() -> bool { + // Get the TERM environment variable + if let Ok(term) = std::env::var("TERM") { + // List of terminals known to handle bell character well + let bell_compatible_terms = [ + "xterm", + "xterm-256color", + "screen", + "screen-256color", + "tmux", + "tmux-256color", + "rxvt", + "rxvt-unicode", + "linux", + "konsole", + "gnome", + "gnome-256color", + "alacritty", + "iterm2", + ]; + + // Check if the current terminal is in the compatible list + for compatible_term in bell_compatible_terms.iter() { + if term.starts_with(compatible_term) { + return true; + } + } + + // For other terminals, don't play the bell + return false; + } + + // If TERM is not set, default to not playing the bell + false +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_truncate_safe() { + assert_eq!(truncate_safe("Hello World", 5), "Hello"); + assert_eq!(truncate_safe("Hello ", 5), "Hello"); + assert_eq!(truncate_safe("Hello World", 11), "Hello World"); + assert_eq!(truncate_safe("Hello World", 15), "Hello World"); + } +} diff --git 
a/crates/kiro-cli/src/cli/debug.rs b/crates/kiro-cli/src/cli/debug.rs new file mode 100644 index 0000000000..ea34c1cdc3 --- /dev/null +++ b/crates/kiro-cli/src/cli/debug.rs @@ -0,0 +1,109 @@ +use std::process::ExitCode; + +use anstream::eprintln; +use clap::{ + Subcommand, + ValueEnum, +}; +use eyre::Result; + +#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] +pub enum Build { + Production, + #[value(alias = "staging")] + Beta, + #[value(hide = true, alias = "dev")] + Develop, +} + +impl std::fmt::Display for Build { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Build::Production => f.write_str("production"), + Build::Beta => f.write_str("beta"), + Build::Develop => f.write_str("develop"), + } + } +} + +#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] +pub enum App { + Dashboard, + Autocomplete, +} + +impl std::fmt::Display for App { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + App::Dashboard => f.write_str("dashboard"), + App::Autocomplete => f.write_str("autocomplete"), + } + } +} + +#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] +pub enum AutocompleteWindowDebug { + On, + Off, +} + +#[derive(Debug, ValueEnum, Clone, PartialEq, Eq)] +pub enum AccessibilityAction { + Refresh, + Reset, + Prompt, + Open, + Status, +} + +#[cfg(target_os = "macos")] +#[derive(Debug, Clone, PartialEq, Eq, ValueEnum)] +pub enum TISAction { + Enable, + Disable, + Select, + Deselect, +} + +#[cfg(target_os = "macos")] +use std::path::PathBuf; + +#[cfg(target_os = "macos")] +#[derive(Debug, Subcommand, Clone, PartialEq, Eq)] +pub enum InputMethodDebugAction { + Install { + bundle_path: Option, + }, + Uninstall { + bundle_path: Option, + }, + List, + Status { + bundle_path: Option, + }, + Source { + bundle_identifier: String, + #[arg(value_enum)] + action: TISAction, + }, +} + +#[derive(Debug, PartialEq, Subcommand)] +pub enum DebugSubcommand { + RefreshAuthToken, +} + +impl DebugSubcommand { + pub async fn execute(&self) -> Result { + match self { + DebugSubcommand::RefreshAuthToken => match crate::fig_auth::refresh_token().await? 
{ + Some(_) => eprintln!("Refreshed token"), + None => { + eprintln!("No token to refresh"); + return Ok(ExitCode::FAILURE); + }, + }, + } + Ok(ExitCode::SUCCESS) + } +} diff --git a/crates/kiro-cli/src/cli/diagnostics.rs b/crates/kiro-cli/src/cli/diagnostics.rs new file mode 100644 index 0000000000..83c94c2d2b --- /dev/null +++ b/crates/kiro-cli/src/cli/diagnostics.rs @@ -0,0 +1,68 @@ +use std::io::{ + IsTerminal, + stdout, +}; +use std::process::ExitCode; + +use anstream::println; +use clap::Args; +use color_eyre::Result; +use crossterm::terminal::{ + Clear, + ClearType, +}; +use crossterm::{ + cursor, + execute, +}; +use spinners::{ + Spinner, + Spinners, +}; + +use super::OutputFormat; +use crate::diagnostics::Diagnostics; + +#[derive(Debug, Args, PartialEq, Eq)] +pub struct DiagnosticArgs { + /// The format of the output + #[arg(long, short, value_enum, default_value_t)] + format: OutputFormat, + /// Force limited diagnostic output + #[arg(long)] + force: bool, +} + +impl DiagnosticArgs { + pub async fn execute(&self) -> Result<ExitCode> { + let spinner = if stdout().is_terminal() { + Some(Spinner::new(Spinners::Dots, "Generating...".into())) + } else { + None + }; + + if spinner.is_some() { + execute!(std::io::stdout(), cursor::Hide)?; + + ctrlc::set_handler(move || { + execute!(std::io::stdout(), cursor::Show).ok(); + std::process::exit(1); + })?; + } + + let diagnostics = Diagnostics::new().await; + + if let Some(mut sp) = spinner { + sp.stop(); + execute!(std::io::stdout(), Clear(ClearType::CurrentLine), cursor::Show)?; + println!(); + } + + self.format.print( + || diagnostics.user_readable().expect("Failed to run user_readable()"), + || &diagnostics, + ); + + Ok(ExitCode::SUCCESS) + } +} diff --git a/crates/kiro-cli/src/cli/feed.rs b/crates/kiro-cli/src/cli/feed.rs new file mode 100644 index 0000000000..7df058c946 --- /dev/null +++ b/crates/kiro-cli/src/cli/feed.rs @@ -0,0 +1,49 @@ +use serde::{ + Deserialize, + Serialize, +}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct Feed { + pub entries: Vec<Entry>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Entry { + #[serde(rename = "type")] + pub entry_type: String, + pub date: String, + pub version: String, + #[serde(default)] + pub hidden: bool, + #[serde(default)] + pub changes: Vec<Change>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Change { + #[serde(rename = "type")] + pub change_type: String, + pub description: String, +} + +impl Feed { + pub fn load() -> Self { + serde_json::from_str(include_str!("../../../../feed.json")).expect("feed.json is valid json") + } + + pub fn get_version_changelog(&self, version: &str) -> Option<Entry> { + self.entries + .iter() + .find(|entry| entry.entry_type == "release" && entry.version == version && !entry.hidden) + .cloned() + } + + pub fn get_all_changelogs(&self) -> Vec<Entry> { + self.entries + .iter() + .filter(|entry| entry.entry_type == "release" && !entry.hidden) + .cloned() + .collect() + } +} diff --git a/crates/kiro-cli/src/cli/issue.rs b/crates/kiro-cli/src/cli/issue.rs new file mode 100644 index 0000000000..87ae041cbb --- /dev/null +++ b/crates/kiro-cli/src/cli/issue.rs @@ -0,0 +1,39 @@ +use std::process::ExitCode; + +use clap::Args; +use eyre::Result; + +#[derive(Debug, Args, PartialEq, Eq)] +pub struct IssueArgs { + /// Force issue creation + #[arg(long, short = 'f')] + force: bool, + /// Issue description + description: Vec<String>, +} + +impl IssueArgs { + #[allow(unreachable_code)] + pub async fn execute(&self) -> Result<ExitCode> { + let joined_description = 
self.description.join(" ").trim().to_owned(); + + let issue_title = match joined_description.len() { + 0 => dialoguer::Input::with_theme(&crate::fig_util::dialoguer_theme()) + .with_prompt("Issue Title") + .interact_text()?, + _ => joined_description, + }; + + let _ = crate::cli::chat::util::issue::IssueCreator { + title: Some(issue_title), + expected_behavior: None, + actual_behavior: None, + steps_to_reproduce: None, + additional_environment: None, + } + .create_url() + .await; + + Ok(ExitCode::SUCCESS) + } +} diff --git a/crates/kiro-cli/src/cli/mod.rs b/crates/kiro-cli/src/cli/mod.rs new file mode 100644 index 0000000000..b6cfa570a0 --- /dev/null +++ b/crates/kiro-cli/src/cli/mod.rs @@ -0,0 +1,522 @@ +//! CLI functionality + +mod chat; +mod debug; +mod diagnostics; +mod feed; +mod issue; +mod settings; +mod telemetry; +mod uninstall; +mod update; +mod user; + +use std::io::{ + Write as _, + stdout, +}; +use std::process::ExitCode; + +use anstream::{ + eprintln, + println, +}; +use chat::cli::Chat; +use clap::{ + ArgAction, + CommandFactory, + Parser, + Subcommand, + ValueEnum, +}; +use crossterm::style::Stylize; +use eyre::Result; +use feed::Feed; +use serde::Serialize; +use tracing::{ + Level, + debug, +}; + +use self::user::RootUserSubcommand; +use crate::fig_log::{ + LogArgs, + initialize_logging, +}; +use crate::fig_telemetry::send_cli_subcommand_executed; +use crate::fig_util::directories::logs_dir; +use crate::fig_util::{ + CLI_BINARY_NAME, + CliContext, +}; + +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, ValueEnum)] +pub enum OutputFormat { + /// Outputs the results as markdown + #[default] + Plain, + /// Outputs the results as JSON + Json, + /// Outputs the results as pretty print JSON + JsonPretty, +} + +impl OutputFormat { + pub fn print(&self, text_fn: TFn, json_fn: JFn) + where + T: std::fmt::Display, + TFn: FnOnce() -> T, + J: Serialize, + JFn: FnOnce() -> J, + { + match self { + OutputFormat::Plain => println!("{}", text_fn()), + OutputFormat::Json => println!("{}", serde_json::to_string(&json_fn()).unwrap()), + OutputFormat::JsonPretty => println!("{}", serde_json::to_string_pretty(&json_fn()).unwrap()), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)] +pub enum Processes { + /// Desktop Process + App, +} + +/// The Amazon Q CLI +#[deny(missing_docs)] +#[derive(Debug, PartialEq, Subcommand)] +pub enum CliRootCommands { + /// Debug the app + #[command(subcommand)] + Debug(debug::DebugSubcommand), + /// Customize appearance & behavior + #[command(alias("setting"))] + Settings(settings::SettingsArgs), + /// Uninstall Amazon Q + #[command(hide = true)] + Uninstall { + /// Force uninstall + #[arg(long, short = 'y')] + no_confirm: bool, + }, + /// Update the Amazon Q application + #[command(alias("upgrade"))] + Update(update::UpdateArgs), + /// Run diagnostic tests + #[command(alias("diagnostics"))] + Diagnostic(diagnostics::DiagnosticArgs), + /// Create a new Github issue + Issue(issue::IssueArgs), + /// Root level user subcommands + #[command(flatten)] + RootUser(user::RootUserSubcommand), + /// Manage your account + #[command(subcommand)] + User(user::UserSubcommand), + /// Enable/disable telemetry + #[command(subcommand, hide = true)] + Telemetry(telemetry::TelemetrySubcommand), + /// Version + #[command(hide = true)] + Version { + /// Show the changelog (use --changelog=all for all versions, or --changelog=x.x.x for a + /// specific version) + #[arg(long, num_args = 0..=1, default_missing_value = "")] + changelog: Option, + }, + /// AI 
assistant in your terminal + #[command(alias("q"))] + Chat(Chat), +} + +impl CliRootCommands { + fn name(&self) -> &'static str { + match self { + CliRootCommands::Debug(_) => "debug", + CliRootCommands::Settings(_) => "settings", + CliRootCommands::Uninstall { .. } => "uninstall", + CliRootCommands::Update(_) => "update", + CliRootCommands::Diagnostic(_) => "diagnostics", + CliRootCommands::Issue(_) => "issue", + CliRootCommands::RootUser(RootUserSubcommand::Login(_)) => "login", + CliRootCommands::RootUser(RootUserSubcommand::Logout) => "logout", + CliRootCommands::RootUser(RootUserSubcommand::Whoami { .. }) => "whoami", + CliRootCommands::RootUser(RootUserSubcommand::Profile) => "profile", + CliRootCommands::User(_) => "user", + CliRootCommands::Telemetry(_) => "telemetry", + CliRootCommands::Version { .. } => "version", + CliRootCommands::Chat { .. } => "chat", + } + } +} + +const HELP_TEXT: &str = color_print::cstr! {" + +q (Amazon Q CLI) + +Popular Subcommands Usage: q [subcommand] +╭────────────────────────────────────────────────────╮ +│ chat Chat with Amazon Q │ +│ settings Customize appearance & behavior │ +╰────────────────────────────────────────────────────╯ + +To see all subcommands, use: + q --help-all +ㅤ +"}; + +#[derive(Debug, Parser, PartialEq, Default)] +#[command(version, about, name = crate::CLI_BINARY_NAME, help_template = HELP_TEXT)] +pub struct Cli { + #[command(subcommand)] + pub subcommand: Option, + /// Increase logging verbosity + #[arg(long, short = 'v', action = ArgAction::Count, global = true)] + pub verbose: u8, + /// Print help for all subcommands + #[arg(long)] + help_all: bool, +} + +impl Cli { + pub async fn execute(self) -> Result { + // Initialize our logger and keep around the guard so logging can perform as expected. + let _log_guard = initialize_logging(LogArgs { + log_level: match self.verbose > 0 { + true => Some( + match self.verbose { + 1 => Level::WARN, + 2 => Level::INFO, + 3 => Level::DEBUG, + _ => Level::TRACE, + } + .to_string(), + ), + false => None, + }, + log_to_stdout: std::env::var_os("Q_LOG_STDOUT").is_some() || self.verbose > 0, + log_file_path: match self.subcommand { + Some(CliRootCommands::Chat { .. }) => Some("chat.log".to_owned()), + _ => match crate::fig_log::get_log_level_max() >= Level::DEBUG { + true => Some("cli.log".to_owned()), + false => None, + }, + } + .map(|name| logs_dir().expect("home dir must be set").join(name)), + delete_old_log_file: false, + }); + + debug!(command =? 
std::env::args().collect::>(), "Command ran"); + + self.send_telemetry().await; + + if self.help_all { + return self.print_help_all(); + } + + let cli_context = CliContext::new(); + + match self.subcommand { + Some(subcommand) => match subcommand { + CliRootCommands::Uninstall { no_confirm } => uninstall::uninstall_command(no_confirm).await, + CliRootCommands::Update(args) => args.execute().await, + CliRootCommands::Diagnostic(args) => args.execute().await, + CliRootCommands::User(user) => user.execute().await, + CliRootCommands::RootUser(root_user) => root_user.execute().await, + CliRootCommands::Settings(settings_args) => settings_args.execute(&cli_context).await, + CliRootCommands::Debug(debug_subcommand) => debug_subcommand.execute().await, + CliRootCommands::Issue(args) => args.execute().await, + CliRootCommands::Telemetry(subcommand) => subcommand.execute().await, + CliRootCommands::Version { changelog } => Self::print_version(changelog), + CliRootCommands::Chat(args) => chat::launch_chat(args).await, + }, + // Root command + None => chat::launch_chat(chat::cli::Chat::default()).await, + } + } + + async fn send_telemetry(&self) { + match &self.subcommand { + None => {}, + Some(subcommand) => { + send_cli_subcommand_executed(subcommand.name()).await; + }, + } + } + + #[allow(clippy::unused_self)] + fn print_help_all(&self) -> Result { + let mut cmd = Self::command().help_template("{all-args}"); + eprintln!(); + eprintln!( + "{}\n {CLI_BINARY_NAME} [OPTIONS] [SUBCOMMAND]\n", + "USAGE:".bold().underlined(), + ); + cmd.print_long_help()?; + Ok(ExitCode::SUCCESS) + } + + fn print_changelog_entry(entry: &feed::Entry) -> Result<()> { + println!("Version {} ({})", entry.version, entry.date); + + if entry.changes.is_empty() { + println!(" No changes recorded for this version."); + } else { + for change in &entry.changes { + let type_label = match change.change_type.as_str() { + "added" => "Added", + "fixed" => "Fixed", + "changed" => "Changed", + other => other, + }; + + println!(" - {}: {}", type_label, change.description); + } + } + + println!(); + Ok(()) + } + + #[allow(clippy::unused_self)] + fn print_version(changelog: Option) -> Result { + // If no changelog is requested, display normal version information + if changelog.is_none() { + let _ = writeln!(stdout(), "{}", Self::command().render_version()); + return Ok(ExitCode::SUCCESS); + } + + let changelog_value = changelog.unwrap_or_default(); + let feed = Feed::load(); + + // Display changelog for all versions + if changelog_value == "all" { + let entries = feed.get_all_changelogs(); + if entries.is_empty() { + println!("No changelog information available."); + } else { + println!("Changelog for all versions:"); + for entry in entries { + Self::print_changelog_entry(&entry)?; + } + } + return Ok(ExitCode::SUCCESS); + } + + // Display changelog for a specific version (--changelog=x.x.x) + if !changelog_value.is_empty() { + match feed.get_version_changelog(&changelog_value) { + Some(entry) => { + println!("Changelog for version {}:", changelog_value); + Self::print_changelog_entry(&entry)?; + return Ok(ExitCode::SUCCESS); + }, + None => { + println!("No changelog information available for version {}.", changelog_value); + return Ok(ExitCode::SUCCESS); + }, + } + } + + // Display changelog for the current version (--changelog only) + let current_version = env!("CARGO_PKG_VERSION"); + match feed.get_version_changelog(current_version) { + Some(entry) => { + println!("Changelog for version {}:", current_version); + 
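+ // Reuses the same entry formatter as the explicit --changelog=x.x.x path above.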
Self::print_changelog_entry(&entry)?; + }, + None => { + println!("No changelog information available for version {}.", current_version); + }, + } + + Ok(ExitCode::SUCCESS) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn debug_assert() { + Cli::command().debug_assert(); + } + + macro_rules! assert_parse { + ( + [ $($args:expr),+ ], + $subcommand:expr + ) => { + assert_eq!( + Cli::parse_from([CLI_BINARY_NAME, $($args),*]), + Cli { + subcommand: Some($subcommand), + ..Default::default() + } + ); + }; + } + + /// Test flag parsing for the top level [Cli] + #[test] + fn test_flags() { + assert_eq!(Cli::parse_from([CLI_BINARY_NAME, "-v"]), Cli { + subcommand: None, + verbose: 1, + help_all: false, + }); + + assert_eq!(Cli::parse_from([CLI_BINARY_NAME, "-vvv"]), Cli { + subcommand: None, + verbose: 3, + help_all: false, + }); + + assert_eq!(Cli::parse_from([CLI_BINARY_NAME, "--help-all"]), Cli { + subcommand: None, + verbose: 0, + help_all: true, + }); + + assert_eq!(Cli::parse_from([CLI_BINARY_NAME, "chat", "-vv"]), Cli { + subcommand: Some(CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: false, + input: None, + profile: None, + trust_all_tools: false, + trust_tools: None, + })), + verbose: 2, + help_all: false, + }); + } + + #[test] + fn test_version_changelog() { + assert_parse!(["version", "--changelog"], CliRootCommands::Version { + changelog: Some("".to_string()), + }); + } + + #[test] + fn test_version_changelog_all() { + assert_parse!(["version", "--changelog=all"], CliRootCommands::Version { + changelog: Some("all".to_string()), + }); + } + + #[test] + fn test_version_changelog_specific() { + assert_parse!(["version", "--changelog=1.8.0"], CliRootCommands::Version { + changelog: Some("1.8.0".to_string()), + }); + } + + #[test] + fn test_chat_with_context_profile() { + assert_parse!( + ["chat", "--profile", "my-profile"], + CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: false, + input: None, + profile: Some("my-profile".to_string()), + trust_all_tools: false, + trust_tools: None, + }) + ); + } + + #[test] + fn test_chat_with_context_profile_and_input() { + assert_parse!( + ["chat", "--profile", "my-profile", "Hello"], + CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: false, + input: Some("Hello".to_string()), + profile: Some("my-profile".to_string()), + trust_all_tools: false, + trust_tools: None, + }) + ); + } + + #[test] + fn test_chat_with_context_profile_and_accept_all() { + assert_parse!( + ["chat", "--profile", "my-profile", "--accept-all"], + CliRootCommands::Chat(Chat { + accept_all: true, + no_interactive: false, + input: None, + profile: Some("my-profile".to_string()), + trust_all_tools: false, + trust_tools: None, + }) + ); + } + + #[test] + fn test_chat_with_no_interactive() { + assert_parse!( + ["chat", "--no-interactive"], + CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: true, + input: None, + profile: None, + trust_all_tools: false, + trust_tools: None, + }) + ); + } + + #[test] + fn test_chat_with_tool_trust_all() { + assert_parse!( + ["chat", "--trust-all-tools"], + CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: false, + input: None, + profile: None, + trust_all_tools: true, + trust_tools: None, + }) + ); + } + + #[test] + fn test_chat_with_tool_trust_none() { + assert_parse!( + ["chat", "--trust-tools="], + CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: false, + input: None, + profile: None, + trust_all_tools: false, + 
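+ // --trust-tools= with an empty value parses as a single empty-string entry (trusting no tools, per the test name).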
trust_tools: Some(vec!["".to_string()]), + }) + ); + } + + #[test] + fn test_chat_with_tool_trust_some() { + assert_parse!( + ["chat", "--trust-tools=fs_read,fs_write"], + CliRootCommands::Chat(Chat { + accept_all: false, + no_interactive: false, + input: None, + profile: None, + trust_all_tools: false, + trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), + }) + ); + } +} diff --git a/crates/kiro-cli/src/cli/settings.rs b/crates/kiro-cli/src/cli/settings.rs new file mode 100644 index 0000000000..09d21f8370 --- /dev/null +++ b/crates/kiro-cli/src/cli/settings.rs @@ -0,0 +1,152 @@ +use std::process::ExitCode; + +use anstream::println; +use clap::{ + ArgGroup, + Args, + Subcommand, +}; +use eyre::{ + Result, + WrapErr, + bail, +}; +use globset::Glob; +use serde_json::json; + +use super::OutputFormat; +use crate::fig_os_shim::Os; +use crate::fig_settings::JsonStore; +use crate::fig_util::{ + CliContext, + directories, +}; + +#[derive(Debug, Subcommand, PartialEq, Eq)] +pub enum SettingsSubcommands { + /// Open the settings file + Open, + /// List all the settings + All { + /// Format of the output + #[arg(long, short, value_enum, default_value_t)] + format: OutputFormat, + }, +} + +#[derive(Debug, Args, PartialEq, Eq)] +#[command(subcommand_negates_reqs = true)] +#[command(args_conflicts_with_subcommands = true)] +#[command(group(ArgGroup::new("vals").requires("key").args(&["value", "delete", "format"])))] +pub struct SettingsArgs { + #[command(subcommand)] + cmd: Option, + /// key + key: Option, + /// value + value: Option, + /// Delete a value + #[arg(long, short)] + delete: bool, + /// Format of the output + #[arg(long, short, value_enum, default_value_t)] + format: OutputFormat, +} + +impl SettingsArgs { + pub async fn execute(&self, cli_context: &CliContext) -> Result { + match self.cmd { + Some(SettingsSubcommands::Open) => { + let file = directories::settings_path().context("Could not get settings path")?; + if cli_context.context().platform().os() == Os::Mac { + tokio::process::Command::new("open").arg(file).output().await?; + Ok(ExitCode::SUCCESS) + } else if let Ok(editor) = cli_context.context().env().get("EDITOR") { + tokio::process::Command::new(editor).arg(file).spawn()?.wait().await?; + Ok(ExitCode::SUCCESS) + } else { + bail!("The EDITOR environment variable is not set") + } + }, + Some(SettingsSubcommands::All { format }) => { + let settings = crate::fig_settings::OldSettings::load()?.map().clone(); + + match format { + OutputFormat::Plain => { + for (key, value) in settings { + println!("{key} = {value}"); + } + }, + OutputFormat::Json => println!("{}", serde_json::to_string(&settings)?), + OutputFormat::JsonPretty => { + println!("{}", serde_json::to_string_pretty(&settings)?); + }, + } + + Ok(ExitCode::SUCCESS) + }, + None => { + let Some(key) = &self.key else { + return Ok(ExitCode::SUCCESS); + }; + + match (&self.value, self.delete) { + (None, false) => match crate::fig_settings::settings::get_value(key)? 
{ + Some(value) => { + match self.format { + OutputFormat::Plain => match value.as_str() { + Some(value) => println!("{value}"), + None => println!("{value:#}"), + }, + OutputFormat::Json => println!("{value}"), + OutputFormat::JsonPretty => println!("{value:#}"), + } + Ok(ExitCode::SUCCESS) + }, + None => match self.format { + OutputFormat::Plain => Err(eyre::eyre!("No value associated with {key}")), + OutputFormat::Json | OutputFormat::JsonPretty => { + println!("null"); + Ok(ExitCode::SUCCESS) + }, + }, + }, + (Some(value_str), false) => { + let value = serde_json::from_str(value_str).unwrap_or_else(|_| json!(value_str)); + crate::fig_settings::settings::set_value(key, value)?; + Ok(ExitCode::SUCCESS) + }, + (None, true) => { + let glob = Glob::new(key).context("Could not create glob")?.compile_matcher(); + let settings = crate::fig_settings::OldSettings::load()?; + let map = settings.map(); + let keys_to_remove = map.keys().filter(|key| glob.is_match(key)).collect::>(); + + match keys_to_remove.len() { + 0 => { + return Err(eyre::eyre!("No settings found matching {key}")); + }, + 1 => { + println!("Removing {:?}", keys_to_remove[0]); + crate::fig_settings::settings::remove_value(keys_to_remove[0])?; + }, + _ => { + println!("Removing:"); + for key in &keys_to_remove { + println!(" - {key}"); + } + + for key in &keys_to_remove { + crate::fig_settings::settings::remove_value(key)?; + } + }, + } + + Ok(ExitCode::SUCCESS) + }, + _ => Ok(ExitCode::SUCCESS), + } + }, + } + } +} diff --git a/crates/kiro-cli/src/cli/telemetry.rs b/crates/kiro-cli/src/cli/telemetry.rs new file mode 100644 index 0000000000..86ce2d5aff --- /dev/null +++ b/crates/kiro-cli/src/cli/telemetry.rs @@ -0,0 +1,53 @@ +use std::process::ExitCode; + +use clap::Subcommand; +use crossterm::style::Stylize; +use eyre::Result; +use serde_json::json; + +use super::OutputFormat; + +const TELEMETRY_ENABLED_KEY: &str = "telemetry.enabled"; + +#[derive(Debug, PartialEq, Eq, Subcommand)] +pub enum TelemetrySubcommand { + Enable, + Disable, + Status { + /// Format of the output + #[arg(long, short, value_enum, default_value_t)] + format: OutputFormat, + }, +} + +impl TelemetrySubcommand { + pub async fn execute(&self) -> Result { + match self { + TelemetrySubcommand::Enable => { + crate::fig_settings::settings::set_value(TELEMETRY_ENABLED_KEY, true)?; + Ok(ExitCode::SUCCESS) + }, + TelemetrySubcommand::Disable => { + crate::fig_settings::settings::set_value(TELEMETRY_ENABLED_KEY, false)?; + Ok(ExitCode::SUCCESS) + }, + TelemetrySubcommand::Status { format } => { + let status = crate::fig_settings::settings::get_bool_or(TELEMETRY_ENABLED_KEY, true); + format.print( + || { + format!( + "Telemetry status: {}", + if status { "enabled" } else { "disabled" }.bold() + ) + }, + || { + json!({ + TELEMETRY_ENABLED_KEY: status, + }) + }, + ); + Ok(ExitCode::SUCCESS) + }, + } + } +} diff --git a/crates/kiro-cli/src/cli/uninstall.rs b/crates/kiro-cli/src/cli/uninstall.rs new file mode 100644 index 0000000000..fc08b9c0c7 --- /dev/null +++ b/crates/kiro-cli/src/cli/uninstall.rs @@ -0,0 +1,174 @@ +use std::process::ExitCode; + +use anstream::println; +use crossterm::style::Stylize; +use eyre::Result; + +use crate::fig_util::{ + CLI_BINARY_NAME, + PRODUCT_NAME, + dialoguer_theme, +}; + +pub async fn uninstall_command(no_confirm: bool) -> Result { + if !no_confirm { + println!( + "\nIs {PRODUCT_NAME} not working? 
Try running {}\n", + format!("{CLI_BINARY_NAME} doctor").bold().magenta() + ); + let should_continue = dialoguer::Select::with_theme(&dialoguer_theme()) + .with_prompt(format!("Are you sure want to continue uninstalling {PRODUCT_NAME}?")) + .items(&["Yes", "No"]) + .default(0) + .interact_opt()?; + + if should_continue == Some(0) { + println!("Uninstalling {PRODUCT_NAME}"); + } else { + println!("Cancelled"); + return Ok(ExitCode::FAILURE); + } + }; + + cfg_if::cfg_if! { + if #[cfg(target_os = "macos")] { + uninstall().await?; + } else if #[cfg(target_os = "linux")] { + use crate::fig_util::manifest::is_minimal; + let ctx = crate::fig_os_shim::Context::new(); + if is_minimal() { + uninstall_linux_minimal(ctx).await?; + } else { + uninstall_linux_full(ctx).await?; + } + + } + } + + Ok(ExitCode::SUCCESS) +} + +#[cfg(target_os = "macos")] +async fn uninstall() -> Result<()> { + crate::fig_auth::logout().await.ok(); + crate::fig_install::uninstall().await?; + Ok(()) +} + +#[cfg(target_os = "linux")] +async fn uninstall_linux_minimal(ctx: std::sync::Arc) -> Result<()> { + use eyre::bail; + use tracing::error; + + let exe_path = ctx.fs().canonicalize(ctx.env().current_exe()?.canonicalize()?).await?; + let Some(exe_name) = exe_path.file_name().and_then(|s| s.to_str()) else { + bail!("Failed to get name of current executable: {exe_path:?}") + }; + let Some(exe_parent) = exe_path.parent() else { + bail!("Failed to get parent of current executable: {exe_path:?}") + }; + // canonicalize to handle if the home dir is a symlink (like on Dev Desktops) + let local_bin = crate::fig_util::directories::home_local_bin_ctx(&ctx)?.canonicalize()?; + + if exe_parent != local_bin { + bail!( + "Uninstall is only supported for binaries installed in {local_bin:?}, the current executable is in {exe_parent:?}" + ); + } + + if exe_name != CLI_BINARY_NAME { + bail!("Uninstall is only supported for {CLI_BINARY_NAME:?}, the current executable is {exe_name:?}"); + } + + if let Err(err) = crate::fig_auth::logout().await { + error!(%err, "Failed to logout"); + } + crate::fig_install::uninstall(crate::fig_install::InstallComponents::all_linux_minimal(), ctx).await?; + Ok(()) +} + +#[cfg(target_os = "linux")] +async fn uninstall_linux_full(ctx: std::sync::Arc) -> Result<()> { + use eyre::bail; + use tracing::error; + + use crate::fig_install::{ + InstallComponents, + UNINSTALL_URL, + uninstall, + }; + + // TODO: Add a better way to distinguish binaries distributed between AppImage and package + // managers. + // We want to support q uninstall for AppImage, but not for package managers. + match ctx.process_info().current_pid().exe() { + Some(exe) => { + let Some(exe_parent) = exe.parent() else { + bail!("Failed to get parent of current executable: {exe:?}") + }; + let local_bin = crate::fig_util::directories::home_local_bin_ctx(&ctx)?.canonicalize()?; + if exe_parent != local_bin { + bail!( + "Managed uninstalls are not supported. 
Please use your package manager to uninstall {}", + PRODUCT_NAME + ); + } + }, + None => bail!("Unable to determine the current process executable."), + } + + if let Err(err) = crate::fig_util::open_url_async(UNINSTALL_URL).await { + error!(%err, %UNINSTALL_URL, "Failed to open uninstall url"); + } + + if let Err(err) = crate::fig_auth::logout().await { + error!(%err, "Failed to logout"); + } + uninstall(InstallComponents::all(), ctx).await?; + Ok(()) +} + +#[cfg(all(unix, not(any(target_os = "macos", target_os = "linux"))))] +async fn uninstall() -> Result<()> { + eyre::bail!("Guided uninstallation is not supported on this platform. Please uninstall manually."); +} + +// #[cfg(target_os = "linux")] +// mod linux { +// use eyre::Result; +// +// pub async fn uninstall_apt(pkg: String) -> Result<()> { +// tokio::process::Command::new("apt") +// .arg("remove") +// .arg("-y") +// .arg(pkg) +// .status() +// .await?; +// std::fs::remove_file("/etc/apt/sources.list.d/fig.list")?; +// std::fs::remove_file("/etc/apt/keyrings/fig.gpg")?; +// +// Ok(()) +// } +// +// pub async fn uninstall_dnf(pkg: String) -> Result<()> { +// tokio::process::Command::new("dnf") +// .arg("remove") +// .arg("-y") +// .arg(pkg) +// .status() +// .await?; +// std::fs::remove_file("/etc/yum.repos.d/fig.repo")?; +// +// Ok(()) +// } +// +// pub async fn uninstall_pacman(pkg: String) -> Result<()> { +// tokio::process::Command::new("pacman") +// .arg("-Rs") +// .arg(pkg) +// .status() +// .await?; +// +// Ok(()) +// } +// } diff --git a/crates/kiro-cli/src/cli/update.rs b/crates/kiro-cli/src/cli/update.rs new file mode 100644 index 0000000000..a1739856c2 --- /dev/null +++ b/crates/kiro-cli/src/cli/update.rs @@ -0,0 +1,58 @@ +use std::process::ExitCode; + +use anstream::println; +use clap::Args; +use crossterm::style::Stylize; +use eyre::Result; +use self_update::{ + Status, + cargo_crate_version, +}; + +use crate::fig_util::CLI_BINARY_NAME; + +#[derive(Debug, PartialEq, Args)] +pub struct UpdateArgs { + /// Don't prompt for confirmation + #[arg(long, short = 'y')] + non_interactive: bool, + /// Relaunch into dashboard after update (false will launch in background) + #[arg(long, default_value = "true")] + relaunch_dashboard: bool, + /// Uses rollout + #[arg(long)] + rollout: bool, +} + +impl UpdateArgs { + pub async fn execute(&self) -> Result { + todo!(); + + let res = self_update::backends::s3::Update::configure() + .bucket_name("self_update_releases") + .asset_prefix("something/self_update") + .region("eu-west-2") + .bin_name("self_update_example") + .show_download_progress(true) + .current_version(cargo_crate_version!()) + .build()? 
+ .update(); + + match res { + Ok(Status::UpToDate(_)) => { + println!( + "No updates available, \n{} is the latest version.", + env!("CARGO_PKG_VERSION").bold() + ); + Ok(ExitCode::SUCCESS) + }, + Ok(Status::Updated(_)) => Ok(ExitCode::SUCCESS), + Err(err) => { + eyre::bail!( + "{err}\n\nIf this is unexpected, try running {} and then try again.\n", + format!("{CLI_BINARY_NAME} doctor").bold() + ) + }, + } + } +} diff --git a/crates/kiro-cli/src/cli/user.rs b/crates/kiro-cli/src/cli/user.rs new file mode 100644 index 0000000000..f878a4057e --- /dev/null +++ b/crates/kiro-cli/src/cli/user.rs @@ -0,0 +1,471 @@ +use std::fmt; +use std::fmt::Display; +use std::process::{ + ExitCode, + exit, +}; +use std::time::Duration; + +use anstream::{ + eprintln, + println, +}; +use clap::{ + Args, + Subcommand, +}; +use crossterm::style::Stylize; +use dialoguer::Select; +use eyre::{ + Result, + bail, +}; +use serde_json::json; +use tokio::signal::unix::{ + SignalKind, + signal, +}; +use tracing::{ + error, + info, +}; + +use super::OutputFormat; +use crate::fig_api_client::list_available_profiles; +use crate::fig_api_client::profile::Profile; +use crate::fig_auth::builder_id::{ + PollCreateToken, + TokenType, + poll_create_token, + start_device_authorization, +}; +use crate::fig_auth::pkce::start_pkce_authorization; +use crate::fig_auth::secret_store::SecretStore; +use crate::fig_telemetry::{ + QProfileSwitchIntent, + TelemetryResult, +}; +use crate::fig_util::spinner::{ + Spinner, + SpinnerComponent, +}; +use crate::fig_util::system_info::is_remote; +use crate::fig_util::{ + CLI_BINARY_NAME, + PRODUCT_NAME, + choose, + input, +}; + +#[derive(Subcommand, Debug, PartialEq, Eq)] +pub enum RootUserSubcommand { + /// Login + Login(LoginArgs), + /// Logout + Logout, + /// Prints details about the current user + Whoami { + /// Output format to use + #[arg(long, short, value_enum, default_value_t)] + format: OutputFormat, + }, + /// Show the profile associated with this idc user + Profile, +} + +#[derive(Args, Debug, PartialEq, Eq, Clone, Default)] +pub struct LoginArgs { + /// License type (pro for Identity Center, free for Builder ID) + #[arg(long, value_enum)] + pub license: Option, + + /// Identity provider URL (for Identity Center) + #[arg(long)] + pub identity_provider: Option, + + /// Region (for Identity Center) + #[arg(long)] + pub region: Option, + + /// Always use the OAuth device flow for authentication. Useful for instances where browser + /// redirects cannot be handled. 
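+ // (Login also falls back to the device flow automatically when the session is remote or the browser cannot be opened.)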
+ #[arg(long)] + pub use_device_flow: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)] +pub enum LicenseType { + /// Free license with Builder ID + Free, + /// Pro license with Identity Center + Pro, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum AuthMethod { + /// Builder ID (free) + BuilderId, + /// IdC (enterprise) + IdentityCenter, +} + +impl Display for AuthMethod { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + AuthMethod::BuilderId => write!(f, "Use for Free with Builder ID"), + AuthMethod::IdentityCenter => write!(f, "Use with Pro license"), + } + } +} + +impl RootUserSubcommand { + pub async fn execute(self) -> Result { + match self { + Self::Login(args) => { + if crate::fig_auth::is_logged_in().await { + eyre::bail!( + "Already logged in, please logout with {} first", + format!("{CLI_BINARY_NAME} logout").magenta() + ); + } + + login_interactive(args).await?; + + Ok(ExitCode::SUCCESS) + }, + Self::Logout => { + let _ = crate::fig_auth::logout().await; + + println!("You are now logged out"); + println!( + "Run {} to log back in to {PRODUCT_NAME}", + format!("{CLI_BINARY_NAME} login").magenta() + ); + Ok(ExitCode::SUCCESS) + }, + Self::Whoami { format } => { + let builder_id = crate::fig_auth::builder_id_token().await; + + match builder_id { + Ok(Some(token)) => { + format.print( + || match token.token_type() { + TokenType::BuilderId => "Logged in with Builder ID".into(), + TokenType::IamIdentityCenter => { + format!( + "Logged in with IAM Identity Center ({})", + token.start_url.as_ref().unwrap() + ) + }, + }, + || { + json!({ + "accountType": match token.token_type() { + TokenType::BuilderId => "BuilderId", + TokenType::IamIdentityCenter => "IamIdentityCenter", + }, + "startUrl": token.start_url, + "region": token.region, + }) + }, + ); + + if matches!(token.token_type(), TokenType::IamIdentityCenter) { + if let Ok(Some(profile)) = crate::fig_settings::state::get::< + crate::fig_api_client::profile::Profile, + >("api.codewhisperer.profile") + { + color_print::cprintln!( + "\nProfile:\n{}\n{}\n", + profile.profile_name, + profile.arn + ); + } + } + Ok(ExitCode::SUCCESS) + }, + _ => { + format.print(|| "Not logged in", || json!({ "account": null })); + Ok(ExitCode::FAILURE) + }, + } + }, + Self::Profile => { + if !crate::fig_util::system_info::in_cloudshell() && !crate::fig_auth::is_logged_in().await { + bail!( + "You are not logged in, please log in with {}", + format!("{CLI_BINARY_NAME} login",).bold() + ); + } + + if let Ok(Some(token)) = crate::fig_auth::builder_id_token().await { + if matches!(token.token_type(), TokenType::BuilderId) { + bail!("This command is only available for Pro users"); + } + } + + select_profile_interactive(false).await?; + + Ok(ExitCode::SUCCESS) + }, + } + } +} + +#[derive(Subcommand, Debug, PartialEq, Eq)] +pub enum UserSubcommand { + #[command(flatten)] + Root(RootUserSubcommand), +} + +impl UserSubcommand { + pub async fn execute(self) -> Result { + match self { + Self::Root(cmd) => cmd.execute().await, + } + } +} + +pub async fn login_interactive(args: LoginArgs) -> Result<()> { + let login_method = match args.license { + Some(LicenseType::Free) => AuthMethod::BuilderId, + Some(LicenseType::Pro) => AuthMethod::IdentityCenter, + None => { + // No license specified, prompt the user to choose + let options = [AuthMethod::BuilderId, AuthMethod::IdentityCenter]; + let i = match choose("Select login method", &options)? 
{ + Some(i) => i, + None => bail!("No login method selected"), + }; + options[i] + }, + }; + + match login_method { + AuthMethod::BuilderId | AuthMethod::IdentityCenter => { + let (start_url, region) = match login_method { + AuthMethod::BuilderId => (None, None), + AuthMethod::IdentityCenter => { + let default_start_url = args.identity_provider.or_else(|| { + crate::fig_settings::state::get_string("auth.idc.start-url") + .ok() + .flatten() + }); + let default_region = args + .region + .or_else(|| crate::fig_settings::state::get_string("auth.idc.region").ok().flatten()); + + let start_url = input("Enter Start URL", default_start_url.as_deref())?; + let region = input("Enter Region", default_region.as_deref())?; + + let _ = crate::fig_settings::state::set_value("auth.idc.start-url", start_url.clone()); + let _ = crate::fig_settings::state::set_value("auth.idc.region", region.clone()); + + (Some(start_url), Some(region)) + }, + }; + let secret_store = SecretStore::new().await?; + + // Remote machine won't be able to handle browser opening and redirects, + // hence always use device code flow. + if is_remote() || args.use_device_flow { + try_device_authorization(&secret_store, start_url.clone(), region.clone()).await?; + } else { + let (client, registration) = start_pkce_authorization(start_url.clone(), region.clone()).await?; + + match crate::fig_util::open::open_url_async(®istration.url).await { + // If it succeeded, finish PKCE. + Ok(()) => { + let mut spinner = Spinner::new(vec![ + SpinnerComponent::Spinner, + SpinnerComponent::Text(" Logging in...".into()), + ]); + let mut ctrl_c_stream = signal(SignalKind::interrupt())?; + tokio::select! { + res = registration.finish(&client, Some(&secret_store)) => res?, + Some(_) = ctrl_c_stream.recv() => { + #[allow(clippy::exit)] + exit(1); + }, + } + crate::fig_telemetry::send_user_logged_in().await; + spinner.stop_with_message("Device authorized".into()); + }, + // If we are unable to open the link with the browser, then fallback to + // the device code flow. + Err(err) => { + error!(%err, "Failed to open URL with browser, falling back to device code flow"); + + // Try device code flow. + try_device_authorization(&secret_store, start_url.clone(), region.clone()).await?; + }, + } + } + }, + }; + + if login_method == AuthMethod::IdentityCenter { + select_profile_interactive(true).await?; + } + + eprintln!("Logged in successfully"); + + Ok(()) +} + +async fn try_device_authorization( + secret_store: &SecretStore, + start_url: Option, + region: Option, +) -> Result<()> { + let device_auth = start_device_authorization(secret_store, start_url.clone(), region.clone()).await?; + + println!(); + println!("Confirm the following code in the browser"); + println!("Code: {}", device_auth.user_code.bold()); + println!(); + + let print_open_url = || println!("Open this URL: {}", device_auth.verification_uri_complete); + + if is_remote() { + print_open_url(); + } else if let Err(err) = crate::fig_util::open::open_url_async(&device_auth.verification_uri_complete).await { + error!(%err, "Failed to open URL with browser"); + print_open_url(); + } + + let mut spinner = Spinner::new(vec![ + SpinnerComponent::Spinner, + SpinnerComponent::Text(" Logging in...".into()), + ]); + + let mut ctrl_c_stream = signal(SignalKind::interrupt())?; + loop { + tokio::select! 
{ + _ = tokio::time::sleep(Duration::from_secs(device_auth.interval.try_into().unwrap_or(1))) => (), + Some(_) = ctrl_c_stream.recv() => { + #[allow(clippy::exit)] + exit(1); + } + } + match poll_create_token( + secret_store, + device_auth.device_code.clone(), + start_url.clone(), + region.clone(), + ) + .await + { + PollCreateToken::Pending => {}, + PollCreateToken::Complete(_) => { + crate::fig_telemetry::send_user_logged_in().await; + spinner.stop_with_message("Device authorized".into()); + break; + }, + PollCreateToken::Error(err) => { + spinner.stop(); + return Err(err.into()); + }, + }; + } + Ok(()) +} + +async fn select_profile_interactive(whoami: bool) -> Result<()> { + let mut spinner = Spinner::new(vec![ + SpinnerComponent::Spinner, + SpinnerComponent::Text(" Fetching profiles...".into()), + ]); + let profiles = list_available_profiles().await; + if profiles.is_empty() { + info!("Available profiles was empty"); + return Ok(()); + } + + let sso_region: Option = crate::fig_settings::state::get_string("auth.idc.region").ok().flatten(); + let total_profiles = profiles.len() as i64; + + if whoami && profiles.len() == 1 { + if let Some(profile_region) = profiles[0].arn.split(':').nth(3) { + crate::fig_telemetry::send_profile_state( + QProfileSwitchIntent::Update, + profile_region.to_string(), + TelemetryResult::Succeeded, + sso_region, + ) + .await; + } + spinner.stop_with_message(String::new()); + return Ok(crate::fig_settings::state::set_value( + "api.codewhisperer.profile", + serde_json::to_value(&profiles[0])?, + )?); + } + + let mut items: Vec = profiles + .iter() + .map(|p| format!("{} (arn: {})", p.profile_name, p.arn)) + .collect(); + let active_profile: Option = crate::fig_settings::state::get("api.codewhisperer.profile")?; + + if let Some(default_idx) = active_profile + .as_ref() + .and_then(|active| profiles.iter().position(|p| p.arn == active.arn)) + { + items[default_idx] = format!("{} (active)", items[default_idx].as_str()); + } + + spinner.stop_with_message(String::new()); + let selected = Select::with_theme(&crate::fig_util::dialoguer_theme()) + .with_prompt("Select an IAM Identity Center profile") + .items(&items) + .default(0) + .interact_opt()?; + + match selected { + Some(i) => { + let chosen = &profiles[i]; + let profile = serde_json::to_value(chosen)?; + eprintln!("Set profile: {}\n", chosen.profile_name.as_str().green()); + crate::fig_settings::state::set_value("api.codewhisperer.profile", profile)?; + crate::fig_settings::state::remove_value("api.selectedCustomization")?; + + if let Some(profile_region) = chosen.arn.split(':').nth(3) { + let intent = if whoami { + QProfileSwitchIntent::Auth + } else { + QProfileSwitchIntent::User + }; + crate::fig_telemetry::send_did_select_profile( + intent, + profile_region.to_string(), + TelemetryResult::Succeeded, + sso_region, + Some(total_profiles), + ) + .await; + } + }, + None => { + crate::fig_telemetry::send_did_select_profile( + QProfileSwitchIntent::User, + "not-set".to_string(), + TelemetryResult::Cancelled, + sso_region, + Some(total_profiles), + ) + .await; + bail!("No profile selected.\n"); + }, + } + + Ok(()) +} + +mod tests { + #[test] + #[ignore] + fn unset_profile() { + crate::fig_settings::state::remove_value("api.codewhisperer.profile").unwrap(); + } +} diff --git a/crates/kiro-cli/src/diagnostics.rs b/crates/kiro-cli/src/diagnostics.rs new file mode 100644 index 0000000000..6de3f937e0 --- /dev/null +++ b/crates/kiro-cli/src/diagnostics.rs @@ -0,0 +1,253 @@ +#![allow(clippy::ref_option_ref)] +use 
std::collections::BTreeMap; + +use serde::Serialize; +use sysinfo::{ + CpuRefreshKind, + MemoryRefreshKind, + RefreshKind, +}; +use time::OffsetDateTime; +use time::format_description::well_known::Rfc3339; + +use crate::fig_os_shim::{ + Context, + Os, + PlatformProvider, +}; +use crate::fig_telemetry::InstallMethod; +use crate::fig_util::consts::build::HASH; +use crate::fig_util::manifest::manifest; +use crate::fig_util::system_info::{ + OSVersion, + os_version, +}; + +fn serialize_display(display: D, serializer: S) -> Result +where + D: std::fmt::Display, + S: serde::Serializer, +{ + serializer.serialize_str(&display.to_string()) +} + +fn is_false(value: &bool) -> bool { + !value +} + +#[derive(Debug, Clone, Serialize, Default)] +#[serde(rename_all = "kebab-case")] +pub struct BuildDetails { + pub version: String, + pub hash: Option<&'static str>, + pub date: Option, + pub variant: String, +} + +impl BuildDetails { + pub fn new() -> BuildDetails { + let date = crate::fig_util::consts::build::DATETIME + .and_then(|input| OffsetDateTime::parse(input, &Rfc3339).ok()) + .and_then(|time| { + let rfc3339 = time.format(&Rfc3339).ok()?; + let duration = OffsetDateTime::now_utc() - time; + Some(format!("{rfc3339} ({duration:.0} ago)")) + }); + + BuildDetails { + version: env!("CARGO_PKG_VERSION").to_owned(), + hash: HASH, + date, + variant: manifest().variant.to_string(), + } + } +} + +fn serialize_os_version(version: &Option<&OSVersion>, serializer: S) -> Result +where + S: serde::Serializer, +{ + match version { + Some(version) => match version { + OSVersion::Linux { .. } => version.serialize(serializer), + other => serializer.serialize_str(&other.to_string()), + }, + None => serializer.serialize_none(), + } +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "kebab-case")] +pub struct SystemInfo { + #[serde(serialize_with = "serialize_os_version")] + pub os: Option<&'static OSVersion>, + pub chip: Option, + pub total_cores: Option, + pub memory: Option, +} + +impl SystemInfo { + fn new() -> SystemInfo { + let system = sysinfo::System::new_with_specifics( + RefreshKind::nothing() + .with_cpu(CpuRefreshKind::everything()) + .with_memory(MemoryRefreshKind::everything()), + ); + + let mut hardware_info = SystemInfo { + os: os_version(), + chip: None, + total_cores: system.physical_core_count(), + memory: Some(format!("{:0.2} GB", system.total_memory() as f32 / 2.0_f32.powi(30))), + }; + + if let Some(processor) = system.cpus().first() { + hardware_info.chip = Some(processor.brand().into()); + } + + hardware_info + } +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "kebab-case")] +pub struct EnvVarDiagnostic { + pub env_vars: BTreeMap, +} + +impl EnvVarDiagnostic { + fn new() -> EnvVarDiagnostic { + let env_vars = std::env::vars() + .filter(|(key, _)| { + let fig_var = crate::fig_util::env_var::ALL.contains(&key.as_str()); + let other_var = [ + // General env vars + "SHELL", + "DISPLAY", + "PATH", + "TERM", + "ZDOTDIR", + // Linux vars + "XDG_CURRENT_DESKTOP", + "XDG_SESSION_DESKTOP", + "XDG_SESSION_TYPE", + "GLFW_IM_MODULE", + "GTK_IM_MODULE", + "QT_IM_MODULE", + "XMODIFIERS", + // Macos vars + "__CFBundleIdentifier", + ] + .contains(&key.as_str()); + + fig_var || other_var + }) + .map(|(key, value)| { + // sanitize username from values + let username = format!("/{}", whoami::username()); + (key, value.replace(&username, "/USER")) + }) + .collect(); + + EnvVarDiagnostic { env_vars } + } +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "kebab-case")] +pub 
struct CurrentEnvironment { + pub cwd: Option, + pub cli_path: Option, + pub os: Os, + #[serde(serialize_with = "serialize_display")] + pub install_method: InstallMethod, + #[serde(skip_serializing_if = "is_false")] + pub in_cloudshell: bool, + #[serde(skip_serializing_if = "is_false")] + pub in_ssh: bool, + #[serde(skip_serializing_if = "is_false")] + pub in_ci: bool, + #[serde(skip_serializing_if = "is_false")] + pub in_wsl: bool, + #[serde(skip_serializing_if = "is_false")] + pub in_codespaces: bool, +} + +impl CurrentEnvironment { + async fn new() -> CurrentEnvironment { + let ctx = Context::new(); + + let username = format!("/{}", whoami::username()); + + let cwd = ctx + .env() + .current_dir() + .ok() + .map(|path| path.to_string_lossy().replace(&username, "/USER")); + + let cli_path = ctx + .env() + .current_dir() + .ok() + .map(|path| path.to_string_lossy().replace(&username, "/USER")); + + let os = ctx.platform().os(); + let install_method = crate::fig_telemetry::get_install_method(); + + let in_cloudshell = crate::fig_util::system_info::in_cloudshell(); + let in_ssh = crate::fig_util::system_info::in_ssh(); + let in_ci = crate::fig_util::system_info::in_ci(); + let in_wsl = crate::fig_util::system_info::in_wsl(); + let in_codespaces = crate::fig_util::system_info::in_codespaces(); + + CurrentEnvironment { + cwd, + cli_path, + os, + install_method, + in_cloudshell, + in_ssh, + in_ci, + in_wsl, + in_codespaces, + } + } +} + +#[derive(Clone, Debug, Serialize)] +#[serde(rename_all = "kebab-case")] +pub struct Diagnostics { + #[serde(rename = "q-details")] + pub build_details: BuildDetails, + pub system_info: SystemInfo, + pub environment: CurrentEnvironment, + #[serde(flatten)] + pub environment_variables: EnvVarDiagnostic, +} + +impl Diagnostics { + pub async fn new() -> Diagnostics { + Diagnostics { + build_details: BuildDetails::new(), + system_info: SystemInfo::new(), + environment: CurrentEnvironment::new().await, + environment_variables: EnvVarDiagnostic::new(), + } + } + + pub fn user_readable(&self) -> Result { + toml::to_string(&self) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_diagnostics_user_readable() { + let diagnostics = Diagnostics::new().await; + let toml = diagnostics.user_readable().unwrap(); + assert!(!toml.is_empty()); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/clients/client.rs b/crates/kiro-cli/src/fig_api_client/clients/client.rs new file mode 100644 index 0000000000..170bca43da --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/clients/client.rs @@ -0,0 +1,208 @@ +use amzn_codewhisperer_client::Client as CodewhispererClient; +use amzn_codewhisperer_client::types::{ + OptOutPreference, + TelemetryEvent, + UserContext, +}; +use amzn_consolas_client::Client as ConsolasClient; +use tracing::error; + +use super::shared::{ + bearer_sdk_config, + sigv4_sdk_config, +}; +use crate::fig_api_client::interceptor::opt_out::OptOutInterceptor; +use crate::fig_api_client::profile::Profile; +use crate::fig_api_client::{ + Endpoint, + Error, +}; +use crate::fig_auth::builder_id::BearerResolver; +use crate::fig_aws_common::{ + UserAgentOverrideInterceptor, + app_name, +}; + +mod inner { + use amzn_codewhisperer_client::Client as CodewhispererClient; + use amzn_consolas_client::Client as ConsolasClient; + + #[derive(Clone, Debug)] + pub enum Inner { + Codewhisperer(CodewhispererClient), + Consolas(ConsolasClient), + Mock, + } +} + +#[derive(Clone, Debug)] +pub struct Client { + inner: inner::Inner, + profile_arn: Option, 
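+ // ARN of the CodeWhisperer profile read from the "api.codewhisperer.profile" state key; attached to requests via set_profile_arn when present.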
+} + +impl Client { + pub async fn new() -> Result { + let endpoint = Endpoint::load_codewhisperer(); + let client = if crate::fig_util::system_info::in_cloudshell() { + Self::new_consolas_client(&endpoint).await? + } else { + Self::new_codewhisperer_client(&endpoint).await + }; + Ok(client) + } + + pub fn mock() -> Self { + Self { + inner: inner::Inner::Mock, + profile_arn: None, + } + } + + pub async fn new_codewhisperer_client(endpoint: &Endpoint) -> Self { + let conf_builder: amzn_codewhisperer_client::config::Builder = (&bearer_sdk_config(endpoint).await).into(); + let conf = conf_builder + .http_client(crate::fig_aws_common::http_client::client()) + .interceptor(OptOutInterceptor::new()) + .interceptor(UserAgentOverrideInterceptor::new()) + .bearer_token_resolver(BearerResolver) + .app_name(app_name()) + .endpoint_url(endpoint.url()) + .build(); + + let inner = inner::Inner::Codewhisperer(CodewhispererClient::from_conf(conf)); + + let profile_arn = match crate::fig_settings::state::get_value("api.codewhisperer.profile") { + Ok(Some(profile)) => match profile.get("arn") { + Some(arn) => match arn.as_str() { + Some(arn) => Some(arn.to_string()), + None => { + error!("Stored arn is not a string. Instead it was: {arn}"); + None + }, + }, + None => { + error!("Stored profile does not contain an arn. Instead it was: {profile}"); + None + }, + }, + Ok(None) => None, + Err(err) => { + error!("Failed to retrieve profile: {}", err); + None + }, + }; + + Self { inner, profile_arn } + } + + pub async fn new_consolas_client(endpoint: &Endpoint) -> Result { + let conf_builder: amzn_consolas_client::config::Builder = (&sigv4_sdk_config(endpoint).await?).into(); + let conf = conf_builder + .http_client(crate::fig_aws_common::http_client::client()) + .interceptor(OptOutInterceptor::new()) + .interceptor(UserAgentOverrideInterceptor::new()) + .app_name(app_name()) + .endpoint_url(endpoint.url()) + .build(); + Ok(Self { + inner: inner::Inner::Consolas(ConsolasClient::from_conf(conf)), + profile_arn: None, + }) + } + + // .telemetry_event(TelemetryEvent::UserTriggerDecisionEvent(user_trigger_decision_event)) + // .user_context(user_context) + // .opt_out_preference(opt_out_preference) + pub async fn send_telemetry_event( + &self, + telemetry_event: TelemetryEvent, + user_context: UserContext, + opt_out: OptOutPreference, + ) -> Result<(), Error> { + match &self.inner { + inner::Inner::Codewhisperer(client) => { + let _ = client + .send_telemetry_event() + .telemetry_event(telemetry_event) + .user_context(user_context) + .opt_out_preference(opt_out) + .set_profile_arn(self.profile_arn.clone()) + .send() + .await; + Ok(()) + }, + inner::Inner::Consolas(_) => Err(Error::UnsupportedConsolas("send_telemetry_event")), + inner::Inner::Mock => Ok(()), + } + } + + pub async fn list_available_profiles(&self) -> Result, Error> { + match &self.inner { + inner::Inner::Codewhisperer(client) => { + let mut profiles = vec![]; + let mut client = client.list_available_profiles().into_paginator().send(); + while let Some(profiles_output) = client.next().await { + profiles.extend(profiles_output?.profiles().iter().cloned().map(Profile::from)); + } + + Ok(profiles) + }, + inner::Inner::Consolas(_) => Err(Error::UnsupportedConsolas("list_available_profiles")), + inner::Inner::Mock => Ok(vec![ + Profile { + arn: "my:arn:1".to_owned(), + profile_name: "MyProfile".to_owned(), + }, + Profile { + arn: "my:arn:2".to_owned(), + profile_name: "MyOtherProfile".to_owned(), + }, + ]), + } + } +} + +#[cfg(test)] +mod tests { + use 
amzn_codewhisperer_client::types::{ + ChatAddMessageEvent, + IdeCategory, + OperatingSystem, + }; + + use super::*; + + #[tokio::test] + async fn create_clients() { + let endpoint = Endpoint::load_codewhisperer(); + + let _ = Client::new().await; + let _ = Client::new_codewhisperer_client(&endpoint).await; + let _ = Client::new_consolas_client(&endpoint).await; + } + + #[tokio::test] + async fn test_mock() { + let client = Client::mock(); + client + .send_telemetry_event( + TelemetryEvent::ChatAddMessageEvent( + ChatAddMessageEvent::builder() + .conversation_id("") + .message_id("") + .build() + .unwrap(), + ), + UserContext::builder() + .ide_category(IdeCategory::Cli) + .operating_system(OperatingSystem::Linux) + .product("") + .build() + .unwrap(), + OptOutPreference::OptIn, + ) + .await + .unwrap(); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/clients/mod.rs b/crates/kiro-cli/src/fig_api_client/clients/mod.rs new file mode 100644 index 0000000000..75b7e4ab87 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/clients/mod.rs @@ -0,0 +1,9 @@ +mod client; +pub(crate) mod shared; +mod streaming_client; + +pub use client::Client; +pub use streaming_client::{ + SendMessageOutput, + StreamingClient, +}; diff --git a/crates/kiro-cli/src/fig_api_client/clients/shared.rs b/crates/kiro-cli/src/fig_api_client/clients/shared.rs new file mode 100644 index 0000000000..a9a458cb06 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/clients/shared.rs @@ -0,0 +1,65 @@ +use std::time::Duration; + +use aws_config::Region; +use aws_config::retry::RetryConfig; +use aws_config::timeout::TimeoutConfig; +use aws_credential_types::Credentials; +use aws_credential_types::provider::ProvideCredentials; +use aws_types::SdkConfig; +use aws_types::sdk_config::StalledStreamProtectionConfig; + +use crate::fig_api_client::credentials::CredentialsChain; +use crate::fig_api_client::{ + Endpoint, + Error, +}; +use crate::fig_aws_common::behavior_version; + +// TODO(bskiser): confirm timeout is updated to an appropriate value? 
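+// The 5 minute default below covers connect, read, operation, and per-attempt timeouts; a millisecond value stored under the "api.timeout" setting overrides it.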
+const DEFAULT_TIMEOUT_DURATION: Duration = Duration::from_secs(60 * 5); + +pub(crate) fn timeout_config() -> TimeoutConfig { + let timeout = crate::fig_settings::settings::get_int("api.timeout") + .ok() + .flatten() + .and_then(|i| i.try_into().ok()) + .map_or(DEFAULT_TIMEOUT_DURATION, Duration::from_millis); + + TimeoutConfig::builder() + .read_timeout(timeout) + .operation_timeout(timeout) + .operation_attempt_timeout(timeout) + .connect_timeout(timeout) + .build() +} + +pub(crate) fn stalled_stream_protection_config() -> StalledStreamProtectionConfig { + StalledStreamProtectionConfig::enabled() + .grace_period(Duration::from_secs(60 * 5)) + .build() +} + +async fn base_sdk_config(region: Region, credentials_provider: impl ProvideCredentials + 'static) -> SdkConfig { + aws_config::defaults(behavior_version()) + .region(region) + .credentials_provider(credentials_provider) + .timeout_config(timeout_config()) + .retry_config(RetryConfig::adaptive()) + .load() + .await +} + +pub(crate) async fn bearer_sdk_config(endpoint: &Endpoint) -> SdkConfig { + let credentials = Credentials::new("xxx", "xxx", None, None, "xxx"); + base_sdk_config(endpoint.region().clone(), credentials).await +} + +pub(crate) async fn sigv4_sdk_config(endpoint: &Endpoint) -> Result { + let credentials_chain = CredentialsChain::new().await; + + if let Err(err) = credentials_chain.provide_credentials().await { + return Err(Error::Credentials(err)); + }; + + Ok(base_sdk_config(endpoint.region().clone(), credentials_chain).await) +} diff --git a/crates/kiro-cli/src/fig_api_client/clients/streaming_client.rs b/crates/kiro-cli/src/fig_api_client/clients/streaming_client.rs new file mode 100644 index 0000000000..1c0a9ec5db --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/clients/streaming_client.rs @@ -0,0 +1,339 @@ +use std::sync::{ + Arc, + Mutex, +}; + +use amzn_codewhisperer_streaming_client::Client as CodewhispererStreamingClient; +use amzn_qdeveloper_streaming_client::Client as QDeveloperStreamingClient; +use aws_types::request_id::RequestId; +use tracing::{ + debug, + error, +}; + +use super::shared::{ + bearer_sdk_config, + sigv4_sdk_config, + stalled_stream_protection_config, +}; +use crate::fig_api_client::interceptor::opt_out::OptOutInterceptor; +use crate::fig_api_client::model::{ + ChatResponseStream, + ConversationState, +}; +use crate::fig_api_client::{ + Endpoint, + Error, +}; +use crate::fig_auth::builder_id::BearerResolver; +use crate::fig_aws_common::{ + UserAgentOverrideInterceptor, + app_name, +}; + +mod inner { + use std::sync::{ + Arc, + Mutex, + }; + + use amzn_codewhisperer_streaming_client::Client as CodewhispererStreamingClient; + use amzn_qdeveloper_streaming_client::Client as QDeveloperStreamingClient; + + use crate::fig_api_client::model::ChatResponseStream; + + #[derive(Clone, Debug)] + pub enum Inner { + Codewhisperer(CodewhispererStreamingClient), + QDeveloper(QDeveloperStreamingClient), + Mock(Arc>>>), + } +} + +#[derive(Clone, Debug)] +pub struct StreamingClient { + inner: inner::Inner, + profile_arn: Option, +} + +impl StreamingClient { + pub async fn new() -> Result { + let client = if crate::fig_util::system_info::in_cloudshell() + || std::env::var("Q_USE_SENDMESSAGE").is_ok_and(|v| !v.is_empty()) + { + Self::new_qdeveloper_client(&Endpoint::load_q()).await? 
+ } else { + Self::new_codewhisperer_client(&Endpoint::load_codewhisperer()).await + }; + Ok(client) + } + + pub fn mock(events: Vec>) -> Self { + Self { + inner: inner::Inner::Mock(Arc::new(Mutex::new(events.into_iter()))), + profile_arn: None, + } + } + + pub async fn new_codewhisperer_client(endpoint: &Endpoint) -> Self { + let conf_builder: amzn_codewhisperer_streaming_client::config::Builder = + (&bearer_sdk_config(endpoint).await).into(); + let conf = conf_builder + .http_client(crate::fig_aws_common::http_client::client()) + .interceptor(OptOutInterceptor::new()) + .interceptor(UserAgentOverrideInterceptor::new()) + .bearer_token_resolver(BearerResolver) + .app_name(app_name()) + .endpoint_url(endpoint.url()) + .stalled_stream_protection(stalled_stream_protection_config()) + .build(); + let inner = inner::Inner::Codewhisperer(CodewhispererStreamingClient::from_conf(conf)); + + let profile_arn = match crate::fig_settings::state::get_value("api.codewhisperer.profile") { + Ok(Some(profile)) => match profile.get("arn") { + Some(arn) => match arn.as_str() { + Some(arn) => Some(arn.to_string()), + None => { + error!("Stored arn is not a string. Instead it was: {arn}"); + None + }, + }, + None => { + error!("Stored profile does not contain an arn. Instead it was: {profile}"); + None + }, + }, + Ok(None) => None, + Err(err) => { + error!("Failed to retrieve profile: {}", err); + None + }, + }; + + Self { inner, profile_arn } + } + + pub async fn new_qdeveloper_client(endpoint: &Endpoint) -> Result { + let conf_builder: amzn_qdeveloper_streaming_client::config::Builder = + (&sigv4_sdk_config(endpoint).await?).into(); + let conf = conf_builder + .http_client(crate::fig_aws_common::http_client::client()) + .interceptor(OptOutInterceptor::new()) + .interceptor(UserAgentOverrideInterceptor::new()) + .app_name(app_name()) + .endpoint_url(endpoint.url()) + .stalled_stream_protection(stalled_stream_protection_config()) + .build(); + let client = QDeveloperStreamingClient::from_conf(conf); + Ok(Self { + inner: inner::Inner::QDeveloper(client), + profile_arn: None, + }) + } + + pub async fn send_message(&self, conversation_state: ConversationState) -> Result { + debug!("Sending conversation: {:#?}", conversation_state); + let ConversationState { + conversation_id, + user_input_message, + history, + } = conversation_state; + + match &self.inner { + inner::Inner::Codewhisperer(client) => { + let conversation_state = amzn_codewhisperer_streaming_client::types::ConversationState::builder() + .set_conversation_id(conversation_id) + .current_message( + amzn_codewhisperer_streaming_client::types::ChatMessage::UserInputMessage( + user_input_message.into(), + ), + ) + .chat_trigger_type(amzn_codewhisperer_streaming_client::types::ChatTriggerType::Manual) + .set_history( + history + .map(|v| v.into_iter().map(|i| i.try_into()).collect::, _>>()) + .transpose()?, + ) + .build() + .expect("building conversation_state should not fail"); + let response = client + .generate_assistant_response() + .conversation_state(conversation_state) + .set_profile_arn(self.profile_arn.clone()) + .send() + .await; + + match response { + Ok(resp) => Ok(SendMessageOutput::Codewhisperer(resp)), + Err(e) => { + let is_quota_breach = e.raw_response().is_some_and(|resp| resp.status().as_u16() == 429); + let is_context_window_overflow = e.as_service_error().is_some_and(|err| { + matches!(err, err if err.meta().code() == Some("ValidationException") + && err.meta().message() == Some("Input is too long.")) + }); + + if is_quota_breach { + 
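+ // An HTTP 429 response is surfaced as a dedicated QuotaBreach error instead of a generic SDK error.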
Err(Error::QuotaBreach("quota has reached its limit")) + } else if is_context_window_overflow { + Err(Error::ContextWindowOverflow) + } else { + Err(e.into()) + } + }, + } + }, + inner::Inner::QDeveloper(client) => { + let conversation_state_builder = amzn_qdeveloper_streaming_client::types::ConversationState::builder() + .set_conversation_id(conversation_id) + .current_message(amzn_qdeveloper_streaming_client::types::ChatMessage::UserInputMessage( + user_input_message.into(), + )) + .chat_trigger_type(amzn_qdeveloper_streaming_client::types::ChatTriggerType::Manual) + .set_history( + history + .map(|v| v.into_iter().map(|i| i.try_into()).collect::, _>>()) + .transpose()?, + ); + + Ok(SendMessageOutput::QDeveloper( + client + .send_message() + .conversation_state(conversation_state_builder.build().expect("fix me")) + .send() + .await?, + )) + }, + inner::Inner::Mock(events) => { + let mut new_events = events.lock().unwrap().next().unwrap_or_default().clone(); + new_events.reverse(); + Ok(SendMessageOutput::Mock(new_events)) + }, + } + } +} + +#[derive(Debug)] +pub enum SendMessageOutput { + Codewhisperer( + amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseOutput, + ), + QDeveloper(amzn_qdeveloper_streaming_client::operation::send_message::SendMessageOutput), + Mock(Vec), +} + +impl SendMessageOutput { + pub fn request_id(&self) -> Option<&str> { + match self { + SendMessageOutput::Codewhisperer(output) => output.request_id(), + SendMessageOutput::QDeveloper(output) => output.request_id(), + SendMessageOutput::Mock(_) => None, + } + } + + pub async fn recv(&mut self) -> Result, Error> { + match self { + SendMessageOutput::Codewhisperer(output) => Ok(output + .generate_assistant_response_response + .recv() + .await? 
+ .map(|s| s.into())), + SendMessageOutput::QDeveloper(output) => Ok(output.send_message_response.recv().await?.map(|s| s.into())), + SendMessageOutput::Mock(vec) => Ok(vec.pop()), + } + } +} + +impl RequestId for SendMessageOutput { + fn request_id(&self) -> Option<&str> { + match self { + SendMessageOutput::Codewhisperer(output) => output.request_id(), + SendMessageOutput::QDeveloper(output) => output.request_id(), + SendMessageOutput::Mock(_) => Some(""), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::fig_api_client::model::{ + AssistantResponseMessage, + ChatMessage, + UserInputMessage, + }; + + #[tokio::test] + async fn create_clients() { + let endpoint = Endpoint::load_codewhisperer(); + + let _ = StreamingClient::new().await; + let _ = StreamingClient::new_codewhisperer_client(&endpoint).await; + let _ = StreamingClient::new_qdeveloper_client(&endpoint).await; + } + + #[tokio::test] + async fn test_mock() { + let client = StreamingClient::mock(vec![vec![ + ChatResponseStream::AssistantResponseEvent { + content: "Hello!".to_owned(), + }, + ChatResponseStream::AssistantResponseEvent { + content: " How can I".to_owned(), + }, + ChatResponseStream::AssistantResponseEvent { + content: " assist you today?".to_owned(), + }, + ]]); + let mut output = client + .send_message(ConversationState { + conversation_id: None, + user_input_message: UserInputMessage { + content: "Hello".into(), + user_input_message_context: None, + user_intent: None, + }, + history: None, + }) + .await + .unwrap(); + + let mut output_content = String::new(); + while let Some(ChatResponseStream::AssistantResponseEvent { content }) = output.recv().await.unwrap() { + output_content.push_str(&content); + } + assert_eq!(output_content, "Hello! How can I assist you today?"); + } + + #[ignore] + #[tokio::test] + async fn assistant_response() { + let client = StreamingClient::new().await.unwrap(); + let mut response = client + .send_message(ConversationState { + conversation_id: None, + user_input_message: UserInputMessage { + content: "How about rustc?".into(), + user_input_message_context: None, + user_intent: None, + }, + history: Some(vec![ + ChatMessage::UserInputMessage(UserInputMessage { + content: "What language is the linux kernel written in, and who wrote it?".into(), + user_input_message_context: None, + user_intent: None, + }), + ChatMessage::AssistantResponseMessage(AssistantResponseMessage { + content: "It is written in C by Linus Torvalds.".into(), + message_id: None, + tool_uses: None, + }), + ]), + }) + .await + .unwrap(); + + while let Some(event) = response.recv().await.unwrap() { + println!("{:?}", event); + } + } +} diff --git a/crates/kiro-cli/src/fig_api_client/consts.rs b/crates/kiro-cli/src/fig_api_client/consts.rs new file mode 100644 index 0000000000..8a954e3a54 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/consts.rs @@ -0,0 +1,19 @@ +use aws_config::Region; + +// Endpoint constants +pub const PROD_CODEWHISPERER_ENDPOINT_URL: &str = "https://codewhisperer.us-east-1.amazonaws.com"; +pub const PROD_CODEWHISPERER_ENDPOINT_REGION: Region = Region::from_static("us-east-1"); + +pub const PROD_Q_ENDPOINT_URL: &str = "https://q.us-east-1.amazonaws.com"; +pub const PROD_Q_ENDPOINT_REGION: Region = Region::from_static("us-east-1"); + +// FRA endpoint constants +pub const PROD_CODEWHISPERER_FRA_ENDPOINT_URL: &str = "https://q.eu-central-1.amazonaws.com/"; +pub const PROD_CODEWHISPERER_FRA_ENDPOINT_REGION: Region = Region::from_static("eu-central-1"); + +// Opt out constants +pub 
const SHARE_CODEWHISPERER_CONTENT_SETTINGS_KEY: &str = "codeWhisperer.shareCodeWhispererContentWithAWS";
+pub const X_AMZN_CODEWHISPERER_OPT_OUT_HEADER: &str = "x-amzn-codewhisperer-optout";
+
+// Session ID constants
+pub const X_AMZN_SESSIONID_HEADER: &str = "x-amzn-sessionid";
diff --git a/crates/kiro-cli/src/fig_api_client/credentials/mod.rs b/crates/kiro-cli/src/fig_api_client/credentials/mod.rs
new file mode 100644
index 0000000000..508052a554
--- /dev/null
+++ b/crates/kiro-cli/src/fig_api_client/credentials/mod.rs
@@ -0,0 +1,80 @@
+use aws_config::default_provider::region::DefaultRegionChain;
+use aws_config::ecs::EcsCredentialsProvider;
+use aws_config::environment::credentials::EnvironmentVariableCredentialsProvider;
+use aws_config::imds::credentials::ImdsCredentialsProvider;
+use aws_config::meta::credentials::CredentialsProviderChain;
+use aws_config::profile::ProfileFileCredentialsProvider;
+use aws_config::provider_config::ProviderConfig;
+use aws_config::web_identity_token::WebIdentityTokenCredentialsProvider;
+use aws_credential_types::Credentials;
+use aws_credential_types::provider::{
+    self,
+    ProvideCredentials,
+    future,
+};
+use tracing::Instrument;
+
+#[derive(Debug)]
+pub struct CredentialsChain {
+    provider_chain: CredentialsProviderChain,
+}
+
+impl CredentialsChain {
+    /// Based on the code for
+    /// [aws_config::default_provider::credentials::DefaultCredentialsChain]:
+    pub async fn new() -> Self {
+        let region = DefaultRegionChain::builder().build().region().await;
+        let config = ProviderConfig::default().with_region(region.clone());
+
+        let env_provider = EnvironmentVariableCredentialsProvider::new();
+        let profile_provider = ProfileFileCredentialsProvider::builder().configure(&config).build();
+        let web_identity_token_provider = WebIdentityTokenCredentialsProvider::builder()
+            .configure(&config)
+            .build();
+        let imds_provider = ImdsCredentialsProvider::builder().configure(&config).build();
+        let ecs_provider = EcsCredentialsProvider::builder().configure(&config).build();
+
+        let mut provider_chain = CredentialsProviderChain::first_try("Environment", env_provider);
+
+        provider_chain = provider_chain
+            .or_else("Profile", profile_provider)
+            .or_else("WebIdentityToken", web_identity_token_provider)
+            .or_else("EcsContainer", ecs_provider)
+            .or_else("Ec2InstanceMetadata", imds_provider);
+
+        CredentialsChain { provider_chain }
+    }
+
+    async fn credentials(&self) -> provider::Result {
+        self.provider_chain
+            .provide_credentials()
+            .instrument(tracing::debug_span!("provide_credentials", provider = %"default_chain"))
+            .await
+    }
+}
+
+impl ProvideCredentials for CredentialsChain {
+    fn provide_credentials<'a>(&'a self) -> future::ProvideCredentials<'a>
+    where
+        Self: 'a,
+    {
+        future::ProvideCredentials::new(self.credentials())
+    }
+
+    fn fallback_on_interrupt(&self) -> Option<Credentials> {
+        self.provider_chain.fallback_on_interrupt()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_credentials_chain() {
+        let credentials_chain = CredentialsChain::new().await;
+        let credentials_res = credentials_chain.provide_credentials().await;
+        let fallback_on_interrupt_res = credentials_chain.fallback_on_interrupt();
+        println!("credentials_res: {credentials_res:?}, fallback_on_interrupt_res: {fallback_on_interrupt_res:?}");
+    }
+}
diff --git a/crates/kiro-cli/src/fig_api_client/customization.rs b/crates/kiro-cli/src/fig_api_client/customization.rs
new file mode 100644
index 0000000000..98acbc718d
--- /dev/null
+++
b/crates/kiro-cli/src/fig_api_client/customization.rs @@ -0,0 +1,161 @@ +use amzn_codewhisperer_client::types::Customization as CodewhispererCustomization; +use amzn_consolas_client::types::CustomizationSummary as ConsolasCustomization; +use serde::{ + Deserialize, + Serialize, +}; + +use crate::fig_settings::State; + +const CUSTOMIZATION_STATE_KEY: &str = "api.selectedCustomization"; + +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Customization { + pub arn: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +impl Customization { + /// Load the currently selected customization from state + pub fn load_selected(state: &State) -> Result, crate::fig_settings::Error> { + state.get(CUSTOMIZATION_STATE_KEY) + } + + /// Save the currently selected customization to state + pub fn save_selected(&self, state: &State) -> Result<(), crate::fig_settings::Error> { + state.set_value(CUSTOMIZATION_STATE_KEY, serde_json::to_value(self)?) + } + + /// Delete the currently selected customization from state + pub fn delete_selected(state: &State) -> Result<(), crate::fig_settings::Error> { + state.remove_value(CUSTOMIZATION_STATE_KEY) + } +} + +impl From for CodewhispererCustomization { + fn from(Customization { arn, name, description }: Customization) -> Self { + CodewhispererCustomization::builder() + .arn(arn) + .set_name(name) + .set_description(description) + .build() + .expect("Failed to build CW Customization") + } +} + +impl From for Customization { + fn from(cw_customization: CodewhispererCustomization) -> Self { + Customization { + arn: cw_customization.arn, + name: cw_customization.name, + description: cw_customization.description, + } + } +} + +impl From for Customization { + fn from(consolas_customization: ConsolasCustomization) -> Self { + Customization { + arn: consolas_customization.arn, + name: Some(consolas_customization.customization_name), + description: consolas_customization.description, + } + } +} + +#[cfg(test)] +mod tests { + use amzn_consolas_client::types::CustomizationStatus; + use aws_smithy_types::DateTime; + + use super::*; + + #[test] + fn test_customization_from_impls() { + let cw_customization = CodewhispererCustomization::builder() + .arn("arn") + .name("name") + .description("description") + .build() + .unwrap(); + + let custom_from_cw: Customization = cw_customization.into(); + let cw_from_custom: CodewhispererCustomization = custom_from_cw.into(); + + assert_eq!(cw_from_custom.arn, "arn"); + assert_eq!(cw_from_custom.name, Some("name".into())); + assert_eq!(cw_from_custom.description, Some("description".into())); + + let cw_customization = CodewhispererCustomization::builder().arn("arn").build().unwrap(); + + let custom_from_cw: Customization = cw_customization.into(); + let cw_from_custom: CodewhispererCustomization = custom_from_cw.into(); + + assert_eq!(cw_from_custom.arn, "arn"); + assert_eq!(cw_from_custom.name, None); + assert_eq!(cw_from_custom.description, None); + + let consolas_customization = ConsolasCustomization::builder() + .arn("arn") + .customization_name("name") + .description("description") + .status(CustomizationStatus::Activated) + .updated_at(DateTime::from_secs(0)) + .build() + .unwrap(); + + let custom_from_consolas: Customization = consolas_customization.into(); + + assert_eq!(custom_from_consolas.arn, "arn"); + assert_eq!(custom_from_consolas.name, Some("name".into())); + 
assert_eq!(custom_from_consolas.description, Some("description".into())); + } + + #[test] + fn test_customization_save_load() { + let state = State::new_fake(); + + let value = Customization { + arn: "arn".into(), + name: Some("name".into()), + description: Some("description".into()), + }; + + value.save_selected(&state).unwrap(); + let loaded_value = Customization::load_selected(&state).unwrap(); + assert_eq!(loaded_value, Some(value)); + + Customization::delete_selected(&state).unwrap(); + } + + #[test] + fn test_customization_serde() { + let customization = Customization { + arn: "arn".into(), + name: Some("name".into()), + description: Some("description".into()), + }; + + let serialized = serde_json::to_string(&customization).unwrap(); + assert_eq!(serialized, r#"{"arn":"arn","name":"name","description":"description"}"#); + + let deserialized: Customization = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, customization); + + let customization = Customization { + arn: "arn".into(), + name: None, + description: None, + }; + + let serialized = serde_json::to_string(&customization).unwrap(); + assert_eq!(serialized, r#"{"arn":"arn"}"#); + + let deserialized: Customization = serde_json::from_str(&serialized).unwrap(); + assert_eq!(deserialized, customization); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/endpoints.rs b/crates/kiro-cli/src/fig_api_client/endpoints.rs new file mode 100644 index 0000000000..c560e6ecf4 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/endpoints.rs @@ -0,0 +1,125 @@ +use std::borrow::Cow; + +use aws_config::Region; +use serde_json::Value; +use tracing::error; + +use crate::fig_api_client::consts::{ + PROD_CODEWHISPERER_ENDPOINT_REGION, + PROD_CODEWHISPERER_ENDPOINT_URL, + PROD_CODEWHISPERER_FRA_ENDPOINT_REGION, + PROD_CODEWHISPERER_FRA_ENDPOINT_URL, + PROD_Q_ENDPOINT_REGION, + PROD_Q_ENDPOINT_URL, +}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Endpoint { + pub url: Cow<'static, str>, + pub region: Region, +} + +impl Endpoint { + pub const CODEWHISPERER_ENDPOINTS: [Self; 2] = [Self::DEFAULT_ENDPOINT, Self { + url: Cow::Borrowed(PROD_CODEWHISPERER_FRA_ENDPOINT_URL), + region: PROD_CODEWHISPERER_FRA_ENDPOINT_REGION, + }]; + pub const DEFAULT_ENDPOINT: Self = Self { + url: Cow::Borrowed(PROD_CODEWHISPERER_ENDPOINT_URL), + region: PROD_CODEWHISPERER_ENDPOINT_REGION, + }; + pub const PROD_Q: Self = Self { + url: Cow::Borrowed(PROD_Q_ENDPOINT_URL), + region: PROD_Q_ENDPOINT_REGION, + }; + + pub fn load_codewhisperer() -> Self { + let (endpoint, region) = if let Ok(Some(Value::Object(o))) = + crate::fig_settings::settings::get_value("api.codewhisperer.service") + { + // The following branch is evaluated in case the user has set their own endpoint. + ( + o.get("endpoint").and_then(|v| v.as_str()).map(|v| v.to_owned()), + o.get("region").and_then(|v| v.as_str()).map(|v| v.to_owned()), + ) + } else if let Ok(Some(Value::Object(o))) = crate::fig_settings::state::get_value("api.codewhisperer.profile") { + // The following branch is evaluated in the case of user profile being set. 
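+            // The region is the fourth ':'-separated field of the profile ARN; it is matched
+            // against the known CodeWhisperer endpoints below.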
+ match o.get("arn").and_then(|v| v.as_str()).map(|v| v.to_owned()) { + Some(arn) => { + let region = arn.split(':').nth(3).unwrap_or_default().to_owned(); + match Self::CODEWHISPERER_ENDPOINTS + .iter() + .find(|e| e.region().as_ref() == region) + { + Some(endpoint) => (Some(endpoint.url().to_owned()), Some(region)), + None => { + error!("Failed to find endpoint for region: {region}"); + (None, None) + }, + } + }, + None => (None, None), + } + } else { + (None, None) + }; + + match (endpoint, region) { + (Some(endpoint), Some(region)) => Self { + url: endpoint.clone().into(), + region: Region::new(region.clone()), + }, + _ => Endpoint::DEFAULT_ENDPOINT, + } + } + + pub fn load_q() -> Self { + match crate::fig_settings::settings::get_value("api.q.service") { + Ok(Some(Value::Object(o))) => { + let endpoint = o.get("endpoint").and_then(|v| v.as_str()); + let region = o.get("region").and_then(|v| v.as_str()); + + match (endpoint, region) { + (Some(endpoint), Some(region)) => Self { + url: endpoint.to_owned().into(), + region: Region::new(region.to_owned()), + }, + _ => Endpoint::PROD_Q, + } + }, + _ => Endpoint::PROD_Q, + } + } + + pub(crate) fn url(&self) -> &str { + &self.url + } + + pub(crate) fn region(&self) -> &Region { + &self.region + } +} + +#[cfg(test)] +mod tests { + use url::Url; + + use super::*; + + #[test] + fn test_endpoints() { + let _ = Endpoint::load_codewhisperer(); + let _ = Endpoint::load_q(); + + let prod = &Endpoint::DEFAULT_ENDPOINT; + Url::parse(prod.url()).unwrap(); + assert_eq!(prod.region(), &PROD_CODEWHISPERER_ENDPOINT_REGION); + + let custom = Endpoint { + region: Region::new("us-west-2"), + url: "https://example.com".into(), + }; + Url::parse(custom.url()).unwrap(); + assert_eq!(custom.region(), &Region::new("us-west-2")); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/error.rs b/crates/kiro-cli/src/fig_api_client/error.rs new file mode 100644 index 0000000000..52dca7db11 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/error.rs @@ -0,0 +1,178 @@ +use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError; +use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError; +use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError; +pub use amzn_codewhisperer_streaming_client::operation::generate_assistant_response::GenerateAssistantResponseError; +use amzn_codewhisperer_streaming_client::types::error::ChatResponseStreamError as CodewhispererChatResponseStreamError; +use amzn_consolas_client::operation::generate_recommendations::GenerateRecommendationsError; +use amzn_consolas_client::operation::list_customizations::ListCustomizationsError; +use amzn_qdeveloper_streaming_client::operation::send_message::SendMessageError as QDeveloperSendMessageError; +use amzn_qdeveloper_streaming_client::types::error::ChatResponseStreamError as QDeveloperChatResponseStreamError; +use aws_credential_types::provider::error::CredentialsError; +use aws_smithy_runtime_api::client::orchestrator::HttpResponse; +pub use aws_smithy_runtime_api::client::result::SdkError; +use aws_smithy_types::event_stream::RawMessage; +use thiserror::Error; + +use crate::fig_aws_common::SdkErrorDisplay; + +#[derive(Debug, Error)] +pub enum Error { + #[error("failed to load credentials: {}", .0)] + Credentials(CredentialsError), + + // Generate completions errors + #[error("{}", SdkErrorDisplay(.0))] + GenerateCompletions(#[from] SdkError), + #[error("{}", SdkErrorDisplay(.0))] + 
GenerateRecommendations(#[from] SdkError), + + // List customizations error + #[error("{}", SdkErrorDisplay(.0))] + ListAvailableCustomizations(#[from] SdkError), + #[error("{}", SdkErrorDisplay(.0))] + ListAvailableServices(#[from] SdkError), + + // Send message errors + #[error("{}", SdkErrorDisplay(.0))] + CodewhispererGenerateAssistantResponse(#[from] SdkError), + #[error("{}", SdkErrorDisplay(.0))] + QDeveloperSendMessage(#[from] SdkError), + + // chat stream errors + #[error("{}", SdkErrorDisplay(.0))] + CodewhispererChatResponseStream(#[from] SdkError), + #[error("{}", SdkErrorDisplay(.0))] + QDeveloperChatResponseStream(#[from] SdkError), + + // quota breach + #[error("quota has reached its limit")] + QuotaBreach(&'static str), + + /// Returned from the backend when the user input is too large to fit within the model context + /// window. + /// + /// Note that we currently do not receive token usage information regarding how large the + /// context window is. + #[error("the context window has overflowed")] + ContextWindowOverflow, + + #[error(transparent)] + SmithyBuild(#[from] aws_smithy_types::error::operation::BuildError), + + #[error("unsupported action by consolas: {0}")] + UnsupportedConsolas(&'static str), + + #[error(transparent)] + ListAvailableProfilesError(#[from] SdkError), +} + +impl Error { + pub fn is_throttling_error(&self) -> bool { + match self { + Error::Credentials(_) => false, + Error::GenerateCompletions(e) => e.as_service_error().is_some_and(|e| e.is_throttling_error()), + Error::GenerateRecommendations(e) => e.as_service_error().is_some_and(|e| e.is_throttling_error()), + Error::ListAvailableCustomizations(e) => e.as_service_error().is_some_and(|e| e.is_throttling_error()), + Error::ListAvailableServices(e) => e.as_service_error().is_some_and(|e| e.is_throttling_error()), + Error::CodewhispererGenerateAssistantResponse(e) => { + e.as_service_error().is_some_and(|e| e.is_throttling_error()) + }, + Error::QDeveloperSendMessage(e) => e.as_service_error().is_some_and(|e| e.is_throttling_error()), + Error::ListAvailableProfilesError(e) => e.as_service_error().is_some_and(|e| e.is_throttling_error()), + Error::CodewhispererChatResponseStream(_) + | Error::QDeveloperChatResponseStream(_) + | Error::SmithyBuild(_) + | Error::UnsupportedConsolas(_) + | Error::ContextWindowOverflow + | Error::QuotaBreach(_) => false, + } + } + + pub fn is_service_error(&self) -> bool { + match self { + Error::Credentials(_) => false, + Error::GenerateCompletions(e) => e.as_service_error().is_some(), + Error::GenerateRecommendations(e) => e.as_service_error().is_some(), + Error::ListAvailableCustomizations(e) => e.as_service_error().is_some(), + Error::ListAvailableServices(e) => e.as_service_error().is_some(), + Error::CodewhispererGenerateAssistantResponse(e) => e.as_service_error().is_some(), + Error::QDeveloperSendMessage(e) => e.as_service_error().is_some(), + Error::ContextWindowOverflow => true, + Error::ListAvailableProfilesError(e) => e.as_service_error().is_some(), + Error::CodewhispererChatResponseStream(_) + | Error::QDeveloperChatResponseStream(_) + | Error::SmithyBuild(_) + | Error::UnsupportedConsolas(_) + | Error::QuotaBreach(_) => false, + } + } +} + +#[cfg(test)] +mod tests { + use std::error::Error as _; + + use aws_smithy_runtime_api::http::Response; + use aws_smithy_types::body::SdkBody; + use aws_smithy_types::event_stream::Message; + + use super::*; + + fn response() -> Response { + Response::new(500.try_into().unwrap(), SdkBody::empty()) + } + + fn raw_message() 
-> RawMessage { + RawMessage::Decoded(Message::new(b"".to_vec())) + } + + fn all_errors() -> Vec { + vec![ + Error::Credentials(CredentialsError::unhandled("")), + Error::GenerateCompletions(SdkError::service_error( + GenerateCompletionsError::unhandled(""), + response(), + )), + Error::GenerateRecommendations(SdkError::service_error( + GenerateRecommendationsError::unhandled(""), + response(), + )), + Error::ListAvailableCustomizations(SdkError::service_error( + ListAvailableCustomizationsError::unhandled(""), + response(), + )), + Error::ListAvailableServices(SdkError::service_error( + ListCustomizationsError::unhandled(""), + response(), + )), + Error::CodewhispererGenerateAssistantResponse(SdkError::service_error( + GenerateAssistantResponseError::unhandled(""), + response(), + )), + Error::QDeveloperSendMessage(SdkError::service_error( + QDeveloperSendMessageError::unhandled(""), + response(), + )), + Error::CodewhispererChatResponseStream(SdkError::service_error( + CodewhispererChatResponseStreamError::unhandled(""), + raw_message(), + )), + Error::QDeveloperChatResponseStream(SdkError::service_error( + QDeveloperChatResponseStreamError::unhandled(""), + raw_message(), + )), + Error::SmithyBuild(aws_smithy_types::error::operation::BuildError::other("")), + Error::UnsupportedConsolas("test"), + ] + } + + #[test] + fn test_errors() { + for error in all_errors() { + let _ = error.is_throttling_error(); + let _ = error.is_service_error(); + let _ = error.source(); + println!("{error} {error:?}"); + } + } +} diff --git a/crates/kiro-cli/src/fig_api_client/interceptor/mod.rs b/crates/kiro-cli/src/fig_api_client/interceptor/mod.rs new file mode 100644 index 0000000000..5722738da9 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/interceptor/mod.rs @@ -0,0 +1,2 @@ +pub mod opt_out; +pub mod session_id; diff --git a/crates/kiro-cli/src/fig_api_client/interceptor/opt_out.rs b/crates/kiro-cli/src/fig_api_client/interceptor/opt_out.rs new file mode 100644 index 0000000000..6ebc40817a --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/interceptor/opt_out.rs @@ -0,0 +1,89 @@ +use aws_smithy_runtime_api::box_error::BoxError; +use aws_smithy_runtime_api::client::interceptors::Intercept; +use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; +use aws_smithy_types::config_bag::ConfigBag; + +use crate::fig_api_client::consts::{ + SHARE_CODEWHISPERER_CONTENT_SETTINGS_KEY, + X_AMZN_CODEWHISPERER_OPT_OUT_HEADER, +}; + +fn is_codewhisperer_content_optout() -> bool { + !crate::fig_settings::settings::get_bool_or(SHARE_CODEWHISPERER_CONTENT_SETTINGS_KEY, true) +} + +#[derive(Debug, Clone)] +pub struct OptOutInterceptor { + override_value: Option, + _inner: (), +} + +impl OptOutInterceptor { + pub const fn new() -> Self { + Self { + override_value: None, + _inner: (), + } + } +} + +impl Intercept for OptOutInterceptor { + fn name(&self) -> &'static str { + "OptOutInterceptor" + } + + fn modify_before_signing( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + _runtime_components: &RuntimeComponents, + _cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let opt_out = self.override_value.unwrap_or_else(is_codewhisperer_content_optout); + context + .request_mut() + .headers_mut() + .insert(X_AMZN_CODEWHISPERER_OPT_OUT_HEADER, opt_out.to_string()); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use amzn_consolas_client::config::RuntimeComponentsBuilder; + use 
amzn_consolas_client::config::interceptors::InterceptorContext; + use aws_smithy_runtime_api::client::interceptors::context::Input; + + use super::*; + + #[test] + fn test_opt_out_interceptor() { + let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); + let mut cfg = ConfigBag::base(); + + let mut context = InterceptorContext::new(Input::erase(())); + context.set_request(aws_smithy_runtime_api::http::Request::empty()); + let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); + + let mut interceptor = OptOutInterceptor::new(); + println!("Interceptor: {}", interceptor.name()); + + interceptor + .modify_before_signing(&mut context, &rc, &mut cfg) + .expect("success"); + + interceptor.override_value = Some(false); + interceptor + .modify_before_signing(&mut context, &rc, &mut cfg) + .expect("success"); + let val = context.request().headers().get(X_AMZN_CODEWHISPERER_OPT_OUT_HEADER); + assert_eq!(val, Some("false")); + + interceptor.override_value = Some(true); + interceptor + .modify_before_signing(&mut context, &rc, &mut cfg) + .expect("success"); + let val = context.request().headers().get(X_AMZN_CODEWHISPERER_OPT_OUT_HEADER); + assert_eq!(val, Some("true")); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/interceptor/session_id.rs b/crates/kiro-cli/src/fig_api_client/interceptor/session_id.rs new file mode 100644 index 0000000000..0ab6b2bf95 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/interceptor/session_id.rs @@ -0,0 +1,82 @@ +use std::sync::{ + Arc, + Mutex, +}; + +use aws_smithy_runtime_api::client::interceptors::Intercept; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; +use aws_smithy_types::config_bag::ConfigBag; + +use crate::fig_api_client::consts::X_AMZN_SESSIONID_HEADER; + +#[derive(Debug, Clone)] +pub struct SessionIdInterceptor { + inner: Arc>>, +} + +impl SessionIdInterceptor { + pub const fn new(inner: Arc>>) -> Self { + Self { inner } + } +} + +impl Intercept for SessionIdInterceptor { + fn name(&self) -> &'static str { + "SessionIdInterceptor" + } + + fn read_after_deserialization( + &self, + context: &amzn_codewhisperer_client::config::interceptors::AfterDeserializationInterceptorContextRef<'_>, + _runtime_components: &RuntimeComponents, + _cfg: &mut ConfigBag, + ) -> Result<(), amzn_codewhisperer_client::error::BoxError> { + *self + .inner + .lock() + .expect("Failed to write to SessionIdInterceptor mutex") = context + .response() + .headers() + .get(X_AMZN_SESSIONID_HEADER) + .map(Into::into); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use amzn_consolas_client::config::RuntimeComponentsBuilder; + use amzn_consolas_client::config::interceptors::{ + AfterDeserializationInterceptorContextRef, + InterceptorContext, + }; + use aws_smithy_runtime_api::client::interceptors::context::Input; + use aws_smithy_runtime_api::http::StatusCode; + use aws_smithy_types::body::SdkBody; + + use super::*; + + #[test] + fn test_opt_out_interceptor() { + let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); + let mut cfg = ConfigBag::base(); + + let mut context = InterceptorContext::new(Input::erase(())); + let mut response = + aws_smithy_runtime_api::http::Response::new(StatusCode::try_from(200).unwrap(), SdkBody::empty()); + response + .headers_mut() + .insert(X_AMZN_SESSIONID_HEADER, "test-session-id"); + context.set_response(response); + let context = AfterDeserializationInterceptorContextRef::from(&context); + + let session_id_lock = Arc::new(Mutex::new(None)); + let interceptor = 
SessionIdInterceptor::new(session_id_lock.clone()); + println!("Interceptor: {}", interceptor.name()); + + interceptor + .read_after_deserialization(&context, &rc, &mut cfg) + .expect("success"); + assert_eq!(*session_id_lock.lock().unwrap(), Some("test-session-id".to_string())); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/mod.rs b/crates/kiro-cli/src/fig_api_client/mod.rs new file mode 100644 index 0000000000..1a90635d0d --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/mod.rs @@ -0,0 +1,17 @@ +pub mod clients; +pub(crate) mod consts; +pub(crate) mod credentials; +mod customization; +mod endpoints; +mod error; +pub(crate) mod interceptor; +pub mod model; +pub mod profile; + +pub use clients::{ + Client, + StreamingClient, +}; +pub use endpoints::Endpoint; +pub use error::Error; +pub use profile::list_available_profiles; diff --git a/crates/kiro-cli/src/fig_api_client/model.rs b/crates/kiro-cli/src/fig_api_client/model.rs new file mode 100644 index 0000000000..ab44cdb119 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/model.rs @@ -0,0 +1,924 @@ +use aws_smithy_types::Document; +use serde::{ + Deserialize, + Serialize, +}; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FileContext { + pub left_file_content: String, + pub right_file_content: String, + pub filename: String, + pub programming_language: ProgrammingLanguage, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ProgrammingLanguage { + pub language_name: LanguageName, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, strum::AsRefStr)] +#[serde(rename_all = "lowercase")] +#[strum(serialize_all = "lowercase")] +pub enum LanguageName { + Python, + Javascript, + Java, + Csharp, + Typescript, + C, + Cpp, + Go, + Kotlin, + Php, + Ruby, + Rust, + Scala, + Shell, + Sql, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ReferenceTrackerConfiguration { + pub recommendations_with_references: RecommendationsWithReferences, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "UPPERCASE")] +pub enum RecommendationsWithReferences { + Block, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RecommendationsInput { + pub file_context: FileContext, + pub max_results: i32, + pub next_token: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RecommendationsOutput { + pub recommendations: Vec, + pub next_token: Option, + pub session_id: Option, + pub request_id: Option, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Recommendation { + pub content: String, +} + +// ========= +// Streaming +// ========= + +#[derive(Debug, Clone)] +pub struct ConversationState { + pub conversation_id: Option, + pub user_input_message: UserInputMessage, + pub history: Option>, +} + +#[derive(Debug, Clone)] +pub enum ChatMessage { + AssistantResponseMessage(AssistantResponseMessage), + UserInputMessage(UserInputMessage), +} + +impl TryFrom for amzn_codewhisperer_streaming_client::types::ChatMessage { + type Error = aws_smithy_types::error::operation::BuildError; + + fn try_from(value: ChatMessage) -> Result { + Ok(match value { + ChatMessage::AssistantResponseMessage(message) => { + 
amzn_codewhisperer_streaming_client::types::ChatMessage::AssistantResponseMessage(message.try_into()?) + }, + ChatMessage::UserInputMessage(message) => { + amzn_codewhisperer_streaming_client::types::ChatMessage::UserInputMessage(message.into()) + }, + }) + } +} + +impl TryFrom for amzn_qdeveloper_streaming_client::types::ChatMessage { + type Error = aws_smithy_types::error::operation::BuildError; + + fn try_from(value: ChatMessage) -> Result { + Ok(match value { + ChatMessage::AssistantResponseMessage(message) => { + amzn_qdeveloper_streaming_client::types::ChatMessage::AssistantResponseMessage(message.try_into()?) + }, + ChatMessage::UserInputMessage(message) => { + amzn_qdeveloper_streaming_client::types::ChatMessage::UserInputMessage(message.into()) + }, + }) + } +} + +/// Information about a tool that can be used. +#[derive(Debug, Clone)] +pub enum Tool { + ToolSpecification(ToolSpecification), +} + +impl From for amzn_codewhisperer_streaming_client::types::Tool { + fn from(value: Tool) -> Self { + match value { + Tool::ToolSpecification(v) => amzn_codewhisperer_streaming_client::types::Tool::ToolSpecification(v.into()), + } + } +} + +impl From for amzn_qdeveloper_streaming_client::types::Tool { + fn from(value: Tool) -> Self { + match value { + Tool::ToolSpecification(v) => amzn_qdeveloper_streaming_client::types::Tool::ToolSpecification(v.into()), + } + } +} + +/// The specification for the tool. +#[derive(Debug, Clone)] +pub struct ToolSpecification { + /// The name for the tool. + pub name: String, + /// The description for the tool. + pub description: String, + /// The input schema for the tool in JSON format. + pub input_schema: ToolInputSchema, +} + +impl From for amzn_codewhisperer_streaming_client::types::ToolSpecification { + fn from(value: ToolSpecification) -> Self { + Self::builder() + .name(value.name) + .description(value.description) + .input_schema(value.input_schema.into()) + .build() + .expect("building ToolSpecification should not fail") + } +} + +impl From for amzn_qdeveloper_streaming_client::types::ToolSpecification { + fn from(value: ToolSpecification) -> Self { + Self::builder() + .name(value.name) + .description(value.description) + .input_schema(value.input_schema.into()) + .build() + .expect("building ToolSpecification should not fail") + } +} + +/// The input schema for the tool in JSON format. +#[derive(Debug, Clone)] +pub struct ToolInputSchema { + pub json: Option, +} + +impl From for amzn_codewhisperer_streaming_client::types::ToolInputSchema { + fn from(value: ToolInputSchema) -> Self { + Self::builder().set_json(value.json).build() + } +} + +impl From for amzn_qdeveloper_streaming_client::types::ToolInputSchema { + fn from(value: ToolInputSchema) -> Self { + Self::builder().set_json(value.json).build() + } +} + +/// Contains information about a tool that the model is requesting be run. The model uses the result +/// from the tool to generate a response. +#[derive(Debug, Clone)] +pub struct ToolUse { + /// The ID for the tool request. + pub tool_use_id: String, + /// The name for the tool. + pub name: String, + /// The input to pass to the tool. 
+ pub input: Document, +} + +impl From for amzn_codewhisperer_streaming_client::types::ToolUse { + fn from(value: ToolUse) -> Self { + Self::builder() + .tool_use_id(value.tool_use_id) + .name(value.name) + .input(value.input) + .build() + .expect("building ToolUse should not fail") + } +} + +impl From for amzn_qdeveloper_streaming_client::types::ToolUse { + fn from(value: ToolUse) -> Self { + Self::builder() + .tool_use_id(value.tool_use_id) + .name(value.name) + .input(value.input) + .build() + .expect("building ToolUse should not fail") + } +} + +/// A tool result that contains the results for a tool request that was previously made. +#[derive(Debug, Clone)] +pub struct ToolResult { + /// The ID for the tool request. + pub tool_use_id: String, + /// Content of the tool result. + pub content: Vec, + /// Status of the tools result. + pub status: ToolResultStatus, +} + +impl From for amzn_codewhisperer_streaming_client::types::ToolResult { + fn from(value: ToolResult) -> Self { + Self::builder() + .tool_use_id(value.tool_use_id) + .set_content(Some(value.content.into_iter().map(Into::into).collect::<_>())) + .status(value.status.into()) + .build() + .expect("building ToolResult should not fail") + } +} + +impl From for amzn_qdeveloper_streaming_client::types::ToolResult { + fn from(value: ToolResult) -> Self { + Self::builder() + .tool_use_id(value.tool_use_id) + .set_content(Some(value.content.into_iter().map(Into::into).collect::<_>())) + .status(value.status.into()) + .build() + .expect("building ToolResult should not fail") + } +} + +#[derive(Debug, Clone)] +pub enum ToolResultContentBlock { + /// A tool result that is JSON format data. + Json(Document), + /// A tool result that is text. + Text(String), +} + +impl From for amzn_codewhisperer_streaming_client::types::ToolResultContentBlock { + fn from(value: ToolResultContentBlock) -> Self { + match value { + ToolResultContentBlock::Json(document) => Self::Json(document), + ToolResultContentBlock::Text(text) => Self::Text(text), + } + } +} + +impl From for amzn_qdeveloper_streaming_client::types::ToolResultContentBlock { + fn from(value: ToolResultContentBlock) -> Self { + match value { + ToolResultContentBlock::Json(document) => Self::Json(document), + ToolResultContentBlock::Text(text) => Self::Text(text), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolResultStatus { + Error, + Success, +} + +impl From for amzn_codewhisperer_streaming_client::types::ToolResultStatus { + fn from(value: ToolResultStatus) -> Self { + match value { + ToolResultStatus::Error => Self::Error, + ToolResultStatus::Success => Self::Success, + } + } +} + +impl From for amzn_qdeveloper_streaming_client::types::ToolResultStatus { + fn from(value: ToolResultStatus) -> Self { + match value { + ToolResultStatus::Error => Self::Error, + ToolResultStatus::Success => Self::Success, + } + } +} + +/// Markdown text message. +#[derive(Debug, Clone)] +pub struct AssistantResponseMessage { + /// Unique identifier for the chat message + pub message_id: Option, + /// The content of the text message in markdown format. 
+ pub content: String, + /// ToolUse Request + pub tool_uses: Option>, +} + +impl TryFrom for amzn_codewhisperer_streaming_client::types::AssistantResponseMessage { + type Error = aws_smithy_types::error::operation::BuildError; + + fn try_from(value: AssistantResponseMessage) -> Result { + Self::builder() + .content(value.content) + .set_message_id(value.message_id) + .set_tool_uses(value.tool_uses.map(|uses| uses.into_iter().map(Into::into).collect())) + .build() + } +} + +impl TryFrom for amzn_qdeveloper_streaming_client::types::AssistantResponseMessage { + type Error = aws_smithy_types::error::operation::BuildError; + + fn try_from(value: AssistantResponseMessage) -> Result { + Self::builder() + .content(value.content) + .set_message_id(value.message_id) + .set_tool_uses(value.tool_uses.map(|uses| uses.into_iter().map(Into::into).collect())) + .build() + } +} + +#[non_exhaustive] +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ChatResponseStream { + AssistantResponseEvent { + content: String, + }, + /// Streaming response event for generated code text. + CodeEvent { + content: String, + }, + // TODO: finish events here + CodeReferenceEvent(()), + FollowupPromptEvent(()), + IntentsEvent(()), + InvalidStateEvent { + reason: String, + message: String, + }, + MessageMetadataEvent { + conversation_id: Option, + utterance_id: Option, + }, + SupplementaryWebLinksEvent(()), + ToolUseEvent { + tool_use_id: String, + name: String, + input: Option, + stop: Option, + }, + + #[non_exhaustive] + Unknown, +} + +impl From for ChatResponseStream { + fn from(value: amzn_codewhisperer_streaming_client::types::ChatResponseStream) -> Self { + match value { + amzn_codewhisperer_streaming_client::types::ChatResponseStream::AssistantResponseEvent( + amzn_codewhisperer_streaming_client::types::AssistantResponseEvent { content, .. }, + ) => ChatResponseStream::AssistantResponseEvent { content }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeEvent( + amzn_codewhisperer_streaming_client::types::CodeEvent { content, .. }, + ) => ChatResponseStream::CodeEvent { content }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeReferenceEvent(_) => { + ChatResponseStream::CodeReferenceEvent(()) + }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::FollowupPromptEvent(_) => { + ChatResponseStream::FollowupPromptEvent(()) + }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::IntentsEvent(_) => { + ChatResponseStream::IntentsEvent(()) + }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::InvalidStateEvent( + amzn_codewhisperer_streaming_client::types::InvalidStateEvent { reason, message, .. }, + ) => ChatResponseStream::InvalidStateEvent { + reason: reason.to_string(), + message, + }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::MessageMetadataEvent( + amzn_codewhisperer_streaming_client::types::MessageMetadataEvent { + conversation_id, + utterance_id, + .. + }, + ) => ChatResponseStream::MessageMetadataEvent { + conversation_id, + utterance_id, + }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::ToolUseEvent( + amzn_codewhisperer_streaming_client::types::ToolUseEvent { + tool_use_id, + name, + input, + stop, + .. 
+ }, + ) => ChatResponseStream::ToolUseEvent { + tool_use_id, + name, + input, + stop, + }, + amzn_codewhisperer_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent(_) => { + ChatResponseStream::SupplementaryWebLinksEvent(()) + }, + _ => ChatResponseStream::Unknown, + } + } +} + +impl From for ChatResponseStream { + fn from(value: amzn_qdeveloper_streaming_client::types::ChatResponseStream) -> Self { + match value { + amzn_qdeveloper_streaming_client::types::ChatResponseStream::AssistantResponseEvent( + amzn_qdeveloper_streaming_client::types::AssistantResponseEvent { content, .. }, + ) => ChatResponseStream::AssistantResponseEvent { content }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeEvent( + amzn_qdeveloper_streaming_client::types::CodeEvent { content, .. }, + ) => ChatResponseStream::CodeEvent { content }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeReferenceEvent(_) => { + ChatResponseStream::CodeReferenceEvent(()) + }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::FollowupPromptEvent(_) => { + ChatResponseStream::FollowupPromptEvent(()) + }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::IntentsEvent(_) => { + ChatResponseStream::IntentsEvent(()) + }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::InvalidStateEvent( + amzn_qdeveloper_streaming_client::types::InvalidStateEvent { reason, message, .. }, + ) => ChatResponseStream::InvalidStateEvent { + reason: reason.to_string(), + message, + }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::MessageMetadataEvent( + amzn_qdeveloper_streaming_client::types::MessageMetadataEvent { + conversation_id, + utterance_id, + .. + }, + ) => ChatResponseStream::MessageMetadataEvent { + conversation_id, + utterance_id, + }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::ToolUseEvent( + amzn_qdeveloper_streaming_client::types::ToolUseEvent { + tool_use_id, + name, + input, + stop, + .. 
+ }, + ) => ChatResponseStream::ToolUseEvent { + tool_use_id, + name, + input, + stop, + }, + amzn_qdeveloper_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent(_) => { + ChatResponseStream::SupplementaryWebLinksEvent(()) + }, + _ => ChatResponseStream::Unknown, + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct EnvState { + pub operating_system: Option, + pub current_working_directory: Option, + pub environment_variables: Vec, +} + +impl From for amzn_codewhisperer_streaming_client::types::EnvState { + fn from(value: EnvState) -> Self { + let environment_variables: Vec<_> = value.environment_variables.into_iter().map(Into::into).collect(); + Self::builder() + .set_operating_system(value.operating_system) + .set_current_working_directory(value.current_working_directory) + .set_environment_variables(if environment_variables.is_empty() { + None + } else { + Some(environment_variables) + }) + .build() + } +} + +impl From for amzn_qdeveloper_streaming_client::types::EnvState { + fn from(value: EnvState) -> Self { + let environment_variables: Vec<_> = value.environment_variables.into_iter().map(Into::into).collect(); + Self::builder() + .set_operating_system(value.operating_system) + .set_current_working_directory(value.current_working_directory) + .set_environment_variables(if environment_variables.is_empty() { + None + } else { + Some(environment_variables) + }) + .build() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EnvironmentVariable { + pub key: String, + pub value: String, +} + +impl From for amzn_codewhisperer_streaming_client::types::EnvironmentVariable { + fn from(value: EnvironmentVariable) -> Self { + Self::builder().key(value.key).value(value.value).build() + } +} + +impl From for amzn_qdeveloper_streaming_client::types::EnvironmentVariable { + fn from(value: EnvironmentVariable) -> Self { + Self::builder().key(value.key).value(value.value).build() + } +} + +#[derive(Debug, Clone)] +pub struct GitState { + pub status: String, +} + +impl From for amzn_codewhisperer_streaming_client::types::GitState { + fn from(value: GitState) -> Self { + Self::builder().status(value.status).build() + } +} + +impl From for amzn_qdeveloper_streaming_client::types::GitState { + fn from(value: GitState) -> Self { + Self::builder().status(value.status).build() + } +} + +#[derive(Debug, Clone)] +pub struct UserInputMessage { + pub content: String, + pub user_input_message_context: Option, + pub user_intent: Option, +} + +impl From for amzn_codewhisperer_streaming_client::types::UserInputMessage { + fn from(value: UserInputMessage) -> Self { + Self::builder() + .content(value.content) + .set_user_input_message_context(value.user_input_message_context.map(Into::into)) + .set_user_intent(value.user_intent.map(Into::into)) + .origin(amzn_codewhisperer_streaming_client::types::Origin::Cli) + .build() + .expect("Failed to build UserInputMessage") + } +} + +impl From for amzn_qdeveloper_streaming_client::types::UserInputMessage { + fn from(value: UserInputMessage) -> Self { + Self::builder() + .content(value.content) + .set_user_input_message_context(value.user_input_message_context.map(Into::into)) + .set_user_intent(value.user_intent.map(Into::into)) + .origin(amzn_qdeveloper_streaming_client::types::Origin::Cli) + .build() + .expect("Failed to build UserInputMessage") + } +} + +#[derive(Debug, Clone, Default)] +pub struct UserInputMessageContext { + pub env_state: Option, + pub git_state: Option, + pub tool_results: Option>, + pub 
tools: Option>, +} + +impl From for amzn_codewhisperer_streaming_client::types::UserInputMessageContext { + fn from(value: UserInputMessageContext) -> Self { + Self::builder() + .set_env_state(value.env_state.map(Into::into)) + .set_git_state(value.git_state.map(Into::into)) + .set_tool_results(value.tool_results.map(|t| t.into_iter().map(Into::into).collect())) + .set_tools(value.tools.map(|t| t.into_iter().map(Into::into).collect())) + .build() + } +} + +impl From for amzn_qdeveloper_streaming_client::types::UserInputMessageContext { + fn from(value: UserInputMessageContext) -> Self { + Self::builder() + .set_env_state(value.env_state.map(Into::into)) + .set_git_state(value.git_state.map(Into::into)) + .set_tool_results(value.tool_results.map(|t| t.into_iter().map(Into::into).collect())) + .set_tools(value.tools.map(|t| t.into_iter().map(Into::into).collect())) + .build() + } +} + +#[derive(Debug, Clone)] +pub enum UserIntent { + ApplyCommonBestPractices, +} + +impl From for amzn_codewhisperer_streaming_client::types::UserIntent { + fn from(value: UserIntent) -> Self { + match value { + UserIntent::ApplyCommonBestPractices => Self::ApplyCommonBestPractices, + } + } +} + +impl From for amzn_qdeveloper_streaming_client::types::UserIntent { + fn from(value: UserIntent) -> Self { + match value { + UserIntent::ApplyCommonBestPractices => Self::ApplyCommonBestPractices, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn build_user_input_message() { + let user_input_message = UserInputMessage { + content: "test content".to_string(), + user_input_message_context: Some(UserInputMessageContext { + env_state: Some(EnvState { + operating_system: Some("test os".to_string()), + current_working_directory: Some("test cwd".to_string()), + environment_variables: vec![EnvironmentVariable { + key: "test key".to_string(), + value: "test value".to_string(), + }], + }), + git_state: Some(GitState { + status: "test status".to_string(), + }), + tool_results: Some(vec![ToolResult { + tool_use_id: "test id".to_string(), + content: vec![ToolResultContentBlock::Text("test text".to_string())], + status: ToolResultStatus::Success, + }]), + tools: Some(vec![Tool::ToolSpecification(ToolSpecification { + name: "test tool name".to_string(), + description: "test tool description".to_string(), + input_schema: ToolInputSchema { + json: Some(Document::Null), + }, + })]), + }), + user_intent: Some(UserIntent::ApplyCommonBestPractices), + }; + + let codewhisper_input = + amzn_codewhisperer_streaming_client::types::UserInputMessage::from(user_input_message.clone()); + let qdeveloper_input = amzn_qdeveloper_streaming_client::types::UserInputMessage::from(user_input_message); + + assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}")); + + let minimal_message = UserInputMessage { + content: "test content".to_string(), + user_input_message_context: None, + user_intent: None, + }; + + let codewhisper_minimal = + amzn_codewhisperer_streaming_client::types::UserInputMessage::from(minimal_message.clone()); + let qdeveloper_minimal = amzn_qdeveloper_streaming_client::types::UserInputMessage::from(minimal_message); + assert_eq!(format!("{codewhisper_minimal:?}"), format!("{qdeveloper_minimal:?}")); + } + + #[test] + fn build_assistant_response_message() { + let message = AssistantResponseMessage { + message_id: Some("testid".to_string()), + content: "test content".to_string(), + tool_uses: Some(vec![ToolUse { + tool_use_id: "tooluseid_test".to_string(), + name: "tool_name_test".to_string(), + 
input: Document::Object([("key1".to_string(), Document::Null)].into_iter().collect()), + }]), + }; + let codewhisper_input = + amzn_codewhisperer_streaming_client::types::AssistantResponseMessage::try_from(message.clone()).unwrap(); + let qdeveloper_input = + amzn_qdeveloper_streaming_client::types::AssistantResponseMessage::try_from(message).unwrap(); + assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}")); + } + + #[test] + fn build_chat_response() { + let assistant_response_event = + amzn_codewhisperer_streaming_client::types::ChatResponseStream::AssistantResponseEvent( + amzn_codewhisperer_streaming_client::types::AssistantResponseEvent::builder() + .content("context") + .build() + .unwrap(), + ); + assert_eq!( + ChatResponseStream::from(assistant_response_event), + ChatResponseStream::AssistantResponseEvent { + content: "context".into(), + } + ); + + let assistant_response_event = + amzn_qdeveloper_streaming_client::types::ChatResponseStream::AssistantResponseEvent( + amzn_qdeveloper_streaming_client::types::AssistantResponseEvent::builder() + .content("context") + .build() + .unwrap(), + ); + assert_eq!( + ChatResponseStream::from(assistant_response_event), + ChatResponseStream::AssistantResponseEvent { + content: "context".into(), + } + ); + + let code_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeEvent( + amzn_codewhisperer_streaming_client::types::CodeEvent::builder() + .content("context") + .build() + .unwrap(), + ); + assert_eq!(ChatResponseStream::from(code_event), ChatResponseStream::CodeEvent { + content: "context".into() + }); + + let code_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeEvent( + amzn_qdeveloper_streaming_client::types::CodeEvent::builder() + .content("context") + .build() + .unwrap(), + ); + assert_eq!(ChatResponseStream::from(code_event), ChatResponseStream::CodeEvent { + content: "context".into() + }); + + let code_reference_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::CodeReferenceEvent( + amzn_codewhisperer_streaming_client::types::CodeReferenceEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(code_reference_event), + ChatResponseStream::CodeReferenceEvent(()) + ); + + let code_reference_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::CodeReferenceEvent( + amzn_qdeveloper_streaming_client::types::CodeReferenceEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(code_reference_event), + ChatResponseStream::CodeReferenceEvent(()) + ); + + let followup_prompt_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::FollowupPromptEvent( + amzn_codewhisperer_streaming_client::types::FollowupPromptEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(followup_prompt_event), + ChatResponseStream::FollowupPromptEvent(()) + ); + + let followup_prompt_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::FollowupPromptEvent( + amzn_qdeveloper_streaming_client::types::FollowupPromptEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(followup_prompt_event), + ChatResponseStream::FollowupPromptEvent(()) + ); + + let intents_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::IntentsEvent( + amzn_codewhisperer_streaming_client::types::IntentsEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(intents_event), + ChatResponseStream::IntentsEvent(()) + ); + + let intents_event = 
amzn_qdeveloper_streaming_client::types::ChatResponseStream::IntentsEvent( + amzn_qdeveloper_streaming_client::types::IntentsEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(intents_event), + ChatResponseStream::IntentsEvent(()) + ); + + let user_input_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::InvalidStateEvent( + amzn_codewhisperer_streaming_client::types::InvalidStateEvent::builder() + .reason(amzn_codewhisperer_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan) + .message("message") + .build() + .unwrap(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::InvalidStateEvent { + reason: amzn_codewhisperer_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan + .to_string(), + message: "message".into() + } + ); + + let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::InvalidStateEvent( + amzn_qdeveloper_streaming_client::types::InvalidStateEvent::builder() + .reason(amzn_qdeveloper_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan) + .message("message") + .build() + .unwrap(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::InvalidStateEvent { + reason: amzn_qdeveloper_streaming_client::types::InvalidStateReason::InvalidTaskAssistPlan.to_string(), + message: "message".into() + } + ); + + let user_input_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::MessageMetadataEvent( + amzn_codewhisperer_streaming_client::types::MessageMetadataEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::MessageMetadataEvent { + conversation_id: None, + utterance_id: None + } + ); + + let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::MessageMetadataEvent( + amzn_qdeveloper_streaming_client::types::MessageMetadataEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::MessageMetadataEvent { + conversation_id: None, + utterance_id: None + } + ); + + let user_input_event = + amzn_codewhisperer_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent( + amzn_codewhisperer_streaming_client::types::SupplementaryWebLinksEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::SupplementaryWebLinksEvent(()) + ); + + let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::SupplementaryWebLinksEvent( + amzn_qdeveloper_streaming_client::types::SupplementaryWebLinksEvent::builder().build(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::SupplementaryWebLinksEvent(()) + ); + + let user_input_event = amzn_codewhisperer_streaming_client::types::ChatResponseStream::ToolUseEvent( + amzn_codewhisperer_streaming_client::types::ToolUseEvent::builder() + .tool_use_id("tool_use_id".to_string()) + .name("tool_name".to_string()) + .build() + .unwrap(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::ToolUseEvent { + tool_use_id: "tool_use_id".to_string(), + name: "tool_name".to_string(), + input: None, + stop: None, + } + ); + + let user_input_event = amzn_qdeveloper_streaming_client::types::ChatResponseStream::ToolUseEvent( + amzn_qdeveloper_streaming_client::types::ToolUseEvent::builder() + .tool_use_id("tool_use_id".to_string()) + .name("tool_name".to_string()) + .build() + 
.unwrap(), + ); + assert_eq!( + ChatResponseStream::from(user_input_event), + ChatResponseStream::ToolUseEvent { + tool_use_id: "tool_use_id".to_string(), + name: "tool_name".to_string(), + input: None, + stop: None, + } + ); + } +} diff --git a/crates/kiro-cli/src/fig_api_client/profile.rs b/crates/kiro-cli/src/fig_api_client/profile.rs new file mode 100644 index 0000000000..543554eaa7 --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/profile.rs @@ -0,0 +1,35 @@ +use serde::{ + Deserialize, + Serialize, +}; + +use crate::fig_api_client::Client; +use crate::fig_api_client::endpoints::Endpoint; + +#[derive(Debug, Deserialize, Serialize)] +pub struct Profile { + pub arn: String, + pub profile_name: String, +} + +impl From for Profile { + fn from(profile: amzn_codewhisperer_client::types::Profile) -> Self { + Self { + arn: profile.arn, + profile_name: profile.profile_name, + } + } +} + +pub async fn list_available_profiles() -> Vec { + let mut profiles = vec![]; + for endpoint in Endpoint::CODEWHISPERER_ENDPOINTS { + let client = Client::new_codewhisperer_client(&endpoint).await; + match client.list_available_profiles().await { + Ok(mut p) => profiles.append(&mut p), + Err(e) => tracing::error!("Failed to list profiles from endpoint {:?}: {:?}", endpoint, e), + } + } + + profiles +} diff --git a/crates/kiro-cli/src/fig_api_client/stage.rs b/crates/kiro-cli/src/fig_api_client/stage.rs new file mode 100644 index 0000000000..31b301786f --- /dev/null +++ b/crates/kiro-cli/src/fig_api_client/stage.rs @@ -0,0 +1,40 @@ +use std::str::FromStr; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Stage { + Prod, + Gamma, + Alpha, + Beta, +} + +impl Stage { + pub fn as_str(&self) -> &'static str { + match self { + Stage::Prod => "prod", + Stage::Gamma => "gamma", + Stage::Alpha => "alpha", + Stage::Beta => "beta", + } + } +} + +impl FromStr for Stage { + type Err = (); + + fn from_str(s: &str) -> Result { + match s.to_ascii_lowercase().trim() { + "prod" | "production" => Ok(Stage::Prod), + "gamma" => Ok(Stage::Gamma), + "alpha" => Ok(Stage::Alpha), + "beta" => Ok(Stage::Beta), + _ => Err(()), + } + } +} + +impl std::fmt::Display for Stage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} diff --git a/crates/kiro-cli/src/fig_auth/builder_id.rs b/crates/kiro-cli/src/fig_auth/builder_id.rs new file mode 100644 index 0000000000..24022b0093 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/builder_id.rs @@ -0,0 +1,708 @@ +//! # Builder ID +//! +//! SSO flow (RFC: ) +//! 1. Get a client id (SSO-OIDC identifier, formatted per RFC6749). +//! - Code: [DeviceRegistration::register] +//! - Calls [Client::register_client] +//! - RETURNS: [DeviceRegistration] +//! - Client registration is valid for potentially months and creates state server-side, so +//! the client SHOULD cache them to disk. +//! 2. Start device authorization. +//! - Code: [start_device_authorization] +//! - Calls [Client::start_device_authorization] +//! - RETURNS (RFC: ): +//! [StartDeviceAuthorizationResponse] +//! 3. Poll for the access token +//! - Code: [poll_create_token] +//! - Calls [Client::create_token] +//! - RETURNS: [PollCreateToken] +//! 4. (Repeat) Tokens SHOULD be refreshed if expired and a refresh token is available. +//! - Code: [BuilderIdToken::refresh_token] +//! - Calls [Client::create_token] +//! 
- RETURNS: [BuilderIdToken] + +use aws_sdk_ssooidc::client::Client; +use aws_sdk_ssooidc::config::retry::RetryConfig; +use aws_sdk_ssooidc::config::{ + BehaviorVersion, + ConfigBag, + RuntimeComponents, + SharedAsyncSleep, +}; +use aws_sdk_ssooidc::error::SdkError; +use aws_sdk_ssooidc::operation::create_token::CreateTokenOutput; +use aws_sdk_ssooidc::operation::register_client::RegisterClientOutput; +use aws_smithy_async::rt::sleep::TokioSleep; +use aws_smithy_runtime_api::client::identity::http::Token; +use aws_smithy_runtime_api::client::identity::{ + Identity, + IdentityFuture, + ResolveIdentity, +}; +use aws_smithy_types::error::display::DisplayErrorContext; +use aws_types::region::Region; +use aws_types::request_id::RequestId; +use time::OffsetDateTime; +use tracing::{ + debug, + error, + warn, +}; + +use crate::fig_auth::consts::*; +use crate::fig_auth::scope::is_scopes; +use crate::fig_auth::secret_store::{ + Secret, + SecretStore, +}; +use crate::fig_auth::{ + Error, + Result, +}; +use crate::fig_aws_common::app_name; +use crate::fig_telemetry_core::{ + Event, + EventType, + TelemetryResult, +}; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum OAuthFlow { + DeviceCode, + Pkce, +} + +impl std::fmt::Display for OAuthFlow { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match *self { + OAuthFlow::DeviceCode => write!(f, "DeviceCode"), + OAuthFlow::Pkce => write!(f, "PKCE"), + } + } +} + +/// Indicates if an expiration time has passed, there is a small 1 min window that is removed +/// so the token will not expire in transit +fn is_expired(expiration_time: &OffsetDateTime) -> bool { + let now = time::OffsetDateTime::now_utc(); + &(now + time::Duration::minutes(1)) > expiration_time +} + +pub(crate) fn oidc_url(region: &Region) -> String { + format!("https://oidc.{region}.amazonaws.com") +} + +pub(crate) fn client(region: Region) -> Client { + let retry_config = RetryConfig::standard().with_max_attempts(3); + let sdk_config = aws_types::SdkConfig::builder() + .http_client(crate::fig_aws_common::http_client::client()) + .behavior_version(BehaviorVersion::v2025_01_17()) + .endpoint_url(oidc_url(®ion)) + .region(region) + .retry_config(retry_config) + .sleep_impl(SharedAsyncSleep::new(TokioSleep::new())) + .app_name(app_name()) + .build(); + Client::new(&sdk_config) +} + +/// Represents an OIDC registered client, resulting from the "register client" API call. +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct DeviceRegistration { + pub client_id: String, + pub client_secret: Secret, + #[serde(with = "time::serde::rfc3339::option")] + pub client_secret_expires_at: Option, + pub region: String, + pub oauth_flow: OAuthFlow, + pub scopes: Option>, +} + +impl DeviceRegistration { + const SECRET_KEY: &'static str = "codewhisperer:odic:device-registration"; + + pub fn from_output( + output: RegisterClientOutput, + region: &Region, + oauth_flow: OAuthFlow, + scopes: Vec, + ) -> Self { + Self { + client_id: output.client_id.unwrap_or_default(), + client_secret: output.client_secret.unwrap_or_default().into(), + client_secret_expires_at: time::OffsetDateTime::from_unix_timestamp(output.client_secret_expires_at).ok(), + region: region.to_string(), + oauth_flow, + scopes: Some(scopes), + } + } + + /// Loads the OIDC registered client from the secret store, deleting it if it is expired. 
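+    // A minimal, hypothetical sketch of how the device-code flow documented at the top of
+    // this module is typically driven end to end (see also the ignored `test_login` test
+    // below). Names come from this module; the fixed one-second poll delay is illustrative,
+    // the server-provided `interval` from the authorization response should be honored:
+    //
+    //     let secret_store = SecretStore::new().await?;
+    //     let auth = start_device_authorization(&secret_store, None, None).await?;
+    //     // Ask the user to open `auth.verification_uri_complete` and approve the request.
+    //     loop {
+    //         match poll_create_token(&secret_store, auth.device_code.clone(), None, None).await {
+    //             PollCreateToken::Pending => tokio::time::sleep(Duration::from_secs(1)).await,
+    //             PollCreateToken::Complete(_token) => break,
+    //             PollCreateToken::Error(err) => return Err(err),
+    //         }
+    //     }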
+ async fn load_from_secret_store(secret_store: &SecretStore, region: &Region) -> Result> { + let device_registration = secret_store.get(Self::SECRET_KEY).await?; + + if let Some(device_registration) = device_registration { + // check that the data is not expired, assume it is invalid if not present + let device_registration: Self = serde_json::from_str(&device_registration.0)?; + + if let Some(client_secret_expires_at) = device_registration.client_secret_expires_at { + if !is_expired(&client_secret_expires_at) && device_registration.region == region.as_ref() { + return Ok(Some(device_registration)); + } + } + } + + // delete the data if its expired or invalid + if let Err(err) = secret_store.delete(Self::SECRET_KEY).await { + error!(?err, "Failed to delete device registration from keychain"); + } + + Ok(None) + } + + /// Loads the client saved in the secret store if available, otherwise registers a new client + /// and saves it in the secret store. + pub async fn init_device_code_registration( + client: &Client, + secret_store: &SecretStore, + region: &Region, + ) -> Result { + match Self::load_from_secret_store(secret_store, region).await { + Ok(Some(registration)) if registration.oauth_flow == OAuthFlow::DeviceCode => match ®istration.scopes { + Some(scopes) if is_scopes(scopes) => return Ok(registration), + _ => warn!("Invalid scopes in device registration, ignoring"), + }, + // If it doesn't exist or is for another OAuth flow, + // then continue with creating a new one. + Ok(None | Some(_)) => {}, + Err(err) => { + error!(?err, "Failed to read device registration from keychain"); + }, + }; + + let mut register = client + .register_client() + .client_name(CLIENT_NAME) + .client_type(CLIENT_TYPE); + for scope in SCOPES { + register = register.scopes(*scope); + } + let output = register.send().await?; + + let device_registration = Self::from_output( + output, + region, + OAuthFlow::DeviceCode, + SCOPES.iter().map(|s| (*s).to_owned()).collect(), + ); + + if let Err(err) = device_registration.save(secret_store).await { + error!(?err, "Failed to write device registration to keychain"); + } + + Ok(device_registration) + } + + /// Saves to the passed secret store. + pub async fn save(&self, secret_store: &SecretStore) -> Result<()> { + secret_store + .set(Self::SECRET_KEY, &serde_json::to_string(&self)?) + .await?; + Ok(()) + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct StartDeviceAuthorizationResponse { + /// Device verification code. + pub device_code: String, + /// User verification code. + pub user_code: String, + /// Verification URI on the authorization server. + pub verification_uri: String, + /// User verification URI on the authorization server. + pub verification_uri_complete: String, + /// Lifetime (seconds) of `device_code` and `user_code`. + pub expires_in: i32, + /// Minimum time (seconds) the client SHOULD wait between polling intervals. + pub interval: i32, + pub region: String, + pub start_url: String, +} + +/// Init a builder id request +pub async fn start_device_authorization( + secret_store: &SecretStore, + start_url: Option, + region: Option, +) -> Result { + let region = region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); + let client = client(region.clone()); + + let DeviceRegistration { + client_id, + client_secret, + .. 
+ } = DeviceRegistration::init_device_code_registration(&client, secret_store, ®ion).await?; + + let output = client + .start_device_authorization() + .client_id(&client_id) + .client_secret(&client_secret.0) + .start_url(start_url.as_deref().unwrap_or(START_URL)) + .send() + .await?; + + Ok(StartDeviceAuthorizationResponse { + device_code: output.device_code.unwrap_or_default(), + user_code: output.user_code.unwrap_or_default(), + verification_uri: output.verification_uri.unwrap_or_default(), + verification_uri_complete: output.verification_uri_complete.unwrap_or_default(), + expires_in: output.expires_in, + interval: output.interval, + region: region.to_string(), + start_url: start_url.unwrap_or_else(|| START_URL.to_owned()), + }) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TokenType { + BuilderId, + IamIdentityCenter, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct BuilderIdToken { + pub access_token: Secret, + #[serde(with = "time::serde::rfc3339")] + pub expires_at: time::OffsetDateTime, + pub refresh_token: Option, + pub region: Option, + pub start_url: Option, + pub oauth_flow: OAuthFlow, + pub scopes: Option>, +} + +impl BuilderIdToken { + const SECRET_KEY: &'static str = "codewhisperer:odic:token"; + + #[cfg(test)] + fn test() -> Self { + Self { + access_token: Secret("test_access_token".to_string()), + expires_at: time::OffsetDateTime::now_utc() + time::Duration::minutes(60), + refresh_token: Some(Secret("test_refresh_token".to_string())), + region: Some(OIDC_BUILDER_ID_REGION.to_string()), + start_url: Some(START_URL.to_string()), + oauth_flow: OAuthFlow::DeviceCode, + scopes: Some(SCOPES.iter().map(|s| (*s).to_owned()).collect()), + } + } + + /// Load the token from the keychain, refresh the token if it is expired and return it + pub async fn load(secret_store: &SecretStore, force_refresh: bool) -> Result> { + match secret_store.get(Self::SECRET_KEY).await { + Ok(Some(secret)) => { + let token: Option = serde_json::from_str(&secret.0)?; + match token { + Some(token) => { + let region = token.region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); + + let client = client(region.clone()); + // if token is expired try to refresh + if token.is_expired() || force_refresh { + token.refresh_token(&client, secret_store, ®ion).await + } else { + Ok(Some(token)) + } + }, + None => Ok(None), + } + }, + Ok(None) => Ok(None), + Err(err) => { + error!(%err, "Error getting builder id token from keychain"); + Err(err) + }, + } + } + + /// Refresh the access token + pub async fn refresh_token( + &self, + client: &Client, + secret_store: &SecretStore, + region: &Region, + ) -> Result> { + let Some(refresh_token) = &self.refresh_token else { + // if the token is expired and has no refresh token, delete it + if let Err(err) = self.delete(secret_store).await { + error!(?err, "Failed to delete builder id token"); + } + + return Ok(None); + }; + + let registration = match DeviceRegistration::load_from_secret_store(secret_store, region).await? { + Some(registration) if registration.oauth_flow == self.oauth_flow => registration, + // If the OIDC client registration is for a different oauth flow or doesn't exist, then + // we can't refresh the token. 
+ Some(registration) => { + warn!( + "Unable to refresh token: Stored client registration has oauth flow: {:?} but current access token has oauth flow: {:?}", + registration.oauth_flow, self.oauth_flow + ); + return Ok(None); + }, + None => { + warn!("Unable to refresh token: No registered client was found"); + return Ok(None); + }, + }; + + debug!("Refreshing access token"); + match client + .create_token() + .client_id(registration.client_id) + .client_secret(registration.client_secret.0) + .refresh_token(&refresh_token.0) + .grant_type(REFRESH_GRANT_TYPE) + .send() + .await + { + Ok(output) => { + crate::fig_telemetry_core::send_event( + Event::new(EventType::RefreshCredentials { + request_id: output.request_id().unwrap_or_default().into(), + result: TelemetryResult::Succeeded, + reason: None, + oauth_flow: registration.oauth_flow.to_string(), + }) + .with_credential_start_url(self.start_url.clone().unwrap_or_else(|| START_URL.to_owned())), + ) + .await; + + let token: BuilderIdToken = Self::from_output( + output, + region.clone(), + self.start_url.clone(), + self.oauth_flow, + self.scopes.clone(), + ); + debug!("Refreshed access token, new token: {:?}", token); + + if let Err(err) = token.save(secret_store).await { + error!(?err, "Failed to store builder id access token"); + }; + + Ok(Some(token)) + }, + Err(err) => { + let display_err = DisplayErrorContext(&err); + error!("Failed to refresh builder id access token: {}", display_err); + + // if the error is the client's fault, clear the token + if let SdkError::ServiceError(service_err) = &err { + crate::fig_telemetry_core::send_event( + Event::new(EventType::RefreshCredentials { + request_id: err.request_id().unwrap_or_default().into(), + result: TelemetryResult::Failed, + reason: Some(display_err.to_string()), + oauth_flow: registration.oauth_flow.to_string(), + }) + .with_credential_start_url(self.start_url.clone().unwrap_or_else(|| START_URL.to_owned())), + ) + .await; + if !service_err.err().is_slow_down_exception() { + if let Err(err) = self.delete(secret_store).await { + error!(?err, "Failed to delete builder id token"); + } + } + } + + Err(err.into()) + }, + } + } + + /// If the time has passed the `expires_at` time + /// + /// The token is marked as expired 1 min before it actually does to account for the potential a + /// token expires while in transit + pub fn is_expired(&self) -> bool { + is_expired(&self.expires_at) + } + + /// Save the token to the keychain + pub async fn save(&self, secret_store: &SecretStore) -> Result<()> { + secret_store + .set(Self::SECRET_KEY, &serde_json::to_string(self)?) 
+ .await?; + Ok(()) + } + + /// Delete the token from the keychain + pub async fn delete(&self, secret_store: &SecretStore) -> Result<()> { + secret_store.delete(Self::SECRET_KEY).await?; + Ok(()) + } + + pub(crate) fn from_output( + output: CreateTokenOutput, + region: Region, + start_url: Option, + oauth_flow: OAuthFlow, + scopes: Option>, + ) -> Self { + Self { + access_token: output.access_token.unwrap_or_default().into(), + expires_at: time::OffsetDateTime::now_utc() + time::Duration::seconds(output.expires_in as i64), + refresh_token: output.refresh_token.map(|t| t.into()), + region: Some(region.to_string()), + start_url, + oauth_flow, + scopes, + } + } + + pub fn token_type(&self) -> TokenType { + match &self.start_url { + Some(url) if url == START_URL => TokenType::BuilderId, + None => TokenType::BuilderId, + Some(_) => TokenType::IamIdentityCenter, + } + } +} + +pub enum PollCreateToken { + Pending, + Complete(BuilderIdToken), + Error(Error), +} + +/// Poll for the create token response +pub async fn poll_create_token( + secret_store: &SecretStore, + device_code: String, + start_url: Option, + region: Option, +) -> PollCreateToken { + let region = region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); + let client = client(region.clone()); + + let DeviceRegistration { + client_id, + client_secret, + scopes, + .. + } = match DeviceRegistration::init_device_code_registration(&client, secret_store, ®ion).await { + Ok(res) => res, + Err(err) => { + return PollCreateToken::Error(err); + }, + }; + + match client + .create_token() + .grant_type(DEVICE_GRANT_TYPE) + .device_code(device_code) + .client_id(client_id) + .client_secret(client_secret.0) + .send() + .await + { + Ok(output) => { + let token: BuilderIdToken = + BuilderIdToken::from_output(output, region, start_url, OAuthFlow::DeviceCode, scopes); + + if let Err(err) = token.save(secret_store).await { + error!(?err, "Failed to store builder id token"); + }; + + PollCreateToken::Complete(token) + }, + Err(SdkError::ServiceError(service_error)) if service_error.err().is_authorization_pending_exception() => { + PollCreateToken::Pending + }, + Err(err) => { + error!(?err, "Failed to poll for builder id token"); + PollCreateToken::Error(err.into()) + }, + } +} + +pub async fn builder_id_token() -> Result> { + let secret_store = SecretStore::new().await?; + BuilderIdToken::load(&secret_store, false).await +} + +pub async fn refresh_token() -> Result> { + let secret_store = SecretStore::new().await?; + BuilderIdToken::load(&secret_store, true).await +} + +pub async fn is_logged_in() -> bool { + matches!(builder_id_token().await, Ok(Some(_))) +} + +pub async fn logout() -> Result<()> { + let Ok(secret_store) = SecretStore::new().await else { + return Ok(()); + }; + + let (builder_res, device_res) = tokio::join!( + secret_store.delete(BuilderIdToken::SECRET_KEY), + secret_store.delete(DeviceRegistration::SECRET_KEY), + ); + + let profile_res = crate::fig_settings::state::remove_value("api.codewhisperer.profile"); + + builder_res?; + device_res?; + profile_res?; + + Ok(()) +} + +#[derive(Debug, Clone)] +pub struct BearerResolver; + +impl ResolveIdentity for BearerResolver { + fn resolve_identity<'a>( + &'a self, + _runtime_components: &'a RuntimeComponents, + _config_bag: &'a ConfigBag, + ) -> IdentityFuture<'a> { + IdentityFuture::new_boxed(Box::pin(async { + let secret_store = SecretStore::new().await?; + let token = BuilderIdToken::load(&secret_store, false).await?; + match token { + Some(token) => Ok(Identity::new( + 
Token::new(token.access_token.0, Some(token.expires_at.into())), + Some(token.expires_at.into()), + )), + None => Err(Error::NoToken.into()), + } + })) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + const US_EAST_1: Region = Region::from_static("us-east-1"); + const US_WEST_2: Region = Region::from_static("us-west-2"); + + macro_rules! test_ser_deser { + ($ty:ident, $variant:expr, $text:expr) => { + let quoted = format!("\"{}\"", $text); + assert_eq!(quoted, serde_json::to_string(&$variant).unwrap()); + assert_eq!($variant, serde_json::from_str("ed).unwrap()); + + assert_eq!($text, format!("{}", $variant)); + }; + } + + #[test] + fn test_oauth_flow_ser_deser() { + test_ser_deser!(OAuthFlow, OAuthFlow::DeviceCode, "DeviceCode"); + test_ser_deser!(OAuthFlow, OAuthFlow::Pkce, "PKCE"); + } + + #[test] + fn test_client() { + println!("{:?}", client(US_EAST_1)); + println!("{:?}", client(US_WEST_2)); + } + + #[test] + fn oidc_url_snapshot() { + insta::assert_snapshot!(oidc_url(&US_EAST_1), @"https://oidc.us-east-1.amazonaws.com"); + insta::assert_snapshot!(oidc_url(&US_WEST_2), @"https://oidc.us-west-2.amazonaws.com"); + } + + #[test] + fn test_is_expired() { + let mut token = BuilderIdToken::test(); + assert!(!token.is_expired()); + + token.expires_at = time::OffsetDateTime::now_utc() - time::Duration::seconds(60); + assert!(token.is_expired()); + } + + #[test] + fn test_token_type() { + let mut token = BuilderIdToken::test(); + assert_eq!(token.token_type(), TokenType::BuilderId); + + token.start_url = None; + assert_eq!(token.token_type(), TokenType::BuilderId); + + token.start_url = Some("https://amzn.awsapps.com/start".into()); + assert_eq!(token.token_type(), TokenType::IamIdentityCenter); + } + + #[ignore = "not in ci"] + #[tokio::test] + async fn logout_test() { + logout().await.unwrap(); + } + + #[ignore = "login flow"] + #[tokio::test] + async fn test_login() { + let start_url = Some("https://amzn.awsapps.com/start".into()); + let region = Some("us-east-1".into()); + + // let start_url = None; + // let region = None; + + let secret_store = SecretStore::new().await.unwrap(); + let res: StartDeviceAuthorizationResponse = + start_device_authorization(&secret_store, start_url.clone(), region.clone()) + .await + .unwrap(); + + println!("{:?}", res); + + loop { + match poll_create_token( + &secret_store, + res.device_code.clone(), + start_url.clone(), + region.clone(), + ) + .await + { + PollCreateToken::Pending => { + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + }, + PollCreateToken::Complete(token) => { + println!("{:?}", token); + break; + }, + PollCreateToken::Error(err) => { + println!("{}", err); + break; + }, + } + } + } + + #[ignore = "not in ci"] + #[tokio::test] + async fn test_load() { + let secret_store = SecretStore::new().await.unwrap(); + let token = BuilderIdToken::load(&secret_store, false).await; + println!("{:?}", token); + // println!("{:?}", token.unwrap().unwrap().access_token.0); + } + + #[ignore = "not in ci"] + #[tokio::test] + async fn test_refresh() { + let token = refresh_token().await.unwrap().unwrap(); + println!("{:?}", token); + } +} diff --git a/crates/kiro-cli/src/fig_auth/consts.rs b/crates/kiro-cli/src/fig_auth/consts.rs new file mode 100644 index 0000000000..a55174141b --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/consts.rs @@ -0,0 +1,25 @@ +use aws_types::region::Region; + +pub(crate) const CLIENT_NAME: &str = "Amazon Q Developer for command line"; + +pub(crate) const OIDC_BUILDER_ID_REGION: Region = 
Region::from_static("us-east-1"); + +/// The scopes requested for OIDC +/// +/// Do not include `sso:account:access`, these permissions are not needed and were +/// previously included +pub(crate) const SCOPES: &[&str] = &[ + "codewhisperer:completions", + "codewhisperer:analysis", + "codewhisperer:conversations", + // "codewhisperer:taskassist", + // "codewhisperer:transformations", +]; + +pub(crate) const CLIENT_TYPE: &str = "public"; + +// The start URL for public builder ID users +pub const START_URL: &str = "https://view.awsapps.com/start"; + +pub(crate) const DEVICE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:device_code"; +pub(crate) const REFRESH_GRANT_TYPE: &str = "refresh_token"; diff --git a/crates/kiro-cli/src/fig_auth/error.rs b/crates/kiro-cli/src/fig_auth/error.rs new file mode 100644 index 0000000000..8739d33cfa --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/error.rs @@ -0,0 +1,47 @@ +use aws_sdk_ssooidc::error::SdkError; +use aws_sdk_ssooidc::operation::create_token::CreateTokenError; +use aws_sdk_ssooidc::operation::register_client::RegisterClientError; +use aws_sdk_ssooidc::operation::start_device_authorization::StartDeviceAuthorizationError; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Ssooidc(#[from] Box), + #[error(transparent)] + SdkRegisterClient(#[from] SdkError), + #[error(transparent)] + SdkCreateToken(#[from] SdkError), + #[error(transparent)] + SdkStartDeviceAuthorization(#[from] SdkError), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + TimeComponentRange(#[from] time::error::ComponentRange), + #[error(transparent)] + Directories(#[from] crate::fig_util::directories::DirectoryError), + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), + #[error("Security error: {}", .0)] + Security(String), + #[error(transparent)] + StringFromUtf8(#[from] std::string::FromUtf8Error), + #[error(transparent)] + StrFromUtf8(#[from] std::str::Utf8Error), + #[error(transparent)] + DbOpenError(#[from] crate::fig_settings::error::DbOpenError), + #[error(transparent)] + Setting(#[from] crate::fig_settings::Error), + #[error("No token")] + NoToken, + #[error("OAuth state mismatch. Actual: {} | Expected: {}", .actual, .expected)] + OAuthStateMismatch { actual: String, expected: String }, + #[error("Timeout waiting for authentication to complete")] + OAuthTimeout, + #[error("No code received on redirect")] + OAuthMissingCode, + #[error("OAuth error: {0}")] + OAuthCustomError(String), +} + +pub(crate) type Result = std::result::Result; diff --git a/crates/kiro-cli/src/fig_auth/index.html b/crates/kiro-cli/src/fig_auth/index.html new file mode 100644 index 0000000000..c68c852af9 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/index.html @@ -0,0 +1,181 @@ + + + + + AWS Authentication + + + + + +
+    [index.html body omitted: static "AWS Authentication" page served after the OAuth browser redirect, showing a "Request approved" confirmation]
+ + + + diff --git a/crates/kiro-cli/src/fig_auth/mod.rs b/crates/kiro-cli/src/fig_auth/mod.rs new file mode 100644 index 0000000000..fd9ee6059b --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/mod.rs @@ -0,0 +1,16 @@ +pub mod builder_id; +mod consts; +mod error; +pub mod pkce; +mod scope; +pub mod secret_store; + +pub use builder_id::{ + builder_id_token, + is_logged_in, + logout, + refresh_token, +}; +pub use consts::START_URL; +pub use error::Error; +pub(crate) use error::Result; diff --git a/crates/kiro-cli/src/fig_auth/pkce.rs b/crates/kiro-cli/src/fig_auth/pkce.rs new file mode 100644 index 0000000000..e65546969c --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/pkce.rs @@ -0,0 +1,627 @@ +//! # OAuth 2.0 Proof Key for Code Exchange +//! +//! This module implements the PKCE integration with AWS OIDC according to their +//! developer guide. +//! +//! The benefit of PKCE over device code is to simplify the user experience by not +//! requiring the user to validate the generated code across the browser and the +//! device. +//! +//! SSO flow (RFC: ) +//! 1. Register an OIDC client +//! - Code: [PkceRegistration::register] +//! 2. Host a local HTTP server to handle the redirect +//! - Code: [PkceRegistration::finish] +//! 3. Open the [PkceRegistration::url] in the browser, and approve the request. +//! 4. Exchange the code for access and refresh tokens. +//! - This completes the future returned by [PkceRegistration::finish]. +//! +//! Once access/refresh tokens are received, there is no difference between PKCE +//! and device code (as already implemented in [crate::builder_id]). + +use std::future::Future; +use std::pin::Pin; +use std::time::Duration; + +pub use aws_sdk_ssooidc::client::Client; +pub use aws_sdk_ssooidc::operation::create_token::CreateTokenOutput; +pub use aws_sdk_ssooidc::operation::register_client::RegisterClientOutput; +pub use aws_types::region::Region; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE; +use bytes::Bytes; +use http_body_util::Full; +use hyper::body::Incoming; +use hyper::server::conn::http1; +use hyper::service::Service; +use hyper::{ + Request, + Response, +}; +use hyper_util::rt::TokioIo; +use percent_encoding::{ + NON_ALPHANUMERIC, + utf8_percent_encode, +}; +use rand::Rng; +use tokio::net::TcpListener; +use tracing::{ + debug, + error, +}; + +use crate::fig_auth::builder_id::*; +use crate::fig_auth::consts::*; +use crate::fig_auth::secret_store::SecretStore; +use crate::fig_auth::{ + Error, + Result, + START_URL, +}; + +const DEFAULT_AUTHORIZATION_TIMEOUT: Duration = Duration::from_secs(60 * 3); + +/// Starts the PKCE authorization flow, using [`START_URL`] and [`OIDC_BUILDER_ID_REGION`] as the +/// default issuer URL and region. Returns the [`PkceClient`] to use to finish the flow. +pub async fn start_pkce_authorization( + start_url: Option, + region: Option, +) -> Result<(Client, PkceRegistration)> { + let issuer_url = start_url.as_deref().unwrap_or(START_URL); + let region = region.clone().map_or(OIDC_BUILDER_ID_REGION, Region::new); + let client = client(region.clone()); + let registration = PkceRegistration::register(&client, region, issuer_url.to_string(), None).await?; + Ok((client, registration)) +} + +/// Represents a client used for registering with AWS IAM OIDC. 
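+// A minimal usage sketch of the PKCE flow described in the module docs above, assuming a
+// secret store is available (it mirrors the ignored `test_pkce_flow_e2e` test at the bottom
+// of this file):
+//
+//     let (client, registration) = start_pkce_authorization(None, None).await?;
+//     // Open `registration.url` in the user's browser, then wait for the redirect:
+//     let secret_store = SecretStore::new().await?;
+//     registration.finish(&client, Some(&secret_store)).await?;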
+#[async_trait::async_trait] +pub trait PkceClient { + /// The scopes that the client will request + fn scopes() -> Vec; + + async fn register_client(&self, redirect_uri: String, issuer_url: String) -> Result; + + async fn create_token(&self, args: CreateTokenArgs) -> Result; +} + +#[derive(Debug, Clone)] +pub struct RegisterClientResponse { + pub output: RegisterClientOutput, +} + +impl RegisterClientResponse { + pub fn client_id(&self) -> &str { + self.output.client_id().unwrap_or_default() + } + + pub fn client_secret(&self) -> &str { + self.output.client_secret().unwrap_or_default() + } +} + +#[derive(Debug)] +pub struct CreateTokenResponse { + pub output: CreateTokenOutput, +} + +#[derive(Debug)] +pub struct CreateTokenArgs { + pub client_id: String, + pub client_secret: String, + pub redirect_uri: String, + pub code_verifier: String, + pub code: String, +} + +#[async_trait::async_trait] +impl PkceClient for Client { + fn scopes() -> Vec { + SCOPES.iter().map(|s| (*s).to_owned()).collect() + } + + async fn register_client(&self, redirect_uri: String, issuer_url: String) -> Result { + let mut register = self + .register_client() + .client_name(CLIENT_NAME) + .client_type(CLIENT_TYPE) + .issuer_url(issuer_url.clone()) + .redirect_uris(redirect_uri.clone()) + .grant_types("authorization_code") + .grant_types("refresh_token"); + for scope in Self::scopes() { + register = register.scopes(scope); + } + let output = register.send().await?; + Ok(RegisterClientResponse { output }) + } + + async fn create_token(&self, args: CreateTokenArgs) -> Result { + let output = self + .create_token() + .client_id(args.client_id.clone()) + .client_secret(args.client_secret.clone()) + .grant_type("authorization_code") + .redirect_uri(args.redirect_uri) + .code_verifier(args.code_verifier) + .code(args.code) + .send() + .await?; + Ok(CreateTokenResponse { output }) + } +} + +/// Represents an active PKCE registration flow. To execute the flow, you should (in order): +/// 1. Call [`PkceRegistration::register`] to register an AWS OIDC client and receive the URL to be +/// opened by the browser. +/// 2. Call [`PkceRegistration::finish`] to host a local server to handle redirects, and trade the +/// authorization code for an access token. +#[derive(Debug)] +pub struct PkceRegistration { + /// URL to be opened by the user's browser. + pub url: String, + registered_client: RegisterClientResponse, + /// Configured URI that the authorization server will redirect the client to. + pub redirect_uri: String, + code_verifier: String, + /// Random value generated for every authentication attempt. + /// + /// + pub state: String, + /// Listener for hosting the local HTTP server. + listener: TcpListener, + region: Region, + /// Interchangeable with the "start URL" concept in the device code flow. + issuer_url: String, + /// Time to wait for [`Self::finish`] to complete. Default is [`DEFAULT_AUTHORIZATION_TIMEOUT`]. 
+ timeout: Duration, +} + +impl PkceRegistration { + pub async fn register( + client: &impl PkceClient, + region: Region, + issuer_url: String, + timeout: Option, + ) -> Result { + let listener = TcpListener::bind("127.0.0.1:0").await?; + let redirect_uri = format!("http://{}/oauth/callback", listener.local_addr()?); + let code_verifier = generate_code_verifier(); + let code_challenge = generate_code_challenge(&code_verifier); + let state = rand::rng() + .sample_iter(rand::distr::Alphanumeric) + .take(10) + .collect::>(); + let state = String::from_utf8(state).unwrap_or("state".to_string()); + + let response = client.register_client(redirect_uri.clone(), issuer_url.clone()).await?; + + let query = PkceQueryParams { + client_id: response.client_id().to_string(), + redirect_uri: redirect_uri.clone(), + // Scopes must be space delimited. + scopes: SCOPES.join(" "), + state: state.clone(), + code_challenge: code_challenge.clone(), + code_challenge_method: "S256".to_string(), + }; + let url = format!("{}/authorize?{}", oidc_url(®ion), query.as_query_params()); + + Ok(Self { + url, + registered_client: response, + code_verifier, + state, + listener, + redirect_uri, + region, + issuer_url, + timeout: timeout.unwrap_or(DEFAULT_AUTHORIZATION_TIMEOUT), + }) + } + + /// Hosts a local HTTP server to listen for browser redirects. If a [`SecretStore`] is passed, + /// then the access and refresh tokens will be saved. + /// + /// Only the first connection will be served. + pub async fn finish(self, client: &C, secret_store: Option<&SecretStore>) -> Result<()> { + let code = tokio::select! { + code = Self::recv_code(self.listener, self.state) => { + code? + }, + _ = tokio::time::sleep(self.timeout) => { + return Err(Error::OAuthTimeout); + } + }; + + let response = client + .create_token(CreateTokenArgs { + client_id: self.registered_client.client_id().to_string(), + client_secret: self.registered_client.client_secret().to_string(), + redirect_uri: self.redirect_uri, + code_verifier: self.code_verifier, + code, + }) + .await?; + + // Tokens are redacted in the log output. + debug!(?response, "Received create_token response"); + + let token = BuilderIdToken::from_output( + response.output, + self.region.clone(), + Some(self.issuer_url), + OAuthFlow::Pkce, + Some(C::scopes()), + ); + + let device_registration = DeviceRegistration::from_output( + self.registered_client.output, + &self.region, + OAuthFlow::Pkce, + C::scopes(), + ); + + let Some(secret_store) = secret_store else { + return Ok(()); + }; + + if let Err(err) = device_registration.save(secret_store).await { + error!(?err, "Failed to store pkce registration to secret store"); + } + + if let Err(err) = token.save(secret_store).await { + error!(?err, "Failed to store builder id token"); + }; + + Ok(()) + } + + async fn recv_code(listener: TcpListener, expected_state: String) -> Result { + let (code_tx, mut code_rx) = tokio::sync::mpsc::channel::>(1); + let (stream, _) = listener.accept().await?; + let stream = TokioIo::new(stream); // Wrapper to implement Hyper IO traits for Tokio types. 
+ let host = listener.local_addr()?.to_string(); + tokio::spawn(async move { + if let Err(err) = http1::Builder::new() + .serve_connection(stream, PkceHttpService { + code_tx: std::sync::Arc::new(code_tx), + host, + }) + .await + { + error!(?err, "Error occurred serving the connection"); + } + }); + match code_rx.recv().await { + Some(Ok((code, state))) => { + debug!(code = "", state, "Received code and state"); + if state != expected_state { + return Err(Error::OAuthStateMismatch { + actual: state, + expected: expected_state, + }); + } + // Give time for the user to be redirected to index.html. + tokio::time::sleep(Duration::from_millis(200)).await; + Ok(code) + }, + Some(Err(err)) => { + // Give time for the user to be redirected to index.html. + tokio::time::sleep(Duration::from_millis(200)).await; + Err(err) + }, + None => Err(Error::OAuthMissingCode), + } + } +} + +type CodeSender = std::sync::Arc>>; +type ServiceError = Error; +type ServiceResponse = Response>; +type ServiceFuture = Pin> + Send>>; + +#[derive(Debug, Clone)] +struct PkceHttpService { + /// [`tokio::sync::mpsc::Sender`] for a (code, state) pair. + code_tx: CodeSender, + + /// The host being served - ie, the hostname and port. + /// Used for responding with redirects. + host: String, +} + +impl PkceHttpService { + /// Handles the browser redirect to `"http://{host}/oauth/callback"` which contains either the + /// code and state query params, or an error query param. Redirects to "/index.html". + /// + /// The [`Request`] doesn't actually contain the host, hence the `host` argument. + async fn handle_oauth_callback( + code_tx: CodeSender, + host: String, + req: Request, + ) -> Result { + let query_params = req + .uri() + .query() + .map(|query| { + query + .split('&') + .filter_map(|kv| kv.split_once('=')) + .collect::>() + }) + .ok_or(Error::OAuthCustomError("query parameters are missing".into()))?; + + // Error handling: if something goes wrong at the authorization endpoint, the + // client will be redirected to the redirect url with "error" and + // "error_description" query parameters. 
+ if let Some(error) = query_params.get("error") { + let error_description = query_params.get("error_description").unwrap_or(&""); + let _ = code_tx + .send(Err(Error::OAuthCustomError(format!( + "error occurred during authorization: {:?}, {:?}", + error, error_description + )))) + .await; + return Self::redirect_to_index(&host, &format!("?error={}", error)); + } else { + let code = query_params.get("code"); + let state = query_params.get("state"); + if let (Some(code), Some(state)) = (code, state) { + let _ = code_tx.send(Ok(((*code).to_string(), (*state).to_string()))).await; + } else { + let _ = code_tx + .send(Err(Error::OAuthCustomError( + "missing code and/or state in the query parameters".into(), + ))) + .await; + return Self::redirect_to_index(&host, "?error=missing%20required%20query%20parameters"); + } + } + + Self::redirect_to_index(&host, "") + } + + fn redirect_to_index(host: &str, query_params: &str) -> Result { + Ok(Response::builder() + .status(302) + .header("Location", format!("http://{}/index.html{}", host, query_params)) + .body("".into()) + .expect("is valid builder, should not panic")) + } +} + +impl Service> for PkceHttpService { + type Error = ServiceError; + type Future = ServiceFuture; + type Response = ServiceResponse; + + fn call(&self, req: Request) -> Self::Future { + let code_tx: CodeSender = std::sync::Arc::clone(&self.code_tx); + let host = self.host.clone(); + Box::pin(async move { + debug!(?req, "Handling connection"); + match req.uri().path() { + "/oauth/callback" | "/oauth/callback/" => Self::handle_oauth_callback(code_tx, host, req).await, + "/index.html" => Ok(Response::builder() + .status(200) + .header("Content-Type", "text/html") + .header("Connection", "close") + .body(include_str!("./index.html").into()) + .expect("valid builder will not panic")), + _ => Ok(Response::builder() + .status(404) + .body("".into()) + .expect("valid builder will not panic")), + } + }) + } +} + +/// Query params for the initial GET request that starts the PKCE flow. Use +/// [`PkceQueryParams::as_query_params`] to get a URL-safe string. +#[derive(Debug, Clone, serde::Serialize)] +struct PkceQueryParams { + client_id: String, + redirect_uri: String, + scopes: String, + state: String, + code_challenge: String, + code_challenge_method: String, +} + +macro_rules! encode { + ($expr:expr) => { + utf8_percent_encode(&$expr, NON_ALPHANUMERIC) + }; +} + +impl PkceQueryParams { + fn as_query_params(&self) -> String { + [ + "response_type=code".to_string(), + format!("client_id={}", encode!(self.client_id)), + format!("redirect_uri={}", encode!(self.redirect_uri)), + format!("scopes={}", encode!(self.scopes)), + format!("state={}", encode!(self.state)), + format!("code_challenge={}", encode!(self.code_challenge)), + format!("code_challenge_method={}", encode!(self.code_challenge_method)), + ] + .join("&") + } +} + +/// Generates a random 43-octet URL safe string according to the RFC recommendation. +/// +/// Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.1 +fn generate_code_verifier() -> String { + URL_SAFE.encode(rand::random::<[u8; 32]>()).replace('=', "") +} + +/// Base64 URL encoded sha256 hash of the code verifier. 
+/// +/// Reference: https://datatracker.ietf.org/doc/html/rfc7636#section-4.2 +fn generate_code_challenge(code_verifier: &str) -> String { + use sha2::{ + Digest, + Sha256, + }; + let mut hasher = Sha256::new(); + hasher.update(code_verifier); + URL_SAFE.encode(hasher.finalize()).replace('=', "") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::fig_auth::scope::is_scopes; + + #[derive(Debug, Clone)] + struct TestPkceClient; + + #[async_trait::async_trait] + impl PkceClient for TestPkceClient { + fn scopes() -> Vec { + vec!["scope:1".to_string(), "scope:2".to_string()] + } + + async fn register_client(&self, _: String, _: String) -> Result { + Ok(RegisterClientResponse { + output: RegisterClientOutput::builder() + .client_id("test_client_id") + .client_secret("test_client_secret") + .build(), + }) + } + + async fn create_token(&self, _: CreateTokenArgs) -> Result { + Ok(CreateTokenResponse { + output: CreateTokenOutput::builder().build(), + }) + } + } + + #[ignore = "not in ci"] + #[tokio::test] + async fn test_pkce_flow_e2e() { + tracing_subscriber::fmt::init(); + let start_url = "https://amzn.awsapps.com/start".to_string(); + let region = Region::new("us-east-1"); + let client = client(region.clone()); + let registration = PkceRegistration::register(&client, region.clone(), start_url, None) + .await + .unwrap(); + println!("{:?}", registration); + if crate::fig_util::open::open_url_async(®istration.url).await.is_err() { + panic!("unable to open the URL"); + } + println!("Waiting for authorization to complete..."); + let secret_store = SecretStore::new().await.unwrap(); + registration.finish(&client, Some(&secret_store)).await.unwrap(); + println!("Authorization successful"); + } + + #[tokio::test] + async fn test_pkce_flow_completes_successfully() { + // tracing_subscriber::fmt::init(); + let region = Region::new("us-east-1"); + let issuer_url = START_URL.into(); + let client = TestPkceClient {}; + let registration = PkceRegistration::register(&client, region, issuer_url, None) + .await + .unwrap(); + + let redirect_uri = registration.redirect_uri.clone(); + let state = registration.state.clone(); + tokio::spawn(async move { + // Let registration.finish be called to handle the request. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + reqwest::get(format!("{}/?code={}&state={}", redirect_uri, "code", state)) + .await + .unwrap(); + }); + + registration.finish(&client, None).await.unwrap(); + } + + #[tokio::test] + async fn test_pkce_flow_with_state_mismatch_throws_err() { + let region = Region::new("us-east-1"); + let issuer_url = START_URL.into(); + let client = TestPkceClient {}; + let registration = PkceRegistration::register(&client, region, issuer_url, None) + .await + .unwrap(); + + let redirect_uri = registration.redirect_uri.clone(); + tokio::spawn(async move { + // Let registration.finish be called to handle the request. 
+ tokio::time::sleep(std::time::Duration::from_secs(1)).await; + reqwest::get(format!("{}/?code={}&state={}", redirect_uri, "code", "not_my_state")) + .await + .unwrap(); + }); + + assert!(matches!( + registration.finish(&client, None).await, + Err(Error::OAuthStateMismatch { actual: _, expected: _ }) + )); + } + + #[tokio::test] + async fn test_pkce_flow_with_authorization_redirect_error() { + let region = Region::new("us-east-1"); + let issuer_url = START_URL.into(); + let client = TestPkceClient {}; + let registration = PkceRegistration::register(&client, region, issuer_url, None) + .await + .unwrap(); + + let redirect_uri = registration.redirect_uri.clone(); + tokio::spawn(async move { + // Let registration.finish be called to handle the request. + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + reqwest::get(format!( + "{}/?error={}&error_description={}", + redirect_uri, "error code", "something bad happened?" + )) + .await + .unwrap(); + }); + + assert!(matches!( + registration.finish(&client, None).await, + Err(Error::OAuthCustomError(_)) + )); + } + + #[tokio::test] + async fn test_pkce_flow_with_timeout() { + let region = Region::new("us-east-1"); + let issuer_url = START_URL.into(); + let client = TestPkceClient {}; + let registration = PkceRegistration::register(&client, region, issuer_url, Some(Duration::from_millis(100))) + .await + .unwrap(); + + assert!(matches!( + registration.finish(&client, None).await, + Err(Error::OAuthTimeout) + )); + } + + #[tokio::test] + async fn verify_gen_code_challenge() { + let code_verifier = generate_code_verifier(); + println!("{:?}", code_verifier); + + let code_challenge = generate_code_challenge(&code_verifier); + println!("{:?}", code_challenge); + assert!(code_challenge.len() >= 43); + } + + #[test] + fn verify_client_scopes() { + assert!(is_scopes(&Client::scopes())); + } +} diff --git a/crates/kiro-cli/src/fig_auth/scope.rs b/crates/kiro-cli/src/fig_auth/scope.rs new file mode 100644 index 0000000000..1a72a69687 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/scope.rs @@ -0,0 +1,33 @@ +use crate::fig_auth::consts::SCOPES; + +pub fn scopes_match, B: AsRef>(a: &[A], b: &[B]) -> bool { + if a.len() != b.len() { + return false; + } + + let mut a = a.iter().map(|s| s.as_ref()).collect::>(); + let mut b = b.iter().map(|s| s.as_ref()).collect::>(); + a.sort(); + b.sort(); + a == b +} + +/// Checks if the given scopes match the predefined scopes. 
+pub(crate) fn is_scopes>(scopes: &[S]) -> bool { + scopes_match(SCOPES, scopes) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_scopes_match() { + assert!(scopes_match(&["a", "b", "c"], &["a", "b", "c"])); + assert!(scopes_match(&["a", "b", "c"], &["a", "c", "b"])); + assert!(!scopes_match(&["a", "b", "c"], &["a", "b"])); + assert!(!scopes_match(&["a", "b"], &["a", "b", "c"])); + + assert!(is_scopes(SCOPES)); + } +} diff --git a/crates/kiro-cli/src/fig_auth/secret_store/linux.rs b/crates/kiro-cli/src/fig_auth/secret_store/linux.rs new file mode 100644 index 0000000000..28fa153398 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/secret_store/linux.rs @@ -0,0 +1,27 @@ +use super::Secret; +use super::sqlite::SqliteSecretStore; +use crate::Result; + +pub struct SecretStoreImpl { + inner: SqliteSecretStore, +} + +impl SecretStoreImpl { + pub async fn new() -> Result { + Ok(Self { + inner: SqliteSecretStore::new().await?, + }) + } + + pub async fn set(&self, key: &str, password: &str) -> Result<()> { + self.inner.set(key, password).await + } + + pub async fn get(&self, key: &str) -> Result> { + self.inner.get(key).await + } + + pub async fn delete(&self, key: &str) -> Result<()> { + self.inner.delete(key).await + } +} diff --git a/crates/kiro-cli/src/fig_auth/secret_store/macos.rs b/crates/kiro-cli/src/fig_auth/secret_store/macos.rs new file mode 100644 index 0000000000..5c28fe8386 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/secret_store/macos.rs @@ -0,0 +1,80 @@ +use super::Secret; +use crate::fig_auth::{ + Error, + Result, +}; + +/// Path to the `security` binary +const SECURITY_BIN: &str = "/usr/bin/security"; + +/// The account name is not used. +const ACCOUNT: &str = ""; + +pub struct SecretStoreImpl { + _private: (), +} + +impl SecretStoreImpl { + pub async fn new() -> Result { + Ok(Self { _private: () }) + } + + /// Sets the `key` to `password` on the keychain, this will override any existing value + pub async fn set(&self, key: &str, password: &str) -> Result<()> { + let output = tokio::process::Command::new(SECURITY_BIN) + .args(["add-generic-password", "-U", "-s", key, "-a", ACCOUNT, "-w", password]) + .output() + .await?; + + if !output.status.success() { + let stderr = std::str::from_utf8(&output.stderr)?; + return Err(Error::Security(stderr.into())); + } + + Ok(()) + } + + /// Returns the password for the `key` + /// + /// If not found the result will be `Ok(None)`, other errors will be returned + pub async fn get(&self, key: &str) -> Result> { + let output = tokio::process::Command::new(SECURITY_BIN) + .args(["find-generic-password", "-s", key, "-a", ACCOUNT, "-w"]) + .output() + .await?; + + if !output.status.success() { + let stderr = std::str::from_utf8(&output.stderr)?; + if stderr.contains("could not be found") { + return Ok(None); + } else { + return Err(Error::Security(stderr.into())); + } + } + + let stdout = std::str::from_utf8(&output.stdout)?; + + // strip newline + let stdout = match stdout.strip_suffix('\n') { + Some(stdout) => stdout, + None => stdout, + }; + + Ok(Some(stdout.into())) + } + + /// Deletes the `key` from the keychain + pub async fn delete(&self, key: &str) -> Result<()> { + let output = tokio::process::Command::new(SECURITY_BIN) + .args(["delete-generic-password", "-s", key, "-a", ACCOUNT]) + .output() + .await?; + + if !output.status.success() { + let stderr = std::str::from_utf8(&output.stderr)?; + return Err(Error::Security(stderr.into())); + } + + Ok(()) + } +} diff --git a/crates/kiro-cli/src/fig_auth/secret_store/mod.rs 
b/crates/kiro-cli/src/fig_auth/secret_store/mod.rs new file mode 100644 index 0000000000..480f011395 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/secret_store/mod.rs @@ -0,0 +1,102 @@ +#[cfg(target_os = "linux")] +mod linux; +#[cfg(target_os = "macos")] +mod macos; +mod sqlite; +#[cfg(target_os = "linux")] +use linux::SecretStoreImpl; +#[cfg(target_os = "macos")] +use macos::SecretStoreImpl; + +use crate::fig_auth::Result; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, serde::Serialize, serde::Deserialize)] +#[serde(transparent)] +pub struct Secret(pub String); + +impl std::fmt::Debug for Secret { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Secret").finish() + } +} + +impl From for Secret +where + T: Into, +{ + fn from(value: T) -> Self { + Self(value.into()) + } +} + +pub struct SecretStore { + inner: SecretStoreImpl, +} + +impl SecretStore { + pub async fn new() -> Result { + SecretStoreImpl::new().await.map(|inner| Self { inner }) + } + + pub async fn set(&self, key: &str, password: &str) -> Result<()> { + self.inner.set(key, password).await + } + + pub async fn get(&self, key: &str) -> Result> { + self.inner.get(key).await + } + + pub async fn delete(&self, key: &str) -> Result<()> { + self.inner.delete(key).await + } +} + +impl std::fmt::Debug for SecretStore { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SecretStore").finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[ignore = "not on ci"] + async fn test_set_password() { + let key = "test_set_password"; + let store = SecretStore::new().await.unwrap(); + store.set(key, "test").await.unwrap(); + assert_eq!(store.get(key).await.unwrap().unwrap().0, "test"); + store.delete(key).await.unwrap(); + } + + #[tokio::test] + #[ignore = "not on ci"] + async fn secret_get_time() { + let key = "test_secret_get_time"; + let store = SecretStore::new().await.unwrap(); + store.set(key, "1234").await.unwrap(); + + let now = std::time::Instant::now(); + for _ in 0..100 { + store.get(key).await.unwrap(); + } + + println!("duration: {:?}", now.elapsed() / 100); + + store.delete(key).await.unwrap(); + } + + #[tokio::test] + #[ignore = "not on ci"] + async fn secret_delete() { + let key = "test_secret_delete"; + + let store = SecretStore::new().await.unwrap(); + store.set(key, "1234").await.unwrap(); + assert_eq!(store.get(key).await.unwrap().unwrap().0, "1234"); + store.delete(key).await.unwrap(); + assert_eq!(store.get(key).await.unwrap(), None); + } +} diff --git a/crates/kiro-cli/src/fig_auth/secret_store/sqlite.rs b/crates/kiro-cli/src/fig_auth/secret_store/sqlite.rs new file mode 100644 index 0000000000..d42fcdcc08 --- /dev/null +++ b/crates/kiro-cli/src/fig_auth/secret_store/sqlite.rs @@ -0,0 +1,50 @@ +#![allow(dead_code)] +use super::Secret; +use crate::Result; +use crate::fig_settings::sqlite::{ + Db, + database, +}; + +pub struct SqliteSecretStore { + db: &'static Db, +} + +impl SqliteSecretStore { + pub async fn new() -> Result { + Ok(Self { db: database()? }) + } + + pub async fn set(&self, key: &str, password: &str) -> Result<()> { + Ok(self.db.set_auth_value(key, password)?) + } + + pub async fn get(&self, key: &str) -> Result> { + Ok(self.db.get_auth_value(key)?.map(Secret)) + } + + pub async fn delete(&self, key: &str) -> Result<()> { + Ok(self.db.unset_auth_value(key)?) 
+ } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_set_get_delete() { + let store = SqliteSecretStore::new().await.unwrap(); + let key = "test_key"; + let password = "test_password"; + + store.set(key, password).await.unwrap(); + + let secret = store.get(key).await.unwrap(); + assert_eq!(secret, Some(Secret(password.to_string()))); + + store.delete(key).await.unwrap(); + let secret = store.get(key).await.unwrap(); + assert_eq!(secret, None); + } +} diff --git a/crates/kiro-cli/src/fig_aws_common/http_client.rs b/crates/kiro-cli/src/fig_aws_common/http_client.rs new file mode 100644 index 0000000000..57a64f1682 --- /dev/null +++ b/crates/kiro-cli/src/fig_aws_common/http_client.rs @@ -0,0 +1,198 @@ +use std::time::Duration; + +use aws_smithy_runtime_api::client::http::{ + HttpClient, + HttpConnector, + HttpConnectorFuture, + HttpConnectorSettings, + SharedHttpConnector, +}; +use aws_smithy_runtime_api::client::result::ConnectorError; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; +use aws_smithy_runtime_api::http::Request; +use aws_smithy_types::body::SdkBody; +use reqwest::Client as ReqwestClient; + +/// Returns a wrapper around the global [fig_request::client] that implements +/// [HttpClient]. +pub fn client() -> Client { + let client = crate::request::client().expect("failed to create http client"); + Client::new(client.clone()) +} + +/// A wrapper around [reqwest::Client] that implements [HttpClient]. +/// +/// This is required to support using proxy servers with the AWS SDK. +#[derive(Debug, Clone)] +pub struct Client { + inner: ReqwestClient, +} + +impl Client { + pub fn new(client: ReqwestClient) -> Self { + Self { inner: client } + } +} + +#[derive(Debug)] +struct CallError { + kind: CallErrorKind, + message: &'static str, + source: Option>, +} + +impl CallError { + fn user(message: &'static str) -> Self { + Self { + kind: CallErrorKind::User, + message, + source: None, + } + } + + fn user_with_source(message: &'static str, source: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + kind: CallErrorKind::User, + message, + source: Some(Box::new(source)), + } + } + + fn timeout(source: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + kind: CallErrorKind::Timeout, + message: "request timed out", + source: Some(Box::new(source)), + } + } + + fn io(source: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + kind: CallErrorKind::Io, + message: "an i/o error occurred", + source: Some(Box::new(source)), + } + } + + fn other(message: &'static str, source: E) -> Self + where + E: std::error::Error + Send + Sync + 'static, + { + Self { + kind: CallErrorKind::Other, + message, + source: Some(Box::new(source)), + } + } +} + +impl std::error::Error for CallError {} + +impl std::fmt::Display for CallError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message)?; + if let Some(err) = self.source.as_ref() { + write!(f, ": {}", err)?; + } + Ok(()) + } +} + +impl From for ConnectorError { + fn from(value: CallError) -> Self { + match &value.kind { + CallErrorKind::User => Self::user(Box::new(value)), + CallErrorKind::Timeout => Self::timeout(Box::new(value)), + CallErrorKind::Io => Self::io(Box::new(value)), + CallErrorKind::Other => Self::other(Box::new(value), None), + } + } +} + +impl From for CallError { + fn from(err: reqwest::Error) -> Self { + if err.is_timeout() { + CallError::timeout(err) 
+ } else if err.is_connect() { + CallError::io(err) + } else { + CallError::other("an unknown error occurred", err) + } + } +} + +#[derive(Debug, Clone)] +enum CallErrorKind { + User, + Timeout, + Io, + Other, +} + +#[derive(Debug)] +struct ReqwestConnector { + client: ReqwestClient, + timeout: Option, +} + +impl HttpConnector for ReqwestConnector { + fn call(&self, request: Request) -> HttpConnectorFuture { + let client = self.client.clone(); + let timeout = self.timeout; + + HttpConnectorFuture::new(async move { + // Convert the aws_smithy_runtime_api request to a reqwest request. + // TODO: There surely has to be a better way to convert an aws_smith_runtime_api + // Request to a reqwest Request. + let mut req_builder = client.request( + reqwest::Method::from_bytes(request.method().as_bytes()) + .map_err(|err| CallError::user_with_source("failed to create method name", err))?, + request.uri().to_owned(), + ); + // Copy the header, body, and timeout. + let parts = request.into_parts(); + for (name, value) in parts.headers.iter() { + let name = name.to_owned(); + let value = value.as_bytes().to_owned(); + req_builder = req_builder.header(name, value); + } + let body_bytes = parts + .body + .bytes() + .ok_or(CallError::user("streaming request body is not supported"))? + .to_owned(); + req_builder = req_builder.body(body_bytes); + if let Some(timeout) = timeout { + req_builder = req_builder.timeout(timeout); + } + + let reqwest_response = req_builder.send().await.map_err(CallError::from)?; + + // Converts from a reqwest Response into an http::Response. + let (parts, body) = http::Response::from(reqwest_response).into_parts(); + let http_response = http::Response::from_parts(parts, SdkBody::from_body_1_x(body)); + + Ok(aws_smithy_runtime_api::http::Response::try_from(http_response) + .map_err(|err| CallError::other("failed to convert to a proper response", err))?) 
+ }) + } +} + +impl HttpClient for Client { + fn http_connector(&self, settings: &HttpConnectorSettings, _components: &RuntimeComponents) -> SharedHttpConnector { + let connector = ReqwestConnector { + client: self.inner.clone(), + timeout: settings.read_timeout(), + }; + SharedHttpConnector::new(connector) + } +} diff --git a/crates/kiro-cli/src/fig_aws_common/mod.rs b/crates/kiro-cli/src/fig_aws_common/mod.rs new file mode 100644 index 0000000000..b9739f9109 --- /dev/null +++ b/crates/kiro-cli/src/fig_aws_common/mod.rs @@ -0,0 +1,36 @@ +pub mod http_client; +mod sdk_error_display; +mod user_agent_override_interceptor; + +use std::sync::LazyLock; + +use aws_smithy_runtime_api::client::behavior_version::BehaviorVersion; +use aws_types::app_name::AppName; +pub use sdk_error_display::SdkErrorDisplay; +pub use user_agent_override_interceptor::UserAgentOverrideInterceptor; + +const APP_NAME_STR: &str = "AmazonQ-For-CLI"; + +pub fn app_name() -> AppName { + static APP_NAME: LazyLock = LazyLock::new(|| AppName::new(APP_NAME_STR).expect("invalid app name")); + APP_NAME.clone() +} + +pub fn behavior_version() -> BehaviorVersion { + BehaviorVersion::v2025_01_17() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_app_name() { + println!("{}", app_name()); + } + + #[test] + fn test_behavior_version() { + assert!(behavior_version() == BehaviorVersion::latest()); + } +} diff --git a/crates/kiro-cli/src/fig_aws_common/sdk_error_display.rs b/crates/kiro-cli/src/fig_aws_common/sdk_error_display.rs new file mode 100644 index 0000000000..6bd8b544c4 --- /dev/null +++ b/crates/kiro-cli/src/fig_aws_common/sdk_error_display.rs @@ -0,0 +1,96 @@ +use std::error::Error; +use std::fmt::{ + self, + Debug, + Display, +}; + +use aws_smithy_runtime_api::client::result::SdkError; + +#[derive(Debug)] +pub struct SdkErrorDisplay<'a, E, R>(pub &'a SdkError); + +impl Display for SdkErrorDisplay<'_, E, R> +where + E: Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match &self.0 { + SdkError::ConstructionFailure(_) => { + write!(f, "failed to construct request") + }, + SdkError::TimeoutError(_) => write!(f, "request has timed out"), + SdkError::DispatchFailure(e) => { + write!(f, "dispatch failure")?; + if let Some(connector_error) = e.as_connector_error() { + if let Some(source) = connector_error.source() { + write!(f, " ({connector_error}): {source}")?; + } else { + write!(f, ": {connector_error}")?; + } + } + Ok(()) + }, + SdkError::ResponseError(_) => write!(f, "response error"), + SdkError::ServiceError(e) => { + write!(f, "{}", e.err()) + }, + other => write!(f, "{other}"), + } + } +} + +impl Error for SdkErrorDisplay<'_, E, R> +where + E: Error + 'static, + R: Debug, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.0.source() + } +} + +#[cfg(test)] +mod tests { + use aws_smithy_runtime_api::client::result::{ + ConnectorError, + ConstructionFailure, + DispatchFailure, + ResponseError, + SdkError, + ServiceError, + TimeoutError, + }; + + use super::SdkErrorDisplay; + + #[test] + fn test_displays_sdk_error() { + let construction_failure = ConstructionFailure::builder().source("").build(); + let sdk_error: SdkError = SdkError::ConstructionFailure(construction_failure); + let sdk_error_display = SdkErrorDisplay(&sdk_error); + assert_eq!("failed to construct request", sdk_error_display.to_string()); + + let timeout_error = TimeoutError::builder().source("").build(); + let sdk_error: SdkError = SdkError::TimeoutError(timeout_error); + let sdk_error_display = 
SdkErrorDisplay(&sdk_error); + assert_eq!("request has timed out", sdk_error_display.to_string()); + + let dispatch_failure = DispatchFailure::builder() + .source(ConnectorError::io("".into())) + .build(); + let sdk_error: SdkError = SdkError::DispatchFailure(dispatch_failure); + let sdk_error_display = SdkErrorDisplay(&sdk_error); + assert_eq!("dispatch failure (io error): ", sdk_error_display.to_string()); + + let response_error = ResponseError::builder().source("").raw("".into()).build(); + let sdk_error: SdkError = SdkError::ResponseError(response_error); + let sdk_error_display = SdkErrorDisplay(&sdk_error); + assert_eq!("response error", sdk_error_display.to_string()); + + let service_error = ServiceError::builder().source("").raw("".into()).build(); + let sdk_error: SdkError = SdkError::ServiceError(service_error); + let sdk_error_display = SdkErrorDisplay(&sdk_error); + assert_eq!("", sdk_error_display.to_string()); + } +} diff --git a/crates/kiro-cli/src/fig_aws_common/user_agent_override_interceptor.rs b/crates/kiro-cli/src/fig_aws_common/user_agent_override_interceptor.rs new file mode 100644 index 0000000000..d53e4960d8 --- /dev/null +++ b/crates/kiro-cli/src/fig_aws_common/user_agent_override_interceptor.rs @@ -0,0 +1,227 @@ +use std::borrow::Cow; +use std::error::Error; +use std::fmt; + +use aws_runtime::user_agent::{ + AdditionalMetadata, + ApiMetadata, + AwsUserAgent, +}; +use aws_smithy_runtime_api::box_error::BoxError; +use aws_smithy_runtime_api::client::interceptors::Intercept; +use aws_smithy_runtime_api::client::interceptors::context::BeforeTransmitInterceptorContextMut; +use aws_smithy_runtime_api::client::runtime_components::RuntimeComponents; +use aws_smithy_types::config_bag::ConfigBag; +use aws_types::app_name::AppName; +use aws_types::os_shim_internal::Env; +use http::header::{ + InvalidHeaderValue, + USER_AGENT, +}; +use tracing::warn; + +/// The environment variable name of additional user agent metadata we include in the user agent +/// string. This is used in AWS CloudShell where they want to track usage by version. +const AWS_TOOLING_USER_AGENT: &str = "AWS_TOOLING_USER_AGENT"; + +const VERSION_HEADER: &str = "Version"; +const VERSION_VALUE: &str = env!("CARGO_PKG_VERSION"); + +#[derive(Debug)] +enum UserAgentOverrideInterceptorError { + MissingApiMetadata, + InvalidHeaderValue(InvalidHeaderValue), +} + +impl Error for UserAgentOverrideInterceptorError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::InvalidHeaderValue(source) => Some(source), + Self::MissingApiMetadata => None, + } + } +} + +impl fmt::Display for UserAgentOverrideInterceptorError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Self::InvalidHeaderValue(_) => "AwsUserAgent generated an invalid HTTP header value. This is a bug. Please file an issue.", + Self::MissingApiMetadata => "The UserAgentInterceptor requires ApiMetadata to be set before the request is made. This is a bug. 
Please file an issue.", + }) + } +} + +impl From for UserAgentOverrideInterceptorError { + fn from(err: InvalidHeaderValue) -> Self { + UserAgentOverrideInterceptorError::InvalidHeaderValue(err) + } +} +/// Generates and attaches the AWS SDK's user agent to a HTTP request +#[non_exhaustive] +#[derive(Debug, Default)] +pub struct UserAgentOverrideInterceptor { + env: Env, +} + +impl UserAgentOverrideInterceptor { + /// Creates a new `UserAgentInterceptor` + pub fn new() -> Self { + Self { env: Env::real() } + } + + #[cfg(test)] + pub fn from_env(env: Env) -> Self { + Self { env } + } +} + +impl Intercept for UserAgentOverrideInterceptor { + fn name(&self) -> &'static str { + "UserAgentOverrideInterceptor" + } + + fn modify_before_signing( + &self, + context: &mut BeforeTransmitInterceptorContextMut<'_>, + _runtime_components: &RuntimeComponents, + cfg: &mut ConfigBag, + ) -> Result<(), BoxError> { + let env = self.env.clone(); + + // Allow for overriding the user agent by an earlier interceptor (so, for example, + // tests can use `AwsUserAgent::for_tests()`) by attempting to grab one out of the + // config bag before creating one. + let ua: Cow<'_, AwsUserAgent> = cfg.load::().map(Cow::Borrowed).map_or_else( + || { + let api_metadata = cfg + .load::() + .ok_or(UserAgentOverrideInterceptorError::MissingApiMetadata)?; + + let aws_tooling_user_agent = env.get(AWS_TOOLING_USER_AGENT); + let mut ua = AwsUserAgent::new_from_environment(env, api_metadata.clone()); + + let ver = format!("{VERSION_HEADER}/{VERSION_VALUE}"); + match AdditionalMetadata::new(clean_metadata(&ver)) { + Ok(md) => { + ua.add_additional_metadata(md); + }, + Err(err) => panic!("Failed to parse version: {err}"), + }; + + let maybe_app_name = cfg.load::(); + if let Some(app_name) = maybe_app_name { + ua.set_app_name(app_name.clone()); + } + if let Ok(val) = aws_tooling_user_agent { + match AdditionalMetadata::new(clean_metadata(&val)) { + Ok(md) => { + ua.add_additional_metadata(md); + }, + Err(err) => warn!(%err, %val, "Failed to parse {AWS_TOOLING_USER_AGENT}"), + }; + } + + Ok(Cow::Owned(ua)) + }, + Result::<_, UserAgentOverrideInterceptorError>::Ok, + )?; + + let headers = context.request_mut().headers_mut(); + headers.insert(USER_AGENT.as_str(), ua.aws_ua_header()); + Ok(()) + } +} + +fn clean_metadata(s: &str) -> String { + let valid_character = |c: char| -> bool { + match c { + _ if c.is_ascii_alphanumeric() => true, + '!' | '#' | '$' | '%' | '&' | '\'' | '*' | '+' | '-' | '.' 
| '^' | '_' | '`' | '|' | '~' => true, + _ => false, + } + }; + s.chars().map(|c| if valid_character(c) { c } else { '-' }).collect() +} + +#[cfg(test)] +mod tests { + use aws_smithy_runtime_api::client::interceptors::context::{ + Input, + InterceptorContext, + }; + use aws_smithy_runtime_api::client::runtime_components::RuntimeComponentsBuilder; + use aws_smithy_types::config_bag::Layer; + use http::HeaderValue; + + use super::*; + use crate::fig_aws_common::{ + APP_NAME_STR, + app_name, + }; + + #[test] + fn error_test() { + let err = UserAgentOverrideInterceptorError::InvalidHeaderValue(HeaderValue::from_bytes(b"\0").unwrap_err()); + assert!(err.source().is_some()); + println!("{err}"); + + let err = UserAgentOverrideInterceptorError::MissingApiMetadata; + assert!(err.source().is_none()); + println!("{err}"); + } + + fn user_agent_base() -> (RuntimeComponents, ConfigBag, InterceptorContext) { + let rc = RuntimeComponentsBuilder::for_tests().build().unwrap(); + let mut cfg = ConfigBag::base(); + + let mut layer = Layer::new("layer"); + layer.store_put(ApiMetadata::new("q", "123")); + layer.store_put(app_name()); + cfg.push_layer(layer); + + let mut context = InterceptorContext::new(Input::erase(())); + context.set_request(aws_smithy_runtime_api::http::Request::empty()); + + (rc, cfg, context) + } + + #[test] + fn user_agent_override_test() { + let (rc, mut cfg, mut context) = user_agent_base(); + let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); + let interceptor = UserAgentOverrideInterceptor::new(); + println!("Interceptor: {}", interceptor.name()); + interceptor + .modify_before_signing(&mut context, &rc, &mut cfg) + .expect("success"); + + let ua = context.request().headers().get(USER_AGENT).unwrap(); + println!("User-Agent: {ua}"); + assert!(ua.contains(&format!("app/{APP_NAME_STR}"))); + assert!(ua.contains(VERSION_HEADER)); + assert!(ua.contains(VERSION_VALUE)); + } + + #[test] + fn user_agent_override_cloudshell_test() { + let (rc, mut cfg, mut context) = user_agent_base(); + let mut context = BeforeTransmitInterceptorContextMut::from(&mut context); + let env = Env::from_slice(&[ + ("AWS_EXECUTION_ENV", "CloudShell"), + (AWS_TOOLING_USER_AGENT, "AWS-CloudShell/2024.08.29"), + ]); + let interceptor = UserAgentOverrideInterceptor::from_env(env); + println!("Interceptor: {}", interceptor.name()); + interceptor + .modify_before_signing(&mut context, &rc, &mut cfg) + .expect("success"); + + let ua = context.request().headers().get(USER_AGENT).unwrap(); + println!("User-Agent: {ua}"); + assert!(ua.contains(&format!("app/{APP_NAME_STR}"))); + assert!(ua.contains("exec-env/CloudShell")); + assert!(ua.contains("md/AWS-CloudShell-2024.08.29")); + assert!(ua.contains(VERSION_HEADER)); + assert!(ua.contains(VERSION_VALUE)); + } +} diff --git a/crates/kiro-cli/src/fig_install.rs b/crates/kiro-cli/src/fig_install.rs new file mode 100644 index 0000000000..82819bcfac --- /dev/null +++ b/crates/kiro-cli/src/fig_install.rs @@ -0,0 +1,119 @@ +use std::str::FromStr; +use std::time::SystemTimeError; + +use thiserror::Error; +use tracing::error; + +use crate::fig_util::manifest::{ + Channel, + manifest, +}; + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error("unsupported platform")] + UnsupportedPlatform, + #[error(transparent)] + Util(#[from] crate::fig_util::Error), + #[error(transparent)] + Settings(#[from] crate::fig_settings::Error), + #[error(transparent)] + Reqwest(#[from] reqwest::Error), + #[error(transparent)] 
+ Semver(#[from] semver::Error), + #[error(transparent)] + SystemTime(#[from] SystemTimeError), + #[error(transparent)] + Strum(#[from] strum::ParseError), + #[error("failed to update: `{0}`")] + UpdateFailed(String), + #[cfg(target_os = "macos")] + #[error("failed to update due to auth error: `{0}`")] + SecurityFramework(#[from] security_framework::base::Error), + #[error("your system is not supported on this channel")] + SystemNotOnChannel, + #[error("Update in progress")] + UpdateInProgress, + #[error("could not convert path to cstring")] + Nul(#[from] std::ffi::NulError), + #[error("failed to get system id")] + SystemIdNotFound, + #[error("unable to find the bundled metadata")] + BundleMetadataNotFound, +} + +use std::path::PathBuf; + +use crate::fig_util::{ + CLI_BINARY_NAME, + OLD_CLI_BINARY_NAMES, + OLD_PTY_BINARY_NAMES, + PTY_BINARY_NAME, + directories, +}; + +pub async fn uninstall() -> Result<(), Error> { + let remove_binary = |path: PathBuf| async move { + match tokio::fs::remove_file(&path).await { + Ok(_) => tracing::info!("Removed binary: {path:?}"), + Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}, + Err(err) => tracing::warn!(%err, "Failed to remove binary: {path:?}"), + } + }; + + // let folders = [directories::home_local_bin()?, Path::new("/usr/local/bin").into()]; + let folders = [directories::home_local_bin()?]; + + let mut all_binary_names = vec![CLI_BINARY_NAME, PTY_BINARY_NAME]; + all_binary_names.extend(OLD_CLI_BINARY_NAMES); + all_binary_names.extend(OLD_PTY_BINARY_NAMES); + + let mut pty_names = vec![PTY_BINARY_NAME]; + pty_names.extend(OLD_PTY_BINARY_NAMES); + + for folder in folders { + for binary_name in &all_binary_names { + let binary_path = folder.join(binary_name); + remove_binary(binary_path).await; + } + } + + Ok(()) +} + +fn update() -> Result<(), Error> { + // let status = self_update::backends::s3::Update::configure() + // .bucket_name("self_update_releases") + // .asset_prefix("something/self_update") + // .region("eu-west-2") + // .bin_name("self_update_example") + // .show_download_progress(true) + // .current_version(cargo_crate_version!()) + // .build()? + // .update()?; + // println!("S3 Update status: `{}`!", status.version()); + todo!(); +} + +impl From for Error { + fn from(err: crate::fig_util::directories::DirectoryError) -> Self { + crate::fig_util::Error::Directory(err).into() + } +} + +// The current selected channel +pub fn get_channel() -> Result { + Ok(match crate::fig_settings::state::get_string("updates.channel")? 
{ + Some(channel) => Channel::from_str(&channel)?, + None => { + let manifest_channel = manifest().default_channel; + if crate::fig_settings::settings::get_bool_or("app.beta", false) { + manifest_channel.max(Channel::Beta) + } else { + manifest_channel + } + }, + }) +} diff --git a/crates/kiro-cli/src/fig_log.rs b/crates/kiro-cli/src/fig_log.rs new file mode 100644 index 0000000000..e7051cd6c0 --- /dev/null +++ b/crates/kiro-cli/src/fig_log.rs @@ -0,0 +1,313 @@ +use std::fs::File; +use std::path::Path; +use std::sync::Mutex; + +use thiserror::Error; +use tracing::info; +use tracing::level_filters::LevelFilter; +use tracing_appender::non_blocking::WorkerGuard; +use tracing_subscriber::filter::Directive; +use tracing_subscriber::prelude::*; +use tracing_subscriber::{ + EnvFilter, + Registry, + fmt, +}; + +use crate::fig_util::env_var::Q_LOG_LEVEL; + +const MAX_FILE_SIZE: u64 = 10 * 1024 * 1024; +const DEFAULT_FILTER: LevelFilter = LevelFilter::ERROR; + +static Q_LOG_LEVEL_GLOBAL: Mutex> = Mutex::new(None); +static MAX_LEVEL: Mutex> = Mutex::new(None); +static ENV_FILTER_RELOADABLE_HANDLE: Mutex>> = + Mutex::new(None); + +// A logging error +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + TracingReload(#[from] tracing_subscriber::reload::Error), +} + +/// Arguments to the initialize_logging function +#[derive(Debug)] +pub struct LogArgs> { + /// The log level to use. When not set, the default log level is used. + pub log_level: Option, + /// Whether or not we log to stdout. + pub log_to_stdout: bool, + /// The log file path which we write logs to. When not set, we do not write to a file. + pub log_file_path: Option, + /// Whether we should delete the log file at each launch. + pub delete_old_log_file: bool, +} + +/// The log guard maintains tracing guards which send log information to other threads. +/// +/// This must be kept alive for logging to function as expected. +#[must_use] +#[derive(Debug)] +pub struct LogGuard { + _file_guard: Option, + _stdout_guard: Option, + _mcp_file_guard: Option, +} + +/// Initialize our application level logging using the given LogArgs. +/// +/// # Returns +/// +/// On success, this returns a guard which must be kept alive. +#[inline] +pub fn initialize_logging>(args: LogArgs) -> Result { + let filter_layer = create_filter_layer(); + let (reloadable_filter_layer, reloadable_handle) = tracing_subscriber::reload::Layer::new(filter_layer); + ENV_FILTER_RELOADABLE_HANDLE.lock().unwrap().replace(reloadable_handle); + let mut mcp_path = None; + + // First we construct the file logging layer if a file name was provided. + let (file_layer, _file_guard) = match args.log_file_path { + Some(log_file_path) => { + let log_path = log_file_path.as_ref(); + + // Make the log path parent directory if it doesn't exist. + if let Some(parent) = log_path.parent() { + if log_path.ends_with("chat.log") { + mcp_path = Some(parent.to_path_buf()); + } + std::fs::create_dir_all(parent)?; + } + + // We delete the old log file when requested each time the logger is initialized, otherwise we only + // delete the file when it has grown too large. + if args.delete_old_log_file { + std::fs::remove_file(log_path).ok(); + } else if log_path.exists() && std::fs::metadata(log_path)?.len() > MAX_FILE_SIZE { + std::fs::remove_file(log_path)?; + } + + // Create the new log file or append to the existing one. + let file = if args.delete_old_log_file { + File::create(log_path)? 
+ } else { + File::options().append(true).create(true).open(log_path)? + }; + + // On posix-like systems, we modify permissions so that only the owner has access. + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Ok(metadata) = file.metadata() { + let mut permissions = metadata.permissions(); + permissions.set_mode(0o600); + file.set_permissions(permissions).ok(); + } + } + + let (non_blocking, guard) = tracing_appender::non_blocking(file); + let file_layer = fmt::layer().with_line_number(true).with_writer(non_blocking); + + (Some(file_layer), Some(guard)) + }, + None => (None, None), + }; + + // If we log to stdout, we need to add this layer to our logger. + let (stdout_layer, _stdout_guard) = if args.log_to_stdout { + let (non_blocking, guard) = tracing_appender::non_blocking(std::io::stdout()); + let stdout_layer = fmt::layer().with_line_number(true).with_writer(non_blocking); + (Some(stdout_layer), Some(guard)) + } else { + (None, None) + }; + + // Set up for mcp servers layer if we are in chat + let (mcp_server_layer, _mcp_file_guard) = if let Some(parent) = mcp_path { + let mcp_path = parent.join("mcp.log"); + if args.delete_old_log_file { + std::fs::remove_file(&mcp_path).ok(); + } else if mcp_path.exists() && std::fs::metadata(&mcp_path)?.len() > MAX_FILE_SIZE { + std::fs::remove_file(&mcp_path)?; + } + let file = if args.delete_old_log_file { + File::create(&mcp_path)? + } else { + File::options().append(true).create(true).open(&mcp_path)? + }; + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Ok(metadata) = file.metadata() { + let mut permissions = metadata.permissions(); + permissions.set_mode(0o600); + file.set_permissions(permissions).ok(); + } + } + let (non_blocking, guard) = tracing_appender::non_blocking(file); + let file_layer = fmt::layer() + .with_line_number(true) + .with_writer(non_blocking) + .with_filter(EnvFilter::new("mcp=trace")); + (Some(file_layer), Some(guard)) + } else { + (None, None) + }; + + if let Some(level) = args.log_level { + set_log_level(level)?; + } + + // Finally, initialize our logging + let subscriber = tracing_subscriber::registry() + .with(reloadable_filter_layer) + .with(file_layer) + .with(stdout_layer); + + if let Some(mcp_server_layer) = mcp_server_layer { + subscriber.with(mcp_server_layer).init(); + return Ok(LogGuard { + _file_guard, + _stdout_guard, + _mcp_file_guard, + }); + } + + subscriber.init(); + + Ok(LogGuard { + _file_guard, + _stdout_guard, + _mcp_file_guard, + }) +} + +/// Get the current log level by first seeing if it is set in application, then environment, then +/// otherwise using the default +/// +/// # Returns +/// +/// Returns a string identifying the current log level. +pub fn get_log_level() -> String { + Q_LOG_LEVEL_GLOBAL + .lock() + .unwrap() + .clone() + .unwrap_or_else(|| std::env::var(Q_LOG_LEVEL).unwrap_or_else(|_| DEFAULT_FILTER.to_string())) +} + +/// Set the log level to the given level. +/// +/// # Returns +/// +/// On success, returns the old log level. 
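
// Illustrative sketch, not part of the patch: wiring the API in this file from
// application startup code. The paths and the "debug" level are assumptions made for
// the example; the one load-bearing detail is that a log file named `chat.log` also
// enables the sibling `mcp.log` layer filtered to `mcp=trace`, as implemented above.
fn logging_setup_sketch() -> Result<LogGuard, Error> {
    let guard = initialize_logging(LogArgs {
        log_level: Some("debug".to_owned()),
        log_to_stdout: false,
        log_file_path: Some(std::path::PathBuf::from("/tmp/kiro/logs/chat.log")),
        delete_old_log_file: false,
    })?;
    // The level can be raised later at runtime; the old level is returned so a caller
    // could restore it afterwards.
    let previous = set_log_level("trace".to_owned())?;
    tracing::debug!(%previous, "log level changed");
    Ok(guard)
}
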
+pub fn set_log_level(level: String) -> Result { + info!("Setting log level to {level:?}"); + + let old_level = get_log_level(); + *Q_LOG_LEVEL_GLOBAL.lock().unwrap() = Some(level); + + let filter_layer = create_filter_layer(); + *MAX_LEVEL.lock().unwrap() = filter_layer.max_level_hint(); + + ENV_FILTER_RELOADABLE_HANDLE + .lock() + .unwrap() + .as_ref() + .expect("set_log_level must not be called before logging is initialized") + .reload(filter_layer)?; + + Ok(old_level) +} + +/// Get the current max log level +/// +/// # Returns +/// +/// The max log level which is set every time the log level is set. +pub fn get_log_level_max() -> LevelFilter { + let max_level = *MAX_LEVEL.lock().unwrap(); + match max_level { + Some(level) => level, + None => { + let filter_layer = create_filter_layer(); + *MAX_LEVEL.lock().unwrap() = filter_layer.max_level_hint(); + filter_layer.max_level_hint().unwrap_or(DEFAULT_FILTER) + }, + } +} + +fn create_filter_layer() -> EnvFilter { + let directive = Directive::from(DEFAULT_FILTER); + + let log_level = Q_LOG_LEVEL_GLOBAL + .lock() + .unwrap() + .clone() + .or_else(|| std::env::var(Q_LOG_LEVEL).ok()); + + match log_level { + Some(level) => EnvFilter::builder() + .with_default_directive(directive) + .parse_lossy(level), + None => EnvFilter::default().add_directive(directive), + } +} + +#[cfg(test)] +mod tests { + use std::fs::read_to_string; + use std::time::Duration; + + use tracing::{ + debug, + error, + trace, + warn, + }; + + use super::*; + + #[test] + fn test_logging() { + // Create a temp path for where we write logs to. + let tempdir = tempfile::TempDir::new().unwrap(); + let log_path = tempdir.path().join("test.log"); + + // Assert that initialize logging simply doesn't panic. + let _guard = initialize_logging(LogArgs { + log_level: Some("trace".to_owned()), + log_to_stdout: true, + log_file_path: Some(&log_path), + delete_old_log_file: true, + }) + .unwrap(); + + // Test that get log level functions as expected. + assert_eq!(get_log_level(), "trace"); + + // Write some log messages out to file. (and stderr) + trace!("abc"); + debug!("def"); + info!("ghi"); + warn!("jkl"); + error!("mno"); + + // Test that set log level functions as expected. + // This also restores the default log level. 
+ set_log_level(DEFAULT_FILTER.to_string()).unwrap(); + assert_eq!(get_log_level(), DEFAULT_FILTER.to_string()); + + // Sleep in order to ensure logs get written to file, then assert on the contents + std::thread::sleep(Duration::from_millis(100)); + let logs = read_to_string(&log_path).unwrap(); + for i in [ + "TRACE", "DEBUG", "INFO", "WARN", "ERROR", "abc", "def", "ghi", "jkl", "mno", + ] { + assert!(logs.contains(i)); + } + } +} diff --git a/crates/kiro-cli/src/fig_os_shim/env.rs b/crates/kiro-cli/src/fig_os_shim/env.rs new file mode 100644 index 0000000000..ee705fd07f --- /dev/null +++ b/crates/kiro-cli/src/fig_os_shim/env.rs @@ -0,0 +1,227 @@ +use std::collections::HashMap; +use std::env::{ + self, + VarError, +}; +use std::ffi::{ + OsStr, + OsString, +}; +use std::io; +use std::path::PathBuf; +use std::sync::{ + Arc, + Mutex, +}; + +use crate::fig_os_shim::Shim; +#[derive(Debug, Clone, Default)] +pub struct Env(inner::Inner); + +mod inner { + use std::collections::HashMap; + use std::path::PathBuf; + use std::sync::{ + Arc, + Mutex, + }; + + #[derive(Debug, Clone, Default)] + pub(super) enum Inner { + #[default] + Real, + Fake(Arc>), + } + + #[derive(Debug, Clone)] + pub(super) struct Fake { + pub vars: HashMap, + pub cwd: PathBuf, + pub current_exe: PathBuf, + } + + impl Default for Fake { + fn default() -> Self { + Self { + vars: HashMap::default(), + cwd: PathBuf::from("/"), + current_exe: PathBuf::from("/current_exe"), + } + } + } +} + +impl Env { + pub fn new() -> Self { + Self::default() + } + + pub fn new_fake() -> Self { + Self(inner::Inner::Fake(Arc::new(Mutex::new(inner::Fake::default())))) + } + + /// Create a fake process environment from a slice of tuples. + pub fn from_slice(vars: &[(&str, &str)]) -> Self { + use inner::Inner; + let map: HashMap<_, _> = vars.iter().map(|(k, v)| ((*k).to_owned(), (*v).to_owned())).collect(); + Self(Inner::Fake(Arc::new(Mutex::new(inner::Fake { + vars: map, + ..Default::default() + })))) + } + + pub fn get>(&self, key: K) -> Result { + use inner::Inner; + match &self.0 { + Inner::Real => env::var(key.as_ref()), + Inner::Fake(fake) => fake + .lock() + .unwrap() + .vars + .get(key.as_ref()) + .cloned() + .ok_or(VarError::NotPresent), + } + } + + pub fn get_os>(&self, key: K) -> Option { + use inner::Inner; + match &self.0 { + Inner::Real => env::var_os(key.as_ref()), + Inner::Fake(fake) => fake + .lock() + .unwrap() + .vars + .get(key.as_ref().to_str()?) + .cloned() + .map(OsString::from), + } + } + + /// Sets the environment variable `key` to the value `value` for the currently running + /// process. + /// + /// # Safety + /// + /// See [std::env::set_var] for the safety requirements. 
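
// Illustrative sketch, not part of the patch: code that reads its environment through
// `Env` instead of `std::env` can be driven by a fake in unit tests. The variable name
// `KIRO_PROXY` is invented for the example.
fn proxy_from_env(env: &Env) -> Option<String> {
    env.get("KIRO_PROXY").ok().filter(|v| !v.is_empty())
}

#[test]
fn proxy_from_env_sketch() {
    assert_eq!(proxy_from_env(&Env::from_slice(&[])), None);
    let env = Env::from_slice(&[("KIRO_PROXY", "http://localhost:3128")]);
    assert_eq!(proxy_from_env(&env).as_deref(), Some("http://localhost:3128"));
}
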
+ pub unsafe fn set_var(&self, key: impl AsRef, value: impl AsRef) { + use inner::Inner; + match &self.0 { + Inner::Real => std::env::set_var(key, value), + Inner::Fake(fake) => { + fake.lock().unwrap().vars.insert( + key.as_ref().to_str().expect("key must be valid str").to_string(), + value.as_ref().to_str().expect("key must be valid str").to_string(), + ); + }, + } + } + + pub fn home(&self) -> Option { + match &self.0 { + inner::Inner::Real => dirs::home_dir(), + inner::Inner::Fake(fake) => fake.lock().unwrap().vars.get("HOME").map(PathBuf::from), + } + } + + pub fn current_dir(&self) -> Result { + use inner::Inner; + match &self.0 { + Inner::Real => std::env::current_dir(), + Inner::Fake(fake) => Ok(fake.lock().unwrap().cwd.clone()), + } + } + + pub fn current_exe(&self) -> Result { + use inner::Inner; + match &self.0 { + Inner::Real => std::env::current_exe(), + Inner::Fake(fake) => Ok(fake.lock().unwrap().current_exe.clone()), + } + } + + pub fn in_cloudshell(&self) -> bool { + self.get("AWS_EXECUTION_ENV") + .is_ok_and(|v| v.trim().eq_ignore_ascii_case("cloudshell")) + } + + pub fn in_ssh(&self) -> bool { + self.get("SSH_CLIENT").is_ok() || self.get("SSH_CONNECTION").is_ok() || self.get("SSH_TTY").is_ok() + } + + pub fn in_codespaces(&self) -> bool { + self.get_os("CODESPACES").is_some() || self.get_os("Q_CODESPACES").is_some() + } + + pub fn in_ci(&self) -> bool { + self.get_os("CI").is_some() || self.get_os("Q_CI").is_some() + } + + /// Whether or not the current executable is run from an AppImage. + /// + /// See: https://docs.appimage.org/packaging-guide/environment-variables.html + pub fn in_appimage(&self) -> bool { + self.get_os("APPIMAGE").is_some() + } +} + +impl Shim for Env { + fn is_real(&self) -> bool { + matches!(self.0, inner::Inner::Real) + } +} + +#[cfg(test)] +mod tests { + use std::path::Path; + + use super::*; + + #[test] + fn test_new() { + let env = Env::new(); + assert!(matches!(env, Env(inner::Inner::Real))); + + let env = Env::default(); + assert!(matches!(env, Env(inner::Inner::Real))); + } + + #[test] + fn test_get() { + let env = Env::new(); + assert!(env.home().is_some()); + assert!(env.get("PATH").is_ok()); + assert!(env.get_os("PATH").is_some()); + assert!(env.get("NON_EXISTENT").is_err()); + + let env = Env::from_slice(&[("HOME", "/home/user"), ("PATH", "/bin:/usr/bin")]); + assert_eq!(env.home().unwrap(), Path::new("/home/user")); + assert_eq!(env.get("PATH").unwrap(), "/bin:/usr/bin"); + assert!(env.get_os("PATH").is_some()); + assert!(env.get("NON_EXISTENT").is_err()); + } + + #[test] + fn test_in_envs() { + let env = Env::from_slice(&[]); + assert!(!env.in_cloudshell()); + assert!(!env.in_ssh()); + + let env = Env::from_slice(&[("AWS_EXECUTION_ENV", "CloudShell"), ("SSH_CLIENT", "1")]); + assert!(env.in_cloudshell()); + assert!(env.in_ssh()); + + let env = Env::from_slice(&[("AWS_EXECUTION_ENV", "CLOUDSHELL\n")]); + assert!(env.in_cloudshell()); + assert!(!env.in_ssh()); + + let env = Env::from_slice(&[("APPIMAGE", "/tmp/.mount-asdf/usr")]); + assert!(env.in_appimage()); + } + + #[test] + fn test_default_current_dir() { + let env = Env::new_fake(); + assert_eq!(env.current_dir().unwrap(), PathBuf::from("/")); + } +} diff --git a/crates/kiro-cli/src/fig_os_shim/fs.rs b/crates/kiro-cli/src/fig_os_shim/fs.rs new file mode 100644 index 0000000000..64e27dbe4c --- /dev/null +++ b/crates/kiro-cli/src/fig_os_shim/fs.rs @@ -0,0 +1,611 @@ +use std::collections::HashMap; +use std::fs::Permissions; +use std::io; +use std::os::unix::ffi::OsStrExt; +use 
std::path::{ + Path, + PathBuf, +}; +use std::sync::{ + Arc, + Mutex, +}; + +use tempfile::TempDir; +use tokio::fs; + +use crate::fig_os_shim::Shim; + +#[derive(Debug, Clone, Default)] +pub struct Fs(inner::Inner); + +mod inner { + use std::collections::HashMap; + use std::path::PathBuf; + use std::sync::{ + Arc, + Mutex, + }; + + use tempfile::TempDir; + + #[derive(Debug, Clone, Default)] + pub(super) enum Inner { + #[default] + Real, + /// Uses the real filesystem except acts as if the process has + /// a different root directory by using [TempDir] + Chroot(Arc), + Fake(Arc>>>), + } +} + +impl Fs { + pub fn new() -> Self { + Self::default() + } + + pub fn new_fake() -> Self { + Self(inner::Inner::Fake(Arc::new(Mutex::new(HashMap::new())))) + } + + pub fn new_chroot() -> Self { + let tempdir = tempfile::tempdir().expect("failed creating temporary directory"); + Self(inner::Inner::Chroot(tempdir.into())) + } + + pub fn is_chroot(&self) -> bool { + matches!(self.0, inner::Inner::Chroot(_)) + } + + pub fn from_slice(vars: &[(&str, &str)]) -> Self { + use inner::Inner; + let map: HashMap<_, _> = vars + .iter() + .map(|(k, v)| (PathBuf::from(k), v.as_bytes().to_vec())) + .collect(); + Self(Inner::Fake(Arc::new(Mutex::new(map)))) + } + + pub async fn create_new(&self, path: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::File::create_new(path).await, + Inner::Chroot(root) => fs::File::create_new(append(root.path(), path)).await, + Inner::Fake(_) => Err(io::Error::new(io::ErrorKind::Other, "unimplemented")), + } + } + + pub async fn create_dir(&self, path: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::create_dir(path).await, + Inner::Chroot(root) => fs::create_dir(append(root.path(), path)).await, + Inner::Fake(_) => Err(io::Error::new(io::ErrorKind::Other, "unimplemented")), + } + } + + pub async fn create_dir_all(&self, path: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::create_dir_all(path).await, + Inner::Chroot(root) => fs::create_dir_all(append(root.path(), path)).await, + Inner::Fake(_) => Err(io::Error::new(io::ErrorKind::Other, "unimplemented")), + } + } + + /// Attempts to open a file in read-only mode. + /// + /// This is a proxy to [`tokio::fs::File::open`]. 
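
// Illustrative sketch, not part of the patch: choosing an `Fs` backend in tests. The
// map-backed fake only supports whole-file reads and writes, so anything touching
// directories, file handles, or symlinks should use the chroot variant instead.
#[tokio::test]
async fn fs_backend_sketch() {
    // Map-backed fake: fine for read/write round trips.
    let fake = Fs::new_fake();
    fake.write("/config.json", b"{}").await.unwrap();
    assert_eq!(fake.read_to_string("/config.json").await.unwrap(), "{}");
    // Directory operations are not implemented on the fake and return an error.
    assert!(fake.create_dir("/a").await.is_err());

    // Chroot: the real filesystem rooted in a TempDir, so directory operations work
    // without touching anything outside the temporary root.
    let chroot = Fs::new_chroot();
    chroot.create_dir_all("/a/b").await.unwrap();
    assert!(chroot.try_exists("/a/b").await.unwrap());
}
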
+ pub async fn open(&self, path: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::File::open(path).await, + Inner::Chroot(root) => fs::File::open(append(root.path(), path)).await, + Inner::Fake(_) => Err(io::Error::new(io::ErrorKind::Other, "unimplemented")), + } + } + + pub async fn read(&self, path: impl AsRef) -> io::Result> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::read(path).await, + Inner::Chroot(root) => fs::read(append(root.path(), path)).await, + Inner::Fake(map) => { + let Ok(lock) = map.lock() else { + return Err(io::Error::new(io::ErrorKind::Other, "poisoned lock")); + }; + let Some(data) = lock.get(path.as_ref()) else { + return Err(io::Error::new(io::ErrorKind::NotFound, "not found")); + }; + Ok(data.clone()) + }, + } + } + + pub async fn read_to_string(&self, path: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::read_to_string(path).await, + Inner::Chroot(root) => fs::read_to_string(append(root.path(), path)).await, + Inner::Fake(map) => { + let Ok(lock) = map.lock() else { + return Err(io::Error::new(io::ErrorKind::Other, "poisoned lock")); + }; + let Some(data) = lock.get(path.as_ref()) else { + return Err(io::Error::new(io::ErrorKind::NotFound, "not found")); + }; + match String::from_utf8(data.clone()) { + Ok(string) => Ok(string), + Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)), + } + }, + } + } + + pub fn read_to_string_sync(&self, path: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => std::fs::read_to_string(path), + Inner::Chroot(root) => std::fs::read_to_string(append(root.path(), path)), + Inner::Fake(map) => { + let Ok(lock) = map.lock() else { + return Err(io::Error::new(io::ErrorKind::Other, "poisoned lock")); + }; + let Some(data) = lock.get(path.as_ref()) else { + return Err(io::Error::new(io::ErrorKind::NotFound, "not found")); + }; + match String::from_utf8(data.clone()) { + Ok(string) => Ok(string), + Err(err) => Err(io::Error::new(io::ErrorKind::InvalidData, err)), + } + }, + } + } + + /// Creates a future that will open a file for writing and write the entire + /// contents of `contents` to it. + /// + /// This is a proxy to [`tokio::fs::write`]. + pub async fn write(&self, path: impl AsRef, contents: impl AsRef<[u8]>) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::write(path, contents).await, + Inner::Chroot(root) => fs::write(append(root.path(), path), contents).await, + Inner::Fake(map) => { + let Ok(mut lock) = map.lock() else { + return Err(io::Error::new(io::ErrorKind::Other, "poisoned lock")); + }; + lock.insert(path.as_ref().to_owned(), contents.as_ref().to_owned()); + Ok(()) + }, + } + } + + /// Removes a file from the filesystem. + /// + /// Note that there is no guarantee that the file is immediately deleted (e.g. + /// depending on platform, other open file descriptors may prevent immediate + /// removal). + /// + /// This is a proxy to [`tokio::fs::remove_file`]. + pub async fn remove_file(&self, path: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::remove_file(path).await, + Inner::Chroot(root) => fs::remove_file(append(root.path(), path)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Removes a directory at this path, after removing all its contents. Use carefully! + /// + /// This is a proxy to [`tokio::fs::remove_dir_all`]. 
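
// Illustrative sketch, not part of the patch: one way to combine the primitives above
// into a write-then-rename save, relying on `rename` replacing an existing destination.
// The `.tmp` suffix convention is an assumption made for the example, and this is only
// as atomic as the underlying backend's `rename`.
async fn save_by_rename(fs: &Fs, path: &Path, contents: &[u8]) -> io::Result<()> {
    let tmp = path.with_extension("tmp");
    fs.write(&tmp, contents).await?;
    fs.rename(&tmp, path).await
}
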
+ pub async fn remove_dir_all(&self, path: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::remove_dir_all(path).await, + Inner::Chroot(root) => fs::remove_dir_all(append(root.path(), path)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Renames a file or directory to a new name, replacing the original file if + /// `to` already exists. + /// + /// This will not work if the new name is on a different mount point. + /// + /// This is a proxy to [`tokio::fs::rename`]. + pub async fn rename(&self, from: impl AsRef, to: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::rename(from, to).await, + Inner::Chroot(root) => fs::rename(append(root.path(), from), append(root.path(), to)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Copies the contents of one file to another. This function will also copy the permission bits + /// of the original file to the destination file. + /// This function will overwrite the contents of to. + /// + /// This is a proxy to [`tokio::fs::copy`]. + pub async fn copy(&self, from: impl AsRef, to: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::copy(from, to).await, + Inner::Chroot(root) => fs::copy(append(root.path(), from), append(root.path(), to)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Returns `Ok(true)` if the path points at an existing entity. + /// + /// This function will traverse symbolic links to query information about the + /// destination file. In case of broken symbolic links this will return `Ok(false)`. + /// + /// This is a proxy to [`tokio::fs::try_exists`]. + pub async fn try_exists(&self, path: impl AsRef) -> Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::try_exists(path).await, + Inner::Chroot(root) => fs::try_exists(append(root.path(), path)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Returns `true` if the path points at an existing entity. + /// + /// This is a proxy to [std::path::Path::exists]. See the related doc comment in std + /// on the pitfalls of using this versus [std::path::Path::try_exists]. + pub fn exists(&self, path: impl AsRef) -> bool { + use inner::Inner; + match &self.0 { + Inner::Real => path.as_ref().exists(), + Inner::Chroot(root) => append(root.path(), path).exists(), + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Returns `true` if the path points at an existing entity without following symlinks. + /// + /// This does *not* guarantee that the path doesn't point to a symlink. For example, `false` + /// will be returned if the user doesn't have permission to perform a metadata operation on + /// `path`. + pub async fn symlink_exists(&self, path: impl AsRef) -> bool { + match self.symlink_metadata(path).await { + Ok(_) => true, + Err(err) if err.kind() != std::io::ErrorKind::NotFound => true, + Err(_) => false, + } + } + + pub async fn create_tempdir(&self) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => TempDir::new(), + Inner::Chroot(root) => TempDir::new_in(root.path()), + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Creates a new symbolic link on the filesystem. + /// + /// The `link` path will be a symbolic link pointing to the `original` path. + /// + /// This is a proxy to [`tokio::fs::symlink`]. 
+ #[cfg(unix)] + pub async fn symlink(&self, original: impl AsRef, link: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::symlink(original, link).await, + Inner::Chroot(root) => fs::symlink(append(root.path(), original), append(root.path(), link)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Creates a new symbolic link on the filesystem. + /// + /// The `link` path will be a symbolic link pointing to the `original` path. + /// + /// This is a proxy to [`std::os::unix::fs::symlink`]. + #[cfg(unix)] + pub fn symlink_sync(&self, original: impl AsRef, link: impl AsRef) -> io::Result<()> { + use inner::Inner; + match &self.0 { + Inner::Real => std::os::unix::fs::symlink(original, link), + Inner::Chroot(root) => std::os::unix::fs::symlink(append(root.path(), original), append(root.path(), link)), + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Query the metadata about a file without following symlinks. + /// + /// This is a proxy to [`tokio::fs::symlink_metadata`] + /// + /// # Errors + /// + /// This function will return an error in the following situations, but is not + /// limited to just these cases: + /// + /// * The user lacks permissions to perform `metadata` call on `path`. + /// * `path` does not exist. + #[cfg(unix)] + pub async fn symlink_metadata(&self, path: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::symlink_metadata(path).await, + Inner::Chroot(root) => fs::symlink_metadata(append(root.path(), path)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Reads a symbolic link, returning the file that the link points to. + /// + /// This is a proxy to [`tokio::fs::read_link`]. + pub async fn read_link(&self, path: impl AsRef) -> io::Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::read_link(path).await, + Inner::Chroot(root) => Ok(append(root.path(), fs::read_link(append(root.path(), path)).await?)), + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Returns a stream over the entries within a directory. + /// + /// This is a proxy to [`tokio::fs::read_dir`]. + pub async fn read_dir(&self, path: impl AsRef) -> Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::read_dir(path).await, + Inner::Chroot(root) => fs::read_dir(append(root.path(), path)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Returns the canonical, absolute form of a path with all intermediate + /// components normalized and symbolic links resolved. + /// + /// This is a proxy to [`tokio::fs::canonicalize`]. + pub async fn canonicalize(&self, path: impl AsRef) -> Result { + use inner::Inner; + match &self.0 { + Inner::Real => fs::canonicalize(path).await, + Inner::Chroot(root) => fs::canonicalize(append(root.path(), path)).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// Changes the permissions found on a file or a directory. + /// + /// This is a proxy to [`tokio::fs::set_permissions`] + pub async fn set_permissions(&self, path: impl AsRef, perm: Permissions) -> Result<(), io::Error> { + use inner::Inner; + match &self.0 { + Inner::Real => fs::set_permissions(path, perm).await, + Inner::Chroot(root) => fs::set_permissions(append(root.path(), path), perm).await, + Inner::Fake(_) => panic!("unimplemented"), + } + } + + /// For test [Fs]'s that use a different root, returns an absolute path. + /// + /// This must be used for any paths indirectly used by code using a chroot + /// [Fs]. 
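
// Illustrative sketch, not part of the patch: `chroot_path` is meant for handing a path
// to something that does not go through this `Fs` (a spawned process, a C library, a
// direct `std::fs` call), since those would otherwise escape the temporary root in
// chroot-based tests. The `.kiro/data.db` file name is invented for the example.
fn database_path(fs: &Fs, env: &crate::fig_os_shim::Env) -> PathBuf {
    let home = env.home().unwrap_or_else(|| PathBuf::from("/"));
    // Under a chroot test Fs this becomes "<tempdir>/home/testuser/.kiro/data.db";
    // under the real Fs the path is returned unchanged.
    fs.chroot_path(home.join(".kiro").join("data.db"))
}
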
+ pub fn chroot_path(&self, path: impl AsRef) -> PathBuf { + use inner::Inner; + match &self.0 { + Inner::Chroot(root) => append(root.path(), path), + _ => path.as_ref().to_path_buf(), + } + } + + /// See [Fs::chroot_path]. + pub fn chroot_path_str(&self, path: impl AsRef) -> String { + use inner::Inner; + match &self.0 { + Inner::Chroot(root) => append(root.path(), path).to_string_lossy().to_string(), + _ => path.as_ref().to_path_buf().to_string_lossy().to_string(), + } + } +} + +impl Shim for Fs { + fn is_real(&self) -> bool { + matches!(self.0, inner::Inner::Real) + } +} + +/// Performs `a.join(b)`, except: +/// - if `b` is an absolute path, then the resulting path will equal `/a/b` +/// - if the prefix of `b` contains some `n` copies of a, then the resulting path will equal `/a/b` +fn append(a: impl AsRef, b: impl AsRef) -> PathBuf { + use std::ffi::OsString; + use std::os::unix::ffi::OsStringExt; + + // Have to use byte slices since rust seems to always append + // a forward slash at the end of a path... + let a = a.as_ref().as_os_str().as_bytes(); + let mut b = b.as_ref().as_os_str().as_bytes(); + while b.starts_with(a) { + b = b.strip_prefix(a).unwrap(); + } + while b.starts_with(b"/") { + b = b.strip_prefix(b"/").unwrap(); + } + PathBuf::from(OsString::from_vec(a.to_vec())).join(PathBuf::from(OsString::from_vec(b.to_vec()))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_impl_is_real() { + let fs = Fs::default(); + assert!(matches!(fs.0, inner::Inner::Real)); + } + + #[tokio::test] + async fn test_fake() { + let dir = PathBuf::from("/dir"); + let fs = Fs::from_slice(&[("/test", "test")]); + + fs.create_dir(dir.join("create_dir")).await.unwrap_err(); + fs.create_dir_all(dir.join("create/dir/all")).await.unwrap_err(); + fs.write(dir.join("write"), b"write").await.unwrap(); + assert_eq!(fs.read(dir.join("write")).await.unwrap(), b"write"); + assert_eq!(fs.read_to_string(dir.join("write")).await.unwrap(), "write"); + } + + #[tokio::test] + async fn test_real() { + let dir = tempfile::tempdir().unwrap(); + let fs = Fs::new(); + + fs.create_dir(dir.path().join("create_dir")).await.unwrap(); + fs.create_dir_all(dir.path().join("create/dir/all")).await.unwrap(); + fs.write(dir.path().join("write"), b"write").await.unwrap(); + assert_eq!(fs.read(dir.path().join("write")).await.unwrap(), b"write"); + assert_eq!(fs.read_to_string(dir.path().join("write")).await.unwrap(), "write"); + } + + #[test] + fn test_append() { + macro_rules! 
assert_append { + ($a:expr, $b:expr, $expected:expr) => { + assert_eq!(append($a, $b), PathBuf::from($expected)); + }; + } + assert_append!("/abc/test", "/test", "/abc/test/test"); + assert_append!("/tmp/.dir", "/tmp/.dir/home/myuser", "/tmp/.dir/home/myuser"); + assert_append!("/tmp/.dir", "/tmp/hello", "/tmp/.dir/tmp/hello"); + assert_append!("/tmp/.dir", "/tmp/.dir/tmp/.dir/home/user", "/tmp/.dir/home/user"); + } + + #[tokio::test] + async fn test_read_to_string() { + let fs = Fs::new_fake(); + fs.write("fake", "contents").await.unwrap(); + fs.write("invalid_utf8", &[255]).await.unwrap(); + + // async tests + assert_eq!( + fs.read_to_string("fake").await.unwrap(), + "contents", + "should read fake file" + ); + assert!( + fs.read_to_string("unknown") + .await + .is_err_and(|err| err.kind() == io::ErrorKind::NotFound), + "unknown path should return NotFound" + ); + assert!( + fs.read_to_string("invalid_utf8") + .await + .is_err_and(|err| err.kind() == io::ErrorKind::InvalidData), + "invalid utf8 should return InvalidData" + ); + + // sync tests + assert_eq!( + fs.read_to_string_sync("fake").unwrap(), + "contents", + "should read fake file" + ); + assert!( + fs.read_to_string_sync("unknown") + .is_err_and(|err| err.kind() == io::ErrorKind::NotFound), + "unknown path should return NotFound" + ); + assert!( + fs.read_to_string_sync("invalid_utf8") + .is_err_and(|err| err.kind() == io::ErrorKind::InvalidData), + "invalid utf8 should return InvalidData" + ); + } + + #[tokio::test] + #[cfg(unix)] + async fn test_chroot_file_operations_for_unix() { + if nix::unistd::Uid::effective().is_root() { + println!("currently running as root, skipping."); + return; + } + + let fs = Fs::new_chroot(); + assert!(fs.is_chroot()); + + fs.write("/fake", "contents").await.unwrap(); + assert_eq!(fs.read_to_string("/fake").await.unwrap(), "contents"); + assert_eq!(fs.read_to_string_sync("/fake").unwrap(), "contents"); + + assert!(!fs.try_exists("/etc").await.unwrap()); + + fs.create_dir_all("/etc/b/c").await.unwrap(); + assert!(fs.try_exists("/etc").await.unwrap()); + let mut read_dir = fs.read_dir("/etc").await.unwrap(); + let e = read_dir.next_entry().await.unwrap(); + assert!(e.unwrap().metadata().await.unwrap().is_dir()); + assert!(read_dir.next_entry().await.unwrap().is_none()); + + fs.remove_dir_all("/etc").await.unwrap(); + assert!(!fs.try_exists("/etc").await.unwrap()); + + fs.copy("/fake", "/fake_copy").await.unwrap(); + assert_eq!(fs.read_to_string("/fake_copy").await.unwrap(), "contents"); + assert_eq!(fs.read_to_string_sync("/fake_copy").unwrap(), "contents"); + + fs.remove_file("/fake_copy").await.unwrap(); + assert!(!fs.try_exists("/fake_copy").await.unwrap()); + + fs.symlink("/fake", "/fake_symlink").await.unwrap(); + fs.symlink_sync("/fake", "/fake_symlink_sync").unwrap(); + assert_eq!(fs.read_to_string("/fake_symlink").await.unwrap(), "contents"); + assert_eq!( + fs.read_to_string(fs.read_link("/fake_symlink").await.unwrap()) + .await + .unwrap(), + "contents" + ); + assert_eq!(fs.read_to_string("/fake_symlink_sync").await.unwrap(), "contents"); + assert_eq!(fs.read_to_string_sync("/fake_symlink").unwrap(), "contents"); + + // Checking symlink exist + assert!(fs.symlink_exists("/fake_symlink").await); + assert!(fs.exists("/fake_symlink")); + fs.remove_file("/fake").await.unwrap(); + assert!(fs.symlink_exists("/fake_symlink").await); + assert!(!fs.exists("/fake_symlink")); + + // Checking rename + fs.write("/rename_1", "abc").await.unwrap(); + fs.write("/rename_2", "123").await.unwrap(); + 
fs.rename("/rename_2", "/rename_1").await.unwrap(); + assert_eq!(fs.read_to_string("/rename_1").await.unwrap(), "123"); + + // Checking open + assert!(fs.open("/does_not_exist").await.is_err()); + assert!(fs.open("/rename_1").await.is_ok()); + } + + #[tokio::test] + async fn test_chroot_tempdir() { + let fs = Fs::new_chroot(); + let tempdir = fs.create_tempdir().await.unwrap(); + if let Fs(inner::Inner::Chroot(root)) = fs { + assert_eq!(tempdir.path().parent().unwrap(), root.path()); + } else { + panic!("tempdir should be created under root"); + } + } + + #[tokio::test] + async fn test_create_new() { + let fs = Fs::new_chroot(); + fs.create_new("my_file.txt").await.unwrap(); + assert!(fs.create_new("my_file.txt").await.is_err()); + } +} diff --git a/crates/kiro-cli/src/fig_os_shim/mod.rs b/crates/kiro-cli/src/fig_os_shim/mod.rs new file mode 100644 index 0000000000..200efeff5f --- /dev/null +++ b/crates/kiro-cli/src/fig_os_shim/mod.rs @@ -0,0 +1,203 @@ +mod env; +mod fs; +mod platform; +mod providers; +mod sysinfo; + +use std::sync::Arc; + +pub use env::Env; +pub use fs::Fs; +pub use platform::{ + Os, + Platform, +}; +pub use providers::{ + EnvProvider, + FsProvider, + PlatformProvider, + SysInfoProvider, +}; +pub use sysinfo::SysInfo; + +pub trait Shim { + /// Returns whether or not the shim is a real implementation. + fn is_real(&self) -> bool; +} + +/// Struct that contains the interface to every system related IO operation. +/// +/// Every operation that accesses the file system, environment, or other related platform +/// primitives should be done through a [Context] as this enables testing otherwise untestable +/// code paths in unit tests. +#[derive(Debug, Clone)] +pub struct Context { + #[allow(dead_code)] + fs: Fs, + env: Env, + platform: Platform, + sysinfo: SysInfo, +} + +impl Context { + /// Returns a new [Context] with real implementations of each OS shim. + pub fn new() -> Arc { + Arc::new_cyclic(|_| Self { + fs: Default::default(), + env: Default::default(), + platform: Default::default(), + sysinfo: SysInfo::default(), + }) + } + + pub fn new_fake() -> Arc { + Arc::new(Self { + fs: Fs::new_fake(), + env: Env::new_fake(), + platform: Platform::new_fake(Os::current()), + sysinfo: SysInfo::new_fake(), + }) + } + + pub fn builder() -> ContextBuilder { + ContextBuilder::new() + } + + pub fn fs(&self) -> &Fs { + &self.fs + } + + pub fn env(&self) -> &Env { + &self.env + } + + pub fn platform(&self) -> &Platform { + &self.platform + } + + pub fn sysinfo(&self) -> &SysInfo { + &self.sysinfo + } +} + +#[derive(Default, Debug)] +pub struct ContextBuilder { + fs: Option, + env: Option, + platform: Option, + sysinfo: Option, +} + +impl ContextBuilder { + pub fn new() -> Self { + Self::default() + } + + /// Builds an immutable [Context] using real implementations for each field by default. + pub fn build(self) -> Arc { + let fs = self.fs.unwrap_or_default(); + let env = self.env.unwrap_or_default(); + let platform = self.platform.unwrap_or_default(); + let sysinfo = self.sysinfo.unwrap_or_default(); + Arc::new_cyclic(|_| Context { + fs, + env, + platform, + sysinfo, + }) + } + + /// Builds an immutable [Context] using fake implementations for each field by default. 
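
// Illustrative sketch, not part of the patch: a helper written against `Context` runs
// unchanged against the real system or a fully faked one. The settings file name is an
// assumption made for the example.
async fn has_settings_file(ctx: &Context) -> bool {
    let Some(home) = ctx.env().home() else { return false };
    ctx.fs().try_exists(home.join(".kiro/settings.json")).await.unwrap_or(false)
}

// In production this would be called with `Context::new()`; in a test the same code
// runs against a chroot filesystem and a fake HOME:
//     let ctx = Context::builder().with_test_home().await.unwrap().build();
//     assert!(!has_settings_file(&ctx).await);
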
+ pub fn build_fake(self) -> Arc { + let fs = self.fs.unwrap_or(Fs::new_fake()); + let env = self.env.unwrap_or(Env::new_fake()); + let platform = self.platform.unwrap_or(Platform::new_fake(Os::Mac)); + let sysinfo = self.sysinfo.unwrap_or(SysInfo::new_fake()); + Arc::new_cyclic(|_| Context { + fs, + env, + platform, + sysinfo, + }) + } + + pub fn with_env(mut self, env: Env) -> Self { + self.env = Some(env); + self + } + + pub fn with_fs(mut self, fs: Fs) -> Self { + self.fs = Some(fs); + self + } + + pub fn with_platform(mut self, platform: Platform) -> Self { + self.platform = Some(platform); + self + } + + /// Creates a chroot filesystem and fake environment so that `$HOME` + /// points to `/home/testuser`. Note that this replaces the + /// [Fs] and [Env] currently set with the builder. + pub async fn with_test_home(mut self) -> Result { + let home = "/home/testuser"; + let fs = Fs::new_chroot(); + fs.create_dir_all(home).await?; + self.fs = Some(fs); + self.env = Some(Env::from_slice(&[("HOME", "/home/testuser"), ("USER", "testuser")])); + Ok(self) + } + + pub fn with_env_var(mut self, key: &str, value: &str) -> Self { + self.env = match self.env { + Some(env) if !env.is_real() => { + unsafe { env.set_var(key, value) }; + Some(env) + }, + _ => Some(Env::from_slice(&[(key, value)])), + }; + self + } + + pub fn with_os(mut self, os: Os) -> Self { + self.platform = Some(Platform::new_fake(os)); + self + } + + pub fn with_running_processes(mut self, process_names: &[&str]) -> Self { + let sysinfo = match self.sysinfo { + Some(sysinfo) if !sysinfo.is_real() => sysinfo, + _ => SysInfo::new_fake(), + }; + sysinfo.add_running_processes(process_names); + self.sysinfo = Some(sysinfo); + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn context_builder_returns_real_impls_by_default() { + let ctx = ContextBuilder::new().build(); + assert!(ctx.fs().is_real()); + assert!(ctx.env().is_real()); + assert!(ctx.platform().is_real()); + assert!(ctx.sysinfo().is_real()); + } + + #[tokio::test] + async fn test_context_builder_with_test_home() { + let ctx = ContextBuilder::new() + .with_test_home() + .await + .unwrap() + .with_env_var("hello", "world") + .build(); + assert!(ctx.fs().try_exists("/home/testuser").await.unwrap()); + assert_eq!(ctx.env().get("HOME").unwrap(), "/home/testuser"); + assert_eq!(ctx.env().get("hello").unwrap(), "world"); + } +} diff --git a/crates/kiro-cli/src/fig_os_shim/platform.rs b/crates/kiro-cli/src/fig_os_shim/platform.rs new file mode 100644 index 0000000000..29d5dc06ee --- /dev/null +++ b/crates/kiro-cli/src/fig_os_shim/platform.rs @@ -0,0 +1,105 @@ +use std::fmt; + +use cfg_if::cfg_if; +use serde::Serialize; + +use crate::fig_os_shim::Shim; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)] +#[non_exhaustive] +pub enum Os { + Mac, + Linux, +} + +impl Os { + pub fn current() -> Self { + cfg_if! 
{ + if #[cfg(target_os = "macos")] { + Self::Mac + } else if #[cfg(target_os = "linux")] { + Self::Linux + } else { + compile_error!("unsupported platform"); + } + } + } + + pub fn all() -> &'static [Self] { + &[Self::Mac, Self::Linux] + } + + pub fn as_str(&self) -> &'static str { + match self { + Self::Mac => "macos", + Self::Linux => "linux", + } + } +} + +impl fmt::Display for Os { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +#[derive(Default, Debug, Clone)] +pub struct Platform(inner::Inner); + +mod inner { + use super::*; + + #[derive(Default, Debug, Clone)] + pub(super) enum Inner { + #[default] + Real, + Fake(Os), + } +} + +impl Platform { + /// Returns a real implementation of [Platform]. + pub fn new() -> Self { + Self(inner::Inner::Real) + } + + /// Returns a new fake [Platform]. + pub fn new_fake(os: Os) -> Self { + Self(inner::Inner::Fake(os)) + } + + /// Returns the current [Os]. + pub fn os(&self) -> Os { + use inner::Inner; + match &self.0 { + Inner::Real => Os::current(), + Inner::Fake(os) => *os, + } + } +} + +impl Shim for Platform { + fn is_real(&self) -> bool { + matches!(self.0, inner::Inner::Real) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_platform() { + let platform = Platform::default(); + assert!(platform.is_real()); + + for os in Os::all() { + let platform = Platform::new_fake(*os); + assert!(!platform.is_real()); + assert_eq!(&platform.os(), os); + + let _ = os.as_str(); + println!("{os:?} {os}"); + } + } +} diff --git a/crates/kiro-cli/src/fig_os_shim/providers.rs b/crates/kiro-cli/src/fig_os_shim/providers.rs new file mode 100644 index 0000000000..c430526802 --- /dev/null +++ b/crates/kiro-cli/src/fig_os_shim/providers.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; + +use crate::fig_os_shim::{ + Context, + Env, + Fs, + Platform, + SysInfo, +}; + +pub trait ContextProvider { + fn context(&self) -> &Context; +} + +pub trait ContextArcProvider { + fn context_arc(&self) -> Arc; +} + +impl ContextArcProvider for Arc { + fn context_arc(&self) -> Arc { + Arc::clone(self) + } +} + +macro_rules! 
impl_context_provider { + ($a:ty) => { + impl ContextProvider for $a { + fn context(&self) -> &Context { + self + } + } + }; +} + +impl_context_provider!(Arc); +impl_context_provider!(&Arc); +impl_context_provider!(Context); +impl_context_provider!(&Context); + +pub trait EnvProvider { + fn env(&self) -> &Env; +} + +impl EnvProvider for Env { + fn env(&self) -> &Env { + self + } +} + +impl EnvProvider for T +where + T: ContextProvider, +{ + fn env(&self) -> &Env { + self.context().env() + } +} + +pub trait FsProvider { + fn fs(&self) -> &Fs; +} + +impl FsProvider for Fs { + fn fs(&self) -> &Fs { + self + } +} + +impl FsProvider for T +where + T: ContextProvider, +{ + fn fs(&self) -> &Fs { + self.context().fs() + } +} + +pub trait PlatformProvider { + fn platform(&self) -> &Platform; +} + +impl PlatformProvider for Platform { + fn platform(&self) -> &Platform { + self + } +} + +impl PlatformProvider for T +where + T: ContextProvider, +{ + fn platform(&self) -> &Platform { + self.context().platform() + } +} + +pub trait SysInfoProvider { + fn sysinfo(&self) -> &SysInfo; +} + +impl SysInfoProvider for SysInfo { + fn sysinfo(&self) -> &SysInfo { + self + } +} + +impl SysInfoProvider for T +where + T: ContextProvider, +{ + fn sysinfo(&self) -> &SysInfo { + self.context().sysinfo() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_env_provider() { + let env = Env::default(); + let env_provider = &env as &dyn EnvProvider; + env_provider.env(); + } + + #[test] + fn test_fs_provider() { + let fs = Fs::default(); + let fs_provider = &fs as &dyn FsProvider; + fs_provider.fs(); + } +} diff --git a/crates/kiro-cli/src/fig_os_shim/sysinfo.rs b/crates/kiro-cli/src/fig_os_shim/sysinfo.rs new file mode 100644 index 0000000000..0d340e97a1 --- /dev/null +++ b/crates/kiro-cli/src/fig_os_shim/sysinfo.rs @@ -0,0 +1,68 @@ +use std::ffi::OsString; +use std::sync::{ + Arc, + Mutex, +}; + +use crate::fig_os_shim::Shim; + +#[derive(Debug, Clone, Default)] +pub struct SysInfo(inner::Inner); + +mod inner { + use std::collections::HashSet; + use std::sync::{ + Arc, + Mutex, + }; + + #[derive(Debug, Clone, Default)] + pub enum Inner { + #[default] + Real, + Fake(Arc>), + } + + #[derive(Debug, Clone, Default)] + pub struct Fake { + pub process_names: HashSet, + } +} + +impl SysInfo { + pub fn new_fake() -> Self { + Self(inner::Inner::Fake(Arc::new(Mutex::new(inner::Fake::default())))) + } + + /// Returns whether the process containing `name` is running. 
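For illustration only (not part of this patch): because of the blanket impl<T: ContextProvider> EnvProvider for T above, a helper written against EnvProvider accepts a bare Env, a Context, or an Arc<Context> equally well. The sketch below assumes these items are re-exported from crate::fig_os_shim; the helper name shell_from_env and the SHELL values are made up for the example.

    use crate::fig_os_shim::{ContextBuilder, Env, EnvProvider};

    // Hypothetical helper: usable with `Env`, `Context`, or `Arc<Context>`
    // thanks to the blanket `EnvProvider` impl for any `ContextProvider`.
    fn shell_from_env(provider: &impl EnvProvider) -> String {
        provider.env().get("SHELL").unwrap()
    }

    #[test]
    fn shell_from_env_accepts_env_or_context() {
        let env = Env::from_slice(&[("SHELL", "/bin/zsh")]);
        assert_eq!(shell_from_env(&env), "/bin/zsh");

        let ctx = ContextBuilder::new()
            .with_env(Env::from_slice(&[("SHELL", "/bin/zsh")]))
            .build_fake();
        assert_eq!(shell_from_env(&ctx), "/bin/zsh");
    }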
+ pub fn is_process_running(&self, name: &str) -> bool { + use inner::Inner; + match &self.0 { + Inner::Real => { + let system = sysinfo::System::new_all(); + let is_running = system.processes_by_name(&OsString::from(name)).next().is_some(); + is_running + }, + Inner::Fake(fake) => fake.lock().unwrap().process_names.contains(name), + } + } + + pub fn add_running_processes(&self, process_names: &[&str]) { + use inner::Inner; + match &self.0 { + Inner::Real => panic!("unimplemented"), + Inner::Fake(fake) => { + let curr_names = &mut fake.lock().unwrap().process_names; + for name in process_names { + curr_names.insert((*name).to_string()); + } + }, + } + } +} + +impl Shim for SysInfo { + fn is_real(&self) -> bool { + matches!(self.0, inner::Inner::Real) + } +} diff --git a/crates/kiro-cli/src/fig_settings/actions.json b/crates/kiro-cli/src/fig_settings/actions.json new file mode 100644 index 0000000000..c5a886f665 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/actions.json @@ -0,0 +1,216 @@ +[ + { + "identifier": "insertSelected", + "name": "Insert selected", + "description": "Insert selected suggestion", + "category": "Insertion", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["enter"] + }, + { + "identifier": "insertCommonPrefix", + "name": "Insert common prefix", + "description": "Insert shared prefix of available suggestions. Shake if there's no common prefix.", + "category": "Insertion", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["tab"] + }, + { + "identifier": "insertCommonPrefixOrNavigateDown", + "name": "Insert common prefix or navigate", + "description": "Insert shared prefix of available suggestions. Navigate if there's no common prefix.", + "category": "Insertion", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "insertCommonPrefixOrInsertSelected", + "name": "Insert common prefix or insert selected", + "description": "Insert shared prefix of available suggestions. 
Insert currently selected suggestion if there's not common prefix.", + "category": "Insertion", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "insertSelectedAndExecute", + "name": "Insert selected and execute", + "description": "Insert selected suggestion and then execute the current command.", + "category": "Insertion", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "execute", + "name": "Execute", + "description": "Execute the current command.", + "category": "Insertion", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "hideAutocomplete", + "name": "Hide autocomplete", + "description": "Hide the autocomplete window", + "category": "General", + "availability": "ALWAYS", + "defaultBindings": ["esc"] + }, + { + "identifier": "showAutocomplete", + "name": "Show autocomplete", + "description": "Show the autocomplete window", + "category": "General", + "availability": "ALWAYS" + }, + { + "identifier": "toggleAutocomplete", + "name": "Toggle autocomplete", + "description": "Toggle the visibility of the autocomplete window", + "availability": "ALWAYS" + }, + { + "identifier": "navigateUp", + "name": "Navigate up", + "description": "Scroll up one entry in the list of suggestions", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["shift+tab", "up", "control+k", "control+p"] + }, + { + "identifier": "navigateDown", + "name": "Navigate down", + "description": "Scroll down one entry in the list of suggestions", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["down", "control+j", "control+n"] + }, + { + "identifier": "selectSuggestion1", + "name": "Select 1st suggestion", + "description": "Select the 1st suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+1"] + }, + { + "identifier": "selectSuggestion2", + "name": "Select 2nd suggestion", + "description": "Select the 2nd suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+2"] + }, + { + "identifier": "selectSuggestion3", + "name": "Select 3rd suggestion", + "description": "Select the 3rd suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+3"] + }, + { + "identifier": "selectSuggestion4", + "name": "Select 4th suggestion", + "description": "Select the 4th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+4"] + }, + { + "identifier": "selectSuggestion5", + "name": "Select 5th suggestion", + "description": "Select the 5th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+5"] + }, + { + "identifier": "selectSuggestion6", + "name": "Select 6th suggestion", + "description": "Select the 6th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+6"] + }, + { + "identifier": "selectSuggestion7", + "name": "Select 7th suggestion", + "description": "Select the 7th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+7"] + }, + { + "identifier": "selectSuggestion8", + "name": "Select 8th suggestion", + "description": "Select the 8th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+8"] + }, + { + "identifier": "selectSuggestion9", + "name": 
"Select 9th suggestion", + "description": "Select the 9th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+9"] + }, + { + "identifier": "selectSuggestion10", + "name": "Select 10th suggestion", + "description": "Select the 10th suggestion of the list", + "category": "Navigation", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+0"] + }, + { + "identifier": "hideDescription", + "name": "Hide description popout", + "description": "Hide autocomplete description popout", + "category": "Appearance", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "showDescription", + "name": "Show description popout", + "description": "Show autocomplete description popout", + "category": "Appearance", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "toggleDescription", + "name": "Toggle description popout", + "description": "Toggle visibility of autocomplete description popout", + "category": "Appearance", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["command+i"] + }, + { + "identifier": "toggleHistoryMode", + "name": "Toggle history mode", + "description": "Toggle between history suggestions and autocomplete spec suggestions", + "category": "General", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["control+r"] + }, + { + "identifier": "toggleFuzzySearch", + "name": "Toggle fuzzy search", + "description": "Toggle between normal prefix search and fuzzy search", + "category": "General", + "availability": "WHEN_FOCUSED" + }, + { + "identifier": "increaseSize", + "name": "Increase window size", + "description": "Increase the size of the autocomplete window", + "category": "Appearance", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["command+="] + }, + { + "identifier": "decreaseSize", + "name": "Decrease window size", + "description": "Decrease the size of the autocomplete window", + "category": "Appearance", + "availability": "WHEN_FOCUSED", + "defaultBindings": ["command+-"] + } +] diff --git a/crates/kiro-cli/src/fig_settings/error.rs b/crates/kiro-cli/src/fig_settings/error.rs new file mode 100644 index 0000000000..c9ecfe8ddb --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/error.rs @@ -0,0 +1,70 @@ +use std::sync::PoisonError; + +use thiserror::Error; + +use crate::fig_util::directories::DirectoryError; + +// A cloneable error +#[derive(Debug, Clone, thiserror::Error)] +#[error("Failed to open database: {}", .0)] +pub struct DbOpenError(pub(crate) String); + +#[derive(Debug, Error)] +pub enum Error { + #[error(transparent)] + IoError(#[from] std::io::Error), + #[error(transparent)] + JsonError(#[from] serde_json::Error), + #[error(transparent)] + FigUtilError(#[from] crate::fig_util::Error), + #[error("settings file is not a json object")] + SettingsNotObject, + #[error(transparent)] + DirectoryError(#[from] DirectoryError), + #[error("memory backend is not used")] + MemoryBackendNotUsed, + #[error(transparent)] + Rusqlite(#[from] rusqlite::Error), + #[error(transparent)] + R2d2(#[from] r2d2::Error), + #[error(transparent)] + DbOpenError(#[from] DbOpenError), + #[error("{}", .0)] + PoisonError(String), +} + +impl From> for Error { + fn from(value: PoisonError) -> Self { + Self::PoisonError(value.to_string()) + } +} + +pub type Result = std::result::Result; + +#[cfg(test)] +mod tests { + use super::*; + + fn all_errors() -> Vec { + vec![ + std::io::Error::new(std::io::ErrorKind::InvalidData, "oops").into(), + serde_json::from_str::<()>("oops").unwrap_err().into(), + 
crate::fig_util::Error::UnsupportedPlatform.into(), + Error::SettingsNotObject, + crate::fig_util::directories::DirectoryError::NoHomeDirectory.into(), + Error::MemoryBackendNotUsed, + rusqlite::Error::SqliteSingleThreadedMode.into(), + // r2d2::Error + DbOpenError("oops".into()).into(), + PoisonError::<()>::new(()).into(), + ] + } + + #[test] + fn test_error_display_debug() { + for error in all_errors() { + eprintln!("{}", error); + eprintln!("{:?}", error); + } + } +} diff --git a/crates/kiro-cli/src/fig_settings/keybindings.rs b/crates/kiro-cli/src/fig_settings/keybindings.rs new file mode 100644 index 0000000000..7718198a7c --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/keybindings.rs @@ -0,0 +1,144 @@ +use std::fmt::Display; + +use serde::{ + Deserialize, + Serialize, +}; + +use super::{ + Error, + JsonStore, + OldSettings, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum Availability { + WhenFocused, + Always, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct KeyBindingDescription { + pub identifier: String, + pub name: Option, + pub description: Option, + pub category: Option, + pub availability: Option, + pub default_bindings: Option>, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct KeyBinding { + pub identifier: String, + pub binding: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(transparent)] +pub struct KeyBindings(pub Vec); + +impl KeyBindings { + pub fn load_hardcoded() -> Self { + let hardcoded_descriptions: Vec = + serde_json::from_str(include_str!("actions.json")).expect("Unable to load hardcoded actions"); + + let key_bindings = hardcoded_descriptions + .into_iter() + .flat_map(|description| { + description + .default_bindings + .unwrap_or_default() + .into_iter() + .map(move |binding| KeyBinding { + identifier: description.identifier.clone(), + binding, + }) + }) + .collect(); + + Self(key_bindings) + } + + fn load_from_json_map( + json_map: &serde_json::Map, + product_namespace: impl Display, + ) -> Self { + let key_bindings = json_map + .into_iter() + .filter_map(|(key, value)| { + if let Some(key) = key.strip_prefix(&format!("{product_namespace}.keybindings.",)) { + Some(KeyBinding { + identifier: value.as_str()?.into(), + binding: key.into(), + }) + } else { + None + } + }) + .collect(); + Self(key_bindings) + } + + pub fn load_from_settings(product_namespace: impl Display) -> Result { + let settings = OldSettings::load()?; + let map = settings.map(); + Ok(Self::load_from_json_map(&map, product_namespace)) + } +} + +impl IntoIterator for KeyBindings { + type IntoIter = std::vec::IntoIter; + type Item = KeyBinding; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_load_json() { + let json = KeyBindings::load_hardcoded(); + assert_eq!(json.0.len(), 24); + + assert_eq!(json.0[0].identifier, "insertSelected"); + assert_eq!(json.0[0].binding, "enter"); + } + + #[test] + fn test_load_from_json_map() { + let json_map = serde_json::json!({ + "autocomplete.keybindings.command+i": "toggleDescription", + "autocomplete.keybindings.control+-": "increaseSize", + "autocomplete.keybindings.control+/": "toggleDescription", + "autocomplete.keybindings.control+=": "decreaseSize", + "autocomplete.other": "other", + "other": "other", + }) + .as_object() + 
.unwrap() + .clone(); + + let json = KeyBindings::load_from_json_map(&json_map, "autocomplete"); + + assert_eq!(json.0.len(), 4); + + assert_eq!(json.0[0].identifier, "toggleDescription"); + assert_eq!(json.0[0].binding, "command+i"); + + assert_eq!(json.0[1].identifier, "increaseSize"); + assert_eq!(json.0[1].binding, "control+-"); + + assert_eq!(json.0[2].identifier, "toggleDescription"); + assert_eq!(json.0[2].binding, "control+/"); + + assert_eq!(json.0[3].identifier, "decreaseSize"); + assert_eq!(json.0[3].binding, "control+="); + } +} diff --git a/crates/kiro-cli/src/fig_settings/mod.rs b/crates/kiro-cli/src/fig_settings/mod.rs new file mode 100644 index 0000000000..415f2646ef --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/mod.rs @@ -0,0 +1,349 @@ +pub mod error; +pub mod keybindings; +pub mod settings; +pub mod sqlite; +pub mod state; + +use std::fs::{ + self, + File, +}; +use std::io::{ + Read, + Seek, + SeekFrom, + Write, +}; +use std::path::PathBuf; + +pub use error::{ + Error, + Result, +}; +use fd_lock::RwLock as FileRwLock; +use parking_lot::{ + MappedRwLockReadGuard, + MappedRwLockWriteGuard, + RwLock, + RwLockReadGuard, + RwLockWriteGuard, +}; +use serde_json::Value; +pub use settings::Settings; +pub use state::State; +use thiserror::Error; +use tracing::error; + +use crate::fig_util::directories; + +pub type Map = serde_json::Map; + +static SETTINGS_FILE_LOCK: RwLock<()> = RwLock::new(()); + +static SETTINGS_DATA: RwLock> = RwLock::new(None); + +#[derive(Debug, Clone)] +pub enum Backend { + Global, + Memory(Map), +} + +pub enum ReadGuard<'a, T> { + Global(RwLockReadGuard<'a, Option>), + Memory(&'a T), +} + +impl<'a, T> ReadGuard<'a, T> { + pub fn map &U>(self, f: F) -> MappedReadGuard<'a, U> { + match self { + ReadGuard::Global(guard) => { + MappedReadGuard::Global(RwLockReadGuard::<'a, Option>::map(guard, |data: &Option| { + f(data.as_ref().expect("global backend is not used")) + })) + }, + ReadGuard::Memory(data) => MappedReadGuard::Memory(f(data)), + } + } + + pub fn try_map Option<&U>>(self, f: F) -> Option> { + match self { + ReadGuard::Global(guard) => RwLockReadGuard::<'a, Option>::try_map(guard, |data: &Option| { + f(data.as_ref().expect("global backend is not used")) + }) + .ok() + .map(MappedReadGuard::Global), + ReadGuard::Memory(data) => f(data).map(MappedReadGuard::Memory), + } + } +} + +impl std::ops::Deref for ReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + ReadGuard::Global(guard) => guard.as_ref().expect("global backend is not used"), + ReadGuard::Memory(data) => data, + } + } +} + +pub enum MappedReadGuard<'a, T> { + Global(MappedRwLockReadGuard<'a, T>), + Memory(&'a T), +} + +impl std::ops::Deref for MappedReadGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + MappedReadGuard::Global(guard) => guard, + MappedReadGuard::Memory(data) => data, + } + } +} + +pub enum WriteGuard<'a, T> { + Global(RwLockWriteGuard<'a, Option>), + Memory(&'a mut T), +} + +impl<'a, T> WriteGuard<'a, T> { + pub fn map &mut U>(self, f: F) -> MappedWriteGuard<'a, U> { + match self { + WriteGuard::Global(guard) => { + MappedWriteGuard::Global(RwLockWriteGuard::<'a, Option>::map(guard, |data: &mut Option| { + f(data.as_mut().expect("global backend is not used")) + })) + }, + WriteGuard::Memory(data) => MappedWriteGuard::Memory(f(data)), + } + } + + pub fn try_map Option<&mut U>>(self, f: F) -> Option> { + match self { + WriteGuard::Global(guard) => RwLockWriteGuard::<'a, Option>::try_map(guard, 
|data: &mut Option| { + f(data.as_mut().expect("global backend is not used")) + }) + .ok() + .map(MappedWriteGuard::Global), + WriteGuard::Memory(data) => f(data).map(MappedWriteGuard::Memory), + } + } +} + +impl std::ops::Deref for WriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + WriteGuard::Global(guard) => guard.as_ref().expect("global backend is not used"), + WriteGuard::Memory(data) => data, + } + } +} + +impl std::ops::DerefMut for WriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + WriteGuard::Global(guard) => guard.as_mut().expect("global backend is not used"), + WriteGuard::Memory(data) => data, + } + } +} + +pub enum MappedWriteGuard<'a, T> { + Global(MappedRwLockWriteGuard<'a, T>), + Memory(&'a mut T), +} + +impl std::ops::Deref for MappedWriteGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + match self { + MappedWriteGuard::Global(guard) => guard, + MappedWriteGuard::Memory(data) => data, + } + } +} + +impl std::ops::DerefMut for MappedWriteGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + match self { + MappedWriteGuard::Global(guard) => guard, + MappedWriteGuard::Memory(data) => data, + } + } +} + +pub trait JsonStore: Sized { + /// Path to the file + fn path() -> Result; + + /// In mem lock on the file + fn file_lock() -> &'static RwLock<()>; + + /// [RwLock] on the data, [None] if not using the global backend + fn data_lock() -> &'static RwLock>; + + fn new_from_backend(backend: Backend) -> Self; + + fn map(&self) -> ReadGuard<'_, Map>; + + fn map_mut(&mut self) -> WriteGuard<'_, Map>; + + fn load() -> Result { + let is_global = Self::data_lock().read().as_ref().is_some(); + if is_global { + Ok(Self::new_from_backend(Backend::Global)) + } else { + Ok(Self::new_from_backend(Backend::Memory(Self::load_from_file()?))) + } + } + + fn load_from_file() -> Result { + let path = Self::path()?; + + // If the folder doesn't exist, create it. + if let Some(parent) = path.parent() { + if !parent.exists() { + fs::create_dir_all(parent)?; + } + } + + let json: Map = { + let _lock_guard = Self::file_lock().write(); + + // If the file doesn't exist, create it. + if !path.exists() { + let mut file = FileRwLock::new(File::create(path)?); + file.write()?.write_all(b"{}")?; + serde_json::Map::new() + } else { + let mut file = FileRwLock::new(File::open(&path)?); + let mut read = file.write()?; + serde_json::from_reader(&mut *read)? + } + }; + + Ok(json) + } + + fn save_to_file(&self) -> Result<()> { + let path = Self::path()?; + + // If the folder doesn't exist, create it. 
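For illustration only (not part of this patch): a minimal usage sketch of the JsonStore trait defined above, assuming OldSettings (declared later in this file) as the implementor. The helper names and the telemetry.disabled key are hypothetical.

    use crate::fig_settings::{JsonStore, OldSettings, Result};

    // Hypothetical round trip: load, mutate, then persist under the file lock.
    fn set_telemetry_disabled(disabled: bool) -> Result<()> {
        // `load()` uses the global in-memory backend when one is installed,
        // otherwise it reads the settings JSON file from disk.
        let mut settings = OldSettings::load()?;
        settings.set("telemetry.disabled", disabled);
        settings.save_to_file()
    }

    fn telemetry_disabled() -> bool {
        OldSettings::load()
            .map(|s| s.get_bool_or("telemetry.disabled", false))
            .unwrap_or(false)
    }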
+ if let Some(parent) = path.parent() { + if !parent.exists() { + fs::create_dir_all(parent)?; + } + } + + let _lock_guard = Self::file_lock().write(); + + let mut file_opts = File::options(); + file_opts.create(true).write(true).truncate(true); + + #[cfg(unix)] + { + use std::os::unix::fs::OpenOptionsExt; + file_opts.mode(0o600); + } + + let mut file = FileRwLock::new(file_opts.open(&path)?); + let mut lock = file.write()?; + + if let Err(_err) = serde_json::to_writer_pretty(&mut *lock, &*self.map()) { + // Write {} to the file if the serialization failed + lock.seek(SeekFrom::Start(0))?; + lock.set_len(0)?; + lock.write_all(b"{}")?; + }; + lock.flush()?; + + Ok(()) + } + + fn set(&mut self, key: impl Into, value: impl Into) { + self.map_mut().insert(key.into(), value.into()); + } + + fn get(&self, key: impl AsRef) -> Option> { + self.map().try_map(|data| data.get(key.as_ref())) + } + + fn remove(&mut self, key: impl AsRef) -> Option { + self.map_mut().remove(key.as_ref()) + } + + fn get_mut(&mut self, key: impl Into) -> Option> { + self.map_mut().try_map(|data| data.get_mut(&key.into())) + } + + fn get_bool(&self, key: impl AsRef) -> Option { + self.get(key).and_then(|value| value.as_bool()) + } + + fn get_bool_or(&self, key: impl AsRef, default: bool) -> bool { + self.get_bool(key).unwrap_or(default) + } + + fn get_string(&self, key: impl AsRef) -> Option { + self.get(key).and_then(|value| value.as_str().map(|s| s.into())) + } + + fn get_string_or(&self, key: impl AsRef, default: String) -> String { + self.get_string(key).unwrap_or(default) + } + + fn get_int(&self, key: impl AsRef) -> Option { + self.get(key).and_then(|value| value.as_i64()) + } + + fn get_int_or(&self, key: impl AsRef, default: i64) -> i64 { + self.get_int(key).unwrap_or(default) + } +} + +pub struct OldSettings { + pub(crate) inner: Backend, +} + +impl JsonStore for OldSettings { + fn path() -> Result { + Ok(directories::settings_path()?) 
+ } + + fn file_lock() -> &'static RwLock<()> { + &SETTINGS_FILE_LOCK + } + + fn data_lock() -> &'static RwLock> { + &SETTINGS_DATA + } + + fn new_from_backend(backend: Backend) -> Self { + match backend { + Backend::Global => Self { inner: Backend::Global }, + Backend::Memory(map) => Self { + inner: Backend::Memory(map), + }, + } + } + + fn map(&self) -> ReadGuard<'_, Map> { + match &self.inner { + Backend::Global => ReadGuard::Global(Self::data_lock().read()), + Backend::Memory(map) => ReadGuard::Memory(map), + } + } + + fn map_mut(&mut self) -> WriteGuard<'_, Map> { + match &mut self.inner { + Backend::Global => WriteGuard::Global(Self::data_lock().write()), + Backend::Memory(map) => WriteGuard::Memory(map), + } + } +} diff --git a/crates/kiro-cli/src/fig_settings/settings.rs b/crates/kiro-cli/src/fig_settings/settings.rs new file mode 100644 index 0000000000..0a0be53c01 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/settings.rs @@ -0,0 +1,243 @@ +use std::sync::{ + Arc, + Mutex, +}; + +use serde::de::DeserializeOwned; +use serde_json::{ + Map, + Value, +}; + +use super::{ + JsonStore, + OldSettings, + Result, +}; + +#[derive(Debug, Clone, Default)] +pub struct Settings(inner::Inner); + +mod inner { + use std::sync::{ + Arc, + Mutex, + }; + + use serde_json::{ + Map, + Value, + }; + + #[derive(Debug, Clone, Default)] + pub enum Inner { + #[default] + Real, + Fake(Arc>>), + } +} + +impl Settings { + pub fn new() -> Self { + Self(inner::Inner::Real) + } + + pub fn new_fake() -> Self { + Self(inner::Inner::Fake(Arc::new(Mutex::new(Map::new())))) + } + + pub fn from_slice(slice: &[(&str, Value)]) -> Self { + Self(inner::Inner::Fake(Arc::new(Mutex::new( + slice.iter().map(|(k, v)| ((*k).to_owned(), v.clone())).collect(), + )))) + } + + pub fn set_value(&self, key: impl Into, value: impl Into) -> Result<()> { + match &self.0 { + inner::Inner::Real => { + let mut settings = OldSettings::load()?; + settings.set(key, value); + settings.save_to_file()?; + Ok(()) + }, + inner::Inner::Fake(map) => { + map.lock()?.insert(key.into(), value.into()); + Ok(()) + }, + } + } + + pub fn remove_value(&self, key: impl AsRef) -> Result<()> { + match &self.0 { + inner::Inner::Real => { + let mut settings = OldSettings::load()?; + settings.remove(key); + settings.save_to_file()?; + Ok(()) + }, + inner::Inner::Fake(map) => { + map.lock()?.remove(key.as_ref()); + Ok(()) + }, + } + } + + pub fn get_value(&self, key: impl AsRef) -> Result> { + match &self.0 { + inner::Inner::Real => Ok(OldSettings::load()?.get(key.as_ref()).map(|v| v.clone())), + inner::Inner::Fake(map) => Ok(map.lock()?.get(key.as_ref()).cloned()), + } + } + + pub fn get(&self, key: impl AsRef) -> Result> { + match &self.0 { + inner::Inner::Real => { + let settings = OldSettings::load()?; + let v = settings.get(key); + match v.as_deref() { + Some(value) => Ok(Some(serde_json::from_value(value.clone())?)), + None => Ok(None), + } + }, + inner::Inner::Fake(map) => { + let value = map.lock()?.get(key.as_ref()).cloned(); + match value { + Some(value) => Ok(Some(serde_json::from_value(value)?)), + None => Ok(None), + } + }, + } + } + + pub fn get_bool(&self, key: impl AsRef) -> Result> { + match &self.0 { + inner::Inner::Real => Ok(OldSettings::load()?.get_bool(key.as_ref())), + inner::Inner::Fake(map) => Ok(map.lock()?.get(key.as_ref()).cloned().and_then(|v| v.as_bool())), + } + } + + pub fn get_bool_or(&self, key: impl AsRef, default: bool) -> bool { + self.get_bool(key).ok().flatten().unwrap_or(default) + } + + pub fn get_string(&self, key: 
impl AsRef) -> Result> { + match &self.0 { + inner::Inner::Real => Ok(OldSettings::load()?.get_string(key.as_ref())), + inner::Inner::Fake(map) => Ok(map + .lock()? + .get(key.as_ref()) + .cloned() + .and_then(|v| v.as_str().map(|s| s.to_owned()))), + } + } + + pub fn get_string_opt(&self, key: impl AsRef) -> Option { + self.get_string(key).ok().flatten() + } + + pub fn get_string_or(&self, key: impl AsRef, default: String) -> String { + self.get_string(key).ok().flatten().unwrap_or(default) + } + + pub fn get_int(&self, key: impl AsRef) -> Result> { + match &self.0 { + inner::Inner::Real => Ok(OldSettings::load()?.get_int(key.as_ref())), + inner::Inner::Fake(map) => Ok(map.lock()?.get(key.as_ref()).cloned().and_then(|v| v.as_i64())), + } + } + + pub fn get_int_or(&self, key: impl AsRef, default: i64) -> i64 { + self.get_int(key).ok().flatten().unwrap_or(default) + } +} + +pub trait SettingsProvider { + fn settings(&self) -> &Settings; +} + +impl SettingsProvider for Settings { + fn settings(&self) -> &Settings { + self + } +} + +pub fn set_value(key: impl Into, value: impl Into) -> Result<()> { + Settings::new().set_value(key, value) +} + +pub fn remove_value(key: impl AsRef) -> Result<()> { + Settings::new().remove_value(key) +} + +pub fn get_value(key: impl AsRef) -> Result> { + Settings::new().get_value(key) +} + +pub fn get(key: impl AsRef) -> Result> { + Settings::new().get(key) +} + +pub fn get_bool(key: impl AsRef) -> Result> { + Settings::new().get_bool(key) +} + +pub fn get_bool_or(key: impl AsRef, default: bool) -> bool { + Settings::new().get_bool_or(key, default) +} + +pub fn get_string(key: impl AsRef) -> Result> { + Settings::new().get_string(key) +} + +pub fn get_string_opt(key: impl AsRef) -> Option { + Settings::new().get_string_opt(key) +} + +pub fn get_string_or(key: impl AsRef, default: String) -> String { + Settings::new().get_string_or(key, default) +} + +pub fn get_int(key: impl AsRef) -> Result> { + Settings::new().get_int(key) +} + +pub fn get_int_or(key: impl AsRef, default: i64) -> i64 { + Settings::new().get_int_or(key, default) +} + +#[cfg(test)] +mod test { + use super::{ + Result, + Settings, + }; + + /// General read/write settings test + #[test] + fn test_settings() -> Result<()> { + let settings = Settings::from_slice(&[]); + + assert!(settings.get_value("test").unwrap().is_none()); + assert!(settings.get::("test").unwrap().is_none()); + settings.set_value("test", "hello :)")?; + assert!(settings.get_value("test").unwrap().is_some()); + assert!(settings.get::("test").unwrap().is_some()); + settings.remove_value("test")?; + assert!(settings.get_value("test").unwrap().is_none()); + assert!(settings.get::("test").unwrap().is_none()); + + assert!(!settings.get_bool_or("bool", false)); + settings.set_value("bool", true).unwrap(); + assert!(settings.get_bool("bool").unwrap().unwrap()); + + assert_eq!(settings.get_string_or("string", "hi".into()), "hi"); + settings.set_value("string", "hi").unwrap(); + assert_eq!(settings.get_string("string").unwrap().unwrap(), "hi"); + + assert_eq!(settings.get_int_or("int", 32), 32); + settings.set_value("int", 32).unwrap(); + assert_eq!(settings.get_int("int").unwrap().unwrap(), 32); + + Ok(()) + } +} diff --git a/crates/kiro-cli/src/fig_settings/sqlite.rs b/crates/kiro-cli/src/fig_settings/sqlite.rs new file mode 100644 index 0000000000..f8569f8c77 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/sqlite.rs @@ -0,0 +1,437 @@ +use std::ops::Deref; +use std::path::{ + Path, + PathBuf, +}; +use std::sync::LazyLock; + +use 
r2d2::Pool; +use r2d2_sqlite::SqliteConnectionManager; +use rusqlite::types::FromSql; +use rusqlite::{ + Connection, + Error, + ToSql, + params, +}; +use serde_json::Map; +use tracing::{ + debug, + info, +}; + +use super::error::DbOpenError; +use crate::fig_settings::Result; +use crate::fig_util::directories::fig_data_dir; + +const STATE_TABLE_NAME: &str = "state"; +const AUTH_TABLE_NAME: &str = "auth_kv"; + +pub static DATABASE: LazyLock> = LazyLock::new(|| { + let db = Db::new().map_err(|e| DbOpenError(e.to_string()))?; + db.migrate().map_err(|e| DbOpenError(e.to_string()))?; + Ok(db) +}); + +pub fn database() -> Result<&'static Db, DbOpenError> { + match DATABASE.as_ref() { + Ok(db) => Ok(db), + Err(err) => Err(err.clone()), + } +} + +#[derive(Debug)] +struct Migration { + name: &'static str, + sql: &'static str, +} + +macro_rules! migrations { + ($($name:expr),*) => {{ + &[ + $( + Migration { + name: $name, + sql: include_str!(concat!("sqlite_migrations/", $name, ".sql")), + } + ),* + ] + }}; +} + +const MIGRATIONS: &[Migration] = migrations![ + "000_migration_table", + "001_history_table", + "002_drop_history_in_ssh_docker", + "003_improved_history_timing", + "004_state_table", + "005_auth_table" +]; + +#[derive(Debug, Clone)] +pub struct Db { + pub(crate) pool: Pool, +} + +impl Db { + fn path() -> Result { + Ok(fig_data_dir()?.join("data.sqlite3")) + } + + pub fn new() -> Result { + Self::open(&Self::path()?) + } + + fn open(path: &Path) -> Result { + // make the parent dir if it doesnt exist + if let Some(parent) = path.parent() { + if !parent.exists() { + std::fs::create_dir_all(parent)?; + } + } + + let conn = SqliteConnectionManager::file(path); + let pool = Pool::builder().build(conn)?; + + // Check the unix permissions of the database file, set them to 0600 if they are not + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let metadata = std::fs::metadata(path)?; + let mut permissions = metadata.permissions(); + if permissions.mode() & 0o777 != 0o600 { + debug!(?path, "Setting database file permissions to 0600"); + permissions.set_mode(0o600); + std::fs::set_permissions(path, permissions)?; + } + } + + Ok(Self { pool }) + } + + pub(crate) fn mock() -> Self { + let conn = SqliteConnectionManager::memory(); + let pool = Pool::builder().build(conn).unwrap(); + Self { pool } + } + + pub fn migrate(&self) -> Result<()> { + let mut conn = self.pool.get()?; + let transaction = conn.transaction()?; + + // select the max migration id + let max_id = max_migration(&transaction); + + for (version, migration) in MIGRATIONS.iter().enumerate() { + // skip migrations that already exist + match max_id { + Some(max_id) if max_id >= version as i64 => continue, + _ => (), + }; + + // execute the migration + transaction.execute_batch(migration.sql)?; + + info!(%version, name =% migration.name, "Applying migration"); + + // insert the migration entry + transaction.execute( + "INSERT INTO migrations (version, migration_time) VALUES (?1, strftime('%s', 'now'));", + params![version], + )?; + } + + // commit the transaction + transaction.commit()?; + + Ok(()) + } + + fn get_value(&self, table: &'static str, key: impl AsRef) -> Result> { + let conn = self.pool.get()?; + let mut stmt = conn.prepare(&format!("SELECT value FROM {table} WHERE key = ?1"))?; + match stmt.query_row([key.as_ref()], |row| row.get(0)) { + Ok(data) => Ok(Some(data)), + Err(Error::QueryReturnedNoRows) => Ok(None), + Err(err) => Err(err.into()), + } + } + + pub fn get_state_value(&self, key: impl AsRef) -> Result> { + 
self.get_value(STATE_TABLE_NAME, key) + } + + pub fn get_auth_value(&self, key: impl AsRef) -> Result> { + self.get_value(AUTH_TABLE_NAME, key) + } + + fn set_value(&self, table: &'static str, key: impl AsRef, value: T) -> Result<()> { + self.pool.get()?.execute( + &format!("INSERT OR REPLACE INTO {table} (key, value) VALUES (?1, ?2)"), + params![key.as_ref(), value], + )?; + Ok(()) + } + + pub fn set_state_value(&self, key: impl AsRef, value: impl Into) -> Result<()> { + self.set_value(STATE_TABLE_NAME, key, value.into()) + } + + pub fn set_auth_value(&self, key: impl AsRef, value: impl Into) -> Result<()> { + self.set_value(AUTH_TABLE_NAME, key, value.into()) + } + + fn unset_value(&self, table: &'static str, key: impl AsRef) -> Result<()> { + self.pool + .get()? + .execute(&format!("DELETE FROM {table} WHERE key = ?1"), [key.as_ref()])?; + Ok(()) + } + + pub fn unset_state_value(&self, key: impl AsRef) -> Result<()> { + self.unset_value(STATE_TABLE_NAME, key) + } + + pub fn unset_auth_value(&self, key: impl AsRef) -> Result<()> { + self.unset_value(AUTH_TABLE_NAME, key) + } + + fn is_value_set(&self, table: &'static str, key: impl AsRef) -> Result { + let conn = self.pool.get()?; + let mut stmt = conn.prepare(&format!("SELECT value FROM {table} WHERE key = ?1"))?; + match stmt.query_row([key.as_ref()], |_| Ok(())) { + Ok(()) => Ok(true), + Err(Error::QueryReturnedNoRows) => Ok(false), + Err(err) => Err(err.into()), + } + } + + pub fn is_state_value_set(&self, key: impl AsRef) -> Result { + self.is_value_set(STATE_TABLE_NAME, key) + } + + pub fn is_auth_value_set(&self, key: impl AsRef) -> Result { + self.is_value_set(AUTH_TABLE_NAME, key) + } + + fn all_values(&self, table: &'static str) -> Result> { + let conn = self.pool.get()?; + let mut stmt = conn.prepare(&format!("SELECT key, value FROM {table}"))?; + let rows = stmt.query_map([], |row| { + let key = row.get(0)?; + let value = row.get(1)?; + Ok((key, value)) + })?; + + let mut map = Map::new(); + for row in rows { + let (key, value) = row?; + map.insert(key, value); + } + + Ok(map) + } + + pub fn all_state_values(&self) -> Result> { + self.all_values(STATE_TABLE_NAME) + } + + // atomic style operations + + fn atomic_op( + &self, + key: impl AsRef, + op: impl FnOnce(&Option) -> Option, + ) -> Result> { + let mut conn = self.pool.get()?; + let tx = conn.transaction()?; + + let value = tx.query_row::, _, _>( + &format!("SELECT value FROM {STATE_TABLE_NAME} WHERE key = ?1"), + [key.as_ref()], + |row| row.get(0), + ); + + let value_0: Option = match value { + Ok(value) => value, + Err(Error::QueryReturnedNoRows) => None, + Err(err) => return Err(err.into()), + }; + + let value_1 = op(&value_0); + + if let Some(value) = value_1 { + tx.execute( + &format!("INSERT OR REPLACE INTO {STATE_TABLE_NAME} (key, value) VALUES (?1, ?2)"), + params![key.as_ref(), value], + )?; + } else { + tx.execute( + &format!("DELETE FROM {STATE_TABLE_NAME} WHERE key = ?1"), + [key.as_ref()], + )?; + } + + tx.commit()?; + + Ok(value_0) + } + + /// Atomically get the value of a key, then perform an or operation on it + /// and set the new value. If the key does not exist, set it to the or value. 
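For illustration only (not part of this patch): a usage sketch for the atomic_bool_or defined just below. Since it returns the previous value, it can gate one-time actions; the helper and the key name are hypothetical.

    use crate::fig_settings::Result;
    use crate::fig_settings::sqlite::Db;

    // Hypothetical run-once gate: the first caller gets `false` back and prints
    // the notice; later callers get `true` because the flag stays set.
    fn maybe_show_first_run_notice(db: &Db) -> Result<()> {
        let already_shown = db.atomic_bool_or("cli.first-run-notice-shown", true)?;
        if !already_shown {
            println!("Welcome! This notice is only printed once.");
        }
        Ok(())
    }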
+ pub fn atomic_bool_or(&self, key: impl AsRef, or: bool) -> Result { + self.atomic_op::(key, |val| match val { + // Some(val) => Some(serde_json::Value::Bool( || or)), + Some(serde_json::Value::Bool(b)) => Some(serde_json::Value::Bool(*b || or)), + Some(_) | None => Some(serde_json::Value::Bool(or)), + }) + .map(|val| val.and_then(|val| val.as_bool()).unwrap_or(false)) + } +} + +fn max_migration>(conn: &C) -> Option { + let mut stmt = conn.prepare("SELECT MAX(id) FROM migrations").ok()?; + stmt.query_row([], |row| row.get(0)).ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn mock() -> Db { + let db = Db::mock(); + db.migrate().unwrap(); + db + } + + #[test] + fn test_migrate() { + let db = mock(); + + // assert migration count is correct + let max_migration = max_migration(&&*db.pool.get().unwrap()); + assert_eq!(max_migration, Some(MIGRATIONS.len() as i64)); + } + + #[test] + fn list_migrations() { + // Assert the migrations are in order + assert!(MIGRATIONS.windows(2).all(|w| w[0].name <= w[1].name)); + + // Assert the migrations start with their index + assert!( + MIGRATIONS + .iter() + .enumerate() + .all(|(i, m)| m.name.starts_with(&format!("{:03}_", i))) + ); + + // Assert all the files in migrations/ are in the list + let migration_folder = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src/sqlite/migrations"); + let migration_count = std::fs::read_dir(migration_folder).unwrap().count(); + assert_eq!(MIGRATIONS.len(), migration_count); + } + + #[test] + fn state_table_tests() { + let db = mock(); + + // set + db.set_state_value("test", "test").unwrap(); + db.set_state_value("int", 1).unwrap(); + db.set_state_value("float", 1.0).unwrap(); + db.set_state_value("bool", true).unwrap(); + db.set_state_value("null", ()).unwrap(); + db.set_state_value("array", vec![1, 2, 3]).unwrap(); + db.set_state_value("object", serde_json::json!({ "test": "test" })) + .unwrap(); + db.set_state_value("binary", b"test".to_vec()).unwrap(); + + // get + assert_eq!(db.get_state_value("test").unwrap().unwrap(), "test"); + assert_eq!(db.get_state_value("int").unwrap().unwrap(), 1); + assert_eq!(db.get_state_value("float").unwrap().unwrap(), 1.0); + assert_eq!(db.get_state_value("bool").unwrap().unwrap(), true); + assert_eq!(db.get_state_value("null").unwrap().unwrap(), serde_json::Value::Null); + assert_eq!( + db.get_state_value("array").unwrap().unwrap(), + serde_json::json!([1, 2, 3]) + ); + assert_eq!( + db.get_state_value("object").unwrap().unwrap(), + serde_json::json!({ "test": "test" }) + ); + assert_eq!( + db.get_state_value("binary").unwrap().unwrap(), + serde_json::json!(b"test".to_vec()) + ); + + // unset + db.unset_state_value("test").unwrap(); + db.unset_state_value("int").unwrap(); + + // is_set + assert!(!db.is_state_value_set("test").unwrap()); + assert!(!db.is_state_value_set("int").unwrap()); + assert!(db.is_state_value_set("float").unwrap()); + assert!(db.is_state_value_set("bool").unwrap()); + } + + #[test] + fn auth_table_tests() { + let db = mock(); + + db.set_auth_value("test", "test").unwrap(); + assert_eq!(db.get_auth_value("test").unwrap().unwrap(), "test"); + assert!(db.is_auth_value_set("test").unwrap()); + db.unset_auth_value("test").unwrap(); + assert!(!db.is_auth_value_set("test").unwrap()); + + assert_eq!(db.get_auth_value("test2").unwrap(), None); + assert!(!db.is_auth_value_set("test2").unwrap()); + } + + #[test] + fn db_open_time() { + let tempdir = tempfile::tempdir().unwrap(); + let path = tempdir.path().join("data.sqlite3"); + + // init the db + let db 
= Db::open(&path).unwrap(); + db.migrate().unwrap(); + drop(db); + + let test_count = 100; + + let instant = std::time::Instant::now(); + let db = Db::open(&path).unwrap(); + for _ in 0..test_count { + db.set_state_value("test", "test").unwrap(); + db.get_state_value("test").unwrap().unwrap(); + } + let elapsed = instant.elapsed() / test_count; + println!("time: {:?}", elapsed); + } + + #[test] + fn test_atomic_bool() { + let key = "test"; + let db = mock(); + + let cases = [ + (None, false, false, false), + (None, true, false, true), + (Some(false), false, false, false), + (Some(false), true, false, true), + (Some(true), false, true, true), + (Some(true), true, true, true), + ]; + + for (a, b, c, d) in cases { + db.set_state_value(key, a).unwrap(); + assert_eq!(db.atomic_bool_or(key, b).unwrap(), c); + assert_eq!(db.get_state_value(key).unwrap().unwrap(), d); + db.unset_state_value(key).unwrap(); + } + } +} diff --git a/crates/kiro-cli/src/fig_settings/sqlite_migrations/000_migration_table.sql b/crates/kiro-cli/src/fig_settings/sqlite_migrations/000_migration_table.sql new file mode 100644 index 0000000000..1437deb0d9 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/sqlite_migrations/000_migration_table.sql @@ -0,0 +1,5 @@ +CREATE TABLE IF NOT EXISTS migrations ( + id INTEGER PRIMARY KEY, + version INTEGER NOT NULL, + migration_time INTEGER NOT NULL +); \ No newline at end of file diff --git a/crates/kiro-cli/src/fig_settings/sqlite_migrations/001_history_table.sql b/crates/kiro-cli/src/fig_settings/sqlite_migrations/001_history_table.sql new file mode 100644 index 0000000000..7d25913387 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/sqlite_migrations/001_history_table.sql @@ -0,0 +1,13 @@ +CREATE TABLE IF NOT EXISTS history ( + id INTEGER PRIMARY KEY, + command TEXT, + shell TEXT, + pid INTEGER, + session_id TEXT, + cwd TEXT, + time INTEGER, + in_ssh INTEGER, + in_docker INTEGER, + hostname TEXT, + exit_code INTEGER +); diff --git a/crates/kiro-cli/src/fig_settings/sqlite_migrations/002_drop_history_in_ssh_docker.sql b/crates/kiro-cli/src/fig_settings/sqlite_migrations/002_drop_history_in_ssh_docker.sql new file mode 100644 index 0000000000..45e518e024 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/sqlite_migrations/002_drop_history_in_ssh_docker.sql @@ -0,0 +1,3 @@ +ALTER TABLE history DROP COLUMN in_ssh; +ALTER TABLE history DROP COLUMN in_docker; + \ No newline at end of file diff --git a/crates/kiro-cli/src/fig_settings/sqlite_migrations/003_improved_history_timing.sql b/crates/kiro-cli/src/fig_settings/sqlite_migrations/003_improved_history_timing.sql new file mode 100644 index 0000000000..58e3bb1c3c --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/sqlite_migrations/003_improved_history_timing.sql @@ -0,0 +1,3 @@ +ALTER TABLE history RENAME COLUMN time TO start_time; +ALTER TABLE history ADD COLUMN end_time INTEGER; +ALTER TABLE history ADD COLUMN duration INTEGER; diff --git a/crates/kiro-cli/src/fig_settings/sqlite_migrations/004_state_table.sql b/crates/kiro-cli/src/fig_settings/sqlite_migrations/004_state_table.sql new file mode 100644 index 0000000000..3a7b43c00e --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/sqlite_migrations/004_state_table.sql @@ -0,0 +1,4 @@ +CREATE TABLE state ( + key TEXT PRIMARY KEY, + value TEXT +); diff --git a/crates/kiro-cli/src/fig_settings/sqlite_migrations/005_auth_table.sql b/crates/kiro-cli/src/fig_settings/sqlite_migrations/005_auth_table.sql new file mode 100644 index 0000000000..17b28fb8e1 --- /dev/null +++ 
b/crates/kiro-cli/src/fig_settings/sqlite_migrations/005_auth_table.sql @@ -0,0 +1,6 @@ +-- We create a separate auth_kv to ensure the data is not available in all the same +-- places that the state is available in +CREATE TABLE auth_kv ( + key TEXT PRIMARY KEY, + value TEXT +); diff --git a/crates/kiro-cli/src/fig_settings/state.rs b/crates/kiro-cli/src/fig_settings/state.rs new file mode 100644 index 0000000000..dea24596f9 --- /dev/null +++ b/crates/kiro-cli/src/fig_settings/state.rs @@ -0,0 +1,202 @@ +use serde::de::DeserializeOwned; +use serde_json::{ + Map, + Value, +}; + +use super::sqlite::{ + Db, + database, +}; +use crate::fig_settings::Result; + +#[derive(Debug, Clone, Default)] +pub struct State(inner::Inner); + +mod inner { + use super::*; + + #[derive(Debug, Clone, Default)] + pub enum Inner { + #[default] + Real, + Fake(Db), + } +} + +impl State { + pub fn new() -> Self { + Self::default() + } + + pub fn new_fake() -> Self { + let db = Db::mock(); + db.migrate().unwrap(); + Self(inner::Inner::Fake(db)) + } + + pub fn from_slice(slice: &[(&str, Value)]) -> Self { + let fake = Self::new_fake(); + for (key, value) in slice { + fake.set_value(key, value.clone()).unwrap(); + } + fake + } + + fn database(&self) -> Result<&Db> { + match &self.0 { + inner::Inner::Real => Ok(database()?), + inner::Inner::Fake(db) => Ok(db), + } + } + + pub fn all(&self) -> Result> { + self.database()?.all_state_values() + } + + pub fn set_value(&self, key: impl AsRef, value: impl Into) -> Result<()> { + self.database()?.set_state_value(key, value)?; + Ok(()) + } + + pub fn remove_value(&self, key: impl AsRef) -> Result<()> { + self.database()?.unset_state_value(key)?; + Ok(()) + } + + pub fn get_value(&self, key: impl AsRef) -> Result> { + self.database()?.get_state_value(key) + } + + pub fn get(&self, key: impl AsRef) -> Result> { + Ok(self + .database()? + .get_state_value(key)? + .map(|value| serde_json::from_value(value.clone())) + .transpose()?) 
+ } + + pub fn get_bool(&self, key: impl AsRef) -> Result> { + Ok(self.database()?.get_state_value(key)?.and_then(|value| value.as_bool())) + } + + pub fn get_bool_or(&self, key: impl AsRef, default: bool) -> bool { + self.get_bool(key).ok().flatten().unwrap_or(default) + } + + pub fn get_string(&self, key: impl AsRef) -> Result> { + Ok(self.database()?.get_state_value(key)?.and_then(|value| match value { + Value::String(s) => Some(s), + _ => None, + })) + } + + pub fn get_string_or(&self, key: impl AsRef, default: impl Into) -> String { + self.get_string(key).ok().flatten().unwrap_or_else(|| default.into()) + } + + pub fn get_int(&self, key: impl AsRef) -> Result> { + Ok(self.database()?.get_state_value(key)?.and_then(|value| value.as_i64())) + } + + pub fn get_int_or(&self, key: impl AsRef, default: i64) -> i64 { + self.get_int(key).ok().flatten().unwrap_or(default) + } + + // Atomic style operations + + pub fn atomic_bool_or(&self, key: impl AsRef, or: bool) -> Result { + self.database()?.atomic_bool_or(key, or) + } +} + +pub trait StateProvider { + fn state(&self) -> &State; +} + +impl StateProvider for State { + fn state(&self) -> &State { + self + } +} + +pub fn all() -> Result> { + State::new().all() +} + +pub fn set_value(key: impl AsRef, value: impl Into) -> Result<()> { + State::new().set_value(key, value) +} + +pub fn remove_value(key: impl AsRef) -> Result<()> { + State::new().remove_value(key) +} + +pub fn get_value(key: impl AsRef) -> Result> { + State::new().get_value(key) +} + +pub fn get(key: impl AsRef) -> Result> { + State::new().get(key) +} + +pub fn get_bool(key: impl AsRef) -> Result> { + State::new().get_bool(key) +} + +pub fn get_bool_or(key: impl AsRef, default: bool) -> bool { + State::new().get_bool_or(key, default) +} + +pub fn get_string(key: impl AsRef) -> Result> { + State::new().get_string(key) +} + +pub fn get_string_or(key: impl AsRef, default: impl Into) -> String { + State::new().get_string_or(key, default) +} + +pub fn get_int(key: impl AsRef) -> Result> { + State::new().get_int(key) +} + +pub fn get_int_or(key: impl AsRef, default: i64) -> i64 { + State::new().get_int_or(key, default) +} + +#[cfg(test)] +mod tests { + use super::{ + Result, + State, + }; + + /// General read/write state test + #[test] + fn test_state() -> Result<()> { + let state = State::new_fake(); + + assert!(state.get_value("test").unwrap().is_none()); + assert!(state.get::("test").unwrap().is_none()); + state.set_value("test", "hello :)")?; + assert!(state.get_value("test").unwrap().is_some()); + assert!(state.get::("test").unwrap().is_some()); + state.remove_value("test")?; + assert!(state.get_value("test").unwrap().is_none()); + assert!(state.get::("test").unwrap().is_none()); + + assert!(!state.get_bool_or("bool", false)); + state.set_value("bool", true).unwrap(); + assert!(state.get_bool("bool").unwrap().unwrap()); + + assert_eq!(state.get_string_or("string", "hi"), "hi"); + state.set_value("string", "hi").unwrap(); + assert_eq!(state.get_string("string").unwrap().unwrap(), "hi"); + + assert_eq!(state.get_int_or("int", 32), 32); + state.set_value("int", 32).unwrap(); + assert_eq!(state.get_int("int").unwrap().unwrap(), 32); + + Ok(()) + } +} diff --git a/crates/kiro-cli/src/fig_telemetry/cognito.rs b/crates/kiro-cli/src/fig_telemetry/cognito.rs new file mode 100644 index 0000000000..21e286a570 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/cognito.rs @@ -0,0 +1,145 @@ +use amzn_toolkit_telemetry_client::config::BehaviorVersion; +use 
aws_credential_types::provider::error::CredentialsError; +use aws_credential_types::{ + Credentials, + provider, +}; +use aws_sdk_cognitoidentity::primitives::{ + DateTime, + DateTimeFormat, +}; + +use crate::fig_aws_common::app_name; +use crate::fig_telemetry::TelemetryStage; + +const CREDENTIALS_KEY: &str = "telemetry-cognito-credentials"; + +const DATE_TIME_FORMAT: DateTimeFormat = DateTimeFormat::DateTime; + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +struct CredentialsJson { + pub access_key_id: Option, + pub secret_key: Option, + pub session_token: Option, + pub expiration: Option, +} + +pub(crate) async fn get_cognito_credentials_send( + telemetry_stage: &TelemetryStage, +) -> Result { + let conf = aws_sdk_cognitoidentity::Config::builder() + .behavior_version(BehaviorVersion::v2025_01_17()) + .region(telemetry_stage.region.clone()) + .app_name(app_name()) + .build(); + let client = aws_sdk_cognitoidentity::Client::from_conf(conf); + + let identity_id = client + .get_id() + .identity_pool_id(telemetry_stage.cognito_pool_id) + .send() + .await + .map_err(CredentialsError::provider_error)? + .identity_id + .ok_or(CredentialsError::provider_error("no identity_id from get_id"))?; + + let credentials = client + .get_credentials_for_identity() + .identity_id(identity_id) + .send() + .await + .map_err(CredentialsError::provider_error)? + .credentials + .ok_or(CredentialsError::provider_error( + "no credentials from get_credentials_for_identity", + ))?; + + if let Ok(json) = serde_json::to_value(CredentialsJson { + access_key_id: credentials.access_key_id.clone(), + secret_key: credentials.secret_key.clone(), + session_token: credentials.session_token.clone(), + expiration: credentials.expiration.and_then(|t| t.fmt(DATE_TIME_FORMAT).ok()), + }) { + crate::fig_settings::state::set_value(CREDENTIALS_KEY, json).ok(); + } + + let Some(access_key_id) = credentials.access_key_id else { + return Err(CredentialsError::provider_error("access key id not found")); + }; + + let Some(secret_key) = credentials.secret_key else { + return Err(CredentialsError::provider_error("secret access key not found")); + }; + + Ok(Credentials::new( + access_key_id, + secret_key, + credentials.session_token, + credentials.expiration.and_then(|dt| dt.try_into().ok()), + "", + )) +} + +pub(crate) async fn get_cognito_credentials(telemetry_stage: &TelemetryStage) -> Result { + match crate::fig_settings::state::get_string(CREDENTIALS_KEY).ok().flatten() { + Some(creds) => { + let CredentialsJson { + access_key_id, + secret_key, + session_token, + expiration, + }: CredentialsJson = serde_json::from_str(&creds).map_err(CredentialsError::provider_error)?; + + let Some(access_key_id) = access_key_id else { + return get_cognito_credentials_send(telemetry_stage).await; + }; + + let Some(secret_key) = secret_key else { + return get_cognito_credentials_send(telemetry_stage).await; + }; + + Ok(Credentials::new( + access_key_id, + secret_key, + session_token, + expiration + .and_then(|s| DateTime::from_str(&s, DATE_TIME_FORMAT).ok()) + .and_then(|dt| dt.try_into().ok()), + "", + )) + }, + None => get_cognito_credentials_send(telemetry_stage).await, + } +} + +#[derive(Debug)] +pub(crate) struct CognitoProvider { + telemetry_stage: TelemetryStage, +} + +impl CognitoProvider { + pub(crate) fn new(telemetry_stage: TelemetryStage) -> CognitoProvider { + CognitoProvider { telemetry_stage } + } +} + +impl provider::ProvideCredentials for CognitoProvider { + fn provide_credentials<'a>(&'a self) -> 
provider::future::ProvideCredentials<'a> + where + Self: 'a, + { + provider::future::ProvideCredentials::new(get_cognito_credentials(&self.telemetry_stage)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[tokio::test] + async fn pools() { + for telemetry_stage in [TelemetryStage::BETA, TelemetryStage::EXTERNAL_PROD] { + get_cognito_credentials_send(&telemetry_stage).await.unwrap(); + } + } +} diff --git a/crates/kiro-cli/src/fig_telemetry/definitions.rs b/crates/kiro-cli/src/fig_telemetry/definitions.rs new file mode 100644 index 0000000000..0421534f75 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/definitions.rs @@ -0,0 +1,35 @@ +// https://github.com/aws/aws-toolkit-common/blob/main/telemetry/telemetryformat.md + +pub trait IntoMetricDatum: Send { + fn into_metric_datum(self) -> amzn_toolkit_telemetry_client::types::MetricDatum; +} + +include!(concat!(env!("OUT_DIR"), "/mod.rs")); + +#[cfg(test)] +mod tests { + use std::time::SystemTime; + + use super::*; + use crate::fig_telemetry::definitions::metrics::CodewhispererterminalAddChatMessage; + + #[test] + fn test_serde() { + let metric_datum_init = Metric::CodewhispererterminalAddChatMessage(CodewhispererterminalAddChatMessage { + amazonq_conversation_id: None, + codewhispererterminal_context_file_length: None, + create_time: Some(SystemTime::now()), + value: None, + credential_start_url: Some("https://example.com".to_owned().into()), + codewhispererterminal_in_cloudshell: Some(false.into()), + }); + + let s = serde_json::to_string_pretty(&metric_datum_init).unwrap(); + println!("{s}"); + + let metric_datum_out: Metric = serde_json::from_str(&s).unwrap(); + println!("{metric_datum_out:#?}"); + + assert_eq!(metric_datum_init, metric_datum_out); + } +} diff --git a/crates/kiro-cli/src/fig_telemetry/endpoint.rs b/crates/kiro-cli/src/fig_telemetry/endpoint.rs new file mode 100644 index 0000000000..681d19af76 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/endpoint.rs @@ -0,0 +1,32 @@ +use amzn_toolkit_telemetry_client::config::endpoint::{ + Endpoint, + EndpointFuture, + Params, + ResolveEndpoint, +}; + +#[derive(Debug, Clone, Copy)] +pub(crate) struct StaticEndpoint(pub &'static str); + +impl ResolveEndpoint for StaticEndpoint { + fn resolve_endpoint<'a>(&'a self, _params: &'a Params) -> EndpointFuture<'a> { + let endpoint = Endpoint::builder().url(self.0).build(); + tracing::info!(?endpoint, "Resolving endpoint"); + EndpointFuture::ready(Ok(endpoint)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_static_endpoint() { + let endpoint = StaticEndpoint("https://example.com"); + let params = Params::builder().build().unwrap(); + let endpoint = endpoint.resolve_endpoint(¶ms).await.unwrap(); + assert_eq!(endpoint.url(), "https://example.com"); + assert!(endpoint.properties().is_empty()); + assert!(endpoint.headers().count() == 0); + } +} diff --git a/crates/kiro-cli/src/fig_telemetry/event.rs b/crates/kiro-cli/src/fig_telemetry/event.rs new file mode 100644 index 0000000000..91832e6d34 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/event.rs @@ -0,0 +1,150 @@ +use std::time::SystemTime; + +use crate::fig_telemetry_core::{ + Event, + EventType, + MetricDatum, +}; + +/// Wrapper around the default telemetry [Event]. Used to initialize other metadata fields +/// within the global telemetry emitter implementation. 
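For illustration only (not part of this patch): a minimal sketch of how the wrapper defined just below is driven, mirroring what the test module at the end of this file exercises. The log_login helper is hypothetical; AppTelemetryEvent::new, EventType::UserLoggedIn, and into_metric_datum come from this patch.

    use crate::fig_telemetry::AppTelemetryEvent;
    use crate::fig_telemetry_core::EventType;

    // Hypothetical call site: build the wrapped event (which fills in the
    // credential start URL and creation time), then convert it into the
    // Toolkit metric datum that the telemetry client would post.
    async fn log_login() {
        let event = AppTelemetryEvent::new(EventType::UserLoggedIn {}).await;
        if let Some(datum) = event.into_metric_datum() {
            println!("metric to post: {}", datum.metric_name());
        }
    }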
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +pub struct AppTelemetryEvent(Event); + +impl std::ops::Deref for AppTelemetryEvent { + type Target = Event; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl AppTelemetryEvent { + pub async fn new(ty: EventType) -> Self { + Self(Event { + ty, + credential_start_url: crate::fig_auth::builder_id_token() + .await + .ok() + .flatten() + .and_then(|t| t.start_url), + created_time: Some(SystemTime::now()), + }) + } + + pub async fn from_event(event: Event) -> Self { + let credential_start_url = match event.credential_start_url { + Some(v) => Some(v), + None => crate::fig_auth::builder_id_token() + .await + .ok() + .flatten() + .and_then(|t| t.start_url), + }; + Self(Event { + ty: event.ty, + credential_start_url, + created_time: event.created_time.or_else(|| Some(SystemTime::now())), + }) + } + + pub fn into_metric_datum(self) -> Option { + self.0.into_metric_datum() + } +} + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct InlineShellCompletionActionedOptions {} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::fig_telemetry_core::TelemetryResult; + + async fn user_logged_in() -> AppTelemetryEvent { + AppTelemetryEvent::new(EventType::UserLoggedIn {}).await + } + + async fn refresh_credentials() -> AppTelemetryEvent { + AppTelemetryEvent::new(EventType::RefreshCredentials { + request_id: "request_id".into(), + result: TelemetryResult::Failed, + reason: Some("some failure".into()), + oauth_flow: "pkce".into(), + }) + .await + } + + async fn cli_subcommand_executed() -> AppTelemetryEvent { + AppTelemetryEvent::new(EventType::CliSubcommandExecuted { + subcommand: "test".into(), + }) + .await + } + + async fn chat_start() -> AppTelemetryEvent { + AppTelemetryEvent::new(EventType::ChatStart { + conversation_id: "XXX".into(), + }) + .await + } + + async fn chat_end() -> AppTelemetryEvent { + AppTelemetryEvent::new(EventType::ChatEnd { + conversation_id: "XXX".into(), + }) + .await + } + + async fn chat_added_message() -> AppTelemetryEvent { + AppTelemetryEvent::new(EventType::ChatAddedMessage { + conversation_id: "XXX".into(), + message_id: "YYY".into(), + context_file_length: Some(5), + }) + .await + } + + pub(crate) async fn all_events() -> Vec { + vec![ + user_logged_in().await, + refresh_credentials().await, + cli_subcommand_executed().await, + chat_start().await, + chat_end().await, + chat_added_message().await, + ] + } + + #[tokio::test] + async fn from_event_test() { + let event = Event { + ty: EventType::UserLoggedIn {}, + credential_start_url: Some("https://example.com".into()), + created_time: None, + }; + let app_event = AppTelemetryEvent::from_event(event).await; + assert_eq!(app_event.ty, EventType::UserLoggedIn {}); + assert_eq!(app_event.credential_start_url, Some("https://example.com".into())); + assert!(app_event.created_time.is_some()); + } + + #[tokio::test] + async fn test_event_ser() { + for event in all_events().await { + let json = serde_json::to_string_pretty(&event).unwrap(); + println!("\n{json}\n"); + } + } + + #[tokio::test] + async fn test_into_metric_datum() { + for event in all_events().await { + let metric_datum = event.into_metric_datum(); + if let Some(metric_datum) = metric_datum { + println!("\n{}: {metric_datum:?}\n", metric_datum.metric_name()); + } + } + } +} diff --git a/crates/kiro-cli/src/fig_telemetry/install_method.rs b/crates/kiro-cli/src/fig_telemetry/install_method.rs new file 
mode 100644 index 0000000000..2a541252af --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/install_method.rs @@ -0,0 +1,45 @@ +use std::process::Command; +use std::sync::LazyLock; + +use serde::{ + Deserialize, + Serialize, +}; + +static INSTALL_METHOD: LazyLock = LazyLock::new(|| { + if let Ok(output) = Command::new("brew").args(["list", "amazon-q", "-1"]).output() { + if output.status.success() { + return InstallMethod::Brew; + } + } + + if let Ok(current_exe) = std::env::current_exe() { + if current_exe.components().any(|c| c.as_os_str() == ".toolbox") { + return InstallMethod::Toolbox; + } + } + + InstallMethod::Unknown +}); + +/// The method of installation that Fig was installed with +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum InstallMethod { + Brew, + Toolbox, + Unknown, +} + +impl std::fmt::Display for InstallMethod { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + InstallMethod::Brew => "brew", + InstallMethod::Toolbox => "toolbox", + InstallMethod::Unknown => "unknown", + }) + } +} + +pub fn get_install_method() -> InstallMethod { + *INSTALL_METHOD +} diff --git a/crates/kiro-cli/src/fig_telemetry/mod.rs b/crates/kiro-cli/src/fig_telemetry/mod.rs new file mode 100644 index 0000000000..f7292925a3 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/mod.rs @@ -0,0 +1,693 @@ +pub mod cognito; +pub mod definitions; +pub mod endpoint; +mod event; +mod install_method; +mod util; + +use std::any::Any; +use std::sync::LazyLock; +use std::time::{ + Duration, + SystemTime, +}; + +use amzn_codewhisperer_client::types::{ + ChatAddMessageEvent, + CompletionType, + IdeCategory, + OperatingSystem, + OptOutPreference, + ProgrammingLanguage, + TelemetryEvent, + TerminalUserInteractionEvent, + TerminalUserInteractionEventType, + UserContext, + UserTriggerDecisionEvent, +}; +use amzn_toolkit_telemetry_client::config::{ + BehaviorVersion, + Region, +}; +use amzn_toolkit_telemetry_client::error::DisplayErrorContext; +use amzn_toolkit_telemetry_client::types::AwsProduct; +use amzn_toolkit_telemetry_client::{ + Client as ToolkitTelemetryClient, + Config, +}; +use aws_credential_types::provider::SharedCredentialsProvider; +use aws_smithy_types::DateTime; +use cognito::CognitoProvider; +use endpoint::StaticEndpoint; +pub use event::AppTelemetryEvent; +pub use install_method::{ + InstallMethod, + get_install_method, +}; +use tokio::sync::{ + Mutex, + OnceCell, +}; +use tokio::task::JoinSet; +use tracing::{ + debug, + error, +}; +use util::telemetry_is_disabled; +use uuid::Uuid; + +use crate::fig_api_client::Client as CodewhispererClient; +use crate::fig_aws_common::app_name; +use crate::fig_settings::State; +pub use crate::fig_telemetry_core::{ + EventType, + QProfileSwitchIntent, + SuggestionState, + TelemetryEmitter, + TelemetryResult, +}; +use crate::fig_util::system_info::os_version; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Telemetry is disabled")] + TelemetryDisabled, + #[error(transparent)] + ClientError(#[from] amzn_toolkit_telemetry_client::operation::post_metrics::PostMetricsError), +} + +const PRODUCT: &str = "CodeWhisperer"; +const PRODUCT_VERSION: &str = env!("CARGO_PKG_VERSION"); + +async fn client() -> &'static Client { + static CLIENT: OnceCell = OnceCell::const_new(); + CLIENT + .get_or_init(|| async { Client::new(TelemetryStage::EXTERNAL_PROD).await }) + .await +} + +/// A telemetry emitter that first tries sending the event to figterm so that the CLI commands can +/// 
execute much quicker. Only falls back to sending it directly on the current task if sending to +/// figterm fails. +struct DispatchingTelemetryEmitter; + +#[async_trait::async_trait] +impl TelemetryEmitter for DispatchingTelemetryEmitter { + async fn send(&self, event: crate::fig_telemetry_core::Event) { + let event = AppTelemetryEvent::from_event(event).await; + send_event(event).await; + } + + fn as_any(&self) -> &dyn Any { + self + } +} + +pub fn init_global_telemetry_emitter() { + crate::fig_telemetry_core::init_global_telemetry_emitter(DispatchingTelemetryEmitter {}); +} + +/// A IDE toolkit telemetry stage +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct TelemetryStage { + pub name: &'static str, + pub endpoint: &'static str, + pub cognito_pool_id: &'static str, + pub region: Region, +} + +impl TelemetryStage { + #[allow(dead_code)] + const BETA: Self = Self::new( + "beta", + "https://7zftft3lj2.execute-api.us-east-1.amazonaws.com/Beta", + "us-east-1:db7bfc9f-8ecd-4fbb-bea7-280c16069a99", + "us-east-1", + ); + const EXTERNAL_PROD: Self = Self::new( + "prod", + "https://client-telemetry.us-east-1.amazonaws.com", + "us-east-1:820fd6d1-95c0-4ca4-bffb-3f01d32da842", + "us-east-1", + ); + + const fn new( + name: &'static str, + endpoint: &'static str, + cognito_pool_id: &'static str, + region: &'static str, + ) -> Self { + Self { + name, + endpoint, + cognito_pool_id, + region: Region::from_static(region), + } + } +} + +static JOIN_SET: LazyLock>> = LazyLock::new(|| Mutex::new(JoinSet::new())); + +/// Joins all current telemetry events +pub async fn finish_telemetry() { + let mut set = JOIN_SET.lock().await; + while let Some(res) = set.join_next().await { + if let Err(err) = res { + error!(%err, "Failed to join telemetry event"); + } + } +} + +/// Joins all current telemetry events and panics if any fail to join +pub async fn finish_telemetry_unwrap() { + let mut set = JOIN_SET.lock().await; + while let Some(res) = set.join_next().await { + res.unwrap(); + } +} + +fn opt_out_preference() -> OptOutPreference { + if telemetry_is_disabled() { + OptOutPreference::OptOut + } else { + OptOutPreference::OptIn + } +} + +#[derive(Debug, Clone)] +pub struct Client { + client_id: Uuid, + toolkit_telemetry_client: Option, + codewhisperer_client: Option, + state: State, +} + +impl Client { + pub async fn new(telemetry_stage: TelemetryStage) -> Self { + let client_id = util::get_client_id(); + let toolkit_telemetry_client = Some(amzn_toolkit_telemetry_client::Client::from_conf( + Config::builder() + .http_client(crate::fig_aws_common::http_client::client()) + .behavior_version(BehaviorVersion::v2025_01_17()) + .endpoint_resolver(StaticEndpoint(telemetry_stage.endpoint)) + .app_name(app_name()) + .region(telemetry_stage.region.clone()) + .credentials_provider(SharedCredentialsProvider::new(CognitoProvider::new(telemetry_stage))) + .build(), + )); + let codewhisperer_client = CodewhispererClient::new().await.ok(); + let state = State::new(); + + Self { + client_id, + toolkit_telemetry_client, + codewhisperer_client, + state, + } + } + + pub fn mock() -> Self { + let client_id = util::get_client_id(); + let toolkit_telemetry_client = None; + let codewhisperer_client = Some(CodewhispererClient::mock()); + let state = State::new_fake(); + + Self { + client_id, + toolkit_telemetry_client, + codewhisperer_client, + state, + } + } + + async fn send_event(&self, event: AppTelemetryEvent) { + self.send_cw_telemetry_event(&event).await; + self.send_telemetry_toolkit_metric(event).await; + } + + async fn 
send_telemetry_toolkit_metric(&self, event: AppTelemetryEvent) { + if telemetry_is_disabled() { + return; + } + let Some(toolkit_telemetry_client) = self.toolkit_telemetry_client.clone() else { + return; + }; + let client_id = self.client_id; + let Some(metric_datum) = event.into_metric_datum() else { + return; + }; + + let mut set = JOIN_SET.lock().await; + set.spawn({ + async move { + let product = AwsProduct::CodewhispererTerminal; + let product_version = env!("CARGO_PKG_VERSION"); + let os = std::env::consts::OS; + let os_architecture = std::env::consts::ARCH; + let os_version = os_version().map(|v| v.to_string()).unwrap_or_default(); + let metric_name = metric_datum.metric_name().to_owned(); + + debug!(?product, ?metric_datum, "Posting metrics"); + if let Err(err) = toolkit_telemetry_client + .post_metrics() + .aws_product(product) + .aws_product_version(product_version) + .client_id(client_id) + .os(os) + .os_architecture(os_architecture) + .os_version(os_version) + .metric_data(metric_datum) + .send() + .await + .map_err(DisplayErrorContext) + { + error!(%err, ?metric_name, "Failed to post metric"); + } + } + }); + } + + async fn send_cw_telemetry_event(&self, event: &AppTelemetryEvent) { + match &event.ty { + EventType::ChatAddedMessage { + conversation_id, + message_id, + .. + } => { + self.send_cw_telemetry_chat_add_message_event(conversation_id.clone(), message_id.clone()) + .await; + }, + _ => {}, + } + } + + fn user_context(&self) -> Option { + let operating_system = match std::env::consts::OS { + "linux" => OperatingSystem::Linux, + "macos" => OperatingSystem::Mac, + "windows" => OperatingSystem::Windows, + os => { + error!(%os, "Unsupported operating system"); + return None; + }, + }; + + match UserContext::builder() + .client_id(self.client_id.hyphenated().to_string()) + .operating_system(operating_system) + .product(PRODUCT) + .ide_category(IdeCategory::Cli) + .ide_version(PRODUCT_VERSION) + .build() + { + Ok(user_context) => Some(user_context), + Err(err) => { + error!(%err, "Failed to build user context"); + None + }, + } + } + + async fn send_cw_telemetry_translation_action( + &self, + latency: Duration, + suggestion_state: SuggestionState, + terminal: Option, + terminal_version: Option, + shell: Option, + shell_version: Option, + ) { + let Some(codewhisperer_client) = self.codewhisperer_client.clone() else { + return; + }; + let user_context = self.user_context().unwrap(); + let opt_out_preference = opt_out_preference(); + + let mut set = JOIN_SET.lock().await; + set.spawn(async move { + let mut terminal_user_interaction_event_builder = TerminalUserInteractionEvent::builder() + .terminal_user_interaction_event_type( + TerminalUserInteractionEventType::CodewhispererTerminalTranslationAction, + ) + .time_to_suggestion(latency.as_millis() as i32) + .is_completion_accepted(suggestion_state == SuggestionState::Accept); + + if let Some(terminal) = terminal { + terminal_user_interaction_event_builder = terminal_user_interaction_event_builder.terminal(terminal); + } + + if let Some(terminal_version) = terminal_version { + terminal_user_interaction_event_builder = + terminal_user_interaction_event_builder.terminal_version(terminal_version); + } + + if let Some(shell) = shell { + terminal_user_interaction_event_builder = terminal_user_interaction_event_builder.shell(shell); + } + + if let Some(shell_version) = shell_version { + terminal_user_interaction_event_builder = + terminal_user_interaction_event_builder.shell_version(shell_version); + } + + let 
terminal_user_interaction_event = terminal_user_interaction_event_builder.build(); + + if let Err(err) = codewhisperer_client + .send_telemetry_event( + TelemetryEvent::TerminalUserInteractionEvent(terminal_user_interaction_event), + user_context, + opt_out_preference, + ) + .await + { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + } + }); + } + + async fn send_cw_telemetry_completion_inserted( + &self, + command: String, + terminal: Option, + shell: Option, + ) { + let Some(codewhisperer_client) = self.codewhisperer_client.clone() else { + return; + }; + let user_context = self.user_context().unwrap(); + let opt_out_preference = opt_out_preference(); + + let mut set = JOIN_SET.lock().await; + set.spawn(async move { + let mut terminal_user_interaction_event_builder = TerminalUserInteractionEvent::builder() + .terminal_user_interaction_event_type( + TerminalUserInteractionEventType::CodewhispererTerminalCompletionInserted, + ) + .cli_tool_command(command); + + if let Some(terminal) = terminal { + terminal_user_interaction_event_builder = terminal_user_interaction_event_builder.terminal(terminal); + } + + if let Some(shell) = shell { + terminal_user_interaction_event_builder = terminal_user_interaction_event_builder.shell(shell); + } + + let terminal_user_interaction_event = terminal_user_interaction_event_builder.build(); + + if let Err(err) = codewhisperer_client + .send_telemetry_event( + TelemetryEvent::TerminalUserInteractionEvent(terminal_user_interaction_event), + user_context, + opt_out_preference, + ) + .await + { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + } + }); + } + + async fn send_cw_telemetry_chat_add_message_event(&self, conversation_id: String, message_id: String) { + let Some(codewhisperer_client) = self.codewhisperer_client.clone() else { + return; + }; + let user_context = self.user_context().unwrap(); + let opt_out_preference = opt_out_preference(); + + let chat_add_message_event = match ChatAddMessageEvent::builder() + .conversation_id(conversation_id) + .message_id(message_id) + .build() + { + Ok(event) => event, + Err(err) => { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + return; + }, + }; + + let mut set = JOIN_SET.lock().await; + set.spawn(async move { + if let Err(err) = codewhisperer_client + .send_telemetry_event( + TelemetryEvent::ChatAddMessageEvent(chat_add_message_event), + user_context, + opt_out_preference, + ) + .await + { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + } + }); + } + + /// This is the user decision to accept a suggestion for inline suggestions + async fn send_cw_telemetry_user_trigger_decision_event( + &self, + session_id: String, + request_id: String, + latency: Duration, + accepted: bool, + suggested_chars_len: i32, + number_of_recommendations: i32, + ) { + let Some(codewhisperer_client) = self.codewhisperer_client.clone() else { + return; + }; + let user_context = self.user_context().unwrap(); + let opt_out_preference = opt_out_preference(); + + let programming_language = match ProgrammingLanguage::builder().language_name("shell").build() { + Ok(language) => language, + Err(err) => { + error!(err =% DisplayErrorContext(err), "Failed to build programming language"); + return; + }, + }; + + let suggestion_state = if accepted { + SuggestionState::Accept + } else { + SuggestionState::Reject + }; + + let user_trigger_decision_event = match UserTriggerDecisionEvent::builder() + .session_id(session_id) + 
.request_id(request_id) + .programming_language(programming_language) + .completion_type(CompletionType::Line) + .suggestion_state(suggestion_state.into()) + .accepted_character_count(if accepted { suggested_chars_len } else { 0 }) + .number_of_recommendations(number_of_recommendations) + .generated_line(1) + .recommendation_latency_milliseconds(latency.as_secs_f64() * 1000.0) + .timestamp(DateTime::from(SystemTime::now())) + .build() + { + Ok(event) => event, + Err(err) => { + error!(err =% DisplayErrorContext(err), "Failed to build user trigger decision event"); + return; + }, + }; + + let mut set = JOIN_SET.lock().await; + set.spawn(async move { + if let Err(err) = codewhisperer_client + .send_telemetry_event( + TelemetryEvent::UserTriggerDecisionEvent(user_trigger_decision_event), + user_context, + opt_out_preference, + ) + .await + { + error!(err =% DisplayErrorContext(err), "Failed to send telemetry event"); + } + }); + } +} + +pub async fn send_event(event: AppTelemetryEvent) { + client().await.send_event(event).await; +} + +pub async fn send_user_logged_in() { + let event = AppTelemetryEvent::new(EventType::UserLoggedIn {}).await; + send_event(event).await; +} + +pub async fn send_cli_subcommand_executed(subcommand: impl Into) { + let event = AppTelemetryEvent::new(EventType::CliSubcommandExecuted { + subcommand: subcommand.into(), + }) + .await; + send_event(event).await; +} + +pub async fn send_start_chat(conversation_id: String) { + let event = AppTelemetryEvent::new(EventType::ChatStart { conversation_id }).await; + send_event(event).await; +} + +pub async fn send_end_chat(conversation_id: String) { + let event = AppTelemetryEvent::new(EventType::ChatEnd { conversation_id }).await; + send_event(event).await; +} + +pub async fn send_chat_added_message(conversation_id: String, message_id: String, context_file_length: Option) { + let event = AppTelemetryEvent::new(EventType::ChatAddedMessage { + conversation_id, + message_id, + context_file_length, + }) + .await; + send_event(event).await; +} + +pub async fn send_did_select_profile( + source: QProfileSwitchIntent, + amazonq_profile_region: String, + result: TelemetryResult, + sso_region: Option, + profile_count: Option, +) { + let event = AppTelemetryEvent::new(EventType::DidSelectProfile { + source, + amazonq_profile_region, + result, + sso_region, + profile_count, + }) + .await; + send_event(event).await; +} + +pub async fn send_profile_state( + source: QProfileSwitchIntent, + amazonq_profile_region: String, + result: TelemetryResult, + sso_region: Option, +) { + let event = AppTelemetryEvent::new(EventType::ProfileState { + source, + amazonq_profile_region, + result, + sso_region, + }) + .await; + send_event(event).await; +} + +#[cfg(test)] +mod test { + use event::tests::all_events; + use uuid::uuid; + + use super::*; + + #[tokio::test] + async fn client_context() { + let client = client().await; + let context = client.user_context().unwrap(); + + assert_eq!(context.ide_category, IdeCategory::Cli); + assert!(matches!( + context.operating_system, + OperatingSystem::Linux | OperatingSystem::Mac | OperatingSystem::Windows + )); + assert_eq!(context.product, PRODUCT); + assert_eq!( + context.client_id, + Some(uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff").hyphenated().to_string()) + ); + assert_eq!(context.ide_version.as_deref(), Some(PRODUCT_VERSION)); + } + + #[tokio::test] + async fn client_send_event_test() { + let client = Client::mock(); + for event in all_events().await { + client.send_event(event).await; + } + } + + 
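+    // A minimal usage sketch of the emit-then-flush flow in this module: build an
+    // `AppTelemetryEvent`, hand it to a `Client`, then drain the global `JOIN_SET` via
+    // `finish_telemetry()` so no spawned post is dropped before exit. The mock client is
+    // assumed here so the sketch stays offline; the event values are illustrative only.
+    #[tokio::test]
+    async fn mock_client_emit_and_flush_sketch() {
+        let client = Client::mock();
+        let event = AppTelemetryEvent::new(EventType::ChatStart {
+            conversation_id: "sketch-conversation".into(),
+        })
+        .await;
+        client.send_event(event).await;
+        // With the mock client no toolkit post is spawned, so this join returns immediately.
+        finish_telemetry().await;
+    }
+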
#[tracing_test::traced_test] + #[tokio::test] + #[ignore = "needs auth which is not in CI"] + async fn test_send() { + // let (shell, shell_version) = Shell::current_shell_version() + // .await + // .map(|(shell, shell_version)| (Some(shell), Some(shell_version))) + // .unwrap_or((None, None)); + + // let client = Client::new(TelemetryStage::BETA).await; + + // client + // .post_metric(metrics::CodewhispererterminalCliSubcommandExecuted { + // create_time: None, + // value: None, + // codewhispererterminal_subcommand: Some(CodewhispererterminalSubcommand("doctor".into())), + // codewhispererterminal_terminal: CURRENT_TERMINAL + // .clone() + // .map(|terminal| CodewhispererterminalTerminal(terminal.internal_id().to_string())), + // codewhispererterminal_terminal_version: CURRENT_TERMINAL_VERSION + // .clone() + // .map(CodewhispererterminalTerminalVersion), + // codewhispererterminal_shell: shell.map(|shell| + // CodewhispererterminalShell(shell.to_string())), + // codewhispererterminal_shell_version: + // shell_version.map(CodewhispererterminalShellVersion), credential_start_url: + // start_url().await, }) + // .await; + + finish_telemetry_unwrap().await; + + assert!(!logs_contain("ERROR")); + assert!(!logs_contain("error")); + assert!(!logs_contain("WARN")); + assert!(!logs_contain("warn")); + assert!(!logs_contain("Failed to post metric")); + } + + #[tracing_test::traced_test] + #[tokio::test] + #[ignore = "needs auth which is not in CI"] + async fn test_all_telemetry() { + send_user_logged_in().await; + send_cli_subcommand_executed("doctor").await; + send_chat_added_message("debug".to_owned(), "debug".to_owned(), Some(123)).await; + + finish_telemetry_unwrap().await; + + assert!(!logs_contain("ERROR")); + assert!(!logs_contain("error")); + assert!(!logs_contain("WARN")); + assert!(!logs_contain("warn")); + assert!(!logs_contain("Failed to post metric")); + } + + #[tokio::test] + #[ignore = "needs auth which is not in CI"] + async fn test_without_optout() { + let client = Client::new(TelemetryStage::BETA).await; + client + .codewhisperer_client + .as_ref() + .unwrap() + .send_telemetry_event( + TelemetryEvent::ChatAddMessageEvent( + ChatAddMessageEvent::builder() + .conversation_id("debug".to_owned()) + .message_id("debug".to_owned()) + .build() + .unwrap(), + ), + client.user_context().unwrap(), + OptOutPreference::OptIn, + ) + .await + .unwrap(); + } +} diff --git a/crates/kiro-cli/src/fig_telemetry/util.rs b/crates/kiro-cli/src/fig_telemetry/util.rs new file mode 100644 index 0000000000..6a725abbb6 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry/util.rs @@ -0,0 +1,162 @@ +use std::str::FromStr; + +use tracing::error; +use uuid::{ + Uuid, + uuid, +}; + +use crate::fig_os_shim::Env; +use crate::fig_settings::{ + Settings, + State, +}; + +const CLIENT_ID_STATE_KEY: &str = "telemetryClientId"; +const CLIENT_ID_ENV_VAR: &str = "Q_TELEMETRY_CLIENT_ID"; + +pub(crate) fn telemetry_is_disabled() -> bool { + let is_test = cfg!(test); + telemetry_is_disabled_inner(is_test, &Env::new(), &Settings::new()) +} + +/// Returns whether or not the user has disabled telemetry through settings or environment +fn telemetry_is_disabled_inner(is_test: bool, env: &Env, settings: &Settings) -> bool { + let env_var = env.get_os("Q_DISABLE_TELEMETRY").is_some(); + let setting = !settings + .get_value("telemetry.enabled") + .ok() + .flatten() + .and_then(|v| v.as_bool()) + .unwrap_or(true); + !is_test && (env_var || setting) +} + +pub(crate) fn get_client_id() -> Uuid { + get_client_id_inner(cfg!(test), 
&Env::new(), &State::new(), &Settings::new())
+}
+
+/// Generates or gets the client id and caches the result
+///
+/// Based on:
+pub(crate) fn get_client_id_inner(is_test: bool, env: &Env, state: &State, settings: &Settings) -> Uuid {
+    if is_test {
+        return uuid!("ffffffff-ffff-ffff-ffff-ffffffffffff");
+    }
+
+    if telemetry_is_disabled_inner(is_test, env, settings) {
+        return uuid!("11111111-1111-1111-1111-111111111111");
+    }
+
+    if let Ok(client_id) = env.get(CLIENT_ID_ENV_VAR) {
+        if let Ok(uuid) = Uuid::from_str(&client_id) {
+            return uuid;
+        }
+    }
+
+    let state_uuid = state
+        .get_string(CLIENT_ID_STATE_KEY)
+        .ok()
+        .flatten()
+        .and_then(|s| Uuid::from_str(&s).ok());
+
+    match state_uuid {
+        Some(uuid) => uuid,
+        None => {
+            let uuid = old_client_id_inner(settings).unwrap_or_else(Uuid::new_v4);
+            if let Err(err) = state.set_value(CLIENT_ID_STATE_KEY, uuid.to_string()) {
+                error!(%err, "Failed to set client id in state");
+            }
+            uuid
+        },
+    }
+}
+
+/// We accidentally generated some clientIds in the settings file; we want to include those in the
+/// telemetry events so we can correlate those users with the correct clientIds
+fn old_client_id_inner(settings: &Settings) -> Option<Uuid> {
+    settings
+        .get_string(CLIENT_ID_STATE_KEY)
+        .ok()
+        .flatten()
+        .and_then(|s| Uuid::from_str(&s).ok())
+}
+
+pub(crate) fn old_client_id() -> Option<Uuid> {
+    old_client_id_inner(&Settings::new())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    const TEST_UUID_STR: &str = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee";
+    const TEST_UUID: Uuid = uuid!(TEST_UUID_STR);
+
+    #[test]
+    fn test_is_telemetry_disabled() {
+        // disabled by default in tests
+        // let is_disabled = telemetry_is_disabled();
+        // assert!(!is_disabled);
+
+        // let settings = Settings::new_fake();
+
+        // let env = Env::from_slice(&[("Q_DISABLE_TELEMETRY", "1")]);
+        // assert!(telemetry_is_disabled_inner(true, &env, &settings));
+        // assert!(telemetry_is_disabled_inner(false, &env, &settings));
+
+        // let env = Env::new_fake();
+        // assert!(telemetry_is_disabled_inner(true, &env, &settings));
+        // assert!(!telemetry_is_disabled_inner(false, &env, &settings));
+
+        // settings.set_value("telemetry.enabled", false).unwrap();
+        // assert!(telemetry_is_disabled_inner(false, &env, &settings));
+        // assert!(!telemetry_is_disabled_inner(true, &env, &settings));
+
+        // settings.set_value("telemetry.enabled", true).unwrap();
+        // assert!(!telemetry_is_disabled_inner(false, &env, &settings));
+        // assert!(!telemetry_is_disabled_inner(true, &env, &settings));
+    }
+
+    #[test]
+    fn test_get_client_id() {
+        // max by default in tests
+        let id = get_client_id();
+        assert!(id.is_max());
+
+        let state = State::new_fake();
+        let settings = Settings::new_fake();
+
+        let env = Env::from_slice(&[(CLIENT_ID_ENV_VAR, TEST_UUID_STR)]);
+        assert_eq!(get_client_id_inner(false, &env, &state, &settings), TEST_UUID);
+
+        let env = Env::new_fake();
+
+        // in tests returns the test uuid
+        assert!(get_client_id_inner(true, &env, &state, &settings).is_max());
+
+        // returns the currently set client id if one is found
+        state.set_value(CLIENT_ID_STATE_KEY, TEST_UUID_STR).unwrap();
+        assert_eq!(get_client_id_inner(false, &env, &state, &settings), TEST_UUID);
+
+        // generates a new client id if none is found
+        state.remove_value(CLIENT_ID_STATE_KEY).unwrap();
+        assert_eq!(
+            get_client_id_inner(false, &env, &state, &settings).to_string(),
+            state.get_string(CLIENT_ID_STATE_KEY).unwrap().unwrap()
+        );
+
+        // migrates the client id in settings
+
state.remove_value(CLIENT_ID_STATE_KEY).unwrap(); + settings.set_value(CLIENT_ID_STATE_KEY, TEST_UUID_STR).unwrap(); + assert_eq!(get_client_id_inner(false, &env, &state, &settings), TEST_UUID); + } + + #[test] + fn test_get_client_id_old() { + let settings = Settings::new_fake(); + assert!(old_client_id_inner(&settings).is_none()); + settings.set_value(CLIENT_ID_STATE_KEY, TEST_UUID_STR).unwrap(); + assert_eq!(old_client_id_inner(&settings), Some(TEST_UUID)); + } +} diff --git a/crates/kiro-cli/src/fig_telemetry_core.rs b/crates/kiro-cli/src/fig_telemetry_core.rs new file mode 100644 index 0000000000..63cdb98bf0 --- /dev/null +++ b/crates/kiro-cli/src/fig_telemetry_core.rs @@ -0,0 +1,418 @@ +use std::any::Any; +use std::sync::OnceLock; +use std::time::SystemTime; + +pub use amzn_toolkit_telemetry_client::types::MetricDatum; +use strum::{ + Display, + EnumString, +}; + +use crate::fig_telemetry::definitions::IntoMetricDatum; +use crate::fig_telemetry::definitions::metrics::{ + AmazonqDidSelectProfile, + AmazonqEndChat, + AmazonqProfileState, + AmazonqStartChat, + CodewhispererterminalAddChatMessage, + CodewhispererterminalCliSubcommandExecuted, + CodewhispererterminalMcpServerInit, + CodewhispererterminalRefreshCredentials, + CodewhispererterminalToolUseSuggested, + CodewhispererterminalUserLoggedIn, +}; +use crate::fig_telemetry::definitions::types::{ + CodewhispererterminalCustomToolInputTokenSize, + CodewhispererterminalCustomToolLatency, + CodewhispererterminalCustomToolOutputTokenSize, + CodewhispererterminalInCloudshell, + CodewhispererterminalIsToolValid, + CodewhispererterminalMcpServerInitFailureReason, + CodewhispererterminalToolName, + CodewhispererterminalToolUseId, + CodewhispererterminalToolUseIsSuccess, + CodewhispererterminalToolsPerMcpServer, + CodewhispererterminalUserInputId, + CodewhispererterminalUtteranceId, +}; + +type GlobalTelemetryEmitter = dyn TelemetryEmitter + Send + Sync + 'static; + +/// Global telemetry emitter for the current process. +static EMITTER: OnceLock> = OnceLock::new(); + +pub fn init_global_telemetry_emitter(telemetry_emitter: T) +where + T: TelemetryEmitter + Send + Sync + 'static, +{ + match EMITTER.set(Box::new(telemetry_emitter)) { + Ok(_) => (), + Err(_) => panic!("The global telemetry emitter can only be initialized once"), + } +} + +/// Sends the telemetry event through the global [TelemetryEmitter] as set by +/// [init_global_telemetry_emitter], returning [None] if no telemetry emitter was set. +pub async fn send_event(event: Event) -> Option<()> { + if let Some(emitter) = EMITTER.get() { + emitter.send(event).await; + Some(()) + } else { + None + } +} + +/// Trait to handle sending telemetry events. This is intended to be used globally within the +/// application, and can be set using [init_global_telemetry_emitter]. Only one global +/// [TelemetryEmitter] impl should exist. +/// +/// TODO: Update all telemetry calls to go through the global [TelemetryEmitter] impl instead. +#[async_trait::async_trait] +pub trait TelemetryEmitter { + async fn send(&self, event: Event); + + fn as_any(&self) -> &dyn Any; +} + +/// A serializable telemetry event that can be sent or queued. 
+#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Event { + pub created_time: Option, + pub credential_start_url: Option, + #[serde(flatten)] + pub ty: EventType, +} + +impl Event { + pub fn new(ty: EventType) -> Self { + Self { + ty, + created_time: Some(SystemTime::now()), + credential_start_url: None, + } + } + + pub fn with_credential_start_url(mut self, credential_start_url: String) -> Self { + self.credential_start_url = Some(credential_start_url); + self + } + + pub fn into_metric_datum(self) -> Option { + match self.ty { + EventType::UserLoggedIn {} => Some( + CodewhispererterminalUserLoggedIn { + create_time: self.created_time, + value: None, + credential_start_url: self.credential_start_url.map(Into::into), + codewhispererterminal_in_cloudshell: in_cloudshell(), + } + .into_metric_datum(), + ), + EventType::RefreshCredentials { + request_id, + result, + reason, + oauth_flow, + } => Some( + CodewhispererterminalRefreshCredentials { + create_time: self.created_time, + value: None, + credential_start_url: self.credential_start_url.map(Into::into), + request_id: Some(request_id.into()), + result: Some(result.to_string().into()), + reason: reason.map(Into::into), + oauth_flow: Some(oauth_flow.into()), + codewhispererterminal_in_cloudshell: in_cloudshell(), + } + .into_metric_datum(), + ), + EventType::CliSubcommandExecuted { subcommand } => Some( + CodewhispererterminalCliSubcommandExecuted { + create_time: self.created_time, + value: None, + credential_start_url: self.credential_start_url.map(Into::into), + codewhispererterminal_subcommand: Some(subcommand.into()), + codewhispererterminal_in_cloudshell: in_cloudshell(), + } + .into_metric_datum(), + ), + EventType::ChatStart { conversation_id } => Some( + AmazonqStartChat { + create_time: self.created_time, + value: None, + credential_start_url: self.credential_start_url.map(Into::into), + amazonq_conversation_id: Some(conversation_id.into()), + codewhispererterminal_in_cloudshell: in_cloudshell(), + } + .into_metric_datum(), + ), + EventType::ChatEnd { conversation_id } => Some( + AmazonqEndChat { + create_time: self.created_time, + value: None, + credential_start_url: self.credential_start_url.map(Into::into), + amazonq_conversation_id: Some(conversation_id.into()), + codewhispererterminal_in_cloudshell: in_cloudshell(), + } + .into_metric_datum(), + ), + EventType::ChatAddedMessage { + conversation_id, + context_file_length, + .. 
+ } => Some( + CodewhispererterminalAddChatMessage { + create_time: self.created_time, + value: None, + amazonq_conversation_id: Some(conversation_id.into()), + credential_start_url: self.credential_start_url.map(Into::into), + codewhispererterminal_in_cloudshell: in_cloudshell(), + codewhispererterminal_context_file_length: context_file_length.map(|l| l as i64).map(Into::into), + } + .into_metric_datum(), + ), + EventType::ToolUseSuggested { + conversation_id, + utterance_id, + user_input_id, + tool_use_id, + tool_name, + is_accepted, + is_valid, + is_success, + is_custom_tool, + input_token_size, + output_token_size, + custom_tool_call_latency, + } => Some( + CodewhispererterminalToolUseSuggested { + create_time: self.created_time, + credential_start_url: self.credential_start_url.map(Into::into), + value: None, + amazonq_conversation_id: Some(conversation_id.into()), + codewhispererterminal_utterance_id: utterance_id.map(CodewhispererterminalUtteranceId), + codewhispererterminal_user_input_id: user_input_id.map(CodewhispererterminalUserInputId), + codewhispererterminal_tool_use_id: tool_use_id.map(CodewhispererterminalToolUseId), + codewhispererterminal_tool_name: tool_name.map(CodewhispererterminalToolName), + codewhispererterminal_is_tool_use_accepted: Some(is_accepted.into()), + codewhispererterminal_is_tool_valid: is_valid.map(CodewhispererterminalIsToolValid), + codewhispererterminal_tool_use_is_success: is_success.map(CodewhispererterminalToolUseIsSuccess), + codewhispererterminal_is_custom_tool: Some(is_custom_tool.into()), + codewhispererterminal_custom_tool_input_token_size: input_token_size + .map(|s| CodewhispererterminalCustomToolInputTokenSize(s as i64)), + codewhispererterminal_custom_tool_output_token_size: output_token_size + .map(|s| CodewhispererterminalCustomToolOutputTokenSize(s as i64)), + codewhispererterminal_custom_tool_latency: custom_tool_call_latency + .map(|l| CodewhispererterminalCustomToolLatency(l as i64)), + } + .into_metric_datum(), + ), + EventType::McpServerInit { + conversation_id, + init_failure_reason, + number_of_tools, + } => Some( + CodewhispererterminalMcpServerInit { + create_time: self.created_time, + value: None, + amazonq_conversation_id: Some(conversation_id.into()), + codewhispererterminal_mcp_server_init_failure_reason: init_failure_reason + .map(CodewhispererterminalMcpServerInitFailureReason), + codewhispererterminal_tools_per_mcp_server: Some(CodewhispererterminalToolsPerMcpServer( + number_of_tools as i64, + )), + } + .into_metric_datum(), + ), + EventType::DidSelectProfile { + source, + amazonq_profile_region, + result, + sso_region, + profile_count, + } => Some( + AmazonqDidSelectProfile { + create_time: self.created_time, + value: None, + source: Some(source.to_string().into()), + amazon_q_profile_region: Some(amazonq_profile_region.into()), + result: Some(result.to_string().into()), + sso_region: sso_region.map(Into::into), + credential_start_url: self.credential_start_url.map(Into::into), + profile_count: profile_count.map(Into::into), + } + .into_metric_datum(), + ), + EventType::ProfileState { + source, + amazonq_profile_region, + result, + sso_region, + } => Some( + AmazonqProfileState { + create_time: self.created_time, + value: None, + source: Some(source.to_string().into()), + amazon_q_profile_region: Some(amazonq_profile_region.into()), + result: Some(result.to_string().into()), + sso_region: sso_region.map(Into::into), + credential_start_url: self.credential_start_url.map(Into::into), + } + .into_metric_datum(), + ), + } 
+ } +} + +#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +#[serde(tag = "type")] +pub enum EventType { + UserLoggedIn {}, + RefreshCredentials { + request_id: String, + result: TelemetryResult, + reason: Option, + oauth_flow: String, + }, + CliSubcommandExecuted { + subcommand: String, + }, + ChatStart { + conversation_id: String, + }, + ChatEnd { + conversation_id: String, + }, + ChatAddedMessage { + conversation_id: String, + message_id: String, + context_file_length: Option, + }, + ToolUseSuggested { + conversation_id: String, + utterance_id: Option, + user_input_id: Option, + tool_use_id: Option, + tool_name: Option, + is_accepted: bool, + is_success: Option, + is_valid: Option, + is_custom_tool: bool, + input_token_size: Option, + output_token_size: Option, + custom_tool_call_latency: Option, + }, + McpServerInit { + conversation_id: String, + init_failure_reason: Option, + number_of_tools: usize, + }, + DidSelectProfile { + source: QProfileSwitchIntent, + amazonq_profile_region: String, + result: TelemetryResult, + sso_region: Option, + profile_count: Option, + }, + ProfileState { + source: QProfileSwitchIntent, + amazonq_profile_region: String, + result: TelemetryResult, + sso_region: Option, + }, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)] +pub enum SuggestionState { + Accept, + Discard, + Empty, + Reject, +} + +impl SuggestionState { + pub fn is_accepted(&self) -> bool { + matches!(self, SuggestionState::Accept) + } +} + +impl From for amzn_codewhisperer_client::types::SuggestionState { + fn from(value: SuggestionState) -> Self { + match value { + SuggestionState::Accept => amzn_codewhisperer_client::types::SuggestionState::Accept, + SuggestionState::Discard => amzn_codewhisperer_client::types::SuggestionState::Discard, + SuggestionState::Empty => amzn_codewhisperer_client::types::SuggestionState::Empty, + SuggestionState::Reject => amzn_codewhisperer_client::types::SuggestionState::Reject, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, EnumString, Display, serde::Serialize, serde::Deserialize)] +pub enum TelemetryResult { + Succeeded, + Failed, + Cancelled, +} + +/// 'user' -> users change the profile through Q CLI user profile command +/// 'auth' -> users change the profile through dashboard +/// 'update' -> CLI auto select the profile on users' behalf as there is only 1 profile +/// 'reload' -> CLI will try to reload previous selected profile upon CLI is running +#[derive(Debug, Copy, Clone, PartialEq, Eq, EnumString, Display, serde::Serialize, serde::Deserialize)] +pub enum QProfileSwitchIntent { + User, + Auth, + Update, + Reload, +} + +fn in_cloudshell() -> Option { + Some(crate::fig_util::system_info::in_cloudshell().into()) +} + +#[cfg(test)] +mod tests { + use std::sync::Mutex; + + use super::*; + + #[derive(Debug, Default)] + struct DummyEmitter(Mutex>); + + #[async_trait::async_trait] + impl TelemetryEmitter for DummyEmitter { + async fn send(&self, event: Event) { + self.0.lock().unwrap().push(event); + } + + fn as_any(&self) -> &dyn Any { + self + } + } + + #[tokio::test] + async fn test_init_global_telemetry_emitter_receives_event() { + init_global_telemetry_emitter(DummyEmitter::default()); + send_event(Event::new(EventType::UserLoggedIn {})).await; + + let events = EMITTER + .get() + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .0 + .lock() + .unwrap(); + assert!(events.len() == 1); + assert!(matches!(events.first().unwrap().ty, 
EventType::UserLoggedIn {})); + } + + #[ignore = "depends on test_init_global_telemetry_emitter_receives_event not being ran"] + #[tokio::test] + async fn test_no_global_telemetry_emitter() { + assert!(send_event(Event::new(EventType::UserLoggedIn {})).await.is_none()); + } +} diff --git a/crates/kiro-cli/src/fig_util/cli_context.rs b/crates/kiro-cli/src/fig_util/cli_context.rs new file mode 100644 index 0000000000..f13382d389 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/cli_context.rs @@ -0,0 +1,58 @@ +use std::sync::Arc; + +use crate::fig_os_shim::Context; +use crate::fig_settings::{ + Settings, + State, +}; + +#[derive(Debug, Clone)] +pub struct CliContext { + settings: Settings, + state: State, + context: Arc, +} + +impl Default for CliContext { + fn default() -> Self { + Self::new() + } +} + +impl CliContext { + pub fn new() -> Self { + let settings = Settings::new(); + let state = State::new(); + let context = Context::new(); + + Self { + settings, + state, + context, + } + } + + pub fn new_fake() -> Self { + let settings = Settings::new_fake(); + let state = State::new_fake(); + let context = Context::new_fake(); + + Self { + settings, + state, + context, + } + } + + pub fn settings(&self) -> &Settings { + &self.settings + } + + pub fn state(&self) -> &State { + &self.state + } + + pub fn context(&self) -> &Context { + &self.context + } +} diff --git a/crates/kiro-cli/src/fig_util/consts.rs b/crates/kiro-cli/src/fig_util/consts.rs new file mode 100644 index 0000000000..86d26f5d53 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/consts.rs @@ -0,0 +1,139 @@ +#[cfg(windows)] +pub const APP_PROCESS_NAME: &str = "q_desktop.exe"; + +/// The name configured under `"package.productName"` in the tauri.conf.json file. +pub const TAURI_PRODUCT_NAME: &str = "q_desktop"; + +pub const CLI_BINARY_NAME: &str = "q"; +pub const CLI_BINARY_NAME_MINIMAL: &str = "q-minimal"; +pub const PTY_BINARY_NAME: &str = "qterm"; + +pub const CLI_CRATE_NAME: &str = "q_cli"; + +pub const URL_SCHEMA: &str = "q"; + +pub const PRODUCT_NAME: &str = "Amazon Q"; + +pub const RUNTIME_DIR_NAME: &str = "cwrun"; + +// These are the old "CodeWhisperer" branding, used anywhere we will not update to Amazon Q +pub const OLD_CLI_BINARY_NAMES: &[&str] = &["cw"]; +pub const OLD_PTY_BINARY_NAMES: &[&str] = &["cwterm"]; + +pub const GITHUB_REPO_NAME: &str = "aws/amazon-q-developer-cli"; + +/// Build time env vars +pub mod build { + /// The target of the current build, e.g. 
"aarch64-unknown-linux-musl" + pub const TARGET_TRIPLE: Option<&str> = option_env!("AMAZON_Q_BUILD_TARGET_TRIPLE"); + + /// The variant of the current build + pub const VARIANT: Option<&str> = option_env!("AMAZON_Q_BUILD_VARIANT"); + + /// A git full sha hash of the current build + pub const HASH: Option<&str> = option_env!("AMAZON_Q_BUILD_HASH"); + + /// The datetime in rfc3339 format of the current build + pub const DATETIME: Option<&str> = option_env!("AMAZON_Q_BUILD_DATETIME"); + + /// If `fish` tests should be skipped + pub const SKIP_FISH_TESTS: bool = option_env!("AMAZON_Q_BUILD_SKIP_FISH_TESTS").is_some(); + + /// If `shellcheck` tests should be skipped + pub const SKIP_SHELLCHECK_TESTS: bool = option_env!("AMAZON_Q_BUILD_SKIP_SHELLCHECK_TESTS").is_some(); +} + +/// macOS specific constants +pub mod macos { + pub const BUNDLE_CONTENTS_MACOS_PATH: &str = "Contents/MacOS"; + pub const BUNDLE_CONTENTS_RESOURCE_PATH: &str = "Contents/Resources"; + pub const BUNDLE_CONTENTS_HELPERS_PATH: &str = "Contents/Helpers"; + pub const BUNDLE_CONTENTS_INFO_PLIST_PATH: &str = "Contents/Info.plist"; +} + +pub mod linux { + pub const DESKTOP_ENTRY_NAME: &str = "amazon-q.desktop"; + + /// Name of the deb package. + pub const PACKAGE_NAME: &str = "amazon-q"; + + /// The wm_class used for the application windows. + pub const DESKTOP_APP_WM_CLASS: &str = "Amazon-q"; +} + +pub mod env_var { + macro_rules! define_env_vars { + ($($(#[$meta:meta])* $ident:ident = $name:expr),*) => { + $( + $(#[$meta])* + pub const $ident: &str = $name; + )* + + pub const ALL: &[&str] = &[$($ident),*]; + } + } + + define_env_vars! { + /// The UUID of the current parent qterm instance + QTERM_SESSION_ID = "QTERM_SESSION_ID", + + /// The current parent socket to connect to + Q_PARENT = "Q_PARENT", + + /// Set the [`Q_PARENT`] parent socket to connect to + Q_SET_PARENT = "Q_SET_PARENT", + + /// Guard for the [`Q_SET_PARENT`] check + Q_SET_PARENT_CHECK = "Q_SET_PARENT_CHECK", + + /// Set if qterm is running, contains the version + Q_TERM = "Q_TERM", + + /// Sets the current log level + Q_LOG_LEVEL = "Q_LOG_LEVEL", + + /// Overrides the ZDOTDIR environment variable + Q_ZDOTDIR = "Q_ZDOTDIR", + + /// Indicates a process was launched by Amazon Q + PROCESS_LAUNCHED_BY_Q = "PROCESS_LAUNCHED_BY_Q", + + /// The shell to use in qterm + Q_SHELL = "Q_SHELL", + + /// Indicates the user is debugging the shell + Q_DEBUG_SHELL = "Q_DEBUG_SHELL", + + /// Indicates the user is using zsh autosuggestions which disables Inline + Q_USING_ZSH_AUTOSUGGESTIONS = "Q_USING_ZSH_AUTOSUGGESTIONS", + + /// Overrides the path to the bundle metadata released with certain desktop builds. 
+ Q_BUNDLE_METADATA_PATH = "Q_BUNDLE_METADATA_PATH" + } +} + +#[cfg(test)] +mod tests { + use time::OffsetDateTime; + use time::format_description::well_known::Rfc3339; + + use super::*; + + #[test] + fn test_build_envs() { + if let Some(build_variant) = build::VARIANT { + println!("build_variant: {build_variant}"); + assert!(["full", "minimal"].contains(&&*build_variant.to_ascii_lowercase())); + } + + if let Some(build_hash) = build::HASH { + println!("build_hash: {build_hash}"); + assert!(!build_hash.is_empty()); + } + + if let Some(build_datetime) = build::DATETIME { + println!("build_datetime: {build_datetime}"); + println!("{}", OffsetDateTime::parse(build_datetime, &Rfc3339).unwrap()); + } + } +} diff --git a/crates/kiro-cli/src/fig_util/directories.rs b/crates/kiro-cli/src/fig_util/directories.rs new file mode 100644 index 0000000000..ff15c5bb56 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/directories.rs @@ -0,0 +1,304 @@ +use std::path::PathBuf; + +use thiserror::Error; + +use crate::fig_os_shim::{ + EnvProvider, + FsProvider, + Os, + Shim, +}; +use crate::fig_util::env_var::Q_PARENT; + +#[derive(Debug, Error)] +pub enum DirectoryError { + #[error("home directory not found")] + NoHomeDirectory, + #[error("runtime directory not found: neither XDG_RUNTIME_DIR nor TMPDIR were found")] + NoRuntimeDirectory, + #[error("non absolute path: {0:?}")] + NonAbsolutePath(PathBuf), + #[error("unsupported platform: {0:?}")] + UnsupportedOs(Os), + #[error("IO Error: {0}")] + Io(#[from] std::io::Error), + #[error(transparent)] + TimeFormat(#[from] time::error::Format), + #[error(transparent)] + Utf8FromPath(#[from] camino::FromPathError), + #[error(transparent)] + Utf8FromPathBuf(#[from] camino::FromPathBufError), + #[error(transparent)] + FromVecWithNul(#[from] std::ffi::FromVecWithNulError), + #[error(transparent)] + IntoString(#[from] std::ffi::IntoStringError), + #[error("{Q_PARENT} env variable not set")] + QParentNotSet, + #[error("must be ran from an appimage executable")] + NotAppImage, +} + +type Result = std::result::Result; + +/// The directory of the users home +/// +/// - Linux: /home/Alice +/// - MacOS: /Users/Alice +/// - Windows: C:\Users\Alice +pub fn home_dir() -> Result { + dirs::home_dir().ok_or(DirectoryError::NoHomeDirectory) +} + +pub fn home_dir_ctx(ctx: &Ctx) -> Result { + if ctx.env().is_real() { + home_dir() + } else { + ctx.env() + .get("HOME") + .map_err(|_err| DirectoryError::NoHomeDirectory) + .and_then(|h| { + if h.is_empty() { + Err(DirectoryError::NoHomeDirectory) + } else { + Ok(h) + } + }) + .map(PathBuf::from) + .map(|p| ctx.fs().chroot_path(p)) + } +} + +/// The directory of the users `$HOME/.local/bin` directory +/// +/// MacOS and Linux path: `$HOME/.local/bin`` +#[cfg(unix)] +pub fn home_local_bin() -> Result { + let mut path = home_dir()?; + path.push(".local/bin"); + Ok(path) +} + +#[cfg(target_os = "linux")] +pub fn home_local_bin_ctx(ctx: &Context) -> Result { + let mut path = home_dir_ctx(ctx)?; + path.push(".local/bin"); + Ok(path) +} + +/// The q data directory +/// +/// - Linux: `$XDG_DATA_HOME/amazon-q` or `$HOME/.local/share/amazon-q` +/// - MacOS: `$HOME/Library/Application Support/amazon-q` +pub fn fig_data_dir() -> Result { + cfg_if::cfg_if! { + if #[cfg(unix)] { + Ok(dirs::data_local_dir() + .ok_or(DirectoryError::NoHomeDirectory)? 
+ .join("amazon-q")) + } else if #[cfg(windows)] { + Ok(fig_dir()?.join("userdata")) + } + } +} + +pub fn fig_data_dir_ctx(fs: &impl FsProvider) -> Result { + Ok(fs.fs().chroot_path(fig_data_dir()?)) +} + +/// Get the macos tempdir from the `confstr` function +/// +/// See: +#[cfg(target_os = "macos")] +fn macos_tempdir() -> Result { + let len = unsafe { libc::confstr(libc::_CS_DARWIN_USER_TEMP_DIR, std::ptr::null::().cast_mut(), 0) }; + let mut buf: Vec = vec![0; len]; + unsafe { libc::confstr(libc::_CS_DARWIN_USER_TEMP_DIR, buf.as_mut_ptr().cast(), buf.len()) }; + let c_string = std::ffi::CString::from_vec_with_nul(buf)?; + let str = c_string.into_string()?; + Ok(PathBuf::from(str)) +} + +/// Runtime dir is used for runtime data that should not be persisted for a long time, e.g. socket +/// files and logs +/// +/// The XDG_RUNTIME_DIR is set by systemd , +/// if this is not set such as on macOS it will fallback to TMPDIR which is secure on macOS +#[cfg(unix)] +pub fn runtime_dir() -> Result { + let mut dir = dirs::runtime_dir(); + dir = dir.or_else(|| std::env::var_os("TMPDIR").map(PathBuf::from)); + + cfg_if::cfg_if! { + if #[cfg(target_os = "macos")] { + let macos_tempdir = macos_tempdir()?; + dir = dir.or(Some(macos_tempdir)); + } else { + dir = dir.or_else(|| Some(std::env::temp_dir())); + } + } + + dir.ok_or(DirectoryError::NoRuntimeDirectory) +} + +/// The directory to all the fig logs +/// - Linux: `/tmp/fig/$USER/logs` +/// - MacOS: `$TMPDIR/logs` +/// - Windows: `%TEMP%\fig\logs` +pub fn logs_dir() -> Result { + cfg_if::cfg_if! { + if #[cfg(unix)] { + use crate::CLI_BINARY_NAME; + Ok(runtime_dir()?.join(format!("{CLI_BINARY_NAME}log"))) + } else if #[cfg(windows)] { + Ok(std::env::temp_dir().join("amazon-q").join("logs")) + } + } +} + +/// The directory to the directory containing config for the `/context` feature in `q chat`. +pub fn chat_global_context_path(ctx: &Ctx) -> Result { + Ok(home_dir_ctx(ctx)? + .join(".aws") + .join("amazonq") + .join("global_context.json")) +} + +/// The directory to the directory containing config for the `/context` feature in `q chat`. +pub fn chat_profiles_dir(ctx: &Ctx) -> Result { + Ok(home_dir_ctx(ctx)?.join(".aws").join("amazonq").join("profiles")) +} + +/// The path to the fig settings file +pub fn settings_path() -> Result { + Ok(fig_data_dir()?.join("settings.json")) +} + +/// The path to the lock file used to indicate that the app is updating +pub fn update_lock_path(ctx: &impl FsProvider) -> Result { + Ok(fig_data_dir_ctx(ctx)?.join("update.lock")) +} + +#[cfg(test)] +mod linux_tests { + use super::*; + + #[test] + fn all_paths() { + let ctx = crate::fig_os_shim::Context::new(); + assert!(logs_dir().is_ok()); + assert!(settings_path().is_ok()); + assert!(update_lock_path(&ctx).is_ok()); + } +} + +// TODO(grant): Add back path tests on linux +#[cfg(all(test, not(target_os = "linux")))] +mod tests { + use insta; + + use super::*; + + macro_rules! assert_directory { + ($value:expr, @$snapshot:literal) => { + insta::assert_snapshot!( + sanitized_directory_path($value), + @$snapshot, + ) + }; + } + + macro_rules! macos { + ($value:expr, @$snapshot:literal) => { + #[cfg(target_os = "macos")] + assert_directory!($value, @$snapshot) + }; + } + + macro_rules! linux { + ($value:expr, @$snapshot:literal) => { + #[cfg(target_os = "linux")] + assert_directory!($value, @$snapshot) + }; + } + + macro_rules! 
windows { + ($value:expr, @$snapshot:literal) => { + #[cfg(target_os = "windows")] + assert_directory!($value, @$snapshot) + }; + } + + fn sanitized_directory_path(path: Result) -> String { + let mut path = path.unwrap().into_os_string().into_string().unwrap(); + + if let Ok(home) = std::env::var("HOME") { + let home = home.strip_suffix('/').unwrap_or(&home); + path = path.replace(home, "$HOME"); + } + + let user = whoami::username(); + path = path.replace(&user, "$USER"); + + if let Ok(tmpdir) = std::env::var("TMPDIR") { + let tmpdir = tmpdir.strip_suffix('/').unwrap_or(&tmpdir); + path = path.replace(tmpdir, "$TMPDIR"); + } + + #[cfg(target_os = "macos")] + { + if let Ok(tmpdir) = macos_tempdir() { + let tmpdir = tmpdir.to_str().unwrap(); + let tmpdir = tmpdir.strip_suffix('/').unwrap_or(tmpdir); + path = path.replace(tmpdir, "$TMPDIR"); + }; + } + + if let Ok(xdg_runtime_dir) = std::env::var("XDG_RUNTIME_DIR") { + let xdg_runtime_dir = xdg_runtime_dir.strip_suffix('/').unwrap_or(&xdg_runtime_dir); + path = path.replace(xdg_runtime_dir, "$XDG_RUNTIME_DIR"); + } + + #[cfg(target_os = "linux")] + { + path = path.replace("/tmp", "$TMPDIR"); + } + + path + } + + #[cfg(unix)] + #[test] + fn snapshot_home_local_bin() { + linux!(home_local_bin(), @"$HOME/.local/bin"); + macos!(home_local_bin(), @"$HOME/.local/bin"); + } + + #[test] + fn snapshot_fig_data_dir() { + linux!(fig_data_dir(), @"$HOME/.local/share/amazon-q"); + macos!(fig_data_dir(), @"$HOME/Library/Application Support/amazon-q"); + windows!(fig_data_dir(), @r"C:\Users\$USER\AppData\Local\Fig\userdata"); + } + + #[test] + fn snapshot_settings_path() { + linux!(settings_path(), @"$HOME/.local/share/amazon-q/settings.json"); + macos!(settings_path(), @"$HOME/Library/Application Support/amazon-q/settings.json"); + windows!(settings_path(), @r"C:\Users\$USER\AppData\Lcoal\Fig\settings.json"); + } + + #[test] + fn snapshot_update_lock_path() { + let ctx = crate::fig_os_shim::Context::new(); + linux!(update_lock_path(&ctx), @"$HOME/.local/share/amazon-q/update.lock"); + macos!(update_lock_path(&ctx), @"$HOME/Library/Application Support/amazon-q/update.lock"); + windows!(update_lock_path(&ctx), @r"C:\Users\$USER\AppData\Local\Fig\userdata\update.lock"); + } + + #[test] + #[cfg(target_os = "macos")] + fn macos_tempdir_test() { + let tmpdir = macos_tempdir().unwrap(); + println!("{:?}", tmpdir); + } +} diff --git a/crates/kiro-cli/src/fig_util/error.rs b/crates/kiro-cli/src/fig_util/error.rs new file mode 100644 index 0000000000..910e76bcdc --- /dev/null +++ b/crates/kiro-cli/src/fig_util/error.rs @@ -0,0 +1,25 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum Error { + #[error("io operation error")] + IoError(#[from] std::io::Error), + #[error("unsupported platform")] + UnsupportedPlatform, + #[error("unsupported architecture")] + UnsupportedArch, + #[error(transparent)] + Directory(#[from] crate::directories::DirectoryError), + #[error("process has no parent")] + NoParentProcess, + #[error("could not find the os hwid")] + HwidNotFound, + #[error("the shell, `{0}`, isn't supported yet")] + UnknownShell(String), + #[error("missing environment variable `{0}`")] + MissingEnv(&'static str), + #[error("unknown display server `{0}`")] + UnknownDisplayServer(String), + #[error("unknown desktop `{0}`")] + UnknownDesktop(String), +} diff --git a/crates/kiro-cli/src/fig_util/manifest.rs b/crates/kiro-cli/src/fig_util/manifest.rs new file mode 100644 index 0000000000..119627ef97 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/manifest.rs 
@@ -0,0 +1,343 @@ +use std::fmt::Display; +use std::str::FromStr; +use std::sync::OnceLock; + +use cfg_if::cfg_if; +use serde::{ + Deserialize, + Deserializer, + Serialize, +}; +use strum::{ + Display, + EnumString, +}; + +use crate::fig_util::build::TARGET_TRIPLE; +use crate::fig_util::consts::build::VARIANT; + +#[derive(Deserialize)] +pub struct Manifest { + #[serde(deserialize_with = "deser_enum_other")] + pub managed_by: ManagedBy, + #[serde(deserialize_with = "deser_enum_other")] + pub target_triple: TargetTriple, + #[serde(deserialize_with = "deser_enum_other")] + pub variant: Variant, + #[serde(deserialize_with = "deser_enum_other")] + pub default_channel: Channel, + pub packaged_at: String, + pub packaged_by: String, +} + +#[derive(EnumString, Display, Deserialize, Serialize, PartialEq, Eq, Clone, Debug)] +#[serde(rename_all = "camelCase")] +#[strum(serialize_all = "camelCase")] +pub enum ManagedBy { + None, + #[strum(default)] + Other(String), +} + +/// The target triplet, describes a platform on which the project is build for. Note that this also +/// includes "fake" targets like `universal-apple-darwin` as provided by [Tauri](https://tauri.app/v1/guides/building/macos/#binary-targets) +#[derive(Deserialize, Serialize, PartialEq, Eq, EnumString, Debug, Display)] +pub enum TargetTriple { + #[serde(rename = "universal-apple-darwin")] + #[strum(serialize = "universal-apple-darwin")] + UniversalAppleDarwin, + #[serde(rename = "x86_64-unknown-linux-gnu")] + #[strum(serialize = "x86_64-unknown-linux-gnu")] + X86_64UnknownLinuxGnu, + #[serde(rename = "x86_64-unknown-linux-musl")] + #[strum(serialize = "x86_64-unknown-linux-musl")] + X86_64UnknownLinuxMusl, + #[serde(rename = "aarch64-unknown-linux-gnu")] + #[strum(serialize = "aarch64-unknown-linux-gnu")] + AArch64UnknownLinuxGnu, + #[serde(rename = "aarch64-unknown-linux-musl")] + #[strum(serialize = "aarch64-unknown-linux-musl")] + AArch64UnknownLinuxMusl, + #[strum(default)] + Other(String), +} + +impl TargetTriple { + const fn from_system() -> Self { + cfg_if! 
{ + if #[cfg(target_os = "macos")] { + TargetTriple::UniversalAppleDarwin + } else if #[cfg(all(target_os = "linux", target_env = "gnu", target_arch = "x86_64"))] { + TargetTriple::X86_64UnknownLinuxGnu + } else if #[cfg(all(target_os = "linux", target_env = "gnu", target_arch = "aarch64"))] { + TargetTriple::AArch64UnknownLinuxGnu + } else if #[cfg(all(target_os = "linux", target_env = "musl", target_arch = "x86_64"))] { + TargetTriple::X86_64UnknownLinuxMusl + } else if #[cfg(all(target_os = "linux", target_env = "musl", target_arch = "aarch64"))] { + TargetTriple::AArch64UnknownLinuxMusl + } else { + compile_error!("unknown target") + } + } + } +} + +#[derive(EnumString, Display, Deserialize, Serialize, PartialEq, Eq, Clone, Debug)] +#[serde(rename_all = "camelCase")] +#[strum(serialize_all = "camelCase")] +pub enum Variant { + Full, + #[serde(alias = "headless")] + #[strum(to_string = "minimal", serialize = "headless")] + Minimal, + #[strum(default)] + Other(String), +} + +#[derive(EnumString, Display, Deserialize, Serialize, PartialEq, Eq, Clone, Debug)] +#[serde(rename_all = "camelCase")] +#[strum(serialize_all = "camelCase")] +pub enum Os { + Macos, + Linux, + #[strum(default)] + Other(String), +} + +impl Os { + pub fn current() -> Self { + match std::env::consts::OS { + "macos" => Os::Macos, + "linux" => Os::Linux, + _ => panic!("Unsupported OS: {}", std::env::consts::OS), + } + } + + pub fn is_current_os(&self) -> bool { + self == &Os::current() + } +} + +#[derive(EnumString, Display, Deserialize, Serialize, PartialEq, Eq, Clone, Debug)] +#[serde(rename_all = "camelCase")] +#[strum(serialize_all = "camelCase")] +pub enum FileType { + Dmg, + TarGz, + TarXz, + TarZst, + Zip, + AppImage, + Deb, + #[strum(default)] + Other(String), +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, EnumString, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +#[strum(serialize_all = "camelCase")] +pub enum Channel { + Stable, + Beta, + Qa, + Nightly, +} + +impl Channel { + pub fn all() -> &'static [Self] { + &[Channel::Stable, Channel::Beta, Channel::Qa, Channel::Nightly] + } + + pub fn id(&self) -> &'static str { + match self { + Channel::Stable => "stable", + Channel::Beta => "beta", + Channel::Qa => "qa", + Channel::Nightly => "nightly", + } + } + + pub fn name(&self) -> &'static str { + match self { + Channel::Stable => "Stable", + Channel::Beta => "Beta", + Channel::Qa => "QA", + Channel::Nightly => "Nightly", + } + } +} + +impl Display for Channel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if f.alternate() { + f.write_str(self.name()) + } else { + f.write_str(self.id()) + } + } +} + +#[derive(Debug, Clone, Deserialize)] +pub struct BundleMetadata { + pub packaged_as: FileType, +} + +fn deser_enum_other<'de, D, T>(deserializer: D) -> Result +where + D: Deserializer<'de>, + T: FromStr, + T::Err: Display, +{ + match T::from_str(<&str as Deserialize<'de>>::deserialize(deserializer)?) 
{
+        Ok(s) => Ok(s),
+        Err(err) => Err(serde::de::Error::custom(err)),
+    }
+}
+
+/// Returns the manifest, reading and parsing it if necessary
+pub fn manifest() -> &'static Manifest {
+    static CACHED: OnceLock<Manifest> = OnceLock::new();
+    CACHED.get_or_init(|| Manifest {
+        managed_by: ManagedBy::None,
+        target_triple: match TARGET_TRIPLE {
+            Some(target) => TargetTriple::from_str(target).expect("parsing target triple should not fail"),
+            _ => TargetTriple::from_system(),
+        },
+        variant: match VARIANT.map(|s| s.to_ascii_lowercase()).as_deref() {
+            Some("minimal") => Variant::Minimal,
+            _ => Variant::Full,
+        },
+        default_channel: Channel::Stable,
+        packaged_at: "unknown".into(),
+        packaged_by: "unknown".into(),
+    })
+}
+
+/// Checks if this is a full build according to the manifest.
+/// Note that this does not guarantee the value of is_minimal
+pub fn is_full() -> bool {
+    cfg_if! {
+        if #[cfg(target_os = "macos")] {
+            true
+        } else if #[cfg(unix)] {
+            matches!(
+                manifest(),
+                Manifest {
+                    variant: Variant::Full,
+                    ..
+                }
+            )
+        } else if #[cfg(windows)] {
+            true
+        }
+    }
+}
+
+/// Checks if this is a minimal build according to the manifest.
+/// Note that this does not guarantee the value of is_full
+pub fn is_minimal() -> bool {
+    cfg_if! {
+        if #[cfg(target_os = "macos")] {
+            false
+        } else if #[cfg(unix)] {
+            matches!(
+                manifest(),
+                Manifest {
+                    variant: Variant::Minimal,
+                    ..
+                }
+            )
+        } else if #[cfg(windows)] {
+            false
+        }
+    }
+}
+
+/// Gets the version from the manifest
+#[deprecated = "versions are unified, use env!(\"CARGO_PKG_VERSION\")"]
+pub fn version() -> Option<&'static str> {
+    Some(env!("CARGO_PKG_VERSION"))
+}
+
+#[cfg(test)]
+mod tests {
+    use serde_json::{
+        from_str,
+        to_string,
+    };
+
+    use super::*;
+
+    macro_rules! test_ser_deser {
+        ($ty:ident, $variant:expr, $text:expr) => {
+            let quoted = format!("\"{}\"", $text);
+            assert_eq!(quoted, to_string(&$variant).unwrap());
+            assert_eq!($variant, from_str(&quoted).unwrap());
+            assert_eq!($variant, $ty::from_str($text).unwrap());
+            assert_eq!($text, $variant.to_string());
+        };
+    }
+
+    #[test]
+    fn test_target_triple_serialize_deserialize() {
+        test_ser_deser!(
+            TargetTriple,
+            TargetTriple::UniversalAppleDarwin,
+            "universal-apple-darwin"
+        );
+        test_ser_deser!(
+            TargetTriple,
+            TargetTriple::X86_64UnknownLinuxGnu,
+            "x86_64-unknown-linux-gnu"
+        );
+        test_ser_deser!(
+            TargetTriple,
+            TargetTriple::AArch64UnknownLinuxGnu,
+            "aarch64-unknown-linux-gnu"
+        );
+        test_ser_deser!(
+            TargetTriple,
+            TargetTriple::X86_64UnknownLinuxMusl,
+            "x86_64-unknown-linux-musl"
+        );
+        test_ser_deser!(
+            TargetTriple,
+            TargetTriple::AArch64UnknownLinuxMusl,
+            "aarch64-unknown-linux-musl"
+        );
+    }
+
+    #[test]
+    fn test_file_type_serialize_deserialize() {
+        test_ser_deser!(FileType, FileType::Dmg, "dmg");
+        test_ser_deser!(FileType, FileType::TarGz, "tarGz");
+        test_ser_deser!(FileType, FileType::TarXz, "tarXz");
+        test_ser_deser!(FileType, FileType::TarZst, "tarZst");
+        test_ser_deser!(FileType, FileType::Zip, "zip");
+        test_ser_deser!(FileType, FileType::AppImage, "appImage");
+        test_ser_deser!(FileType, FileType::Deb, "deb");
+    }
+
+    #[test]
+    fn test_managed_by_serialize_deserialize() {
+        test_ser_deser!(ManagedBy, ManagedBy::None, "none");
+    }
+
+    #[test]
+    fn test_variant_serialize_deserialize() {
+        test_ser_deser!(Variant, Variant::Full, "full");
+        test_ser_deser!(Variant, Variant::Minimal, "minimal");
+
+        // headless is a special case that should deserialize to Minimal
+        assert_eq!(Variant::Minimal, from_str("\"headless\"").unwrap());
+
assert_eq!(Variant::Minimal, Variant::from_str("headless").unwrap()); + } + + #[test] + fn test_channel_serialize_deserialize() { + test_ser_deser!(Channel, Channel::Stable, "stable"); + test_ser_deser!(Channel, Channel::Beta, "beta"); + test_ser_deser!(Channel, Channel::Qa, "qa"); + test_ser_deser!(Channel, Channel::Nightly, "nightly"); + } +} diff --git a/crates/kiro-cli/src/fig_util/mod.rs b/crates/kiro-cli/src/fig_util/mod.rs new file mode 100644 index 0000000000..9c3044f6fd --- /dev/null +++ b/crates/kiro-cli/src/fig_util/mod.rs @@ -0,0 +1,379 @@ +mod cli_context; +pub mod directories; +pub mod manifest; +pub mod open; +pub mod pid_file; +pub mod process_info; +mod region_check; +pub mod spinner; +pub mod system_info; + +pub mod consts; + +use std::cmp::Ordering; +use std::env; +use std::ffi::OsStr; +use std::fmt::Display; +use std::io::{ + ErrorKind, + stdout, +}; +use std::path::{ + Path, + PathBuf, +}; +use std::process::Command; + +use anstream::stream::IsTerminal; +use cfg_if::cfg_if; +pub use cli_context::CliContext; +pub use consts::*; +use crossterm::style::Stylize; +use dialoguer::Select; +use dialoguer::theme::ColorfulTheme; +use eyre::{ + Context, + ContextCompat, + Result, + bail, +}; +use globset::{ + Glob, + GlobSet, + GlobSetBuilder, +}; +use rand::Rng; +use regex::Regex; +use thiserror::Error; +use tracing::warn; + +#[derive(Debug, Error)] +pub enum Error { + #[error("io operation error")] + IoError(#[from] std::io::Error), + #[error("unsupported platform")] + UnsupportedPlatform, + #[error("unsupported architecture")] + UnsupportedArch, + #[error(transparent)] + Directory(#[from] directories::DirectoryError), + #[error("process has no parent")] + NoParentProcess, + #[error("could not find the os hwid")] + HwidNotFound, + #[error("the shell, `{0}`, isn't supported yet")] + UnknownShell(String), + #[error("missing environment variable `{0}`")] + MissingEnv(&'static str), + #[error("unknown display server `{0}`")] + UnknownDisplayServer(String), + #[error("unknown desktop, checked environment variables: {0}")] + UnknownDesktop(UnknownDesktopErrContext), + #[error(transparent)] + StrUtf8Error(#[from] std::str::Utf8Error), + #[error(transparent)] + Json(#[from] serde_json::Error), +} + +#[derive(Debug, Clone)] +pub struct UnknownDesktopErrContext { + xdg_current_desktop: String, + xdg_session_desktop: String, + gdm_session: String, +} + +impl std::fmt::Display for UnknownDesktopErrContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "XDG_CURRENT_DESKTOP: `{}`, ", self.xdg_current_desktop)?; + write!(f, "XDG_SESSION_DESKTOP: `{}`, ", self.xdg_session_desktop)?; + write!(f, "GDMSESSION: `{}`", self.gdm_session) + } +} + +/// Returns a random 64 character hex string +/// +/// # Example +/// +/// ``` +/// use crate::fig_util::gen_hex_string; +/// +/// let hex = gen_hex_string(); +/// assert_eq!(hex.len(), 64); +/// ``` +pub fn gen_hex_string() -> String { + let mut buf = [0u8; 32]; + rand::rng().fill(&mut buf); + hex::encode(buf) +} + +pub fn search_xdg_data_dirs(ext: impl AsRef) -> Option { + let ext = ext.as_ref(); + if let Ok(xdg_data_dirs) = std::env::var("XDG_DATA_DIRS") { + for base in xdg_data_dirs.split(':') { + let check = Path::new(base).join(ext); + if check.exists() { + return Some(check); + } + } + } + None +} + +/// Returns the path to the original executable, not the symlink +pub fn current_exe_origin() -> Result { + Ok(std::env::current_exe()?.canonicalize()?) 
+} + +pub fn partitioned_compare(lhs: &str, rhs: &str, by: char) -> Ordering { + let sides = lhs + .split(by) + .filter(|x| !x.is_empty()) + .zip(rhs.split(by).filter(|x| !x.is_empty())); + + for (lhs, rhs) in sides { + match if lhs.chars().all(|x| x.is_numeric()) && rhs.chars().all(|x| x.is_numeric()) { + // perform a numerical comparison + let lhs: u64 = lhs.parse().unwrap(); + let rhs: u64 = rhs.parse().unwrap(); + lhs.cmp(&rhs) + } else { + // perform a lexical comparison + lhs.cmp(rhs) + } { + Ordering::Equal => continue, + s => return s, + } + } + + lhs.len().cmp(&rhs.len()) +} + +/// Glob patterns against full paths +pub fn glob_dir(glob: &GlobSet, directory: impl AsRef) -> Result> { + let mut files = Vec::new(); + + // List files in the directory + let dir = std::fs::read_dir(directory)?; + + for entry in dir { + let path = entry?.path(); + + // Check if the file matches the glob pattern + if glob.is_match(&path) { + files.push(path); + } + } + + Ok(files) +} + +/// Glob patterns against the file name +pub fn glob_files(glob: &GlobSet, directory: impl AsRef) -> Result> { + let mut files = Vec::new(); + + // List files in the directory + let dir = std::fs::read_dir(directory)?; + + for entry in dir { + let entry = entry?; + let path = entry.path(); + let file_name = path.file_name(); + + // Check if the file matches the glob pattern + if let Some(file_name) = file_name { + if glob.is_match(file_name) { + files.push(path); + } + } + } + + Ok(files) +} + +pub fn glob(patterns: I) -> Result +where + I: IntoIterator, + S: AsRef, +{ + let mut builder = GlobSetBuilder::new(); + for pattern in patterns { + builder.add(Glob::new(pattern.as_ref())?); + } + Ok(builder.build()?) +} + +pub fn app_path_from_bundle_id(bundle_id: impl AsRef) -> Option { + cfg_if! { + if #[cfg(target_os = "macos")] { + let installed_apps = std::process::Command::new("mdfind") + .arg("kMDItemCFBundleIdentifier") + .arg("=") + .arg(bundle_id) + .output() + .ok()?; + + let path = String::from_utf8_lossy(&installed_apps.stdout); + Some(path.trim().split('\n').next()?.into()) + } else { + let _bundle_id = bundle_id; + None + } + } +} + +pub fn is_executable_in_path(program: impl AsRef) -> bool { + match env::var_os("PATH") { + Some(path) => env::split_paths(&path).any(|p| p.join(&program).is_file()), + _ => false, + } +} + +pub fn app_not_running_message() -> String { + format!( + "\n{}\n{PRODUCT_NAME} app might not be running, to launch {PRODUCT_NAME} run: {}\n", + format!("Unable to connect to {PRODUCT_NAME} app").bold(), + format!("{CLI_BINARY_NAME} launch").magenta() + ) +} + +pub fn login_message() -> String { + format!( + "{}\nLooks like you aren't logged in to {PRODUCT_NAME}, to login run: {}", + "Not logged in".bold(), + format!("{CLI_BINARY_NAME} login").magenta() + ) +} + +pub fn match_regex(regex: impl AsRef, input: impl AsRef) -> Option { + Some( + Regex::new(regex.as_ref()) + .unwrap() + .captures(input.as_ref())? + .get(1)? 
+ .as_str() + .into(), + ) +} + +pub fn choose(prompt: impl Display, options: &[impl ToString]) -> Result> { + if options.is_empty() { + bail!("no options passed to choose") + } + + if !stdout().is_terminal() { + warn!("called choose while stdout is not a terminal"); + return Ok(Some(0)); + } + + match Select::with_theme(&dialoguer_theme()) + .items(options) + .default(0) + .with_prompt(prompt.to_string()) + .interact_opt() + { + Ok(ok) => Ok(ok), + Err(dialoguer::Error::IO(io)) if io.kind() == ErrorKind::Interrupted => Ok(None), + Err(e) => Err(e).wrap_err("Failed to choose"), + } +} + +pub fn input(prompt: &str, initial_text: Option<&str>) -> Result { + if !stdout().is_terminal() { + warn!("called input while stdout is not a terminal"); + return Ok(String::new()); + } + + let theme = dialoguer_theme(); + let mut input = dialoguer::Input::with_theme(&theme).with_prompt(prompt); + + if let Some(initial_text) = initial_text { + input = input.with_initial_text(initial_text); + } + + Ok(input.interact_text()?) +} + +pub fn get_running_app_info(bundle_id: impl AsRef, field: impl AsRef) -> Result { + let info = Command::new("lsappinfo") + .args(["info", "-only", field.as_ref(), "-app", bundle_id.as_ref()]) + .output()?; + let info = String::from_utf8(info.stdout)?; + let value = info + .split('=') + .nth(1) + .context(eyre::eyre!("Could not get field value for {}", field.as_ref()))? + .replace('"', ""); + Ok(value.trim().into()) +} + +pub fn dialoguer_theme() -> ColorfulTheme { + ColorfulTheme { + prompt_prefix: dialoguer::console::style("?".into()).for_stderr().magenta(), + ..ColorfulTheme::default() + } +} + +#[cfg(target_os = "macos")] +pub async fn is_brew_reinstall() -> bool { + let regex = regex::bytes::Regex::new(r"brew(\.\w+)?\s+(upgrade|reinstall|install)").unwrap(); + + tokio::process::Command::new("ps") + .args(["aux", "-o", "args"]) + .output() + .await + .is_ok_and(|output| regex.is_match(&output.stdout)) +} + +#[cfg(test)] +mod tests { + use std::cmp::Ordering; + + use super::*; + + #[test] + fn regex() { + let regex_test = |regex: &str, input: &str, expected: Option<&str>| { + assert_eq!(match_regex(regex, input), expected.map(|s| s.into())); + }; + + regex_test(r"foo=(\S+)", "foo=bar", Some("bar")); + regex_test(r"foo=(\S+)", "bar=foo", None); + regex_test(r"foo=(\S+)", "foo=bar baz", Some("bar")); + regex_test(r"foo=(\S+)", "foo=", None); + } + + #[test] + fn exe_path() { + #[cfg(unix)] + assert!(is_executable_in_path("cargo")); + + #[cfg(windows)] + assert!(is_executable_in_path("cargo.exe")); + } + + #[test] + fn globs() { + let set = glob(["*.txt", "*.md"]).unwrap(); + assert!(set.is_match("README.md")); + assert!(set.is_match("LICENSE.txt")); + } + + #[test] + fn test_partitioned_compare() { + assert_eq!(partitioned_compare("1.2.3", "1.2.3", '.'), Ordering::Equal); + assert_eq!(partitioned_compare("1.2.3", "1.2.2", '.'), Ordering::Greater); + assert_eq!(partitioned_compare("4-a-b", "4-a-c", '-'), Ordering::Less); + assert_eq!(partitioned_compare("0?0?0", "0?0", '?'), Ordering::Greater); + } + + #[test] + fn test_gen_hex_string() { + let hex = gen_hex_string(); + assert_eq!(hex.len(), 64); + } + + #[test] + fn test_current_exe_origin() { + current_exe_origin().unwrap(); + } +} diff --git a/crates/kiro-cli/src/fig_util/open.rs b/crates/kiro-cli/src/fig_util/open.rs new file mode 100644 index 0000000000..6315adc118 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/open.rs @@ -0,0 +1,101 @@ +use cfg_if::cfg_if; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + 
#[error(transparent)] + Io(#[from] std::io::Error), + #[error("Failed to open URL")] + Failed, +} + +#[cfg(target_os = "macos")] +#[allow(unexpected_cfgs)] +fn open_macos(url_str: impl AsRef) -> Result<(), Error> { + use objc2::ClassType; + use objc2_foundation::{ + NSString, + NSURL, + }; + + let url_nsstring = NSString::from_str(url_str.as_ref()); + let nsurl = unsafe { NSURL::initWithString(NSURL::alloc(), &url_nsstring) }.ok_or(Error::Failed)?; + let res = unsafe { objc2_app_kit::NSWorkspace::sharedWorkspace().openURL(&nsurl) }; + res.then_some(()).ok_or(Error::Failed) +} + +#[cfg(target_os = "windows")] +fn open_command(url: impl AsRef) -> std::process::Command { + use std::os::windows::process::CommandExt; + + let detached = 0x8; + let mut command = std::process::Command::new("cmd"); + command.creation_flags(detached); + command.args(["/c", "start", url.as_ref()]); + command +} + +#[cfg(any(target_os = "linux", target_os = "freebsd"))] +fn open_command(url: impl AsRef) -> std::process::Command { + let executable = if crate::system_info::in_wsl() { + "wslview" + } else { + "xdg-open" + }; + + let mut command = std::process::Command::new(executable); + command.arg(url.as_ref()); + command +} + +/// Returns bool indicating whether the URL was opened successfully +pub fn open_url(url: impl AsRef) -> Result<(), Error> { + cfg_if! { + if #[cfg(target_os = "macos")] { + open_macos(url) + } else { + match open_command(url).output() { + Ok(output) => { + tracing::trace!(?output, "open_url output"); + if output.status.success() { + Ok(()) + } else { + Err(Error::Failed) + } + }, + Err(err) => Err(err.into()), + } + } + } +} + +/// Returns bool indicating whether the URL was opened successfully +pub async fn open_url_async(url: impl AsRef) -> Result<(), Error> { + cfg_if! { + if #[cfg(target_os = "macos")] { + open_macos(url) + } else { + match tokio::process::Command::from(open_command(url)).output().await { + Ok(output) => { + tracing::trace!(?output, "open_url_async output"); + if output.status.success() { + Ok(()) + } else { + Err(Error::Failed) + } + }, + Err(err) => Err(err.into()), + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[ignore] + #[test] + fn test_open_url() { + open_url("https://fig.io").unwrap(); + } +} diff --git a/crates/kiro-cli/src/fig_util/pid_file.rs b/crates/kiro-cli/src/fig_util/pid_file.rs new file mode 100644 index 0000000000..0f86ef5c66 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/pid_file.rs @@ -0,0 +1,167 @@ +use std::fs::{ + File, + OpenOptions, +}; +use std::io::{ + Error, + ErrorKind, + Seek, + SeekFrom, + Write, +}; +use std::os::unix::fs::OpenOptionsExt; +use std::path::PathBuf; + +use eyre::Result; +use nix::fcntl::{ + Flock, + FlockArg, +}; +use nix::sys::signal::{ + Signal, + kill, +}; +use nix::unistd::Pid; +use tokio::fs::read_to_string; +use tokio::time::sleep; +use tracing::{ + debug, + error, + info, + instrument, + warn, +}; + +/// A file-based process lock that ensures only one instance of a process is running. +/// +/// `PidLock` works by: +/// 1. Creating/opening a PID file at the specified path +/// 2. Attempting to acquire an exclusive lock on the file +/// 3. Writing the current process ID to the file +/// 4. If another process holds the lock, attempts to terminate that process first +/// +/// The lock is automatically released when the `PidLock` instance is dropped. 
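+///
+/// # Example
+///
+/// A minimal usage sketch; the path below is only illustrative and error handling is
+/// left to the caller:
+///
+/// ```no_run
+/// use std::path::PathBuf;
+///
+/// use crate::fig_util::pid_file::PidLock;
+///
+/// async fn demo() -> eyre::Result<()> {
+///     // Acquire the lock (terminating any previous holder recorded in the file),
+///     // do the exclusive work, then release it explicitly.
+///     let lock = PidLock::new(PathBuf::from("/tmp/example.pid")).await?;
+///     // ... exclusive work ...
+///     lock.release()?;
+///     Ok(())
+/// }
+/// ```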
+#[derive(Debug)] +pub struct PidLock { + lock: Flock, + pid_path: PathBuf, +} + +impl PidLock { + #[instrument(name = "PidLock::new")] + pub async fn new(pid_path: PathBuf) -> Result { + let file = OpenOptions::new() + .read(true) + .write(true) + .create(true) + .truncate(false) + .mode(0o644) + .open(&pid_path) + .inspect_err(|err| error!(%err, "Failed to open pid file"))?; + + // Try to get exclusive lock + let mut lock = match Flock::lock(file, FlockArg::LockExclusiveNonblock) { + Ok(lock) => lock, + Err((file, err)) => { + debug!(%err, "Failed to acquire lock, trying to handle existing process"); + + // Read existing PID + match read_to_string(&pid_path).await { + Ok(content) => match content.trim().parse::() { + Ok(pid) => { + debug!(%pid, "Found existing process ID"); + if let Err(err) = kill_process(pid).await { + error!(%err, %pid, "Failed to kill existing process"); + } else { + info!(%pid, "Successfully killed existing process"); + } + }, + Err(err) => { + warn!(%err, %content, "Failed to parse PID from lockfile"); + }, + }, + Err(err) => warn!(%err, "Failed to read PID from lockfile"), + } + + Flock::lock(file, FlockArg::LockExclusiveNonblock).map_err(|(_, err)| { + error!(%err, "Failed to acquire lock after handling existing process"); + err + })? + }, + }; + + // Write current PID + let current_pid = std::process::id(); + lock.set_len(0) + .inspect_err(|err| error!(%err, "Failed to truncate lock file"))?; + lock.seek(SeekFrom::Start(0)) + .inspect_err(|err| error!(%err, "Failed to seek to start of file"))?; + lock.write_all(current_pid.to_string().as_bytes()) + .inspect_err(|err| error!(%err, "Failed to write PID to file"))?; + lock.flush() + .inspect_err(|err| error!(%err, "Failed to flush lock file"))?; + + info!(%current_pid, "Successfully created and locked PID file"); + Ok(PidLock { lock, pid_path }) + } + + #[instrument(name = "PidLock::release", skip(self), fields(pid_path =? 
self.pid_path))] + pub fn release(mut self) -> Result<(), Error> { + debug!("Releasing PID lock"); + self.lock + .set_len(0) + .inspect_err(|err| error!(%err, "Failed to truncate lock file during release"))?; + self.lock + .flush() + .inspect_err(|err| error!(%err, "Failed to flush lock file during release"))?; + self.lock.unlock().map_err(|(_, err)| { + error!(%err, "Failed to unlock file during release"); + err + })?; + Ok(()) + } +} + +#[instrument(level = "debug")] +fn process_exists(pid: i32) -> bool { + let exists = kill(Pid::from_raw(pid), None).is_ok(); + debug!(%pid, %exists, "Checked if process exists"); + exists +} + +#[instrument(level = "debug")] +async fn kill_process(pid: i32) -> Result<()> { + if !process_exists(pid) { + error!(%pid, "Process not found"); + return Err(Error::new(ErrorKind::NotFound, format!("Process already running with PID {pid}")).into()); + } + + info!(%pid, "Attempting to terminate process"); + match kill(Pid::from_raw(pid), Signal::SIGINT) { + Ok(_) => { + debug!(%pid, "Sent SIGINT signal"); + + // Wait for the process to terminate + for i in 0..50 { + if !process_exists(pid) { + info!(%pid, "Process terminated successfully"); + return Ok(()); + } + debug!(%pid, attempt = i, "Process still running, waiting"); + sleep(std::time::Duration::from_millis(100)).await; + } + + if process_exists(pid) { + warn!(%pid, "Process didn't terminate gracefully, sending SIGKILL"); + let _ = kill(Pid::from_raw(pid), Signal::SIGKILL) + .inspect_err(|err| error!(%err, %pid, "Failed to send SIGKILL")); + sleep(std::time::Duration::from_millis(100)).await; + } + Ok(()) + }, + Err(err) => { + error!(%err, %pid, "Failed to send SIGINT"); + Err(Error::new(ErrorKind::Other, format!("Failed to terminate existing process: {err}")).into()) + }, + } +} diff --git a/crates/kiro-cli/src/fig_util/process_info/freebsd.rs b/crates/kiro-cli/src/fig_util/process_info/freebsd.rs new file mode 100644 index 0000000000..384d9ded68 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/process_info/freebsd.rs @@ -0,0 +1,20 @@ +use std::path::PathBuf; + +use super::{ + Pid, + PidExt, +}; + +impl PidExt for Pid { + fn current() -> Self { + nix::unistd::getpid().into() + } + + fn parent(&self) -> Option { + None + } + + fn exe(&self) -> Option { + None + } +} diff --git a/crates/kiro-cli/src/fig_util/process_info/linux.rs b/crates/kiro-cli/src/fig_util/process_info/linux.rs new file mode 100644 index 0000000000..a013f31580 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/process_info/linux.rs @@ -0,0 +1,41 @@ +use std::path::PathBuf; +use std::str::FromStr; + +pub trait LinuxExt { + fn cmdline(&self) -> Option; +} + +use super::{ + Pid, + PidExt, +}; + +impl PidExt for Pid { + fn current() -> Self { + nix::unistd::getpid().into() + } + + fn parent(&self) -> Option { + std::fs::read_to_string(format!("/proc/{self}/status")) + .ok() + .and_then(|s| { + s.lines() + .find(|line| line.starts_with("PPid:")) + .and_then(|line| line.strip_prefix("PPid:")) + .map(|line| line.trim()) + .and_then(|pid_str| Pid::from_str(pid_str).ok()) + }) + } + + fn exe(&self) -> Option { + std::path::PathBuf::from(format!("/proc/{self}/exe")).read_link().ok() + } +} + +impl LinuxExt for Pid { + fn cmdline(&self) -> Option { + std::fs::read_to_string(format!("/proc/{self}/cmdline")) + .ok() + .map(|s| s.replace('\0', "")) + } +} diff --git a/crates/kiro-cli/src/fig_util/process_info/macos.rs b/crates/kiro-cli/src/fig_util/process_info/macos.rs new file mode 100644 index 0000000000..d260c36456 --- /dev/null +++ 
b/crates/kiro-cli/src/fig_util/process_info/macos.rs @@ -0,0 +1,49 @@ +use std::ffi::OsStr; +use std::mem::MaybeUninit; +use std::os::unix::prelude::OsStrExt; +use std::path::PathBuf; + +use super::{ + Pid, + PidExt, +}; + +impl PidExt for Pid { + fn current() -> Self { + nix::unistd::getpid().into() + } + + fn parent(&self) -> Option { + let pid = self.0; + let mut info = MaybeUninit::::zeroed(); + let ret = unsafe { + nix::libc::proc_pidinfo( + pid, + nix::libc::PROC_PIDTBSDINFO, + 0, + info.as_mut_ptr().cast(), + std::mem::size_of::() as _, + ) + }; + if ret as usize != std::mem::size_of::() { + return None; + } + let info = unsafe { info.assume_init() }; + match info.pbi_ppid { + 0 => None, + ppid => Some(Pid(ppid.try_into().ok()?)), + } + } + + fn exe(&self) -> Option { + let mut buffer = [0u8; 4096]; + let pid = self.0; + let buffer_ptr = buffer.as_mut_ptr().cast::(); + let buffer_size = buffer.len() as u32; + let ret = unsafe { nix::libc::proc_pidpath(pid, buffer_ptr, buffer_size) }; + match ret { + 0 => None, + len => Some(PathBuf::from(OsStr::from_bytes(&buffer[..len as usize]))), + } + } +} diff --git a/crates/kiro-cli/src/fig_util/process_info/mod.rs b/crates/kiro-cli/src/fig_util/process_info/mod.rs new file mode 100644 index 0000000000..cc415a785c --- /dev/null +++ b/crates/kiro-cli/src/fig_util/process_info/mod.rs @@ -0,0 +1,118 @@ +use std::path::PathBuf; +use std::{ + fmt, + str, +}; + +use cfg_if::cfg_if; + +#[cfg(target_os = "linux")] +mod linux; +#[cfg(target_os = "linux")] +pub use linux::*; + +#[cfg(target_os = "macos")] +mod macos; +// #[cfg(target_os = "macos")] +// pub use macos::*; + +#[cfg(target_os = "windows")] +mod windows; +#[cfg(target_os = "windows")] +pub use self::windows::*; + +#[cfg(target_os = "freebsd")] +mod freebsd; +#[cfg(target_os = "freebsd")] +pub use self::freebsd::*; + +macro_rules! pid_decl { + ($typ:ty) => { + #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] + #[repr(transparent)] + pub struct Pid(pub(crate) $typ); + + impl From<$typ> for Pid { + fn from(v: $typ) -> Self { + Self(v) + } + } + impl From for $typ { + fn from(v: Pid) -> Self { + v.0 + } + } + impl str::FromStr for Pid { + type Err = <$typ as str::FromStr>::Err; + + fn from_str(s: &str) -> Result { + Ok(Self(<$typ>::from_str(s)?)) + } + } + impl fmt::Display for Pid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } + } + }; +} + +cfg_if! 
{ + if #[cfg(unix)] { + use nix::libc::pid_t; + + pid_decl!(pid_t); + + impl From for Pid { + fn from(pid: nix::unistd::Pid) -> Self { + Pid(pid.as_raw()) + } + } + + impl From for nix::unistd::Pid { + fn from(pid: Pid) -> Self { + nix::unistd::Pid::from_raw(pid.0) + } + } + } else if #[cfg(windows)] { + pid_decl!(u32); + } +} + +pub trait PidExt { + fn current() -> Self; + fn parent(&self) -> Option; + fn exe(&self) -> Option; +} + +pub fn get_parent_process_exe() -> Option { + let mut pid = Pid::current(); + loop { + pid = pid.parent()?; + match pid.exe() { + // We ignore toolbox-exec since we never want to know if that is the parent process + Some(pid) if pid.file_name().is_some_and(|s| s == "toolbox-exec") => {}, + other => return other, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parent_name() { + let process_pid = Pid::current(); + let parent_pid = process_pid.parent().unwrap(); + let parent_exe = parent_pid.exe().unwrap(); + let parent_name = parent_exe.file_name().unwrap().to_str().unwrap(); + + assert!(parent_name.contains("cargo")); + } + + #[test] + fn test_get_parent_process_exe() { + get_parent_process_exe(); + } +} diff --git a/crates/kiro-cli/src/fig_util/process_info/windows.rs b/crates/kiro-cli/src/fig_util/process_info/windows.rs new file mode 100644 index 0000000000..5ecfa6615e --- /dev/null +++ b/crates/kiro-cli/src/fig_util/process_info/windows.rs @@ -0,0 +1,136 @@ +use std::ffi::CStr; +use std::mem::{ + MaybeUninit, + size_of, +}; +use std::ops::Deref; +use std::path::PathBuf; + +use windows::Win32::Foundation::{ + CloseHandle, + HANDLE, + MAX_PATH, +}; +use windows::Win32::System::Threading::{ + GetCurrentProcessId, + NtQueryInformationProcess, + OpenProcess, + PROCESS_BASIC_INFORMATION, + PROCESS_NAME_FORMAT, + PROCESS_QUERY_INFORMATION, + PROCESS_QUERY_LIMITED_INFORMATION, + PROCESS_VM_READ, + ProcessBasicInformation, + QueryFullProcessImageNameA, +}; +use windows::core::PSTR; + +use super::{ + Pid, + PidExt, +}; + +struct SafeHandle(HANDLE); + +impl SafeHandle { + fn new(handle: HANDLE) -> Option { + if !handle.is_invalid() { Some(Self(handle)) } else { None } + } +} + +impl Drop for SafeHandle { + fn drop(&mut self) { + unsafe { + CloseHandle(self.0); + } + } +} + +impl Deref for SafeHandle { + type Target = HANDLE; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +fn get_process_handle(pid: &Pid) -> Option { + if pid.0 == 0 { + return None; + } + + let handle = unsafe { + match OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, false, pid.0) { + Ok(handle) => handle, + Err(_) => match OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, false, pid.0) { + Ok(handle) => handle, + Err(_) => return None, + }, + } + }; + + SafeHandle::new(handle) +} + +impl PidExt for Pid { + fn current() -> Self { + unsafe { Pid(GetCurrentProcessId()) } + } + + fn parent(&self) -> Option { + let handle = get_process_handle(self)?; + + unsafe { + let mut info: MaybeUninit = MaybeUninit::uninit(); + let mut len = 0; + if NtQueryInformationProcess( + *handle, + ProcessBasicInformation, + info.as_mut_ptr() as *mut _, + size_of::() as _, + &mut len, + ) + .is_err() + { + return None; + } + + let info = info.assume_init(); + + if info.InheritedFromUniqueProcessId as usize != 0 { + Some(Pid(info.InheritedFromUniqueProcessId as u32)) + } else { + None + } + } + } + + fn exe(&self) -> Option { + unsafe { + let handle = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, false, self.0).ok()?; + + // Get the terminal name + let mut len = MAX_PATH; + let mut 
process_name = [0; MAX_PATH as usize + 1];
+            process_name[MAX_PATH as usize] = u8::try_from('\0').unwrap();
+
+            if !QueryFullProcessImageNameA(
+                handle,
+                PROCESS_NAME_FORMAT(0),
+                PSTR(process_name.as_mut_ptr()),
+                &mut len,
+            )
+            .as_bool()
+            {
+                return None;
+            }
+
+            let title = CStr::from_bytes_with_nul(&process_name[0..=len as usize])
+                .ok()?
+                .to_str()
+                .ok()?;
+
+            Some(PathBuf::from(title))
+        }
+    }
+}
diff --git a/crates/kiro-cli/src/fig_util/region_check.rs b/crates/kiro-cli/src/fig_util/region_check.rs
new file mode 100644
index 0000000000..0a9f18d8a2
--- /dev/null
+++ b/crates/kiro-cli/src/fig_util/region_check.rs
@@ -0,0 +1,15 @@
+use super::system_info::in_cloudshell;
+
+const GOV_REGIONS: &[&str] = &["us-gov-east-1", "us-gov-west-1"];
+
+pub fn region_check(capability: &'static str) -> eyre::Result<()> {
+    let Ok(region) = std::env::var("AWS_REGION") else {
+        return Ok(());
+    };
+
+    if in_cloudshell() && GOV_REGIONS.contains(&region.as_str()) {
+        eyre::bail!("AWS GovCloud ({region}) is not supported for {capability}.");
+    }
+
+    Ok(())
+}
diff --git a/crates/kiro-cli/src/fig_util/spinner.rs b/crates/kiro-cli/src/fig_util/spinner.rs
new file mode 100644
index 0000000000..2a38859dac
--- /dev/null
+++ b/crates/kiro-cli/src/fig_util/spinner.rs
@@ -0,0 +1,126 @@
+use std::io::{
+    Write,
+    stdout,
+};
+use std::sync::mpsc::{
+    Sender,
+    TryRecvError,
+    channel,
+};
+use std::thread;
+use std::thread::JoinHandle;
+use std::time::Duration;
+
+use anstream::{
+    print,
+    println,
+};
+use crossterm::ExecutableCommand;
+
+const FRAMES: &[&str] = &[
+    "▰▱▱▱▱▱▱",
+    "▰▰▱▱▱▱▱",
+    "▰▰▰▱▱▱▱",
+    "▰▰▰▰▱▱▱",
+    "▰▰▰▰▰▱▱",
+    "▰▰▰▰▰▰▱",
+    "▰▰▰▰▰▰▰",
+    "▰▱▱▱▱▱▱",
+];
+const INTERVAL: Duration = Duration::from_millis(100);
+
+pub struct Spinner {
+    sender: Sender<Option<String>>,
+    join: Option<JoinHandle<()>>,
+}
+
+impl Drop for Spinner {
+    fn drop(&mut self) {
+        if self.join.is_some() {
+            self.sender.send(Some("\x1b[2K\r".into())).unwrap();
+            self.join.take().unwrap().join().unwrap();
+        }
+    }
+}
+
+#[derive(Debug, Clone)]
+pub enum SpinnerComponent {
+    Text(String),
+    Spinner,
+}
+
+impl Spinner {
+    pub fn new(components: Vec<SpinnerComponent>) -> Self {
+        let (sender, recv) = channel::<Option<String>>();
+
+        stdout().execute(crossterm::cursor::Hide).ok();
+
+        let join = thread::spawn(move || {
+            'outer: loop {
+                let mut stdout = stdout();
+                for frame in FRAMES.iter() {
+                    let (do_stop, stop_symbol) = match recv.try_recv() {
+                        Ok(stop_symbol) => (true, stop_symbol),
+                        Err(TryRecvError::Disconnected) => (true, None),
+                        Err(TryRecvError::Empty) => (false, None),
+                    };
+
+                    let frame = stop_symbol.unwrap_or_else(|| (*frame).to_string());
+
+                    let line = components.iter().fold(String::new(), |mut acc, elem| {
+                        acc.push_str(match elem {
+                            SpinnerComponent::Text(ref t) => t,
+                            SpinnerComponent::Spinner => &frame,
+                        });
+                        acc
+                    });
+
+                    print!("\r{line}");
+
+                    stdout.flush().unwrap();
+
+                    if do_stop {
+                        stdout.execute(crossterm::cursor::Show).ok();
+                        break 'outer;
+                    }
+
+                    thread::sleep(INTERVAL);
+                }
+            }
+        });
+
+        Self {
+            sender,
+            join: Some(join),
+        }
+    }
+
+    fn stop_inner(&mut self, stop_symbol: Option<String>) {
+        self.sender.send(stop_symbol).expect("Could not stop spinner thread.");
+        self.join.take().unwrap().join().unwrap();
+    }
+
+    pub fn stop(&mut self) {
+        self.stop_inner(None);
+    }
+
+    pub fn stop_with_message(&mut self, msg: String) {
+        self.stop();
+        println!("\x1b[2K\r{msg}");
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_spinner() {
+        let mut spinner = Spinner::new(vec![
+            SpinnerComponent::Spinner,
+
SpinnerComponent::Text("Loading".into()), + ]); + thread::sleep(Duration::from_secs(1)); + spinner.stop_with_message("Done".into()); + } +} diff --git a/crates/kiro-cli/src/fig_util/system_info/linux.rs b/crates/kiro-cli/src/fig_util/system_info/linux.rs new file mode 100644 index 0000000000..d611fe9293 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/system_info/linux.rs @@ -0,0 +1,285 @@ +use std::io; +use std::path::Path; +use std::sync::OnceLock; + +use regex::Regex; +use serde::{ + Deserialize, + Serialize, +}; + +use crate::fig_os_shim::EnvProvider; +use crate::fig_util::{ + Error, + UnknownDesktopErrContext, +}; + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DisplayServer { + X11, + Wayland, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DesktopEnvironment { + Gnome, + Plasma, + I3, + Sway, +} + +pub fn get_display_server(env: &impl EnvProvider) -> Result { + match env.env().get("XDG_SESSION_TYPE") { + Ok(session) => match session.as_str() { + "x11" => Ok(DisplayServer::X11), + "wayland" => Ok(DisplayServer::Wayland), + _ => Err(Error::UnknownDisplayServer(session)), + }, + // x11 is not guarantee this var is set, so we just assume x11 if it is not set + _ => Ok(DisplayServer::X11), + } +} + +pub fn get_desktop_environment(env: &impl EnvProvider) -> Result { + let env = env.env(); + + // Prioritize XDG_CURRENT_DESKTOP and check other common env vars as fallback. + // https://superuser.com/a/1643180 + let xdg_current_desktop = match env.get("XDG_CURRENT_DESKTOP") { + Ok(current) => { + let current_lower = current.to_lowercase(); + let (_, desktop) = current_lower.split_once(':').unwrap_or(("", current_lower.as_str())); + match desktop.to_lowercase().as_str() { + "gnome" | "gnome-xorg" | "ubuntu" | "pop" => return Ok(DesktopEnvironment::Gnome), + "kde" | "plasma" => return Ok(DesktopEnvironment::Plasma), + "i3" => return Ok(DesktopEnvironment::I3), + "sway" => return Ok(DesktopEnvironment::Sway), + _ => current, + } + }, + _ => "".into(), + }; + + let xdg_session_desktop = match env.get("XDG_SESSION_DESKTOP") { + Ok(session) => { + let session_lower = session.to_lowercase(); + match session_lower.as_str() { + "gnome" | "ubuntu" => return Ok(DesktopEnvironment::Gnome), + "kde" => return Ok(DesktopEnvironment::Plasma), + _ => session, + } + }, + _ => "".into(), + }; + + let gdm_session = match env.get("GDMSESSION") { + Ok(session) if session.to_lowercase().starts_with("ubuntu") => return Ok(DesktopEnvironment::Gnome), + Ok(session) => session, + _ => "".into(), + }; + + Err(Error::UnknownDesktop(UnknownDesktopErrContext { + xdg_current_desktop, + xdg_session_desktop, + gdm_session, + })) +} + +pub fn get_os_release() -> Option<&'static OsRelease> { + static OS_RELEASE: OnceLock> = OnceLock::new(); + OS_RELEASE.get_or_init(|| OsRelease::load().ok()).as_ref() +} + +/// Fields from +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct OsRelease { + pub id: Option, + + pub name: Option, + pub pretty_name: Option, + + pub version_id: Option, + pub version: Option, + + pub build_id: Option, + + pub variant_id: Option, + pub variant: Option, +} + +impl OsRelease { + fn path() -> &'static Path { + Path::new("/etc/os-release") + } + + pub(crate) fn load() -> io::Result { + let os_release_str = std::fs::read_to_string(Self::path())?; + Ok(OsRelease::from_str(&os_release_str)) + } + + pub(crate) fn from_str(s: &str) -> OsRelease { + // Remove the starting and ending quotes from a string if they match + let 
strip_quotes = |s: &str| -> Option { + if s.starts_with('"') && s.ends_with('"') { + Some(s[1..s.len() - 1].into()) + } else { + Some(s.into()) + } + }; + + let mut os_release = OsRelease::default(); + for line in s.lines() { + if let Some((key, value)) = line.split_once('=') { + match key { + "ID" => os_release.id = strip_quotes(value), + "NAME" => os_release.name = strip_quotes(value), + "PRETTY_NAME" => os_release.pretty_name = strip_quotes(value), + "VERSION" => os_release.version = strip_quotes(value), + "VERSION_ID" => os_release.version_id = strip_quotes(value), + "BUILD_ID" => os_release.build_id = strip_quotes(value), + "VARIANT" => os_release.variant = strip_quotes(value), + "VARIANT_ID" => os_release.variant_id = strip_quotes(value), + _ => {}, + } + } + } + os_release + } +} + +fn containerenv_engine_re() -> &'static Regex { + static CONTAINERENV_ENGINE_RE: OnceLock = OnceLock::new(); + CONTAINERENV_ENGINE_RE.get_or_init(|| Regex::new(r#"engine="([^"\s]+)""#).unwrap()) +} + +pub enum SandboxKind { + None, + Flatpak, + Snap, + Docker, + Container(Option), +} + +pub fn detect_sandbox() -> SandboxKind { + if Path::new("/.flatpak-info").exists() { + return SandboxKind::Flatpak; + } + if std::env::var("SNAP").is_ok() { + return SandboxKind::Snap; + } + if Path::new("/.dockerenv").exists() { + return SandboxKind::Docker; + } + if let Ok(env) = std::fs::read_to_string("/var/run/.containerenv") { + return SandboxKind::Container( + containerenv_engine_re() + .captures(&env) + .and_then(|x| x.get(1)) + .map(|x| x.as_str().to_string()), + ); + } + + SandboxKind::None +} + +impl SandboxKind { + pub fn is_container(&self) -> bool { + matches!(self, SandboxKind::Docker | SandboxKind::Container(_)) + } + + pub fn is_app_runtime(&self) -> bool { + matches!(self, SandboxKind::Flatpak | SandboxKind::Snap) + } + + pub fn is_none(&self) -> bool { + matches!(self, SandboxKind::None) + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::fig_os_shim::Env; + + #[cfg(target_os = "linux")] + #[test] + fn os_release() { + if OsRelease::path().exists() { + OsRelease::load().unwrap(); + } else { + println!("Skipping os-release test as /etc/os-release does not exist"); + } + } + + #[test] + fn os_release_parse() { + let os_release_str = indoc::indoc! 
{r#" + NAME="Amazon Linux" + VERSION="2023" + ID="amzn" + ID_LIKE="fedora" + VERSION_ID="2023" + PLATFORM_ID="platform:al2023" + PRETTY_NAME="Amazon Linux 2023.4.20240416" + ANSI_COLOR="0;33" + CPE_NAME="cpe:2.3:o:amazon:amazon_linux:2023" + HOME_URL="https://aws.amazon.com/linux/amazon-linux-2023/" + DOCUMENTATION_URL="https://docs.aws.amazon.com/linux/" + SUPPORT_URL="https://aws.amazon.com/premiumsupport/" + BUG_REPORT_URL="https://github.com/amazonlinux/amazon-linux-2023" + VENDOR_NAME="AWS" + VENDOR_URL="https://aws.amazon.com/" + SUPPORT_END="2028-03-15" + "#}; + + let os_release = OsRelease::from_str(os_release_str); + + assert_eq!(os_release.id, Some("amzn".into())); + + assert_eq!(os_release.name, Some("Amazon Linux".into())); + assert_eq!(os_release.pretty_name, Some("Amazon Linux 2023.4.20240416".into())); + + assert_eq!(os_release.version_id, Some("2023".into())); + assert_eq!(os_release.version, Some("2023".into())); + + assert_eq!(os_release.build_id, None); + + assert_eq!(os_release.variant_id, None); + assert_eq!(os_release.variant, None); + } + + #[test] + fn test_get_desktop_environment() { + let tests = [ + (vec![("XDG_CURRENT_DESKTOP", "UBUNTU:gnome")], DesktopEnvironment::Gnome), + ( + vec![("XDG_CURRENT_DESKTOP", "Unity"), ("XDG_SESSION_DESKTOP", "ubuntu")], + DesktopEnvironment::Gnome, + ), + ( + vec![("XDG_CURRENT_DESKTOP", "Unity"), ("XDG_SESSION_DESKTOP", "GNOME")], + DesktopEnvironment::Gnome, + ), + (vec![("GDMSESSION", "ubuntu")], DesktopEnvironment::Gnome), + ]; + + for (env, expected_desktop_env) in tests { + let env = Env::from_slice(&env); + assert_eq!( + get_desktop_environment(&env).unwrap(), + expected_desktop_env, + "expected: {:?} from env: {:?}", + expected_desktop_env, + env + ); + } + } + + #[test] + fn test_get_desktop_environment_err() { + let env = Env::from_slice(&[("XDG_CURRENT_DESKTOP", "Unity"), ("XDG_SESSION_DESKTOP", "")]); + let res = get_desktop_environment(&env); + println!("{}", res.as_ref().unwrap_err()); + assert!(matches!(res, Err(Error::UnknownDesktop(_)))); + } +} diff --git a/crates/kiro-cli/src/fig_util/system_info/mod.rs b/crates/kiro-cli/src/fig_util/system_info/mod.rs new file mode 100644 index 0000000000..ae6cb65da0 --- /dev/null +++ b/crates/kiro-cli/src/fig_util/system_info/mod.rs @@ -0,0 +1,382 @@ +pub mod linux; + +use std::borrow::Cow; +use std::sync::OnceLock; + +use cfg_if::cfg_if; +use serde::{ + Deserialize, + Serialize, +}; +use sha2::{ + Digest, + Sha256, +}; + +use crate::fig_os_shim::Env; +use crate::fig_util::Error; +use crate::fig_util::env_var::Q_PARENT; +use crate::fig_util::manifest::is_minimal; + +/// The support level for different platforms +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum SupportLevel { + /// A fully supported platform + Supported, + /// Supported, but with a caveat + SupportedWithCaveat { info: Cow<'static, str> }, + /// A platform that is currently in development + InDevelopment { info: Option> }, + /// A platform that is not supported + Unsupported, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum OSVersion { + MacOS { + major: i32, + minor: i32, + patch: Option, + build: String, + }, + Linux { + kernel_version: String, + #[serde(flatten)] + os_release: Option, + }, + Windows { + name: String, + build: u32, + }, + FreeBsd { + version: String, + }, +} + +impl OSVersion { + pub fn support_level(&self) -> SupportLevel { + match self { + OSVersion::MacOS { major, minor, .. 
} => { + // Minimum supported macOS version is 10.14.0 + if *major > 10 || (*major == 10 && *minor >= 14) { + SupportLevel::Supported + } else { + SupportLevel::Unsupported + } + }, + OSVersion::Linux { .. } => match (is_remote(), is_minimal()) { + (true, true) => SupportLevel::Supported, + (false, true) => SupportLevel::SupportedWithCaveat { + info: "Autocomplete is not yet available on Linux, but other products should work as expected." + .into(), + }, + (_, _) => SupportLevel::Supported, + }, + OSVersion::Windows { build, .. } => match build { + // Only Windows 11 is fully supported at the moment + build if *build >= 22000 => SupportLevel::Supported, + // Windows 10 development has known issues + build if *build >= 10240 => SupportLevel::InDevelopment { + info: Some( + "Since support for Windows 10 is still in progress,\ +Autocomplete only works in Git Bash with the default prompt.\ +Please upgrade to Windows 11 or wait for a fix while we work this issue out." + .into(), + ), + }, + // Earlier versions of Windows are not supported + _ => SupportLevel::Unsupported, + }, + OSVersion::FreeBsd { .. } => SupportLevel::InDevelopment { info: None }, + } + } + + pub fn user_readable(&self) -> Vec { + match self { + OSVersion::Linux { + kernel_version, + os_release, + } => { + let mut v = vec![format!("kernel: {kernel_version}")]; + + if let Some(os_release) = os_release { + if let Some(name) = &os_release.name { + v.push(format!("distro: {name}")); + } + + if let Some(version) = &os_release.version { + v.push(format!("distro-version: {version}")); + } else if let Some(version) = &os_release.version_id { + v.push(format!("distro-version: {version}")); + } + + if let Some(variant) = &os_release.variant { + v.push(format!("distro-variant: {variant}")); + } else if let Some(variant) = &os_release.variant_id { + v.push(format!("distro-variant: {variant}")); + } + + if let Some(build) = &os_release.build_id { + v.push(format!("distro-build: {build}")); + } + } + + v + }, + other => vec![format!("{other}")], + } + } +} + +impl std::fmt::Display for OSVersion { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + OSVersion::MacOS { + major, + minor, + patch, + build, + } => { + let patch = patch.unwrap_or(0); + write!(f, "macOS {major}.{minor}.{patch} ({build})") + }, + OSVersion::Linux { + kernel_version, + os_release, + } => match os_release + .as_ref() + .and_then(|r| r.pretty_name.as_ref().or(r.name.as_ref())) + { + Some(distro_name) => write!(f, "Linux {kernel_version} - {distro_name}"), + None => write!(f, "Linux {kernel_version}"), + }, + OSVersion::Windows { name, build } => write!(f, "{name} (or newer) - build {build}"), + OSVersion::FreeBsd { version } => write!(f, "FreeBSD {version}"), + } + } +} + +pub fn os_version() -> Option<&'static OSVersion> { + static OS_VERSION: OnceLock> = OnceLock::new(); + OS_VERSION.get_or_init(|| { + cfg_if! { + if #[cfg(target_os = "macos")] { + use std::process::Command; + use regex::Regex; + + let version_info = Command::new("sw_vers") + .output() + .ok()?; + + let version_info: String = String::from_utf8_lossy(&version_info.stdout).trim().into(); + + let version_regex = Regex::new(r"ProductVersion:\s*(\S+)").unwrap(); + let build_regex = Regex::new(r"BuildVersion:\s*(\S+)").unwrap(); + + let version: String = version_regex + .captures(&version_info) + .and_then(|c| c.get(1)) + .map(|v| v.as_str().into())?; + + let major = version + .split('.') + .next()? + .parse().ok()?; + + let minor = version + .split('.') + .nth(1)? 
+ .parse().ok()?; + + let patch = version.split('.').nth(2).and_then(|p| p.parse().ok()); + + let build = build_regex + .captures(&version_info) + .and_then(|c| c.get(1))? + .as_str() + .into(); + + Some(OSVersion::MacOS { + major, + minor, + patch, + build, + }) + } else if #[cfg(target_os = "linux")] { + use nix::sys::utsname::uname; + + let kernel_version = uname().ok()?.release().to_string_lossy().into(); + let os_release = linux::get_os_release().cloned(); + + Some(OSVersion::Linux { + kernel_version, + os_release, + }) + } else if #[cfg(target_os = "windows")] { + use winreg::enums::HKEY_LOCAL_MACHINE; + use winreg::RegKey; + + let rkey = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey(r"SOFTWARE\Microsoft\Windows NT\CurrentVersion").ok()?; + let build: String = rkey.get_value("CurrentBuild").ok()?; + + Some(OSVersion::Windows { + name: rkey.get_value("ProductName").ok()?, + build: build.parse::().ok()?, + }) + } else if #[cfg(target_os = "freebsd")] { + use nix::sys::utsname::uname; + + let version = uname().ok()?.release().to_string_lossy().into(); + + Some(OSVersion::FreeBsd { + version, + }) + + } + } + }).as_ref() +} + +pub fn in_ssh() -> bool { + static IN_SSH: OnceLock = OnceLock::new(); + *IN_SSH.get_or_init(|| Env::new().in_ssh()) +} + +/// Test if the program is running under WSL +pub fn in_wsl() -> bool { + cfg_if! { + if #[cfg(target_os = "linux")] { + static IN_WSL: OnceLock = OnceLock::new(); + *IN_WSL.get_or_init(|| { + if let Ok(b) = std::fs::read("/proc/sys/kernel/osrelease") { + if let Ok(s) = std::str::from_utf8(&b) { + let a = s.to_ascii_lowercase(); + return a.contains("microsoft") || a.contains("wsl"); + } + } + false + }) + } else { + false + } + } +} + +/// Is the calling binary running on a remote instance +pub fn is_remote() -> bool { + // TODO(chay): Add detection for inside docker container + in_ssh() || in_cloudshell() || in_wsl() || std::env::var_os("Q_FAKE_IS_REMOTE").is_some() +} + +/// Determines if we have an IPC path to a Desktop app from a remote environment +pub fn has_parent() -> bool { + static HAS_PARENT: OnceLock = OnceLock::new(); + *HAS_PARENT.get_or_init(|| std::env::var_os(Q_PARENT).is_some()) +} + +/// This true if the env var `AWS_EXECUTION_ENV=CloudShell` +pub fn in_cloudshell() -> bool { + static IN_CLOUDSHELL: OnceLock = OnceLock::new(); + *IN_CLOUDSHELL.get_or_init(|| Env::new().in_cloudshell()) +} + +pub fn in_codespaces() -> bool { + static IN_CODESPACES: OnceLock = OnceLock::new(); + *IN_CODESPACES + .get_or_init(|| std::env::var_os("CODESPACES").is_some() || std::env::var_os("Q_CODESPACES").is_some()) +} + +pub fn in_ci() -> bool { + static IN_CI: OnceLock = OnceLock::new(); + *IN_CI.get_or_init(|| std::env::var_os("CI").is_some() || std::env::var_os("Q_CI").is_some()) +} + +#[cfg(target_os = "macos")] +fn raw_system_id() -> Result { + let output = std::process::Command::new("ioreg") + .args(["-rd1", "-c", "IOPlatformExpertDevice"]) + .output()?; + + let output = String::from_utf8_lossy(&output.stdout); + + let machine_id: String = output + .lines() + .find(|line| line.contains("IOPlatformUUID")) + .ok_or(Error::HwidNotFound)? + .split('=') + .nth(1) + .ok_or(Error::HwidNotFound)? 
+ .trim() + .trim_start_matches('"') + .trim_end_matches('"') + .into(); + + Ok(machine_id) +} + +#[cfg(target_os = "linux")] +fn raw_system_id() -> Result { + for path in ["/var/lib/dbus/machine-id", "/etc/machine-id"] { + if std::path::Path::new(path).exists() { + return Ok(std::fs::read_to_string(path)?); + } + } + Err(Error::HwidNotFound) +} + +#[cfg(target_os = "windows")] +fn raw_system_id() -> Result { + use winreg::RegKey; + use winreg::enums::HKEY_LOCAL_MACHINE; + + let rkey = RegKey::predef(HKEY_LOCAL_MACHINE).open_subkey(r"SOFTWARE\Microsoft\Cryptography")?; + let id: String = rkey.get_value("MachineGuid")?; + + Ok(id) +} + +#[cfg(target_os = "freebsd")] +fn raw_system_id() -> Result { + Err(Error::HwidNotFound) +} + +pub fn get_system_id() -> Option<&'static str> { + static SYSTEM_ID: OnceLock> = OnceLock::new(); + SYSTEM_ID + .get_or_init(|| { + let hwid = raw_system_id().ok()?; + let mut hasher = Sha256::new(); + hasher.update(hwid); + Some(format!("{:x}", hasher.finalize())) + }) + .as_deref() +} + +pub fn get_platform() -> &'static str { + if let Some(over_ride) = option_env!("Q_OVERRIDE_PLATFORM") { + over_ride + } else { + std::env::consts::OS + } +} + +pub fn get_arch() -> &'static str { + if let Some(over_ride) = option_env!("Q_OVERRIDE_ARCH") { + over_ride + } else { + std::env::consts::ARCH + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_get_system_id() { + let id = get_system_id(); + assert!(id.is_some()); + assert_eq!(id.unwrap().len(), 64); + } +} diff --git a/crates/kiro-cli/src/main.rs b/crates/kiro-cli/src/main.rs new file mode 100644 index 0000000000..ff4ee3a97e --- /dev/null +++ b/crates/kiro-cli/src/main.rs @@ -0,0 +1,102 @@ +mod cli; +mod diagnostics; +mod fig_api_client; +mod fig_auth; +mod fig_aws_common; +mod fig_install; +mod fig_log; +mod fig_os_shim; +mod fig_settings; +mod fig_telemetry; +mod fig_telemetry_core; +mod fig_util; +mod mcp_client; +mod request; + +use std::process::ExitCode; + +use anstream::eprintln; +use clap::Parser; +use clap::error::{ + ContextKind, + ErrorKind, +}; +use crossterm::style::Stylize; +use eyre::Result; +use fig_log::get_log_level_max; +use tracing::metadata::LevelFilter; + +use crate::fig_telemetry::{ + finish_telemetry, + init_global_telemetry_emitter, +}; +use crate::fig_util::{ + CLI_BINARY_NAME, + PRODUCT_NAME, +}; + +#[global_allocator] +static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; + +fn main() -> Result { + color_eyre::install()?; + init_global_telemetry_emitter(); + + let multithread = matches!( + std::env::args().nth(1).as_deref(), + Some("init" | "_" | "internal" | "completion" | "hook") + ); + + let parsed = match cli::Cli::try_parse() { + Ok(cli) => cli, + Err(err) => { + let _ = err.print(); + + let unknown_arg = matches!(err.kind(), ErrorKind::UnknownArgument | ErrorKind::InvalidSubcommand) + && !err.context().any(|(context_kind, _)| { + matches!( + context_kind, + ContextKind::SuggestedSubcommand | ContextKind::SuggestedArg + ) + }); + + if unknown_arg { + eprintln!( + "\nThis command may be valid in newer versions of the {PRODUCT_NAME} CLI. 
Try running {} {}.", + CLI_BINARY_NAME.magenta(), + "update".magenta() + ); + } + + return Ok(ExitCode::from(err.exit_code().try_into().unwrap_or(2))); + }, + }; + + let verbose = parsed.verbose > 0; + + let runtime = if multithread { + tokio::runtime::Builder::new_multi_thread() + } else { + tokio::runtime::Builder::new_current_thread() + } + .enable_all() + .build()?; + + let result = runtime.block_on(async { + let result = parsed.execute().await; + finish_telemetry().await; + result + }); + + match result { + Ok(exit_code) => Ok(exit_code), + Err(err) => { + if verbose || get_log_level_max() > LevelFilter::INFO { + eprintln!("{} {err:?}", "error:".bold().red()); + } else { + eprintln!("{} {err}", "error:".bold().red()); + } + Ok(ExitCode::FAILURE) + }, + } +} diff --git a/crates/kiro-cli/src/mcp_client/client.rs b/crates/kiro-cli/src/mcp_client/client.rs new file mode 100644 index 0000000000..5f87cc25f9 --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/client.rs @@ -0,0 +1,764 @@ +use std::collections::HashMap; +use std::process::Stdio; +use std::sync::atomic::{ + AtomicBool, + AtomicU64, + Ordering, +}; +use std::sync::{ + Arc, + RwLock as SyncRwLock, +}; +use std::time::Duration; + +use nix::sys::signal::Signal; +use nix::unistd::Pid; +use serde::{ + Deserialize, + Serialize, +}; +use thiserror::Error; +use tokio::time; +use tokio::time::error::Elapsed; + +use crate::mcp_client::transport::base_protocol::{ + JsonRpcMessage, + JsonRpcNotification, + JsonRpcRequest, + JsonRpcVersion, +}; +use crate::mcp_client::transport::stdio::JsonRpcStdioTransport; +use crate::mcp_client::transport::{ + self, + Transport, + TransportError, +}; +use crate::mcp_client::{ + JsonRpcResponse, + Listener as _, + LogListener, + PaginationSupportedOps, + PromptGet, + PromptsListResult, + ResourceTemplatesListResult, + ResourcesListResult, + ToolsListResult, +}; + +pub type ServerCapabilities = serde_json::Value; +pub type ClientInfo = serde_json::Value; +pub type StdioTransport = JsonRpcStdioTransport; + +/// Represents the capabilities of a client in the Model Context Protocol. +/// This structure is sent to the server during initialization to communicate +/// what features the client supports and provide information about the client. +/// When features are added to the client, these should be declared in the [From] trait implemented +/// for the struct. 
+#[derive(Default, Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct ClientCapabilities { + protocol_version: JsonRpcVersion, + capabilities: HashMap, + client_info: serde_json::Value, +} + +impl From for ClientCapabilities { + fn from(client_info: ClientInfo) -> Self { + ClientCapabilities { + client_info, + ..Default::default() + } + } +} + +#[derive(Debug, Deserialize)] +pub struct ClientConfig { + pub server_name: String, + pub bin_path: String, + pub args: Vec, + pub timeout: u64, + pub client_info: serde_json::Value, + pub env: Option>, +} + +#[derive(Debug, Error)] +pub enum ClientError { + #[error(transparent)] + TransportError(#[from] TransportError), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Serialization(#[from] serde_json::Error), + #[error("Operation timed out: {context}")] + RuntimeError { + #[source] + source: tokio::time::error::Elapsed, + context: String, + }, + #[error("{0}")] + NegotiationError(String), + #[error("Failed to obtain process id")] + MissingProcessId, +} + +impl From<(tokio::time::error::Elapsed, String)> for ClientError { + fn from((error, context): (tokio::time::error::Elapsed, String)) -> Self { + ClientError::RuntimeError { source: error, context } + } +} + +#[derive(Debug)] +pub struct Client { + server_name: String, + transport: Arc, + timeout: u64, + server_process_id: Option, + client_info: serde_json::Value, + current_id: Arc, + pub prompt_gets: Arc>>, + pub is_prompts_out_of_date: Arc, +} + +impl Clone for Client { + fn clone(&self) -> Self { + Self { + server_name: self.server_name.clone(), + transport: self.transport.clone(), + timeout: self.timeout, + // Note that we cannot have an id for the clone because we would kill the original + // process when we drop the clone + server_process_id: None, + client_info: self.client_info.clone(), + current_id: self.current_id.clone(), + prompt_gets: self.prompt_gets.clone(), + is_prompts_out_of_date: self.is_prompts_out_of_date.clone(), + } + } +} + +impl Client { + pub fn from_config(config: ClientConfig) -> Result { + let ClientConfig { + server_name, + bin_path, + args, + timeout, + client_info, + env, + } = config; + let child = { + let mut command = tokio::process::Command::new(bin_path); + command + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .process_group(0) + .envs(std::env::vars()); + if let Some(env) = env { + for (env_name, env_value) in env { + command.env(env_name, env_value); + } + } + command.args(args).spawn()? + }; + let server_process_id = child.id().ok_or(ClientError::MissingProcessId)?; + #[allow(clippy::map_err_ignore)] + let server_process_id = Pid::from_raw( + server_process_id + .try_into() + .map_err(|_| ClientError::MissingProcessId)?, + ); + let server_process_id = Some(server_process_id); + let transport = Arc::new(transport::stdio::JsonRpcStdioTransport::client(child)?); + Ok(Self { + server_name, + transport, + timeout, + server_process_id, + client_info, + current_id: Arc::new(AtomicU64::new(0)), + prompt_gets: Arc::new(SyncRwLock::new(HashMap::new())), + is_prompts_out_of_date: Arc::new(AtomicBool::new(false)), + }) + } +} + +impl Drop for Client +where + T: Transport, +{ + // IF the servers are implemented well, they will shutdown once the pipe closes. + // This drop trait is here as a fail safe to ensure we don't leave behind any orphans. 
+ fn drop(&mut self) { + if let Some(process_id) = self.server_process_id { + let _ = nix::sys::signal::kill(process_id, Signal::SIGTERM); + } + } +} + +impl Client +where + T: Transport, +{ + /// Exchange of information specified as per https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/lifecycle/#initialization + /// + /// Also done is the spawn of a background task that constantly listens for incoming messages + /// from the server. + pub async fn init(&self) -> Result { + let transport_ref = self.transport.clone(); + let server_name = self.server_name.clone(); + + tokio::spawn(async move { + let mut listener = transport_ref.get_listener(); + loop { + match listener.recv().await { + Ok(msg) => { + match msg { + JsonRpcMessage::Request(_req) => {}, + JsonRpcMessage::Notification(notif) => { + let JsonRpcNotification { method, params, .. } = notif; + if method.as_str() == "notifications/message" || method.as_str() == "message" { + let level = params + .as_ref() + .and_then(|p| p.get("level")) + .and_then(|v| serde_json::to_string(v).ok()); + let data = params + .as_ref() + .and_then(|p| p.get("data")) + .and_then(|v| serde_json::to_string(v).ok()); + if let (Some(level), Some(data)) = (level, data) { + match level.to_lowercase().as_str() { + "error" => { + tracing::error!(target: "mcp", "{}: {}", server_name, data); + }, + "warn" => { + tracing::warn!(target: "mcp", "{}: {}", server_name, data); + }, + "info" => { + tracing::info!(target: "mcp", "{}: {}", server_name, data); + }, + "debug" => { + tracing::debug!(target: "mcp", "{}: {}", server_name, data); + }, + "trace" => { + tracing::trace!(target: "mcp", "{}: {}", server_name, data); + }, + _ => {}, + } + } + } + }, + JsonRpcMessage::Response(_resp) => { /* noop since direct response is handled inside the request api */ + }, + } + }, + Err(e) => { + tracing::error!("Background listening thread for client {}: {:?}", server_name, e); + }, + } + } + }); + + let transport_ref = self.transport.clone(); + let server_name = self.server_name.clone(); + + // Spawning a task to listen and log stderr output + tokio::spawn(async move { + let mut log_listener = transport_ref.get_log_listener(); + loop { + match log_listener.recv().await { + Ok(msg) => { + tracing::trace!(target: "mcp", "{server_name} logged {}", msg); + }, + Err(e) => { + tracing::error!( + "Error encountered while reading from stderr for {server_name}: {:?}\nEnding stderr listening task.", + e + ); + break; + }, + } + } + }); + + let init_params = Some({ + let client_cap = ClientCapabilities::from(self.client_info.clone()); + serde_json::json!(client_cap) + }); + let server_capabilities = self.request("initialize", init_params).await?; + if let Err(e) = examine_server_capabilities(&server_capabilities) { + return Err(ClientError::NegotiationError(format!( + "Client {} has failed to negotiate server capabilities with server: {:?}", + self.server_name, e + ))); + } + self.notify("initialized", None).await?; + + // TODO: group this into examine_server_capabilities + // Prefetch prompts in the background. 
We should only do this after the server has been + // initialized + if let Some(res) = &server_capabilities.result { + if let Some(cap) = res.get("capabilities") { + if cap.get("prompts").is_some() { + self.is_prompts_out_of_date.store(true, Ordering::Relaxed); + let client_ref = (*self).clone(); + tokio::spawn(async move { + let Ok(resp) = client_ref.request("prompts/list", None).await else { + tracing::error!("Prompt list query failed for {0}", client_ref.server_name); + return; + }; + let Some(result) = resp.result else { + tracing::warn!("Prompt list query returned no result for {0}", client_ref.server_name); + return; + }; + let Some(prompts) = result.get("prompts") else { + tracing::warn!( + "Prompt list query result contained no field named prompts for {0}", + client_ref.server_name + ); + return; + }; + let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { + tracing::error!( + "Prompt list query deserialization failed for {0}", + client_ref.server_name + ); + return; + }; + let Ok(mut lock) = client_ref.prompt_gets.write() else { + tracing::error!( + "Failed to obtain write lock for prompt list query for {0}", + client_ref.server_name + ); + return; + }; + for prompt in prompts { + let name = prompt.name.clone(); + lock.insert(name, prompt); + } + }); + } + } + } + + Ok(serde_json::to_value(server_capabilities)?) + } + + /// Sends a request to the server associated. + /// This call will yield until a response is received. + pub async fn request( + &self, + method: &str, + params: Option, + ) -> Result { + let send_map_err = |e: Elapsed| (e, method.to_string()); + let recv_map_err = |e: Elapsed| (e, format!("recv for {method}")); + let mut id = self.get_id(); + let request = JsonRpcRequest { + jsonrpc: JsonRpcVersion::default(), + id, + method: method.to_owned(), + params, + }; + tracing::trace!(target: "mcp", "To {}:\n{:#?}", self.server_name, request); + let msg = JsonRpcMessage::Request(request); + time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) + .await + .map_err(send_map_err)??; + let mut listener = self.transport.get_listener(); + let mut resp = time::timeout(Duration::from_millis(self.timeout), async { + // we want to ignore all other messages sent by the server at this point and let the + // background loop handle them + loop { + if let JsonRpcMessage::Response(resp) = listener.recv().await? { + if resp.id == id { + break Ok::(resp); + } + } + } + }) + .await + .map_err(recv_map_err)??; + // Pagination support: https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#pagination-model + let mut next_cursor = resp.result.as_ref().and_then(|v| v.get("nextCursor")); + if next_cursor.is_some() { + let mut current_resp = resp.clone(); + let mut results = Vec::::new(); + let pagination_supported_ops = { + let maybe_pagination_supported_op: Result = method.try_into(); + maybe_pagination_supported_op.ok() + }; + if let Some(ops) = pagination_supported_ops { + loop { + let result = current_resp.result.as_ref().cloned().unwrap(); + let mut list: Vec = match ops { + PaginationSupportedOps::Resources => { + let ResourcesListResult { resources: list, .. } = + serde_json::from_value::(result) + .map_err(ClientError::Serialization)?; + list + }, + PaginationSupportedOps::ResourceTemplates => { + let ResourceTemplatesListResult { + resource_templates: list, + .. 
+ } = serde_json::from_value::(result) + .map_err(ClientError::Serialization)?; + list + }, + PaginationSupportedOps::Prompts => { + let PromptsListResult { prompts: list, .. } = + serde_json::from_value::(result) + .map_err(ClientError::Serialization)?; + list + }, + PaginationSupportedOps::Tools => { + let ToolsListResult { tools: list, .. } = serde_json::from_value::(result) + .map_err(ClientError::Serialization)?; + list + }, + }; + results.append(&mut list); + if next_cursor.is_none() { + break; + } + id = self.get_id(); + let next_request = JsonRpcRequest { + jsonrpc: JsonRpcVersion::default(), + id, + method: method.to_owned(), + params: Some(serde_json::json!({ + "cursor": next_cursor, + })), + }; + let msg = JsonRpcMessage::Request(next_request); + time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) + .await + .map_err(send_map_err)??; + let resp = time::timeout(Duration::from_millis(self.timeout), async { + // we want to ignore all other messages sent by the server at this point and let the + // background loop handle them + loop { + if let JsonRpcMessage::Response(resp) = listener.recv().await? { + if resp.id == id { + break Ok::(resp); + } + } + } + }) + .await + .map_err(recv_map_err)??; + current_resp = resp; + next_cursor = current_resp.result.as_ref().and_then(|v| v.get("nextCursor")); + } + resp.result = Some({ + let mut map = serde_json::Map::new(); + map.insert(ops.as_key().to_owned(), serde_json::to_value(results)?); + serde_json::to_value(map)? + }); + } + } + tracing::trace!(target: "mcp", "From {}:\n{:#?}", self.server_name, resp); + Ok(resp) + } + + /// Sends a notification to the server associated. + /// Notifications are requests that expect no responses. + pub async fn notify(&self, method: &str, params: Option) -> Result<(), ClientError> { + let send_map_err = |e: Elapsed| (e, method.to_string()); + let notification = JsonRpcNotification { + jsonrpc: JsonRpcVersion::default(), + method: format!("notifications/{}", method), + params, + }; + let msg = JsonRpcMessage::Notification(notification); + Ok( + time::timeout(Duration::from_millis(self.timeout), self.transport.send(&msg)) + .await + .map_err(send_map_err)??, + ) + } + + pub async fn shutdown(&self) -> Result<(), ClientError> { + Ok(self.transport.shutdown().await?) + } + + fn get_id(&self) -> u64 { + self.current_id.fetch_add(1, Ordering::SeqCst) + } +} + +fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientError> { + // Check the jrpc version. + // Currently we are only proceeding if the versions are EXACTLY the same. 
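// For reference, the cursor pagination handled in `request` above exchanges
// messages of roughly this shape (ids and cursor values are illustrative):
//
//     -> {"jsonrpc":"2.0","id":7,"method":"tools/list"}
//     <- {"jsonrpc":"2.0","id":7,"result":{"tools":[...],"nextCursor":"abc"}}
//     -> {"jsonrpc":"2.0","id":8,"method":"tools/list","params":{"cursor":"abc"}}
//     <- {"jsonrpc":"2.0","id":8,"result":{"tools":[...]}}
//
// The accumulated items are then handed back to the caller merged under the
// operation's key ("tools" in this case), with no trailing nextCursor.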
+ let jrpc_version = ser_cap.jsonrpc.as_u32_vec(); + let client_jrpc_version = JsonRpcVersion::default().as_u32_vec(); + for (sv, cv) in jrpc_version.iter().zip(client_jrpc_version.iter()) { + if sv != cv { + return Err(ClientError::NegotiationError( + "Incompatible jrpc version between server and client".to_owned(), + )); + } + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use std::path::PathBuf; + + use serde_json::Value; + + use super::*; + const TEST_BIN_OUT_DIR: &str = "target/debug"; + const TEST_SERVER_NAME: &str = "test_mcp_server"; + + fn get_workspace_root() -> PathBuf { + let output = std::process::Command::new("cargo") + .args(["metadata", "--format-version=1", "--no-deps"]) + .output() + .expect("Failed to execute cargo metadata"); + + let metadata: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("Failed to parse cargo metadata"); + + let workspace_root = metadata["workspace_root"] + .as_str() + .expect("Failed to find workspace_root in metadata"); + + PathBuf::from(workspace_root) + } + + #[tokio::test(flavor = "multi_thread")] + async fn test_client_stdio() { + std::process::Command::new("cargo") + .args(["build", "--bin", TEST_SERVER_NAME]) + .status() + .expect("Failed to build binary"); + let workspace_root = get_workspace_root(); + let bin_path = workspace_root.join(TEST_BIN_OUT_DIR).join(TEST_SERVER_NAME); + println!("bin path: {}", bin_path.to_str().unwrap_or("no path found")); + + // Testing 2 concurrent sessions to make sure transport layer does not overlap. + let client_info_one = serde_json::json!({ + "name": "TestClientOne", + "version": "1.0.0" + }); + let client_config_one = ClientConfig { + server_name: "test_tool".to_owned(), + bin_path: bin_path.to_str().unwrap().to_string(), + args: ["1".to_owned()].to_vec(), + timeout: 120 * 1000, + client_info: client_info_one.clone(), + env: { + let mut map = HashMap::::new(); + map.insert("ENV_ONE".to_owned(), "1".to_owned()); + map.insert("ENV_TWO".to_owned(), "2".to_owned()); + Some(map) + }, + }; + let client_info_two = serde_json::json!({ + "name": "TestClientTwo", + "version": "1.0.0" + }); + let client_config_two = ClientConfig { + server_name: "test_tool".to_owned(), + bin_path: bin_path.to_str().unwrap().to_string(), + args: ["2".to_owned()].to_vec(), + timeout: 120 * 1000, + client_info: client_info_two.clone(), + env: { + let mut map = HashMap::::new(); + map.insert("ENV_ONE".to_owned(), "1".to_owned()); + map.insert("ENV_TWO".to_owned(), "2".to_owned()); + Some(map) + }, + }; + let mut client_one = Client::::from_config(client_config_one).expect("Failed to create client"); + let mut client_two = Client::::from_config(client_config_two).expect("Failed to create client"); + let client_one_cap = ClientCapabilities::from(client_info_one); + let client_two_cap = ClientCapabilities::from(client_info_two); + + let (res_one, res_two) = tokio::join!( + time::timeout( + time::Duration::from_secs(5), + test_client_routine(&mut client_one, serde_json::json!(client_one_cap)) + ), + time::timeout( + time::Duration::from_secs(5), + test_client_routine(&mut client_two, serde_json::json!(client_two_cap)) + ) + ); + let res_one = res_one.expect("Client one timed out"); + let res_two = res_two.expect("Client two timed out"); + assert!(res_one.is_ok()); + assert!(res_two.is_ok()); + } + + async fn test_client_routine( + client: &mut Client, + cap_sent: serde_json::Value, + ) -> Result<(), Box> { + // Test init + let _ = client.init().await.expect("Client init failed"); + 
tokio::time::sleep(time::Duration::from_millis(1500)).await; + let client_capabilities_sent = client + .request("verify_init_ack_sent", None) + .await + .expect("Verify init ack mock request failed"); + let has_server_recvd_init_ack = client_capabilities_sent + .result + .expect("Failed to retrieve client capabilities sent."); + assert_eq!(has_server_recvd_init_ack.to_string(), "true"); + let cap_recvd = client + .request("verify_init_params_sent", None) + .await + .expect("Verify init params mock request failed"); + let cap_recvd = cap_recvd + .result + .expect("Verify init params mock request does not contain required field (result)"); + assert!(are_json_values_equal(&cap_sent, &cap_recvd)); + + // test list tools + let fake_tool_names = ["get_weather_one", "get_weather_two", "get_weather_three"]; + let mock_result_spec = fake_tool_names.map(create_fake_tool_spec); + let mock_tool_specs_for_verify = serde_json::json!(mock_result_spec.clone()); + let mock_tool_specs_prep_param = mock_result_spec + .iter() + .zip(fake_tool_names.iter()) + .map(|(v, n)| { + serde_json::json!({ + "key": (*n).to_string(), + "value": v + }) + }) + .collect::>(); + let mock_tool_specs_prep_param = + serde_json::to_value(mock_tool_specs_prep_param).expect("Failed to create mock tool specs prep param"); + let _ = client + .request("store_mock_tool_spec", Some(mock_tool_specs_prep_param)) + .await + .expect("Mock tool spec prep failed"); + let tool_spec_recvd = client.request("tools/list", None).await.expect("List tools failed"); + assert!(are_json_values_equal( + tool_spec_recvd + .result + .as_ref() + .and_then(|v| v.get("tools")) + .expect("Failed to retrieve tool specs from result received"), + &mock_tool_specs_for_verify + )); + + // Test list prompts directly + let fake_prompt_names = ["code_review_one", "code_review_two", "code_review_three"]; + let mock_result_prompts = fake_prompt_names.map(create_fake_prompts); + let mock_prompts_for_verify = serde_json::json!(mock_result_prompts.clone()); + let mock_prompts_prep_param = mock_result_prompts + .iter() + .zip(fake_prompt_names.iter()) + .map(|(v, n)| { + serde_json::json!({ + "key": (*n).to_string(), + "value": v + }) + }) + .collect::>(); + let mock_prompts_prep_param = + serde_json::to_value(mock_prompts_prep_param).expect("Failed to create mock prompts prep param"); + let _ = client + .request("store_mock_prompts", Some(mock_prompts_prep_param)) + .await + .expect("Mock prompt prep failed"); + let prompts_recvd = client.request("prompts/list", None).await.expect("List prompts failed"); + assert!(are_json_values_equal( + prompts_recvd + .result + .as_ref() + .and_then(|v| v.get("prompts")) + .expect("Failed to retrieve prompts from results received"), + &mock_prompts_for_verify + )); + + // Test env var inclusion + let env_vars = client.request("get_env_vars", None).await.expect("Get env vars failed"); + let env_one = env_vars + .result + .as_ref() + .expect("Failed to retrieve results from env var request") + .get("ENV_ONE") + .expect("Failed to retrieve env one from env var request"); + let env_two = env_vars + .result + .as_ref() + .expect("Failed to retrieve results from env var request") + .get("ENV_TWO") + .expect("Failed to retrieve env two from env var request"); + let env_one_as_str = serde_json::to_string(env_one).expect("Failed to convert env one to string"); + let env_two_as_str = serde_json::to_string(env_two).expect("Failed to convert env two to string"); + assert_eq!(env_one_as_str, "\"1\"".to_string()); + assert_eq!(env_two_as_str, 
"\"2\"".to_string()); + + let shutdown_result = client.shutdown().await; + assert!(shutdown_result.is_ok()); + Ok(()) + } + + fn are_json_values_equal(a: &Value, b: &Value) -> bool { + match (a, b) { + (Value::Null, Value::Null) => true, + (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, + (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, + (Value::String(a_val), Value::String(b_val)) => a_val == b_val, + (Value::Array(a_arr), Value::Array(b_arr)) => { + if a_arr.len() != b_arr.len() { + return false; + } + a_arr + .iter() + .zip(b_arr.iter()) + .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) + }, + (Value::Object(a_obj), Value::Object(b_obj)) => { + if a_obj.len() != b_obj.len() { + return false; + } + a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { + Some(b_value) => are_json_values_equal(a_value, b_value), + None => false, + }) + }, + _ => false, + } + } + + fn create_fake_tool_spec(name: &str) -> serde_json::Value { + serde_json::json!({ + "name": name, + "description": "Get current weather information for a location", + "inputSchema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City name or zip code" + } + }, + "required": ["location"] + } + }) + } + + fn create_fake_prompts(name: &str) -> serde_json::Value { + serde_json::json!({ + "name": name, + "description": "Asks the LLM to analyze code quality and suggest improvements", + "arguments": [ + { + "name": "code", + "description": "The code to review", + "required": true + } + ] + }) + } +} diff --git a/crates/kiro-cli/src/mcp_client/error.rs b/crates/kiro-cli/src/mcp_client/error.rs new file mode 100644 index 0000000000..d05e7efa4d --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/error.rs @@ -0,0 +1,66 @@ +/// Error codes as defined in the MCP protocol. +/// +/// These error codes are based on the JSON-RPC 2.0 specification with additional +/// MCP-specific error codes in the -32000 to -32099 range. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +pub enum ErrorCode { + /// Invalid JSON was received by the server. + /// An error occurred on the server while parsing the JSON text. + ParseError = -32700, + + /// The JSON sent is not a valid Request object. + InvalidRequest = -32600, + + /// The method does not exist / is not available. + MethodNotFound = -32601, + + /// Invalid method parameter(s). + InvalidParams = -32602, + + /// Internal JSON-RPC error. + InternalError = -32603, + + /// Server has not been initialized. + /// This error is returned when a request is made before the server + /// has been properly initialized. + ServerNotInitialized = -32002, + + /// Unknown error code. + /// This error is returned when an error code is received that is not + /// recognized by the implementation. + UnknownErrorCode = -32001, + + /// Request failed. + /// This error is returned when a request fails for a reason not covered + /// by other error codes. 
+ RequestFailed = -32000, +} + +impl From for ErrorCode { + fn from(code: i32) -> Self { + match code { + -32700 => ErrorCode::ParseError, + -32600 => ErrorCode::InvalidRequest, + -32601 => ErrorCode::MethodNotFound, + -32602 => ErrorCode::InvalidParams, + -32603 => ErrorCode::InternalError, + -32002 => ErrorCode::ServerNotInitialized, + -32001 => ErrorCode::UnknownErrorCode, + -32000 => ErrorCode::RequestFailed, + _ => ErrorCode::UnknownErrorCode, + } + } +} + +impl From for i32 { + fn from(code: ErrorCode) -> Self { + code as i32 + } +} + +impl std::fmt::Display for ErrorCode { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} diff --git a/crates/kiro-cli/src/mcp_client/facilitator_types.rs b/crates/kiro-cli/src/mcp_client/facilitator_types.rs new file mode 100644 index 0000000000..38d4aca280 --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/facilitator_types.rs @@ -0,0 +1,229 @@ +use serde::{ + Deserialize, + Serialize, +}; +use thiserror::Error; + +/// https://spec.modelcontextprotocol.io/specification/2024-11-05/server/utilities/pagination/#operations-supporting-pagination +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PaginationSupportedOps { + Resources, + ResourceTemplates, + Prompts, + Tools, +} + +impl PaginationSupportedOps { + pub fn as_key(&self) -> &str { + match self { + PaginationSupportedOps::Resources => "resources", + PaginationSupportedOps::ResourceTemplates => "resourceTemplates", + PaginationSupportedOps::Prompts => "prompts", + PaginationSupportedOps::Tools => "tools", + } + } +} + +impl TryFrom<&str> for PaginationSupportedOps { + type Error = OpsConversionError; + + fn try_from(value: &str) -> Result { + match value { + "resources/list" => Ok(PaginationSupportedOps::Resources), + "resources/templates/list" => Ok(PaginationSupportedOps::ResourceTemplates), + "prompts/list" => Ok(PaginationSupportedOps::Prompts), + "tools/list" => Ok(PaginationSupportedOps::Tools), + _ => Err(OpsConversionError::InvalidMethod), + } + } +} + +#[derive(Error, Debug)] +pub enum OpsConversionError { + #[error("Invalid method encountered")] + InvalidMethod, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +#[serde(rename_all = "camelCase")] +/// Role assumed for a particular message +pub enum Role { + User, + Assistant, +} + +impl std::fmt::Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Role::User => write!(f, "user"), + Role::Assistant => write!(f, "assistant"), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Result of listing resources operation +pub struct ResourcesListResult { + /// List of resources + pub resources: Vec, + /// Optional cursor for pagination + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +/// Result of listing resource templates operation +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceTemplatesListResult { + /// List of resource templates + pub resource_templates: Vec, + /// Optional cursor for pagination + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Result of prompt listing query +pub struct PromptsListResult { + /// List of prompts + pub prompts: Vec, + /// Optional cursor for pagination + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: 
Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Represents an argument to be supplied to a [PromptGet] +pub struct PromptGetArg { + /// The name identifier of the prompt + pub name: String, + /// Optional description providing context about the prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Indicates whether a response to this prompt is required + /// If not specified, defaults to false + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Represents a request to get a prompt from a mcp server +pub struct PromptGet { + /// Unique identifier for the prompt + pub name: String, + /// Optional description providing context about the prompt's purpose + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Optional list of arguments that define the structure of information to be collected + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// `result` field in [JsonRpcResponse] from a `prompts/get` request +pub struct PromptGetResult { + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub messages: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Completed prompt from `prompts/get` to be returned by a mcp server +pub struct Prompt { + pub role: Role, + pub content: MessageContent, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +/// Result of listing tools operation +pub struct ToolsListResult { + /// List of tools + pub tools: Vec, + /// Optional cursor for pagination + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolCallResult { + pub content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, +} + +/// Content of a message +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum MessageContent { + /// Text content + Text { + /// The text content + text: String, + }, + /// Image content + #[serde(rename_all = "camelCase")] + Image { + /// base64-encoded-data + data: String, + mime_type: String, + }, + /// Resource content + Resource { + /// The resource + resource: Resource, + }, +} + +impl From for String { + fn from(val: MessageContent) -> Self { + match val { + MessageContent::Text { text } => text, + MessageContent::Image { data, mime_type } => serde_json::json!({ + "data": data, + "mime_type": mime_type + }) + .to_string(), + MessageContent::Resource { resource } => serde_json::json!(resource).to_string(), + } + } +} + +impl std::fmt::Display for MessageContent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MessageContent::Text { text } => write!(f, "{}", text), + MessageContent::Image { data: _, mime_type } => write!(f, "Image [base64-encoded-string] ({})", mime_type), + MessageContent::Resource { resource } => write!(f, "Resource: {} ({})", resource.title, resource.uri), + } + } +} + +/// Resource contents +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ResourceContents { + Text { text: String 
}, + Blob { data: Vec }, +} + +/// A resource in the system +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Resource { + /// Unique identifier for the resource + pub uri: String, + /// Human-readable title + pub title: String, + /// Optional description + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Resource contents + pub contents: ResourceContents, +} diff --git a/crates/kiro-cli/src/mcp_client/mod.rs b/crates/kiro-cli/src/mcp_client/mod.rs new file mode 100644 index 0000000000..199fb0aeea --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/mod.rs @@ -0,0 +1,9 @@ +mod client; +mod error; +mod facilitator_types; +mod server; +mod transport; + +pub use client::*; +pub use facilitator_types::*; +pub use transport::*; diff --git a/crates/kiro-cli/src/mcp_client/server.rs b/crates/kiro-cli/src/mcp_client/server.rs new file mode 100644 index 0000000000..a2e4767e6d --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/server.rs @@ -0,0 +1,293 @@ +use std::collections::HashMap; +use std::sync::atomic::{ + AtomicBool, + AtomicU64, + Ordering, +}; +use std::sync::{ + Arc, + Mutex, +}; + +use tokio::io::{ + Stdin, + Stdout, +}; +use tokio::task::JoinHandle; + +use crate::mcp_client::Listener as _; +use crate::mcp_client::client::StdioTransport; +use crate::mcp_client::error::ErrorCode; +use crate::mcp_client::transport::base_protocol::{ + JsonRpcError, + JsonRpcMessage, + JsonRpcNotification, + JsonRpcRequest, + JsonRpcResponse, +}; +use crate::mcp_client::transport::stdio::JsonRpcStdioTransport; +use crate::mcp_client::transport::{ + JsonRpcVersion, + Transport, + TransportError, +}; + +pub type Request = serde_json::Value; +pub type Response = Option; +pub type InitializedServer = JoinHandle>; + +pub trait PreServerRequestHandler { + fn register_pending_request_callback(&mut self, cb: impl Fn(u64) -> Option + Send + Sync + 'static); + fn register_send_request_callback( + &mut self, + cb: impl Fn(&str, Option) -> Result<(), ServerError> + Send + Sync + 'static, + ); +} + +#[async_trait::async_trait] +pub trait ServerRequestHandler: PreServerRequestHandler + Send + Sync + 'static { + async fn handle_initialize(&self, params: Option) -> Result; + async fn handle_incoming(&self, method: &str, params: Option) -> Result; + async fn handle_response(&self, resp: JsonRpcResponse) -> Result<(), ServerError>; + async fn handle_shutdown(&self) -> Result<(), ServerError>; +} + +pub struct Server { + transport: Option>, + handler: Option, + #[allow(dead_code)] + pending_requests: Arc>>, + #[allow(dead_code)] + current_id: Arc, +} + +#[derive(Debug, thiserror::Error)] +pub enum ServerError { + #[error(transparent)] + TransportError(#[from] TransportError), + #[error(transparent)] + Io(#[from] std::io::Error), + #[error(transparent)] + Serialization(#[from] serde_json::Error), + #[error("Unexpected msg type encountered")] + UnexpectedMsgType, + #[error("{0}")] + NegotiationError(String), + #[error(transparent)] + TokioJoinError(#[from] tokio::task::JoinError), + #[error("Failed to obtain mutex lock")] + MutexError, + #[error("Failed to obtain request method")] + MissingMethod, + #[error("Failed to obtain request id")] + MissingId, + #[error("Failed to initialize server. Missing transport")] + MissingTransport, + #[error("Failed to initialize server. 
Missing handler")] + MissingHandler, +} + +impl Server +where + H: ServerRequestHandler, +{ + pub fn new(mut handler: H, stdin: Stdin, stdout: Stdout) -> Result { + let transport = Arc::new(JsonRpcStdioTransport::server(stdin, stdout)?); + let pending_requests = Arc::new(Mutex::new(HashMap::::new())); + let pending_requests_clone_one = pending_requests.clone(); + let current_id = Arc::new(AtomicU64::new(0)); + let pending_request_getter = move |id: u64| -> Option { + match pending_requests_clone_one.lock() { + Ok(mut p) => p.remove(&id), + Err(_) => None, + } + }; + handler.register_pending_request_callback(pending_request_getter); + let transport_clone = transport.clone(); + let pending_request_clone_two = pending_requests.clone(); + let current_id_clone = current_id.clone(); + let request_sender = move |method: &str, params: Option| -> Result<(), ServerError> { + let id = current_id_clone.fetch_add(1, Ordering::SeqCst); + let request = JsonRpcRequest { + jsonrpc: JsonRpcVersion::default(), + id, + method: method.to_owned(), + params, + }; + let msg = JsonRpcMessage::Request(request.clone()); + let transport = transport_clone.clone(); + tokio::task::spawn(async move { + let _ = transport.send(&msg).await; + }); + #[allow(clippy::map_err_ignore)] + let mut pending_request = pending_request_clone_two.lock().map_err(|_| ServerError::MutexError)?; + pending_request.insert(id, request); + Ok(()) + }; + handler.register_send_request_callback(request_sender); + let server = Self { + transport: Some(transport), + handler: Some(handler), + pending_requests, + current_id, + }; + Ok(server) + } +} + +impl Server +where + T: Transport, + H: ServerRequestHandler, +{ + pub fn init(mut self) -> Result { + let transport = self.transport.take().ok_or(ServerError::MissingTransport)?; + let handler = Arc::new(self.handler.take().ok_or(ServerError::MissingHandler)?); + let has_initialized = Arc::new(AtomicBool::new(false)); + let listener = tokio::spawn(async move { + let mut listener = transport.get_listener(); + loop { + let request = listener.recv().await; + let transport_clone = transport.clone(); + let has_init_clone = has_initialized.clone(); + let handler_clone = handler.clone(); + tokio::task::spawn(async move { + process_request(has_init_clone, transport_clone, handler_clone, request).await; + }); + } + }); + Ok(listener) + } +} + +async fn process_request( + has_initialized: Arc, + transport: Arc, + handler: Arc, + request: Result, +) where + T: Transport, + H: ServerRequestHandler, +{ + match request { + Ok(msg) if msg.is_initialize() => { + let id = msg.id().unwrap_or_default(); + if has_initialized.load(Ordering::SeqCst) { + let resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion::default(), + id, + error: Some(JsonRpcError { + code: ErrorCode::InvalidRequest.into(), + message: "Server has already been initialized".to_owned(), + data: None, + }), + ..Default::default() + }); + let _ = transport.send(&resp).await; + return; + } + let JsonRpcMessage::Request(req) = msg else { + let resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion::default(), + id, + error: Some(JsonRpcError { + code: ErrorCode::InvalidRequest.into(), + message: "Invalid method for initialization (use request)".to_owned(), + data: None, + }), + ..Default::default() + }); + let _ = transport.send(&resp).await; + return; + }; + let JsonRpcRequest { params, .. 
} = req; + match handler.handle_initialize(params).await { + Ok(result) => { + let resp = JsonRpcMessage::Response(JsonRpcResponse { + id, + result, + ..Default::default() + }); + let _ = transport.send(&resp).await; + has_initialized.store(true, Ordering::SeqCst); + }, + Err(_e) => { + let resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion::default(), + id, + error: Some(JsonRpcError { + code: ErrorCode::InternalError.into(), + message: "Error producing initialization response".to_owned(), + data: None, + }), + ..Default::default() + }); + let _ = transport.send(&resp).await; + }, + } + }, + Ok(msg) if msg.is_shutdown() => { + // TODO: add shutdown routine + }, + Ok(msg) if has_initialized.load(Ordering::SeqCst) => match msg { + JsonRpcMessage::Request(req) => { + let JsonRpcRequest { + id, + jsonrpc, + params, + ref method, + } = req; + let resp = handler.handle_incoming(method, params).await.map_or_else( + |error| { + let err = JsonRpcError { + code: ErrorCode::InternalError.into(), + message: error.to_string(), + data: None, + }; + let resp = JsonRpcResponse { + jsonrpc: jsonrpc.clone(), + id, + result: None, + error: Some(err), + }; + JsonRpcMessage::Response(resp) + }, + |result| { + let resp = JsonRpcResponse { + jsonrpc: jsonrpc.clone(), + id, + result, + error: None, + }; + JsonRpcMessage::Response(resp) + }, + ); + let _ = transport.send(&resp).await; + }, + JsonRpcMessage::Notification(notif) => { + let JsonRpcNotification { ref method, params, .. } = notif; + let _ = handler.handle_incoming(method, params).await; + }, + JsonRpcMessage::Response(resp) => { + let _ = handler.handle_response(resp).await; + }, + }, + Ok(msg) => { + let id = msg.id().unwrap_or_default(); + let resp = JsonRpcMessage::Response(JsonRpcResponse { + jsonrpc: JsonRpcVersion::default(), + id, + error: Some(JsonRpcError { + code: ErrorCode::ServerNotInitialized.into(), + message: "Server has not been initialized".to_owned(), + data: None, + }), + ..Default::default() + }); + let _ = transport.send(&resp).await; + }, + Err(_e) => { + // TODO: error handling + }, + } +} diff --git a/crates/kiro-cli/src/mcp_client/transport/base_protocol.rs b/crates/kiro-cli/src/mcp_client/transport/base_protocol.rs new file mode 100644 index 0000000000..b0394e6e0c --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/transport/base_protocol.rs @@ -0,0 +1,108 @@ +//! Referencing https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/messages/ +//! Protocol Revision 2024-11-05 +use serde::{ + Deserialize, + Serialize, +}; + +pub type RequestId = u64; + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct JsonRpcVersion(String); + +impl Default for JsonRpcVersion { + fn default() -> Self { + JsonRpcVersion("2.0".to_owned()) + } +} + +impl JsonRpcVersion { + pub fn as_u32_vec(&self) -> Vec { + self.0 + .split(".") + .map(|n| n.parse::().unwrap()) + .collect::>() + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(untagged)] +#[serde(deny_unknown_fields)] +// DO NOT change the order of these variants. This body of json is [untagged](https://serde.rs/enum-representations.html#untagged) +// The categorization of the deserialization depends on the order in which the variants are +// declared. 
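// For example, a notification such as
//     {"jsonrpc":"2.0","method":"notifications/initialized"}
// carries no "id". Because JsonRpcRequest is marked #[serde(default)], that same
// JSON would also deserialize as a Request with id 0 if Request were tried first,
// so Notification must stay ahead of Request in the declaration order below
// (assuming serde's derived behavior for untagged enums).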
+pub enum JsonRpcMessage { + Response(JsonRpcResponse), + Notification(JsonRpcNotification), + Request(JsonRpcRequest), +} + +impl JsonRpcMessage { + pub fn is_initialize(&self) -> bool { + match self { + JsonRpcMessage::Request(req) => req.method == "initialize", + _ => false, + } + } + + pub fn is_shutdown(&self) -> bool { + match self { + JsonRpcMessage::Notification(notif) => notif.method == "notification/shutdown", + _ => false, + } + } + + pub fn id(&self) -> Option { + match self { + JsonRpcMessage::Request(req) => Some(req.id), + JsonRpcMessage::Response(resp) => Some(resp.id), + JsonRpcMessage::Notification(_) => None, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(default, deny_unknown_fields)] +pub struct JsonRpcRequest { + pub jsonrpc: JsonRpcVersion, + pub id: RequestId, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(default, deny_unknown_fields)] +pub struct JsonRpcResponse { + pub jsonrpc: JsonRpcVersion, + pub id: RequestId, + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(default, deny_unknown_fields)] +pub struct JsonRpcNotification { + pub jsonrpc: JsonRpcVersion, + pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +#[serde(default, deny_unknown_fields)] +pub struct JsonRpcError { + pub code: i32, + pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq, Eq)] +pub enum TransportType { + #[default] + Stdio, + Websocket, +} diff --git a/crates/kiro-cli/src/mcp_client/transport/mod.rs b/crates/kiro-cli/src/mcp_client/transport/mod.rs new file mode 100644 index 0000000000..f86fc498f3 --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/transport/mod.rs @@ -0,0 +1,56 @@ +pub mod base_protocol; +pub mod stdio; + +use std::fmt::Debug; + +pub use base_protocol::*; +pub use stdio::JsonRpcStdioTransport; +use thiserror::Error; + +#[derive(Clone, Debug, Error)] +pub enum TransportError { + #[error("Serialization error: {0}")] + Serialization(String), + #[error("IO error: {0}")] + Stdio(String), + #[error("{0}")] + Custom(String), + #[error(transparent)] + RecvError(#[from] tokio::sync::broadcast::error::RecvError), +} + +impl From for TransportError { + fn from(err: serde_json::Error) -> Self { + TransportError::Serialization(err.to_string()) + } +} + +impl From for TransportError { + fn from(err: std::io::Error) -> Self { + TransportError::Stdio(err.to_string()) + } +} + +#[async_trait::async_trait] +pub trait Transport: Send + Sync + Debug + 'static { + /// Sends a message over the transport layer. + async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError>; + /// Listens to awaits for a response. This is a call that should be used after `send` is called + /// to listen for a response from the message recipient. + fn get_listener(&self) -> impl Listener; + /// Gracefully terminates the transport connection, cleaning up any resources. + /// This should be called when the transport is no longer needed to ensure proper cleanup. 
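// A typical round trip against a Transport, mirroring how the MCP client earlier
// in this patch drives it (sketch; `request` stands for a caller-built
// JsonRpcRequest):
//
//     transport.send(&JsonRpcMessage::Request(request)).await?;
//     let mut listener = transport.get_listener();
//     let reply = listener.recv().await?;
//     // the caller pattern-matches `reply` for a Response and compares its id
//     // against the id of the request it sent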
+ async fn shutdown(&self) -> Result<(), TransportError>; + /// Listener that listens for logging messages. + fn get_log_listener(&self) -> impl LogListener; +} + +#[async_trait::async_trait] +pub trait Listener: Send + Sync + 'static { + async fn recv(&mut self) -> Result; +} + +#[async_trait::async_trait] +pub trait LogListener: Send + Sync + 'static { + async fn recv(&mut self) -> Result; +} diff --git a/crates/kiro-cli/src/mcp_client/transport/stdio.rs b/crates/kiro-cli/src/mcp_client/transport/stdio.rs new file mode 100644 index 0000000000..270756f2d9 --- /dev/null +++ b/crates/kiro-cli/src/mcp_client/transport/stdio.rs @@ -0,0 +1,272 @@ +use std::sync::Arc; + +use tokio::io::{ + AsyncBufReadExt, + AsyncRead, + AsyncWriteExt as _, + BufReader, + Stdin, + Stdout, +}; +use tokio::process::{ + Child, + ChildStdin, +}; +use tokio::sync::{ + Mutex, + broadcast, +}; + +use super::base_protocol::JsonRpcMessage; +use super::{ + Listener, + LogListener, + Transport, + TransportError, +}; + +#[derive(Debug)] +pub enum JsonRpcStdioTransport { + Client { + stdin: Arc>, + receiver: broadcast::Receiver>, + log_receiver: broadcast::Receiver, + }, + Server { + stdout: Arc>, + receiver: broadcast::Receiver>, + }, +} + +impl JsonRpcStdioTransport { + fn spawn_reader( + reader: R, + tx: broadcast::Sender>, + ) { + tokio::spawn(async move { + let mut buffer = Vec::::new(); + let mut buf_reader = BufReader::new(reader); + loop { + buffer.clear(); + // Messages are delimited by newlines and assumed to contain no embedded newlines + // See https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio + match buf_reader.read_until(b'\n', &mut buffer).await { + Ok(0) => continue, + Ok(_) => match serde_json::from_slice::(buffer.as_slice()) { + Ok(msg) => { + let _ = tx.send(Ok(msg)); + }, + Err(e) => { + let _ = tx.send(Err(e.into())); + }, + }, + Err(e) => { + let _ = tx.send(Err(e.into())); + }, + } + } + }); + } + + pub fn client(child_process: Child) -> Result { + let (tx, receiver) = broadcast::channel::>(100); + let Some(stdout) = child_process.stdout else { + return Err(TransportError::Custom("No stdout found on child process".to_owned())); + }; + let Some(stdin) = child_process.stdin else { + return Err(TransportError::Custom("No stdin found on child process".to_owned())); + }; + let Some(stderr) = child_process.stderr else { + return Err(TransportError::Custom("No stderr found on child process".to_owned())); + }; + let (log_tx, log_receiver) = broadcast::channel::(100); + tokio::task::spawn(async move { + let stderr = tokio::io::BufReader::new(stderr); + let mut lines = stderr.lines(); + while let Ok(Some(line)) = lines.next_line().await { + let _ = log_tx.send(line); + } + }); + let stdin = Arc::new(Mutex::new(stdin)); + Self::spawn_reader(stdout, tx); + Ok(JsonRpcStdioTransport::Client { + stdin, + receiver, + log_receiver, + }) + } + + pub fn server(stdin: Stdin, stdout: Stdout) -> Result { + let (tx, receiver) = broadcast::channel::>(100); + Self::spawn_reader(stdin, tx); + let stdout = Arc::new(Mutex::new(stdout)); + Ok(JsonRpcStdioTransport::Server { stdout, receiver }) + } +} + +#[async_trait::async_trait] +impl Transport for JsonRpcStdioTransport { + async fn send(&self, msg: &JsonRpcMessage) -> Result<(), TransportError> { + match self { + JsonRpcStdioTransport::Client { stdin, .. 
} => { + let mut serialized = serde_json::to_vec(msg)?; + serialized.push(b'\n'); + let mut stdin = stdin.lock().await; + stdin + .write_all(&serialized) + .await + .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; + stdin + .flush() + .await + .map_err(|e| TransportError::Custom(format!("Error writing to server: {:?}", e)))?; + Ok(()) + }, + JsonRpcStdioTransport::Server { stdout, .. } => { + let mut serialized = serde_json::to_vec(msg)?; + serialized.push(b'\n'); + let mut stdout = stdout.lock().await; + stdout + .write_all(&serialized) + .await + .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; + stdout + .flush() + .await + .map_err(|e| TransportError::Custom(format!("Error writing to client: {:?}", e)))?; + Ok(()) + }, + } + } + + fn get_listener(&self) -> impl Listener { + match self { + JsonRpcStdioTransport::Client { receiver, .. } | JsonRpcStdioTransport::Server { receiver, .. } => { + StdioListener { + receiver: receiver.resubscribe(), + } + }, + } + } + + async fn shutdown(&self) -> Result<(), TransportError> { + match self { + JsonRpcStdioTransport::Client { stdin, .. } => { + let mut stdin = stdin.lock().await; + Ok(stdin.shutdown().await?) + }, + JsonRpcStdioTransport::Server { stdout, .. } => { + let mut stdout = stdout.lock().await; + Ok(stdout.shutdown().await?) + }, + } + } + + fn get_log_listener(&self) -> impl LogListener { + match self { + JsonRpcStdioTransport::Client { log_receiver, .. } => StdioLogListener { + receiver: log_receiver.resubscribe(), + }, + JsonRpcStdioTransport::Server { .. } => unreachable!("server does not need a log listener"), + } + } +} + +pub struct StdioListener { + pub receiver: broadcast::Receiver>, +} + +#[async_trait::async_trait] +impl Listener for StdioListener { + async fn recv(&mut self) -> Result { + self.receiver.recv().await? + } +} + +pub struct StdioLogListener { + pub receiver: broadcast::Receiver, +} + +#[async_trait::async_trait] +impl LogListener for StdioLogListener { + async fn recv(&mut self) -> Result { + Ok(self.receiver.recv().await?) 
+ } +} + +#[cfg(test)] +mod tests { + use std::process::Stdio; + + use serde_json::{ + Value, + json, + }; + use tokio::process::Command; + + use super::*; + + // Helpers for testing + fn create_test_message() -> JsonRpcMessage { + serde_json::from_value(json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "test_method", + "params": { + "test_param": "test_value" + } + })) + .unwrap() + } + + #[tokio::test] + async fn test_client_transport() { + let mut cmd = Command::new("cat"); + cmd.stdin(Stdio::piped()).stdout(Stdio::piped()).stderr(Stdio::piped()); + + // Inject our mock transport instead + let child = cmd.spawn().expect("Failed to spawn command"); + let transport = JsonRpcStdioTransport::client(child).expect("Failed to create client transport"); + + let message = create_test_message(); + let result = transport.send(&message).await; + assert!(result.is_ok(), "Failed to send message: {:?}", result); + + let echo = transport + .get_listener() + .recv() + .await + .expect("Failed to receive message"); + let echo_value = serde_json::to_value(&echo).expect("Failed to convert echo to value"); + let message_value = serde_json::to_value(&message).expect("Failed to convert message to value"); + assert!(are_json_values_equal(&echo_value, &message_value)); + } + + fn are_json_values_equal(a: &Value, b: &Value) -> bool { + match (a, b) { + (Value::Null, Value::Null) => true, + (Value::Bool(a_val), Value::Bool(b_val)) => a_val == b_val, + (Value::Number(a_val), Value::Number(b_val)) => a_val == b_val, + (Value::String(a_val), Value::String(b_val)) => a_val == b_val, + (Value::Array(a_arr), Value::Array(b_arr)) => { + if a_arr.len() != b_arr.len() { + return false; + } + a_arr + .iter() + .zip(b_arr.iter()) + .all(|(a_item, b_item)| are_json_values_equal(a_item, b_item)) + }, + (Value::Object(a_obj), Value::Object(b_obj)) => { + if a_obj.len() != b_obj.len() { + return false; + } + a_obj.iter().all(|(key, a_value)| match b_obj.get(key) { + Some(b_value) => are_json_values_equal(a_value, b_value), + None => false, + }) + }, + _ => false, + } + } +} diff --git a/crates/kiro-cli/src/mcp_client/transport/websocket.rs b/crates/kiro-cli/src/mcp_client/transport/websocket.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/kiro-cli/src/request.rs b/crates/kiro-cli/src/request.rs new file mode 100644 index 0000000000..0da105efd1 --- /dev/null +++ b/crates/kiro-cli/src/request.rs @@ -0,0 +1,188 @@ +use std::env::current_exe; +use std::fs::File; +use std::io::BufReader; +use std::path::Path; +use std::sync::{ + Arc, + LazyLock, +}; + +use reqwest::Client; +use rustls::{ + ClientConfig, + RootCertStore, +}; +use url::ParseError; + +#[derive(Debug)] +pub enum RequestError { + Reqwest(reqwest::Error), + Serde(serde_json::Error), + Io(std::io::Error), + Dir(crate::fig_util::directories::DirectoryError), + Settings(crate::fig_settings::Error), + UrlParseError(ParseError), +} + +impl std::fmt::Display for RequestError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + RequestError::Reqwest(err) => write!(f, "Reqwest error: {err}"), + RequestError::Serde(err) => write!(f, "Serde error: {err}"), + RequestError::Io(err) => write!(f, "Io error: {err}"), + RequestError::Dir(err) => write!(f, "Dir error: {err}"), + RequestError::Settings(err) => write!(f, "Settings error: {err}"), + RequestError::UrlParseError(err) => write!(f, "Url parse error: {err}"), + } + } +} + +impl std::error::Error for RequestError {} + +impl From for RequestError { + fn from(e: 
reqwest::Error) -> Self { + RequestError::Reqwest(e) + } +} + +impl From for RequestError { + fn from(e: serde_json::Error) -> Self { + RequestError::Serde(e) + } +} + +impl From for RequestError { + fn from(e: std::io::Error) -> Self { + RequestError::Io(e) + } +} + +impl From for RequestError { + fn from(e: crate::fig_util::directories::DirectoryError) -> Self { + RequestError::Dir(e) + } +} + +impl From for RequestError { + fn from(e: crate::fig_settings::Error) -> Self { + RequestError::Settings(e) + } +} + +impl From for RequestError { + fn from(e: ParseError) -> Self { + RequestError::UrlParseError(e) + } +} + +pub fn client() -> Option<&'static Client> { + CLIENT_NATIVE_CERTS.as_ref() +} + +pub fn create_default_root_cert_store() -> RootCertStore { + let mut root_cert_store: RootCertStore = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect(); + + // The errors are ignored because root certificates often include + // ancient or syntactically invalid certificates + let rustls_native_certs::CertificateResult { certs, errors: _, .. } = rustls_native_certs::load_native_certs(); + for cert in certs { + let _ = root_cert_store.add(cert); + } + + let custom_cert = std::env::var("Q_CUSTOM_CERT") + .ok() + .or_else(|| crate::fig_settings::state::get_string("Q_CUSTOM_CERT").ok().flatten()); + + if let Some(custom_cert) = custom_cert { + match File::open(Path::new(&custom_cert)) { + Ok(file) => { + let reader = &mut BufReader::new(file); + for cert in rustls_pemfile::certs(reader) { + match cert { + Ok(cert) => { + if let Err(err) = root_cert_store.add(cert) { + tracing::error!(path =% custom_cert, %err, "Failed to add custom cert"); + }; + }, + Err(err) => tracing::error!(path =% custom_cert, %err, "Failed to parse cert"), + } + } + }, + Err(err) => tracing::error!(path =% custom_cert, %err, "Failed to open cert at"), + } + } + + root_cert_store +} + +fn client_config() -> ClientConfig { + let provider = rustls::crypto::CryptoProvider::get_default() + .cloned() + .unwrap_or_else(|| Arc::new(rustls::crypto::ring::default_provider())); + + ClientConfig::builder_with_provider(provider) + .with_protocol_versions(rustls::DEFAULT_VERSIONS) + .expect("Failed to set supported TLS versions") + .with_root_certificates(create_default_root_cert_store()) + .with_no_client_auth() +} + +static CLIENT_CONFIG_NATIVE_CERTS: LazyLock> = LazyLock::new(|| Arc::new(client_config())); + +pub fn client_config_cached() -> Arc { + CLIENT_CONFIG_NATIVE_CERTS.clone() +} + +static USER_AGENT: LazyLock = LazyLock::new(|| { + let name = current_exe() + .ok() + .and_then(|exe| exe.file_stem().and_then(|name| name.to_str().map(String::from))) + .unwrap_or_else(|| "unknown-rust-client".into()); + + let os = std::env::consts::OS; + let arch = std::env::consts::ARCH; + let version = env!("CARGO_PKG_VERSION"); + + format!("{name}-{os}-{arch}-{version}") +}); + +pub static CLIENT_NATIVE_CERTS: LazyLock> = LazyLock::new(|| { + Some( + Client::builder() + .use_preconfigured_tls((*client_config_cached()).clone()) + .user_agent(USER_AGENT.chars().filter(|c| c.is_ascii_graphic()).collect::()) + .cookie_store(true) + .build() + .unwrap(), + ) +}); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn get_client() { + client().unwrap(); + } + + #[tokio::test] + async fn request_test() { + let mut server = mockito::Server::new_async().await; + let mock = server + .mock("GET", "/hello") + .with_status(200) + .with_header("content-type", "text/plain") + .with_body("world") + .create(); + let url = server.url(); + + let client = 
client().unwrap(); + let res = client.get(format!("{url}/hello")).send().await.unwrap(); + assert_eq!(res.status(), 200); + assert_eq!(res.headers()["content-type"], "text/plain"); + assert_eq!(res.text().await.unwrap(), "world"); + + mock.expect(1).assert(); + } +} diff --git a/crates/kiro-cli/telemetry_definitions.json b/crates/kiro-cli/telemetry_definitions.json new file mode 100644 index 0000000000..55254d38dc --- /dev/null +++ b/crates/kiro-cli/telemetry_definitions.json @@ -0,0 +1,265 @@ +{ + "types": [ + { + "name": "amazonQProfileRegion", + "type": "string", + "description": "Region of the Q Profile associated with a metric\n- \"n/a\" if metric is not associated with a profile or region.\n- \"not-set\" if metric is associated with a profile, but profile is unknown." + }, + { + "name": "ssoRegion", + "type": "string", + "description": "Region of the current SSO connection. Typically associated with credentialStartUrl\n- \"n/a\" if metric is not associated with a region.\n- \"not-set\" if metric is associated with a region, but region is unknown." + }, + { + "name": "profileCount", + "type": "int", + "description": "The number of profiles that were available to choose from" + }, + { + "name": "source", + "type": "string", + "description": "Identifies the source component where the telemetry event originated." + }, + { + "name": "amazonqConversationId", + "type": "string", + "description": "Uniquely identifies a message with which the user interacts." + }, + { + "name": "codewhispererterminal_command", + "type": "string", + "description": "The CLI tool a completion was for" + }, + { + "name": "codewhispererterminal_subcommand", + "type": "string", + "description": "A codewhisperer CLI subcommand" + }, + { + "name": "codewhispererterminal_inCloudshell", + "type": "boolean", + "description": "Whether the CLI is running in the AWS CloudShell environment" + }, + { + "name": "credentialStartUrl", + "type": "string", + "description": "The start URL of current SSO connection" + }, + { + "name": "requestId", + "type": "string", + "description": "The id assigned to an AWS request" + }, + { + "name": "oauthFlow", + "type": "string", + "description": "The oauth authentication flow executed by the user, e.g. device code or PKCE" + }, + { + "name": "result", + "type": "string", + "description": "Whether or not the operation succeeded" + }, + { + "name": "reason", + "type": "string", + "description": "Description of what caused an error, if any" + }, + { + "name": "codewhispererterminal_toolUseId", + "type": "string", + "description": "The id assigned to the client by the model representing a tool use event" + }, + { + "name": "codewhispererterminal_toolName", + "type": "string", + "description": "The name associated with a tool" + }, + { + "name": "codewhispererterminal_isToolUseAccepted", + "type": "boolean", + "description": "Denotes if a tool use event has been fulfilled" + }, + { + "name": "codewhispererterminal_toolUseIsSuccess", + "type": "boolean", + "description": "The outcome of a tool use" + }, + { + "name": "codewhispererterminal_utteranceId", + "type": "string", + "description": "Id associated with a given response from the model" + }, + { + "name": "codewhispererterminal_userInputId", + "type": "string", + "description": "Id associated with a given user input. This is used to differentiate responses to user input and that of retries from tool uses. 
This id is the utterance id of the first response following an user input" + }, + { + "name": "codewhispererterminal_isToolValid", + "type": "boolean", + "description": "If the use of tool as instructed by the model is valid" + }, + { + "name": "codewhispererterminal_contextFileLength", + "type": "int", + "description": "The length of the files included as part of context management" + }, + { + "name": "codewhispererterminal_mcpServerInitFailureReason", + "type": "string", + "description": "Reason for which a mcp server has failed to be initialized" + }, + { + "name": "codewhispererterminal_toolsPerMcpServer", + "type": "int", + "description": "The number of tools provided by a mcp server" + }, + { + "name": "codewhispererterminal_isCustomTool", + "type": "boolean", + "description": "Denoting whether or not the tool is a custom tool" + }, + { + "name": "codewhispererterminal_customToolInputTokenSize", + "type": "int", + "description": "Number of tokens used on invoking the custom tool" + }, + { + "name": "codewhispererterminal_customToolOutputTokenSize", + "type": "int", + "description": "Number of tokens received from invoking the custom tool" + }, + { + "name": "codewhispererterminal_customToolLatency", + "type": "int", + "description": "Custom tool call latency in seconds" + } + ], + "metrics": [ + { + "name": "amazonq_startChat", + "description": "Captures start of the conversation with amazonq /dev", + "metadata": [ + { "type": "amazonqConversationId" }, + { "type": "credentialStartUrl", "required": false }, + { "type": "codewhispererterminal_inCloudshell" } + ] + }, + { + "name": "codewhispererterminal_addChatMessage", + "description": "Captures active usage with Q Chat in shell", + "metadata": [ + { "type": "amazonqConversationId" }, + { "type": "credentialStartUrl", "required": false }, + { "type": "codewhispererterminal_inCloudshell" }, + { "type": "codewhispererterminal_contextFileLength", "required": false } + ] + }, + { + "name": "amazonq_endChat", + "description": "Captures end of the conversation with amazonq /dev", + "metadata": [ + { "type": "amazonqConversationId" }, + { "type": "credentialStartUrl", "required": false }, + { "type": "codewhispererterminal_inCloudshell" } + ] + }, + { + "name": "codewhispererterminal_userLoggedIn", + "description": "Emitted when users log in", + "passive": false, + "metadata": [ + { "type": "credentialStartUrl" }, + { "type": "codewhispererterminal_inCloudshell" } + ] + }, + { + "name": "codewhispererterminal_refreshCredentials", + "description": "Emitted when users refresh their credentials", + "passive": false, + "metadata": [ + { "type": "credentialStartUrl" }, + { "type": "requestId" }, + { "type": "oauthFlow" }, + { "type": "result" }, + { "type": "reason", "required": false }, + { "type": "codewhispererterminal_inCloudshell" } + ] + }, + { + "name": "codewhispererterminal_cliSubcommandExecuted", + "description": "Emitted on CW CLI subcommand executed", + "passive": false, + "metadata": [ + { "type": "credentialStartUrl" }, + { "type": "codewhispererterminal_subcommand" }, + { "type": "codewhispererterminal_inCloudshell" } + ] + }, + { + "name": "codewhispererterminal_toolUseSuggested", + "description": "Emitted once per tool use to report outcome of tool use suggested", + "passive": false, + "metadata": [ + { "type": "credentialStartUrl" }, + { "type": "amazonqConversationId" }, + { "type": "codewhispererterminal_utteranceId" }, + { "type": "codewhispererterminal_userInputId" }, + { "type": "codewhispererterminal_toolUseId" }, + { 
"type": "codewhispererterminal_toolName" }, + { "type": "codewhispererterminal_isToolUseAccepted" }, + { "type": "codewhispererterminal_isToolValid" }, + { "type": "codewhispererterminal_toolUseIsSuccess", "required": false }, + { "type": "codewhispererterminal_isCustomTool" }, + { + "type": "codewhispererterminal_customToolInputTokenSize", + "required": false + }, + { + "type": "codewhispererterminal_customToolOutputTokenSize", + "required": false + }, + { "type": "codewhispererterminal_customToolLatency", "required": false } + ] + }, + { + "name": "codewhispererterminal_mcpServerInit", + "description": "Emitted once per mcp server on start up", + "passive": false, + "metadata": [ + { "type": "amazonqConversationId" }, + { + "type": "codewhispererterminal_mcpServerInitFailureReason", + "required": false + }, + { "type": "codewhispererterminal_toolsPerMcpServer" } + ] + }, + { + "name": "amazonq_didSelectProfile", + "description": "Emitted after the user's Q Profile has been set, whether the user was prompted with a dialog, or a profile was automatically assigned after signing in.", + "metadata": [ + { "type": "source" }, + { "type": "amazonQProfileRegion" }, + { "type": "result" }, + { "type": "ssoRegion", "required": false }, + { "type": "credentialStartUrl", "required": false }, + { "type": "profileCount", "required": false } + ], + "passive": true + }, + { + "name": "amazonq_profileState", + "description": "Indicates a change in the user's Q Profile state", + "metadata": [ + { "type": "source" }, + { "type": "amazonQProfileRegion" }, + { "type": "result" }, + { "type": "ssoRegion", "required": false }, + { "type": "credentialStartUrl", "required": false } + ], + "passive": true + } + ] +} From afbf74630da8b874cec503010dda5a1a19c34d31 Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Sat, 3 May 2025 01:22:14 -0700 Subject: [PATCH 2/3] remove chat from q_cli --- Cargo.lock | 1671 ++++--- Cargo.toml | 1 - .../Cargo.toml | 28 +- .../LICENSE | 0 .../src/auth_plugin.rs | 0 .../src/client.rs | 4 +- .../src/client/customize.rs | 0 .../src/client/post_error_report.rs | 0 .../src/client/post_feedback.rs | 0 .../src/client/post_metrics.rs | 0 .../src/config.rs | 98 +- .../src/config/endpoint.rs | 0 .../src/config/interceptors.rs | 0 .../src/config/retry.rs | 0 .../src/config/timeout.rs | 0 .../src/error.rs | 0 .../src/error/sealed_unhandled.rs | 0 .../src/error_meta.rs | 0 .../src/json_errors.rs | 0 .../src/lib.rs | 2 +- .../src/meta.rs | 0 .../src/operation.rs | 0 .../src/operation/post_error_report.rs | 0 .../_post_error_report_input.rs | 0 .../_post_error_report_output.rs | 0 .../operation/post_error_report/builders.rs | 0 .../src/operation/post_feedback.rs | 0 .../post_feedback/_post_feedback_input.rs | 0 .../post_feedback/_post_feedback_output.rs | 0 .../src/operation/post_feedback/builders.rs | 0 .../src/operation/post_metrics.rs | 0 .../post_metrics/_post_metrics_input.rs | 0 .../post_metrics/_post_metrics_output.rs | 0 .../src/operation/post_metrics/builders.rs | 0 .../src/primitives.rs | 0 .../src/primitives/event_stream.rs | 0 .../src/primitives/sealed_enum_unknown.rs | 0 .../src/protocol_serde.rs | 0 .../src/protocol_serde/shape_error_details.rs | 0 .../protocol_serde/shape_metadata_entry.rs | 0 .../src/protocol_serde/shape_metric_datum.rs | 0 .../protocol_serde/shape_post_error_report.rs | 0 .../shape_post_error_report_input.rs | 0 .../src/protocol_serde/shape_post_feedback.rs | 0 .../shape_post_feedback_input.rs | 0 .../src/protocol_serde/shape_post_metrics.rs | 0 
.../shape_post_metrics_input.rs | 0 .../src/protocol_serde/shape_userdata.rs | 0 .../src/serialization_settings.rs | 0 .../src/types.rs | 0 .../src/types/_aws_product.rs | 0 .../src/types/_error_details.rs | 0 .../src/types/_metadata_entry.rs | 0 .../src/types/_metric_datum.rs | 0 .../src/types/_sentiment.rs | 0 .../src/types/_unit.rs | 0 .../src/types/_userdata.rs | 0 .../src/types/builders.rs | 0 .../Cargo.toml | 2 +- .../build.rs | 18 +- .../src/lib.rs | 2 +- crates/fig_auth/src/builder_id.rs | 2 +- crates/fig_aws_common/src/lib.rs | 2 +- crates/fig_telemetry/Cargo.toml | 2 +- crates/fig_telemetry/src/cognito.rs | 4 +- crates/fig_telemetry/src/endpoint.rs | 2 +- crates/fig_telemetry/src/lib.rs | 14 +- crates/fig_telemetry_core/Cargo.toml | 2 +- crates/fig_telemetry_core/src/lib.rs | 2 +- crates/kiro-cli/Cargo.toml | 2 +- crates/q_chat/Cargo.toml | 59 - crates/q_chat/src/cli.rs | 25 - crates/q_chat/src/command.rs | 1093 ----- crates/q_chat/src/consts.rs | 19 - crates/q_chat/src/context.rs | 1016 ----- crates/q_chat/src/conversation_state.rs | 1051 ----- crates/q_chat/src/hooks.rs | 557 --- crates/q_chat/src/input_source.rs | 107 - crates/q_chat/src/lib.rs | 3848 ----------------- crates/q_chat/src/message.rs | 407 -- crates/q_chat/src/parse.rs | 762 ---- crates/q_chat/src/parser.rs | 375 -- crates/q_chat/src/prompt.rs | 364 -- crates/q_chat/src/skim_integration.rs | 378 -- crates/q_chat/src/token_counter.rs | 251 -- crates/q_chat/src/tool_manager.rs | 1019 ----- crates/q_chat/src/tools/custom_tool.rs | 241 -- crates/q_chat/src/tools/execute_bash.rs | 373 -- crates/q_chat/src/tools/fs_read.rs | 669 --- crates/q_chat/src/tools/fs_write.rs | 953 ---- crates/q_chat/src/tools/gh_issue.rs | 222 - crates/q_chat/src/tools/mod.rs | 433 -- crates/q_chat/src/tools/tool_index.json | 176 - crates/q_chat/src/tools/use_aws.rs | 315 -- crates/q_chat/src/util/issue.rs | 82 - crates/q_chat/src/util/mod.rs | 114 - crates/q_chat/src/util/shared_writer.rs | 89 - crates/q_chat/src/util/ui.rs | 212 - crates/q_cli/Cargo.toml | 1 - crates/q_cli/src/cli/issue.rs | 10 +- crates/q_cli/src/cli/mod.rs | 130 +- typos.toml | 4 +- 102 files changed, 1134 insertions(+), 16079 deletions(-) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/Cargo.toml (81%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/LICENSE (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/auth_plugin.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/client.rs (97%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/client/customize.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/client/post_error_report.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/client/post_feedback.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/client/post_metrics.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/config.rs (94%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/config/endpoint.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/config/interceptors.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/config/retry.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/config/timeout.rs (100%) rename crates/{amzn-toolkit-telemetry => 
amzn-toolkit-telemetry-client}/src/error.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/error/sealed_unhandled.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/error_meta.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/json_errors.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/lib.rs (99%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/meta.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_error_report.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_error_report/_post_error_report_input.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_error_report/_post_error_report_output.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_error_report/builders.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_feedback.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_feedback/_post_feedback_input.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_feedback/_post_feedback_output.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_feedback/builders.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_metrics.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_metrics/_post_metrics_input.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_metrics/_post_metrics_output.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/operation/post_metrics/builders.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/primitives.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/primitives/event_stream.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/primitives/sealed_enum_unknown.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_error_details.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_metadata_entry.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_metric_datum.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_post_error_report.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_post_error_report_input.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_post_feedback.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_post_feedback_input.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_post_metrics.rs (100%) rename 
crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_post_metrics_input.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/protocol_serde/shape_userdata.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/serialization_settings.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_aws_product.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_error_details.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_metadata_entry.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_metric_datum.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_sentiment.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_unit.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/_userdata.rs (100%) rename crates/{amzn-toolkit-telemetry => amzn-toolkit-telemetry-client}/src/types/builders.rs (100%) delete mode 100644 crates/q_chat/Cargo.toml delete mode 100644 crates/q_chat/src/cli.rs delete mode 100644 crates/q_chat/src/command.rs delete mode 100644 crates/q_chat/src/consts.rs delete mode 100644 crates/q_chat/src/context.rs delete mode 100644 crates/q_chat/src/conversation_state.rs delete mode 100644 crates/q_chat/src/hooks.rs delete mode 100644 crates/q_chat/src/input_source.rs delete mode 100644 crates/q_chat/src/lib.rs delete mode 100644 crates/q_chat/src/message.rs delete mode 100644 crates/q_chat/src/parse.rs delete mode 100644 crates/q_chat/src/parser.rs delete mode 100644 crates/q_chat/src/prompt.rs delete mode 100644 crates/q_chat/src/skim_integration.rs delete mode 100644 crates/q_chat/src/token_counter.rs delete mode 100644 crates/q_chat/src/tool_manager.rs delete mode 100644 crates/q_chat/src/tools/custom_tool.rs delete mode 100644 crates/q_chat/src/tools/execute_bash.rs delete mode 100644 crates/q_chat/src/tools/fs_read.rs delete mode 100644 crates/q_chat/src/tools/fs_write.rs delete mode 100644 crates/q_chat/src/tools/gh_issue.rs delete mode 100644 crates/q_chat/src/tools/mod.rs delete mode 100644 crates/q_chat/src/tools/tool_index.json delete mode 100644 crates/q_chat/src/tools/use_aws.rs delete mode 100644 crates/q_chat/src/util/issue.rs delete mode 100644 crates/q_chat/src/util/mod.rs delete mode 100644 crates/q_chat/src/util/shared_writer.rs delete mode 100644 crates/q_chat/src/util/ui.rs diff --git a/Cargo.lock b/Cargo.lock index 8c6130ef4e..9383ba0d20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,19 +24,13 @@ dependencies = [ [[package]] name = "addr2line" -version = "0.21.0" +version = "0.24.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +checksum = "dfbe277e56a376000877090da837660b4427aad530e3028d44e0bffe4f89a1c1" dependencies = [ "gimli", ] -[[package]] -name = "adler" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" - [[package]] name = "adler2" version = "2.0.0" @@ -73,7 +67,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", - "shell-color", + "shell-color 1.10.0", "tracing", "unicode-width 
0.2.0", "vte 0.15.0", @@ -98,7 +92,7 @@ dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -118,7 +112,7 @@ dependencies = [ "aws-runtime", "aws-smithy-async", "aws-smithy-eventstream", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -136,7 +130,7 @@ dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -156,7 +150,7 @@ dependencies = [ "aws-runtime", "aws-smithy-async", "aws-smithy-eventstream", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -168,14 +162,14 @@ dependencies = [ ] [[package]] -name = "amzn-toolkit-telemetry" +name = "amzn-toolkit-telemetry-client" version = "1.0.0" dependencies = [ "aws-credential-types", "aws-http", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.60.12", "aws-smithy-json 0.60.7", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -248,11 +242,12 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "3.0.6" +version = "3.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +checksum = "ca3534e77181a9cc07539ad51f2141fe32f6c3ffd4df76db8ad92346b003ae4e" dependencies = [ "anstyle", + "once_cell", "windows-sys 0.59.0", ] @@ -298,9 +293,9 @@ checksum = "c1df21f715862ede32a0c525ce2ca4d52626bb0007f8c18b87a384503ac33e70" dependencies = [ "clipboard-win", "log", - "objc2 0.6.0", - "objc2-app-kit 0.3.0", - "objc2-foundation 0.3.0", + "objc2 0.6.1", + "objc2-app-kit 0.3.1", + "objc2-foundation 0.3.1", "parking_lot", "percent-encoding", "wl-clipboard-rs", @@ -326,24 +321,21 @@ checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" [[package]] name = "ashpd" -version = "0.10.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9c39d707614dbcc6bed00015539f488d8e3fe3e66ed60961efc0c90f4b380b3" +checksum = "6cbdf310d77fd3aaee6ea2093db7011dc2d35d2eb3481e5607f1f8d942ed99df" dependencies = [ "async-fs", "async-net", "enumflags2", "futures-channel", "futures-util", - "rand 0.8.5", + "rand 0.9.1", "raw-window-handle", "serde", "serde_repr", "url", - "wayland-backend", - "wayland-client", - "wayland-protocols", - "zbus 5.2.0", + "zbus 5.5.0", ] [[package]] @@ -369,9 +361,9 @@ dependencies = [ [[package]] name = "assert_cmd" -version = "2.0.16" +version = "2.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1835b7f27878de8525dc71410b5a31cdcc5f230aed5ba5df968e09c201b23d" +checksum = "2bd389a4b2970a01282ee455294913c0a43724daedcd1a24c3eb0ec1c1320b66" dependencies = [ "anstyle", "bstr", @@ -409,9 +401,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.18" +version = "0.4.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df895a515f70646414f4b45c0b79082783b80552b373a68283012928df56f522" +checksum = "b37fc50485c4f3f736a4fb14199f6d5f5ba008d7f28fe710306c92780f004c07" dependencies = [ "flate2", "futures-core", @@ -543,9 +535,9 @@ checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" [[package]] name = "async-trait" 
-version = "0.1.87" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d556ec1359574147ec0c4fc5eb525f3f23263a592b1a9c07e0a75b427de55c97" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", @@ -616,18 +608,18 @@ dependencies = [ [[package]] name = "avif-serialize" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e335041290c43101ca215eed6f43ec437eb5a42125573f600fc3fa42b9bddd62" +checksum = "98922d6a4cfbcb08820c69d8eeccc05bb1f29bfa06b4f5b1dbfe9a868bd7608e" dependencies = [ "arrayvec", ] [[package]] name = "aws-config" -version = "1.5.13" +version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c03a50b30228d3af8865ce83376b4e99e1ffa34728220fe2860e4df0bb5278d6" +checksum = "b6fcc63c9860579e4cb396239570e979376e70aab79e496621748a09913f8b36" dependencies = [ "aws-credential-types", "aws-runtime", @@ -635,7 +627,7 @@ dependencies = [ "aws-sdk-ssooidc", "aws-sdk-sts", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -644,7 +636,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 0.2.12", + "http 1.3.1", "ring", "time", "tokio", @@ -655,9 +647,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.1" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -676,39 +668,37 @@ dependencies = [ [[package]] name = "aws-lc-rs" -version = "1.12.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f409eb70b561706bf8abba8ca9c112729c481595893fd06a2dd9af8ed8441148" +checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878" dependencies = [ "aws-lc-sys", - "paste", "zeroize", ] [[package]] name = "aws-lc-sys" -version = "0.24.1" +version = "0.28.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "923ded50f602b3007e5e63e3f094c479d9c8a9b42d7f4034e4afe456aa48bfd2" +checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1" dependencies = [ "bindgen 0.69.5", "cc", "cmake", "dunce", "fs_extra", - "paste", ] [[package]] name = "aws-runtime" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76dd04d39cc12844c0994f2c9c5a6f5184c22e9188ec1ff723de41910a21dcad" +checksum = "6c4063282c69991e57faab9e5cb21ae557e59f5b0fb285c196335243df8dc25c" dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -717,7 +707,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -726,20 +715,21 @@ dependencies = [ [[package]] name = "aws-sdk-cognitoidentity" -version = "1.54.0" +version = "1.66.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3a63eeb333e6aac318474715bcb47130ceb02d4ce4caa4ebd632ef456ef1f7a" +checksum = "1cdb376404ce63c89ca527732904caf24cd3a97a9b54239e87974b94f4b934c8" dependencies = [ "aws-credential-types", 
"aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -748,20 +738,21 @@ dependencies = [ [[package]] name = "aws-sdk-cognitoidentityprovider" -version = "1.63.0" +version = "1.77.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cea251ab30246099e3b57406a312ac9d96a8f4cde3fce0470b7d2109ba27307" +checksum = "8054c266053cc1061f6a816fb6da6066beea601b4c3677958c9df96f5f33d9d6" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -770,20 +761,21 @@ dependencies = [ [[package]] name = "aws-sdk-sso" -version = "1.53.0" +version = "1.65.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1605dc0bf9f0a4b05b451441a17fcb0bda229db384f23bf5cead3adbab0664ac" +checksum = "8efec445fb78df585327094fcef4cad895b154b58711e504db7a93c41aa27151" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -792,20 +784,21 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.54.0" +version = "1.66.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59f3f73466ff24f6ad109095e0f3f2c830bfb4cd6c8b12f744c8e61ebf4d3ba1" +checksum = "5e49cca619c10e7b002dc8e66928ceed66ab7f56c1a3be86c5437bf2d8d89bba" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-json 0.61.3", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -814,14 +807,14 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.54.0" +version = "1.66.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "249b2acaa8e02fd4718705a9494e3eb633637139aa4bb09d70965b0448e865db" +checksum = "7420479eac0a53f776cc8f0d493841ffe58ad9d9783f3947be7265784471b47a" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-json 0.61.3", "aws-smithy-query", "aws-smithy-runtime", @@ -829,6 +822,7 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -837,12 +831,12 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.9" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" +checksum = "3503af839bd8751d0bdc5a46b9cac93a003a353e635b0c12cf2376b5b53e41ea" dependencies = [ "aws-credential-types", - "aws-smithy-http", + "aws-smithy-http 0.62.1", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", @@ -850,8 +844,7 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "percent-encoding", "sha2", "time", @@ -901,6 +894,60 @@ dependencies = [ "tracing", ] +[[package]] +name = 
"aws-smithy-http" +version = "0.62.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" +dependencies = [ + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "bytes-utils", + "futures-core", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "percent-encoding", + "pin-project-lite", + "pin-utils", + "tracing", +] + +[[package]] +name = "aws-smithy-http-client" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8aff1159006441d02e57204bf57a1b890ba68bedb6904ffd2873c1c4c11c546b" +dependencies = [ + "aws-smithy-async", + "aws-smithy-protocol-test", + "aws-smithy-runtime-api", + "aws-smithy-types", + "bytes", + "h2 0.4.9", + "http 0.2.12", + "http 1.3.1", + "http-body 0.4.6", + "http-body 1.0.1", + "hyper 0.14.32", + "hyper 1.6.0", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.5", + "hyper-util", + "indexmap 2.9.0", + "pin-project-lite", + "rustls 0.21.12", + "rustls 0.23.26", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "serde", + "serde_json", + "tokio", + "tower", + "tracing", +] + [[package]] name = "aws-smithy-json" version = "0.60.7" @@ -919,11 +966,20 @@ dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-observability" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +dependencies = [ + "aws-smithy-runtime-api", +] + [[package]] name = "aws-smithy-protocol-test" -version = "0.63.0" +version = "0.63.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b92b62199921f10685c6b588fdbeb81168ae4e7950ae3e5f50145a01bb5f1ad" +checksum = "5b42f13304bed0b96d7471e4770c270bb3eb4fea277727fb03c811e84cb4bf3a" dependencies = [ "assert-json-diff 1.1.0", "aws-smithy-runtime-api", @@ -950,31 +1006,24 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.8" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d526a12d9ed61fadefda24abe2e682892ba288c2018bcb38b1b4c111d13f6d92" +checksum = "14302f06d1d5b7d333fd819943075b13d27c7700b414f574c3c35859bfb55d5e" dependencies = [ "aws-smithy-async", - "aws-smithy-http", - "aws-smithy-protocol-test", + "aws-smithy-http 0.62.1", + "aws-smithy-http-client", + "aws-smithy-observability", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "fastrand", - "h2 0.3.26", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "httparse", - "hyper 0.14.32", - "hyper-rustls 0.24.2", - "indexmap 2.9.0", - "once_cell", "pin-project-lite", "pin-utils", - "rustls 0.21.12", - "serde", - "serde_json", "tokio", "tracing", "tracing-subscriber", @@ -982,15 +1031,15 @@ dependencies = [ [[package]] name = "aws-smithy-runtime-api" -version = "1.7.3" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" +checksum = "a1e5d9e3a80a18afa109391fb5ad09c3daf887b516c6fd805a157c6ea7994a57" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -1008,11 +1057,11 @@ dependencies = [ "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", - "itoa 1.0.14", + "itoa 1.0.15", "num-integer", "pin-project-lite", 
"pin-utils", @@ -1036,7 +1085,7 @@ dependencies = [ name = "aws-toolkit-telemetry-definitions" version = "0.1.0" dependencies = [ - "amzn-toolkit-telemetry", + "amzn-toolkit-telemetry-client", "convert_case 0.8.0", "prettyplease", "quote", @@ -1047,9 +1096,9 @@ dependencies = [ [[package]] name = "aws-types" -version = "1.3.5" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbd0a668309ec1f66c0f6bda4840dd6d4796ae26d699ebc266d7cc95c6d040f" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -1061,17 +1110,17 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.71" +version = "0.3.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b05800d2e817c8b3b4b54abd461726265fa9789ae34330622f2db9ee696f9d" +checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" dependencies = [ "addr2line", - "cc", "cfg-if", "libc", - "miniz_oxide 0.7.4", + "miniz_oxide", "object", "rustc-demangle", + "windows-targets 0.52.6", ] [[package]] @@ -1167,7 +1216,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash 2.1.0", + "rustc-hash 2.1.1", "shlex", "syn 2.0.101", ] @@ -1238,6 +1287,15 @@ dependencies = [ "objc2 0.5.2", ] +[[package]] +name = "block2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "340d2f0bdb2a43c1d3cd40513185b2bd7def0aa1052f956455114bc98f82dcf2" +dependencies = [ + "objc2 0.6.1", +] + [[package]] name = "blocking" version = "1.6.1" @@ -1273,21 +1331,21 @@ dependencies = [ [[package]] name = "built" -version = "0.7.5" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c360505aed52b7ec96a3636c3f039d99103c37d1d9b4f7a8c743d3ea9ffcd03b" +checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b" [[package]] name = "bumpalo" -version = "3.16.0" +version = "3.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" +checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "bytemuck" -version = "1.21.0" +version = "1.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" +checksum = "9134a6ef01ce4b366b50689c94f82c14bc72bc5d0386829828a2e2752ef7958c" [[package]] name = "byteorder" @@ -1378,9 +1436,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.16" +version = "1.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" +checksum = "8691782945451c1c383942c4874dbe63814f61cb57ef773cda2972682b7bb3c0" dependencies = [ "jobserver", "libc", @@ -1437,9 +1495,9 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chrono" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", @@ -1525,9 +1583,9 @@ dependencies = [ [[package]] name = "clap_complete" -version = "4.5.46" +version = "4.5.48" source = "registry+https://github.com/rust-lang/crates.io-index" 
-checksum = "f5c5508ea23c5366f77e53f5a0070e5a84e51687ec3ef9e0464c86dc8d13ce98" +checksum = "be8c97f3a6f02b9e24cadc12aaba75201d18754b53ea0a9d99642f806ccdb4c9" dependencies = [ "clap", ] @@ -1573,9 +1631,9 @@ dependencies = [ [[package]] name = "cmake" -version = "0.1.52" +version = "0.1.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" dependencies = [ "cc", ] @@ -1591,7 +1649,7 @@ dependencies = [ "cocoa-foundation", "core-foundation 0.10.0", "core-graphics", - "foreign-types", + "foreign-types 0.5.0", "libc", "objc", ] @@ -1612,16 +1670,16 @@ dependencies = [ [[package]] name = "color-eyre" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55146f5e46f237f7423d74111267d4597b59b0dad0ffaf7303bce9945d843ad5" +checksum = "e6e1761c0e16f8883bbbb8ce5990867f4f06bf11a0253da6495a04ce4b6ef0ec" dependencies = [ "backtrace", "color-spantrace", "eyre", "indenter", "once_cell", - "owo-colors 3.5.0", + "owo-colors", "tracing-error", ] @@ -1648,12 +1706,12 @@ dependencies = [ [[package]] name = "color-spantrace" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd6be1b2a7e382e2b98b43b2adcca6bb0e465af0bdd38123873ae61eb17a72c2" +checksum = "2ddd8d5bfda1e11a501d0a7303f3bfed9aa632ebdb859be40d0fd70478ed70d5" dependencies = [ "once_cell", - "owo-colors 3.5.0", + "owo-colors", "tracing-core", "tracing-error", ] @@ -1672,11 +1730,10 @@ checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" [[package]] name = "colored" -version = "2.2.0" +version = "3.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" dependencies = [ - "lazy_static", "windows-sys 0.59.0", ] @@ -1701,9 +1758,9 @@ dependencies = [ [[package]] name = "console" -version = "0.15.10" +version = "0.15.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" dependencies = [ "encode_unicode", "libc", @@ -1727,7 +1784,7 @@ version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "tiny-keccak", ] @@ -1811,7 +1868,7 @@ dependencies = [ "bitflags 2.9.0", "core-foundation 0.10.0", "core-graphics-types", - "foreign-types", + "foreign-types 0.5.0", "libc", ] @@ -1828,9 +1885,9 @@ dependencies = [ [[package]] name = "cpufeatures" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" dependencies = [ "libc", ] @@ -1964,9 +2021,9 @@ dependencies = [ [[package]] name = "crunchy" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +checksum = 
"43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" [[package]] name = "crypto-common" @@ -2066,9 +2123,9 @@ dependencies = [ [[package]] name = "data-encoding" -version = "2.6.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8566979429cf69b49a5c740c60791108e86440e8be149bbea4fe54d2c32d6e2" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" [[package]] name = "dbus" @@ -2152,9 +2209,9 @@ dependencies = [ [[package]] name = "derive_more" -version = "0.99.18" +version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f33878137e4dafd7fa914ad4e259e18a4e8e532b9617a2d0150262bf53abfce" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ "convert_case 0.4.0", "proc-macro2", @@ -2206,7 +2263,16 @@ version = "5.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" dependencies = [ - "dirs-sys", + "dirs-sys 0.4.1", +] + +[[package]] +name = "dirs" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" +dependencies = [ + "dirs-sys 0.5.0", ] [[package]] @@ -2227,10 +2293,22 @@ checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" dependencies = [ "libc", "option-ext", - "redox_users", + "redox_users 0.4.6", "windows-sys 0.48.0", ] +[[package]] +name = "dirs-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" +dependencies = [ + "libc", + "option-ext", + "redox_users 0.5.0", + "windows-sys 0.59.0", +] + [[package]] name = "dirs-sys-next" version = "0.1.2" @@ -2238,7 +2316,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" dependencies = [ "libc", - "redox_users", + "redox_users 0.4.6", "winapi", ] @@ -2248,6 +2326,28 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd0c93bb4b0c6d9b77f4435b0ae98c24d17f1c45b2ff844c6151a07256ca923b" +[[package]] +name = "dispatch2" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a0d569e003ff27784e0e14e4a594048698e0c0f0b66cabcb51511be55a7caa0" +dependencies = [ + "bitflags 2.9.0", + "block2 0.6.1", + "libc", + "objc2 0.6.1", +] + +[[package]] +name = "dispatch2" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89a09f22a6c6069a18470eb92d2298acf25463f14256d24778e1230d789a2aec" +dependencies = [ + "bitflags 2.9.0", + "objc2 0.6.1", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -2259,15 +2359,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "dlib" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "330c60081dcc4c72131f8eb70510f1ac07223e5d4163db481a04a0befcffa412" -dependencies = [ - "libloading 0.8.6", -] - [[package]] name = "dlopen2" version = "0.7.0" @@ -2308,9 +2399,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "document-features" -version = "0.2.10" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"cb6969eaabd2421f8a2775cfd2471a2b634372b4a25d41e3bd647b79912850a0" +checksum = "95249b50c6c185bee49034bcb378a49dc2b5dff0be90ff6616d31d64febab05d" dependencies = [ "litrs", ] @@ -2323,18 +2414,18 @@ checksum = "75b325c5dbd37f80359721ad39aca5a29fb04c89279657cffdda8736d0c0b9d2" [[package]] name = "dpi" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f25c0e292a7ca6d6498557ff1df68f32c99850012b6ea401cf8daf771f22ff53" +checksum = "d8b14ccef22fc6f5a8f4d7d768562a182c04ce9a3b3157b91390b52ddfdf1a76" dependencies = [ "serde", ] [[package]] name = "dtoa" -version = "1.0.9" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcbb2bf8e87535c23f7a8a321e364ce21462d0ff10cb6407820e8e96dfff6653" +checksum = "d6add3b8cff394282be81f3fc1a0605db594ed69890078ca6e2cab1c408bcf04" [[package]] name = "dtoa-short" @@ -2353,9 +2444,9 @@ checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "encode_unicode" @@ -2436,15 +2527,15 @@ dependencies = [ [[package]] name = "equivalent" -version = "1.0.1" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "erased-serde" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d" +checksum = "e004d887f51fcb9fef17317a2f3525c887d8aa3f4f50fed920816a688284a5b7" dependencies = [ "serde", "typeid", @@ -2452,9 +2543,9 @@ dependencies = [ [[package]] name = "errno" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", "windows-sys 0.59.0", @@ -2479,9 +2570,9 @@ dependencies = [ [[package]] name = "event-listener-strategy" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c3e4e0dd3673c1139bf041f3008816d9cf2946bbfac2945c09e523b8d7b05b2" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ "event-listener", "pin-project-lite", @@ -2496,7 +2587,7 @@ dependencies = [ "bit_field", "half", "lebe", - "miniz_oxide 0.8.5", + "miniz_oxide", "rayon-core", "smallvec", "zune-inflate", @@ -2560,7 +2651,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce92ff622d6dadf7349484f42c93271a0d49b7cc4d466a936405bacbe10aa78" dependencies = [ "cfg-if", - "rustix 1.0.5", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -2617,7 +2708,7 @@ dependencies = [ "fig_request", "fig_settings", "fig_util", - "http 1.2.0", + "http 1.3.1", "regex", "serde", "serde_json", @@ -2653,7 +2744,7 @@ dependencies = [ "hyper-util", "insta", "percent-encoding", - "rand 0.9.0", + "rand 0.9.1", "reqwest", "serde", "serde_json", @@ -2675,7 +2766,7 @@ dependencies = [ "aws-smithy-types", "aws-types", 
"fig_request", - "http 1.2.0", + "http 1.3.1", "tracing", ] @@ -2715,7 +2806,7 @@ dependencies = [ "freedesktop-icons", "futures", "gtk", - "http 1.2.0", + "http 1.3.1", "image", "infer", "keyboard-types 0.8.0", @@ -2733,7 +2824,7 @@ dependencies = [ "parking_lot", "paste", "percent-encoding", - "rand 0.9.0", + "rand 0.9.1", "regex", "rfd", "semver", @@ -2898,7 +2989,7 @@ dependencies = [ "clap", "core-foundation 0.10.0", "dbus", - "dirs", + "dirs 5.0.1", "dispatch", "fig_os_shim", "fig_settings", @@ -2909,7 +3000,7 @@ dependencies = [ "macos-utils", "nix 0.29.0", "objc", - "owo-colors 4.2.0", + "owo-colors", "plist", "regex", "serde", @@ -2934,7 +3025,7 @@ dependencies = [ "flate2", "nix 0.29.0", "pin-project-lite", - "rand 0.9.0", + "rand 0.9.1", "tempfile", "thiserror 2.0.12", "tokio", @@ -2962,7 +3053,7 @@ name = "fig_os_shim" version = "1.10.0" dependencies = [ "cfg-if", - "dirs", + "dirs 5.0.1", "nix 0.29.0", "serde", "sysinfo", @@ -2984,7 +3075,7 @@ dependencies = [ "prost-build", "prost-reflect", "prost-reflect-build", - "rand 0.9.0", + "rand 0.9.1", "rmp-serde", "serde", "serde_json", @@ -3023,7 +3114,7 @@ dependencies = [ "mockito", "reqwest", "reqwest_cookie_store", - "rustls 0.23.23", + "rustls 0.23.26", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "serde", @@ -3062,7 +3153,7 @@ name = "fig_telemetry" version = "1.10.0" dependencies = [ "amzn-codewhisperer-client", - "amzn-toolkit-telemetry", + "amzn-toolkit-telemetry-client", "anyhow", "async-trait", "aws-credential-types", @@ -3096,7 +3187,7 @@ name = "fig_telemetry_core" version = "1.10.0" dependencies = [ "amzn-codewhisperer-client", - "amzn-toolkit-telemetry", + "amzn-toolkit-telemetry-client", "async-trait", "aws-toolkit-telemetry-definitions", "fig_util", @@ -3129,7 +3220,7 @@ dependencies = [ "fig_os_shim", "fig_util", "hex", - "http 1.2.0", + "http 1.3.1", "http-body-util", "hyper 1.6.0", "hyper-util", @@ -3149,7 +3240,7 @@ dependencies = [ "cfg-if", "clap", "core-foundation 0.10.0", - "dirs", + "dirs 5.0.1", "fig_os_shim", "fig_test", "hex", @@ -3162,7 +3253,7 @@ dependencies = [ "objc2-app-kit 0.2.2", "objc2-foundation 0.2.2", "paste", - "rand 0.9.0", + "rand 0.9.1", "regex", "serde", "serde_json", @@ -3226,7 +3317,7 @@ dependencies = [ "serde", "serde_json", "shared_library", - "shell-color", + "shell-color 1.10.0", "shell-words", "shellexpand", "shlex", @@ -3288,6 +3379,12 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flate2" version = "1.1.1" @@ -3295,7 +3392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" dependencies = [ "crc32fast", - "miniz_oxide 0.8.5", + "miniz_oxide", ] [[package]] @@ -3327,9 +3424,18 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "foreign-types" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared 0.1.1", +] [[package]] name = "foreign-types" @@ -3338,7 +3444,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" dependencies = [ "foreign-types-macros", - "foreign-types-shared", + "foreign-types-shared 0.3.1", ] [[package]] @@ -3352,6 +3458,12 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "foreign-types-shared" version = "0.3.1" @@ -3373,7 +3485,7 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8ef34245e0540c9a3ce7a28340b98d2c12b75da0d446da4e8224923fcaa0c16" dependencies = [ - "dirs", + "dirs 5.0.1", "once_cell", "rust-ini", "thiserror 1.0.69", @@ -3670,9 +3782,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "js-sys", @@ -3683,14 +3795,16 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "73fea8450eea4bac3940448fb7ae50d91f034f941199fcd9d909a5a07aa455f0" dependencies = [ "cfg-if", + "js-sys", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", + "wasm-bindgen", ] [[package]] @@ -3705,9 +3819,9 @@ dependencies = [ [[package]] name = "gimli" -version = "0.28.1" +version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gio" @@ -3891,16 +4005,16 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.7" +version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccae279728d634d083c00f6099cb58f01cc99c145b84b8be2f6c74618d79922e" +checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633" dependencies = [ "atomic-waker", "bytes", "fnv", "futures-core", "futures-sink", - "http 1.2.0", + "http 1.3.1", "indexmap 2.9.0", "slab", "tokio", @@ -3910,9 +4024,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -3944,9 +4058,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" dependencies = [ "allocator-api2", 
"equivalent", @@ -3991,6 +4105,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +[[package]] +name = "hermit-abi" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbd780fe5cc30f81464441920d82ac8740e2e46b29a6fad543ddd075229ce37e" + [[package]] name = "hex" version = "0.4.3" @@ -4037,18 +4157,18 @@ checksum = "601cbb57e577e2f5ef5be8e7b83f0f63994f25aa94d673e54a92d5c516d101f1" dependencies = [ "bytes", "fnv", - "itoa 1.0.14", + "itoa 1.0.15", ] [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", - "itoa 1.0.14", + "itoa 1.0.15", ] [[package]] @@ -4069,7 +4189,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] @@ -4080,16 +4200,16 @@ checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", "futures-core", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.9.5" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -4112,7 +4232,7 @@ dependencies = [ "http-body 0.4.6", "httparse", "httpdate", - "itoa 1.0.14", + "itoa 1.0.15", "pin-project-lite", "socket2", "tokio", @@ -4130,12 +4250,12 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.7", - "http 1.2.0", + "h2 0.4.9", + "http 1.3.1", "http-body 1.0.1", "httparse", "httpdate", - "itoa 1.0.14", + "itoa 1.0.15", "pin-project-lite", "smallvec", "tokio", @@ -4165,18 +4285,34 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" dependencies = [ "futures-util", - "http 1.2.0", + "http 1.3.1", "hyper 1.6.0", "hyper-util", - "rustls 0.23.23", + "rustls 0.23.26", "rustls-native-certs 0.8.1", "rustls-pki-types", "tokio", - "tokio-rustls 0.26.1", + "tokio-rustls 0.26.2", "tower-service", "webpki-roots", ] +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" version = "0.1.11" @@ -4186,7 +4322,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 1.0.1", "hyper 1.6.0", "libc", @@ -4199,16 +4335,17 @@ dependencies = [ [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ 
"android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", - "windows-core 0.52.0", + "windows-core 0.61.0", ] [[package]] @@ -4261,9 +4398,9 @@ dependencies = [ [[package]] name = "icu_locid_transform_data" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" +checksum = "7515e6d781098bf9f7205ab3fc7e9709d34554ae0b21ddbcb5febfa4bc7df11d" [[package]] name = "icu_normalizer" @@ -4285,9 +4422,9 @@ dependencies = [ [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "c5e8338228bdc8ab83303f16b797e177953730f601a96c25d10cb3ab0daa0cb7" [[package]] name = "icu_properties" @@ -4306,9 +4443,9 @@ dependencies = [ [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "85fb8799753b75aee8d2a21d7c14d9f38921b54b3dbda10f5a3c7a7b82dba5e2" [[package]] name = "icu_provider" @@ -4390,9 +4527,9 @@ dependencies = [ [[package]] name = "image-webp" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e031e8e3d94711a9ccb5d6ea357439ef3dcbed361798bd4071dc4d9793fbe22f" +checksum = "b77d01e822461baa8409e156015a1d91735549f0f2c17691bd2d996bef238f7f" dependencies = [ "byteorder-lite", "quick-error", @@ -4427,7 +4564,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "serde", ] @@ -4503,9 +4640,9 @@ dependencies = [ [[package]] name = "inventory" -version = "0.3.17" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b31349d02fe60f80bbbab1a9402364cad7460626d6030494b08ac4a2075bf81" +checksum = "ab08d7cd2c5897f2c949e5383ea7c7db03fb19130ffcfbf7eda795137ae3cb83" dependencies = [ "rustversion", ] @@ -4521,19 +4658,19 @@ dependencies = [ [[package]] name = "ipnet" -version = "2.10.1" +version = "2.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.13" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ - "hermit-abi", + "hermit-abi 0.5.0", "libc", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -4575,6 +4712,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "0.4.8" @@ -4583,9 +4729,9 @@ checksum = "b71991ff56294aa922b450139ee08b3bfc70982c6b2c7562771375cf73542dd4" [[package]] name = "itoa" -version = "1.0.14" +version = 
"1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "javascriptcore-rs" @@ -4612,9 +4758,9 @@ dependencies = [ [[package]] name = "jiff" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ad87c89110f55e4cd4dc2893a9790820206729eaf221555f742d540b0724a0" +checksum = "27e77966151130221b079bcec80f1f34a9e414fa489d99152a201c07fd2182bc" dependencies = [ "jiff-static", "log", @@ -4625,9 +4771,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.8" +version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d076d5b64a7e2fe6f0743f02c43ca4a6725c0f904203bfe276a5b3e793103605" +checksum = "97265751f8a9a4228476f2fc17874a9e7e70e96b893368e42619880fe143b48a" dependencies = [ "proc-macro2", "quote", @@ -4658,10 +4804,11 @@ checksum = "8eaf4bc02d17cbdd7ff4c7438cafcdf7fb9a4613313ad11b4f8fefe7d3fa0130" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.2", "libc", ] @@ -4703,19 +4850,135 @@ dependencies = [ ] [[package]] -name = "kqueue" -version = "1.0.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c" +name = "kiro_cli" +version = "1.10.0" dependencies = [ - "kqueue-sys", - "libc", -] - -[[package]] -name = "kqueue-sys" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" + "amzn-codewhisperer-client", + "amzn-codewhisperer-streaming-client", + "amzn-consolas-client", + "amzn-qdeveloper-streaming-client", + "amzn-toolkit-telemetry-client", + "anstream", + "arboard", + "assert_cmd", + "async-trait", + "aws-config", + "aws-credential-types", + "aws-runtime", + "aws-sdk-cognitoidentity", + "aws-sdk-ssooidc", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "base64 0.22.1", + "bitflags 2.9.0", + "bstr", + "bytes", + "camino", + "cfg-if", + "clap", + "clap_complete", + "clap_complete_fig", + "color-eyre", + "color-print", + "convert_case 0.8.0", + "cookie", + "criterion", + "crossterm", + "ctrlc", + "dialoguer", + "dirs 5.0.1", + "eyre", + "fd-lock", + "futures", + "glob", + "globset", + "hex", + "http 1.3.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "indicatif", + "indoc", + "insta", + "libc", + "mimalloc", + "mockito", + "nix 0.29.0", + "objc2 0.5.2", + "objc2-app-kit 0.2.2", + "objc2-foundation 0.2.2", + "owo-colors", + "parking_lot", + "paste", + "percent-encoding", + "predicates", + "prettyplease", + "quote", + "r2d2", + "r2d2_sqlite", + "rand 0.9.1", + "regex", + "reqwest", + "ring", + "rusqlite", + "rustls 0.23.26", + "rustls-native-certs 0.8.1", + "rustls-pemfile 2.2.0", + "rustyline", + "security-framework 3.2.0", + "self_update", + "semver", + "serde", + "serde_json", + "sha2", + "shell-color 1.0.0", + "shell-words", + "shellexpand", + "shlex", + "similar", + "skim", + "spinners", + "strip-ansi-escapes", + "strum 0.27.1", + "syn 2.0.101", + "syntect", + "sysinfo", + "tempfile", + "thiserror 2.0.12", + "time", + "tokio", + 
"tokio-tungstenite", + "tokio-util", + "toml", + "tracing", + "tracing-appender", + "tracing-subscriber", + "tracing-test", + "unicode-width 0.2.0", + "url", + "uuid", + "walkdir", + "webpki-roots", + "whoami", + "winnow 0.6.2", +] + +[[package]] +name = "kqueue" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c" +dependencies = [ + "kqueue-sys", + "libc", +] + +[[package]] +name = "kqueue-sys" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" dependencies = [ "bitflags 1.3.2", @@ -4873,15 +5136,15 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "linux-raw-sys" -version = "0.9.2" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6db9c683daf087dc577b7506e9695b3d556a9f3849903fa28186283afd6809e9" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "23fb14cb19457329c82206317a5663005a4d404783dc74f4252769b0d5f42856" [[package]] name = "litrs" @@ -4934,7 +5197,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.3", ] [[package]] @@ -4968,7 +5231,7 @@ dependencies = [ "accessibility", "accessibility-sys", "appkit-nsworkspace-bindings", - "block2", + "block2 0.5.1", "cocoa", "core-foundation 0.10.0", "core-graphics", @@ -5085,27 +5348,26 @@ dependencies = [ [[package]] name = "miette" -version = "7.5.0" +version = "7.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a955165f87b37fd1862df2a59547ac542c77ef6d17c666f619d1ad22dd89484" +checksum = "5f98efec8807c63c752b5bd61f862c165c115b0a35685bdcfd9238c7aeb592b7" dependencies = [ "cfg-if", "miette-derive", - "owo-colors 4.2.0", + "owo-colors", "supports-color", "supports-hyperlinks", "supports-unicode", "terminal_size", "textwrap", - "thiserror 1.0.69", "unicode-width 0.1.14", ] [[package]] name = "miette-derive" -version = "7.5.0" +version = "7.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf45bf44ab49be92fd1227a3be6fc6f617f1a337c06af54981048574d8783147" +checksum = "db5b29714e950dbb20d5e6f74f9dcec4edbcc1067bb7f8ed198c097b8c1a818b" dependencies = [ "proc-macro2", "quote", @@ -5135,18 +5397,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.4" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" -dependencies = [ - "adler", -] - -[[package]] -name = "miniz_oxide" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", "simd-adler32", @@ -5174,13 +5427,13 @@ dependencies = [ "bytes", "colored", "futures-util", - "http 1.2.0", + "http 1.3.1", "http-body 
1.0.1", "http-body-util", "hyper 1.6.0", "hyper-util", "log", - "rand 0.9.0", + "rand 0.9.1", "regex", "serde_json", "serde_urlencoded", @@ -5241,7 +5494,24 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a51313c5820b0b02bd422f4b44776fbf47961755c74ce64afc73bfad10226c3" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", +] + +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework 2.11.1", + "security-framework-sys", + "tempfile", ] [[package]] @@ -5416,7 +5686,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fcc7c92f190c97f79b4a332f5e81dcf68c8420af2045c936c9be0bc9de6f63b5" dependencies = [ - "proc-macro-crate 3.2.0", + "proc-macro-crate 3.3.0", "proc-macro2", "quote", "syn 1.0.109", @@ -5504,7 +5774,7 @@ version = "0.104.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "41c68c7c06898a5c4c9f10038da63759661cb8ac8f301ce7d159173a595c8258" dependencies = [ - "dirs", + "dirs 5.0.1", "omnipath", "pwd", "ref-cast", @@ -5519,8 +5789,8 @@ dependencies = [ "bytes", "chrono", "chrono-humanize", - "dirs", - "dirs-sys", + "dirs 5.0.1", + "dirs-sys 0.4.1", "fancy-regex", "heck 0.5.0", "indexmap 2.9.0", @@ -5618,7 +5888,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a652d9771a63711fd3c3deb670acfbe5c30a4072e664d7a3bf5a9e1056ac72c3" dependencies = [ "arrayvec", - "itoa 1.0.14", + "itoa 1.0.15", ] [[package]] @@ -5665,7 +5935,7 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" dependencies = [ - "proc-macro-crate 3.2.0", + "proc-macro-crate 3.3.0", "proc-macro2", "quote", "syn 2.0.101", @@ -5716,9 +5986,9 @@ dependencies = [ [[package]] name = "objc2" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3531f65190d9cff863b77a99857e74c314dd16bf56c538c4b57c7cbc3f3a6e59" +checksum = "88c6597e14493ab2e44ce58f2fdecf095a51f12ca57bec060a11c57332520551" dependencies = [ "objc2-encode", ] @@ -5730,7 +6000,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e4e89ad9e3d7d297152b17d39ed92cd50ca8063a89a9fa569046d41568891eff" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "libc", "objc2 0.5.2", "objc2-core-data", @@ -5741,14 +6011,16 @@ dependencies = [ [[package]] name = "objc2-app-kit" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5906f93257178e2f7ae069efb89fbd6ee94f0592740b5f8a1512ca498814d0fb" +checksum = "e6f29f568bec459b0ddff777cec4fe3fd8666d82d5a40ebd0ff7e66134f89bcc" dependencies = [ "bitflags 2.9.0", - "objc2 0.6.0", + "block2 0.6.1", + "objc2 0.6.1", + "objc2-core-foundation", "objc2-core-graphics", - "objc2-foundation 0.3.0", + "objc2-foundation 0.3.1", ] [[package]] @@ -5758,7 +6030,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74dd3b56391c7a0596a295029734d3c1c5e7e510a4cb30245f8221ccea96b009" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-core-location", "objc2-foundation 0.2.2", @@ -5770,7 +6042,7 @@ version = 
"0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5ff520e9c33812fd374d8deecef01d4a840e7b41862d849513de77e44aa4889" dependencies = [ - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-foundation 0.2.2", ] @@ -5782,29 +6054,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "617fbf49e071c178c0b24c080767db52958f716d9eabdf0890523aeae54773ef" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-foundation 0.2.2", ] [[package]] name = "objc2-core-foundation" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daeaf60f25471d26948a1c2f840e3f7d86f4109e3af4e8e4b5cd70c39690d925" +checksum = "1c10c2894a6fed806ade6027bcd50662746363a9589d3ec9d9bef30a4e4bc166" dependencies = [ "bitflags 2.9.0", - "objc2 0.6.0", + "dispatch2 0.3.0", + "objc2 0.6.1", ] [[package]] name = "objc2-core-graphics" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dca602628b65356b6513290a21a6405b4d4027b8b250f0b98dddbb28b7de02" +checksum = "989c6c68c13021b5c2d6b71456ebb0f9dc78d752e86a98da7c716f4f9470f5a4" dependencies = [ "bitflags 2.9.0", - "objc2 0.6.0", + "dispatch2 0.3.0", + "objc2 0.6.1", "objc2-core-foundation", "objc2-io-surface", ] @@ -5815,7 +6089,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55260963a527c99f1819c4f8e3b47fe04f9650694ef348ffd2227e8196d34c80" dependencies = [ - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-foundation 0.2.2", "objc2-metal", @@ -5827,7 +6101,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "000cfee34e683244f284252ee206a27953279d370e309649dc3ee317b37e5781" dependencies = [ - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-contacts", "objc2-foundation 0.2.2", @@ -5846,20 +6120,20 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ee638a5da3799329310ad4cfa62fbf045d5f56e3ef5ba4149e7452dcf89d5a8" dependencies = [ "bitflags 2.9.0", - "block2", - "dispatch", + "block2 0.5.1", "libc", "objc2 0.5.2", ] [[package]] name = "objc2-foundation" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a21c6c9014b82c39515db5b396f91645182611c97d24637cf56ac01e5f8d998" +checksum = "900831247d2fe1a09a683278e5384cfb8c80c79fe6b166f9d14bfdde0ea1b03c" dependencies = [ "bitflags 2.9.0", - "objc2 0.6.0", + "block2 0.6.1", + "objc2 0.6.1", "objc2-core-foundation", ] @@ -5876,12 +6150,12 @@ dependencies = [ [[package]] name = "objc2-io-surface" -version = "0.3.0" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "161a8b87e32610086e1a7a9e9ec39f84459db7b3a0881c1f16ca5a2605581c19" +checksum = "7282e9ac92529fa3457ce90ebb15f4ecbc383e8338060960760fa2cf75420c3c" dependencies = [ "bitflags 2.9.0", - "objc2 0.6.0", + "objc2 0.6.1", "objc2-core-foundation", ] @@ -5891,7 +6165,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1a1ae721c5e35be65f01a03b6d2ac13a54cb4fa70d8a5da293d7b0020261398" dependencies = [ - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-app-kit 0.2.2", "objc2-foundation 0.2.2", @@ -5904,7 +6178,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd0cba1276f6023976a406a14ffa85e1fdd19df6b0f737b063b95f6c8c7aadd6" dependencies = [ "bitflags 2.9.0", - "block2", + 
"block2 0.5.1", "objc2 0.5.2", "objc2-foundation 0.2.2", ] @@ -5916,7 +6190,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e42bee7bff906b14b167da2bac5efe6b6a07e6f7c0a21a7308d40c960242dc7a" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-foundation 0.2.2", "objc2-metal", @@ -5939,7 +6213,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8bb46798b20cd6b91cbd113524c490f1686f4c4e8f49502431415f3512e2b6f" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-cloud-kit", "objc2-core-data", @@ -5959,7 +6233,7 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44fa5f9748dbfe1ca6c0b79ad20725a11eca7c2218bceb4b005cb1be26273bfe" dependencies = [ - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-foundation 0.2.2", ] @@ -5971,7 +6245,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76cfcbf642358e8689af64cee815d139339f3ed8ad05103ed5eaf73db8d84cb3" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-core-location", "objc2-foundation 0.2.2", @@ -5984,7 +6258,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68bc69301064cebefc6c4c90ce9cba69225239e4b8ff99d445a2b5563797da65" dependencies = [ "bitflags 2.9.0", - "block2", + "block2 0.5.1", "objc2 0.5.2", "objc2-app-kit 0.2.2", "objc2-foundation 0.2.2", @@ -5992,9 +6266,9 @@ dependencies = [ [[package]] name = "object" -version = "0.32.2" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6a622008b6e321afc04970976f62ee297fdbaa6f95318ca343e3eebb9648441" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] @@ -6007,9 +6281,9 @@ checksum = "80adb31078122c880307e9cdfd4e3361e6545c319f9b9dcafcb03acd3b51a575" [[package]] name = "once_cell" -version = "1.20.2" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "onig" @@ -6035,15 +6309,53 @@ dependencies = [ [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "openssl" +version = "0.10.72" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" +dependencies = [ + "bitflags 2.9.0", + "cfg-if", + "foreign-types 0.3.2", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] [[package]] name = "openssl-probe" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = 
"0.9.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e145e1651e858e820e4860f7b9c5e169bc1d8ce1c86043be79fa7b7634821847" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] [[package]] name = "option-ext" @@ -6092,9 +6404,9 @@ dependencies = [ [[package]] name = "outref" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4030760ffd992bef45b0ae3f10ce1aba99e33464c90d14dd7c039884963ddc7a" +checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] name = "overload" @@ -6102,12 +6414,6 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" -[[package]] -name = "owo-colors" -version = "3.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" - [[package]] name = "owo-colors" version = "4.2.0" @@ -6186,7 +6492,17 @@ version = "0.6.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ - "fixedbitset", + "fixedbitset 0.4.2", + "indexmap 2.9.0", +] + +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset 0.5.7", "indexmap 2.9.0", ] @@ -6250,6 +6566,16 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "phf_generator" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" +dependencies = [ + "phf_shared 0.11.3", + "rand 0.8.5", +] + [[package]] name = "phf_macros" version = "0.8.0" @@ -6270,7 +6596,7 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c00cf8b9eafe68dde5e9eaa2cef8ee84a9336a47d566ec55ca16589633b65af7" dependencies = [ - "siphasher", + "siphasher 0.3.11", ] [[package]] @@ -6279,7 +6605,16 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6796ad771acdc0123d2a88dc428b5e38ef24456743ddb1744ed628f9815c096" dependencies = [ - "siphasher", + "siphasher 0.3.11", +] + +[[package]] +name = "phf_shared" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5" +dependencies = [ + "siphasher 1.0.1", ] [[package]] @@ -6327,9 +6662,9 @@ dependencies = [ [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plist" @@ -6382,7 +6717,7 @@ dependencies = [ "crc32fast", "fdeflate", "flate2", - "miniz_oxide 0.8.5", + "miniz_oxide", ] [[package]] @@ -6393,7 +6728,7 @@ checksum = "a604568c3202727d1507653cb121dbd627a58684eb09a820fd746bee38b4442f" dependencies = [ "cfg-if", "concurrent-queue", - "hermit-abi", + "hermit-abi 0.4.0", "pin-project-lite", "rustix 0.38.44", "tracing", @@ -6408,9 +6743,9 @@ checksum = "2f3a9f18d041e6d0e102a0a46750538147e5e8992d3b4873aaafee2520b00ce3" [[package]] name = "portable-atomic" -version 
= "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "portable-atomic-util" @@ -6450,11 +6785,11 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.7.35", + "zerocopy 0.8.25", ] [[package]] @@ -6534,11 +6869,11 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ - "toml_edit 0.22.22", + "toml_edit 0.22.26", ] [[package]] @@ -6595,9 +6930,9 @@ checksum = "dc375e1527247fe1a97d8b7156678dfe7c1af2fc075c9a4db3690ecd2a148068" [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -6663,11 +6998,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be769465445e8c1474e9c5dac2018218498557af32d9ed057325ec9a41ae81bf" dependencies = [ "heck 0.5.0", - "itertools 0.13.0", + "itertools 0.14.0", "log", "multimap", "once_cell", - "petgraph", + "petgraph 0.7.1", "prettyplease", "prost", "prost-types", @@ -6683,7 +7018,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a56d757972c98b346a9b766e3f02746cde6dd1cd1d1d563472929fdd74bec4d" dependencies = [ "anyhow", - "itertools 0.13.0", + "itertools 0.14.0", "proc-macro2", "quote", "syn 2.0.101", @@ -6776,56 +7111,6 @@ dependencies = [ "thiserror 1.0.69", ] -[[package]] -name = "q_chat" -version = "1.10.0" -dependencies = [ - "anstream", - "aws-smithy-types", - "bstr", - "clap", - "color-print", - "convert_case 0.8.0", - "crossterm", - "eyre", - "fig_api_client", - "fig_auth", - "fig_diagnostic", - "fig_install", - "fig_os_shim", - "fig_settings", - "fig_telemetry", - "fig_util", - "futures", - "glob", - "mcp_client", - "rand 0.9.0", - "regex", - "rustyline", - "semver", - "serde", - "serde_json", - "shell-color", - "shell-words", - "shellexpand", - "shlex", - "similar", - "skim", - "spinners", - "strip-ansi-escapes", - "syntect", - "tempfile", - "thiserror 2.0.12", - "time", - "tokio", - "tracing", - "tracing-subscriber", - "unicode-width 0.2.0", - "url", - "uuid", - "winnow 0.6.22", -] - [[package]] name = "q_cli" version = "1.10.0" @@ -6882,12 +7167,11 @@ dependencies = [ "objc2 0.5.2", "objc2-app-kit 0.2.2", "objc2-foundation 0.2.2", - "owo-colors 4.2.0", + "owo-colors", "parking_lot", "paste", "predicates", - "q_chat", - "rand 0.9.0", + "rand 0.9.1", "regex", "semver", "serde", @@ -6948,43 +7232,45 @@ dependencies = [ [[package]] name = "quick-xml" -version = "0.37.4" +version = "0.37.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"a4ce8c88de324ff838700f36fb6ab86c96df0e3c4ab6ef3a9b2044465cce1369" +checksum = "331e97a1af0bf59823e6eadffe373d7b27f485be8748f71471c662c1f269b7fb" dependencies = [ "memchr", ] [[package]] name = "quinn" -version = "0.11.6" +version = "0.11.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e96808277ec6f97351a2380e6c25114bc9e67037775464979f3037c92d05ef" +checksum = "c3bd15a6f2967aef83887dcb9fec0014580467e33720d073560cf015a5683012" dependencies = [ "bytes", + "cfg_aliases", "pin-project-lite", "quinn-proto", "quinn-udp", - "rustc-hash 2.1.0", - "rustls 0.23.23", + "rustc-hash 2.1.1", + "rustls 0.23.26", "socket2", "thiserror 2.0.12", "tokio", "tracing", + "web-time", ] [[package]] name = "quinn-proto" -version = "0.11.9" +version = "0.11.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d" +checksum = "bcbafbbdbb0f638fe3f35f3c56739f77a8a1d070cb25603226c83339b391472b" dependencies = [ "bytes", - "getrandom 0.2.15", - "rand 0.8.5", + "getrandom 0.3.2", + "rand 0.9.1", "ring", - "rustc-hash 2.1.0", - "rustls 0.23.23", + "rustc-hash 2.1.1", + "rustls 0.23.26", "rustls-pki-types", "slab", "thiserror 2.0.12", @@ -6995,9 +7281,9 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.5.9" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c40286217b4ba3a71d644d752e6a0b71f13f1b6a2c5311acfcbe0c2418ed904" +checksum = "ee4e529991f949c5e25755532370b8af5d114acae52326361d68d47af64aa842" dependencies = [ "cfg_aliases", "libc", @@ -7016,6 +7302,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "r2d2" version = "0.8.10" @@ -7075,13 +7367,12 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", - "zerocopy 0.8.23", ] [[package]] @@ -7129,7 +7420,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -7138,7 +7429,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom 0.3.1", + "getrandom 0.3.2", ] [[package]] @@ -7196,9 +7487,9 @@ dependencies = [ [[package]] name = "ravif" -version = "0.11.11" +version = "0.11.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2413fd96bd0ea5cdeeb37eaf446a22e6ed7b981d792828721e74ded1980a45c6" +checksum = "d6a5f31fcf7500f9401fea858ea4ab5525c99f2322cfcee732c0e6c74208c0c6" dependencies = [ "avif-serialize", "imgref", @@ -7237,9 +7528,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.8" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" +checksum = 
"d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" dependencies = [ "bitflags 2.9.0", ] @@ -7250,11 +7541,22 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", "thiserror 1.0.69", ] +[[package]] +name = "redox_users" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd6f9d3d47bdd2ad6945c5015a226ec6155d0bcdfd8f7cd29f86b71f8de99d2b" +dependencies = [ + "getrandom 0.2.16", + "libredox", + "thiserror 2.0.12", +] + [[package]] name = "ref-cast" version = "1.0.24" @@ -7327,9 +7629,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "989e327e510263980e231de548a33e63d34962d29ae61b467389a1a09627a254" +checksum = "d19c46a6fdd48bc4dab94b6103fccc55d34c67cc0ad04653aad4ea2a07cd7bbb" dependencies = [ "async-compression", "base64 0.22.1", @@ -7337,24 +7639,27 @@ dependencies = [ "cookie", "cookie_store", "encoding_rs", + "futures-channel", "futures-core", "futures-util", - "h2 0.4.7", - "http 1.2.0", + "h2 0.4.9", + "http 1.3.1", "http-body 1.0.1", "http-body-util", "hyper 1.6.0", "hyper-rustls 0.27.5", + "hyper-tls", "hyper-util", "ipnet", "js-sys", "log", "mime", + "native-tls", "once_cell", "percent-encoding", "pin-project-lite", "quinn", - "rustls 0.23.23", + "rustls 0.23.26", "rustls-native-certs 0.8.1", "rustls-pemfile 2.2.0", "rustls-pki-types", @@ -7363,7 +7668,8 @@ dependencies = [ "serde_urlencoded", "sync_wrapper", "tokio", - "tokio-rustls 0.26.1", + "tokio-native-tls", + "tokio-rustls 0.26.2", "tokio-socks", "tokio-util", "tower", @@ -7390,19 +7696,19 @@ dependencies = [ [[package]] name = "rfd" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a24763657bff09769a8ccf12c8b8a50416fb035fe199263b4c5071e4e3f006f" +checksum = "80c844748fdc82aae252ee4594a89b6e7ebef1063de7951545564cbc4e57075d" dependencies = [ "ashpd", - "block2", - "core-foundation 0.10.0", - "core-foundation-sys", + "block2 0.6.1", + "dispatch2 0.2.0", "js-sys", "log", - "objc2 0.5.2", - "objc2-app-kit 0.2.2", - "objc2-foundation 0.2.2", + "objc2 0.6.1", + "objc2-app-kit 0.3.1", + "objc2-core-foundation", + "objc2-foundation 0.3.1", "pollster", "raw-window-handle", "urlencoding", @@ -7426,7 +7732,7 @@ checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -7502,9 +7808,9 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustc-hash" -version = "2.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustc_version" @@ -7530,14 +7836,14 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.5" +version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d97817398dd4bb2e6da002002db259209759911da105da92bec29ccb12cf58bf" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" 
dependencies = [ "bitflags 2.9.0", "errno", "libc", - "linux-raw-sys 0.9.2", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -7555,16 +7861,16 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.23" +version = "0.23.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "47796c98c480fce5406ef69d1c76378375492c3b0a0de587be0c1d9feb12f395" +checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" dependencies = [ "aws-lc-rs", "log", "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.102.8", + "rustls-webpki 0.103.1", "subtle", "zeroize", ] @@ -7613,9 +7919,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" dependencies = [ "web-time", ] @@ -7632,9 +7938,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" dependencies = [ "aws-lc-rs", "ring", @@ -7644,9 +7950,9 @@ dependencies = [ [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "rustyline" @@ -7658,7 +7964,6 @@ dependencies = [ "cfg-if", "clipboard-win", "fd-lock", - "home", "libc", "log", "memchr", @@ -7684,9 +7989,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -7793,6 +8098,36 @@ dependencies = [ "thin-slice", ] +[[package]] +name = "self-replace" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ec815b5eab420ab893f63393878d89c90fdd94c0bcc44c07abb8ad95552fb7" +dependencies = [ + "fastrand", + "tempfile", + "windows-sys 0.52.0", +] + +[[package]] +name = "self_update" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d832c086ece0dacc29fb2947bb4219b8f6e12fe9e40b7108f9e57c4224e47b5c" +dependencies = [ + "hyper 1.6.0", + "indicatif", + "log", + "quick-xml 0.37.5", + "regex", + "reqwest", + "self-replace", + "semver", + "serde_json", + "tempfile", + "urlencoding", +] + [[package]] name = "semver" version = "1.0.26" @@ -7845,7 +8180,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ "indexmap 2.9.0", - "itoa 1.0.14", + "itoa 1.0.15", "memchr", "ryu", "serde", @@ -7887,7 +8222,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" dependencies = [ "form_urlencoded", - "itoa 1.0.14", + "itoa 1.0.15", "ryu", "serde", ] @@ -7899,7 +8234,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47" dependencies = [ "indexmap 2.9.0", - "itoa 1.0.14", + "itoa 1.0.15", "ryu", "serde", "unsafe-libyaml", @@ -7998,6 +8333,17 @@ dependencies = [ "libc", ] +[[package]] +name = "shell-color" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fce6d5bc71503c9ec2337c80dc41f4fb2ac62fe52d6ab7500d899db19ae436f8" +dependencies = [ + "bitflags 2.9.0", + "nu-ansi-term 0.50.1", + "nu-color-config", +] + [[package]] name = "shell-color" version = "1.10.0" @@ -8029,7 +8375,7 @@ version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b1fdf65dd6331831494dd616b30351c38e96e45921a27745cf98490458b90bb" dependencies = [ - "dirs", + "dirs 6.0.0", ] [[package]] @@ -8061,9 +8407,9 @@ dependencies = [ [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -8095,6 +8441,12 @@ version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" +[[package]] +name = "siphasher" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d" + [[package]] name = "skim" version = "0.16.2" @@ -8113,7 +8465,7 @@ dependencies = [ "indexmap 2.9.0", "log", "nix 0.29.0", - "rand 0.9.0", + "rand 0.9.1", "rayon", "regex", "shell-quote", @@ -8137,9 +8489,9 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.13.2" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "socket2" @@ -8211,26 +8563,25 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "string_cache" -version = "0.8.7" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f91138e76242f575eb1d3b38b4f1362f10d3a43f47d182a5b359af488a02293b" +checksum = "bf776ba3fa74f83bf4b63c3dcbbf82173db2632ed8452cb2d891d33f459de70f" dependencies = [ "new_debug_unreachable", - "once_cell", "parking_lot", - "phf_shared 0.10.0", + "phf_shared 0.11.3", "precomputed-hash", "serde", ] [[package]] name = "string_cache_codegen" -version = "0.5.2" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bb30289b722be4ff74a408c3cc27edeaad656e06cb1fe8fa9231fa59c728988" +checksum = "c711928715f1fe0fe509c53b43e993a9a557babc2d0a3567d0a3006f1ac931a0" dependencies = [ - "phf_generator 0.10.0", - "phf_shared 0.10.0", + "phf_generator 0.11.3", + "phf_shared 0.11.3", "proc-macro2", "quote", ] @@ -8373,9 +8724,9 @@ dependencies = [ [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", @@ -8536,15 
+8887,14 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.18.0" +version = "3.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c317e0a526ee6120d8dabad239c8dadca62b24b6f168914bbbc8e2fb1f0e567" +checksum = "7437ac7763b9b123ccf33c338a5cc1bac6f69b45a136c19bdd8a65e3916435bf" dependencies = [ - "cfg-if", "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.2", "once_cell", - "rustix 1.0.5", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -8572,11 +8922,11 @@ dependencies = [ [[package]] name = "terminal_size" -version = "0.4.1" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5352447f921fda68cf61b4101566c0bdb5104eff6804d0678e5227580ab6a4e9" +checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" dependencies = [ - "rustix 0.38.44", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -8607,9 +8957,9 @@ dependencies = [ [[package]] name = "test-log-macros" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5999e24eaa32083191ba4e425deb75cdf25efefabe5aaccb7446dd0d4122a3f5" +checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" dependencies = [ "proc-macro2", "quote", @@ -8618,12 +8968,12 @@ dependencies = [ [[package]] name = "textwrap" -version = "0.16.1" +version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" +checksum = "c13547615a44dc9c452a8a534638acdf07120d4b6847c8178705da06306a3057" dependencies = [ "unicode-linebreak", - "unicode-width 0.1.14", + "unicode-width 0.2.0", ] [[package]] @@ -8700,7 +9050,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", - "itoa 1.0.14", + "itoa 1.0.15", "libc", "num-conv", "num_threads", @@ -8766,9 +9116,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" dependencies = [ "tinyvec_macros", ] @@ -8809,6 +9159,16 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" @@ -8821,11 +9181,11 @@ dependencies = [ [[package]] name = "tokio-rustls" -version = "0.26.1" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" dependencies = [ - "rustls 0.23.23", + "rustls 0.23.26", "tokio", ] @@ -8882,21 +9242,21 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.19" +version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" +checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae" dependencies = [ "serde", "serde_spanned", 
"toml_datetime", - "toml_edit 0.22.22", + "toml_edit 0.22.26", ] [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" dependencies = [ "serde", ] @@ -8925,17 +9285,24 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.22" +version = "0.22.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" +checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" dependencies = [ "indexmap 2.9.0", "serde", "serde_spanned", "toml_datetime", - "winnow 0.6.22", + "toml_write", + "winnow 0.7.9", ] +[[package]] +name = "toml_write" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" + [[package]] name = "tower" version = "0.5.2" @@ -9085,21 +9452,22 @@ dependencies = [ [[package]] name = "tray-icon" -version = "0.19.2" +version = "0.19.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48a05076dd272615d03033bf04f480199f7d1b66a8ac64d75c625fc4a70c06b" +checksum = "eadd75f5002e2513eaa19b2365f533090cc3e93abd38788452d9ea85cff7b48a" dependencies = [ - "core-graphics", "crossbeam-channel", - "dirs", + "dirs 6.0.0", "libappindicator", "muda", - "objc2 0.5.2", - "objc2-app-kit 0.2.2", - "objc2-foundation 0.2.2", + "objc2 0.6.1", + "objc2-app-kit 0.3.1", + "objc2-core-foundation", + "objc2-core-graphics", + "objc2-foundation 0.3.1", "once_cell", "png", - "thiserror 1.0.69", + "thiserror 2.0.12", "windows-sys 0.59.0", ] @@ -9113,7 +9481,7 @@ dependencies = [ "memchr", "nom", "once_cell", - "petgraph", + "petgraph 0.6.5", ] [[package]] @@ -9144,10 +9512,10 @@ checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" dependencies = [ "bytes", "data-encoding", - "http 1.2.0", + "http 1.3.1", "httparse", "log", - "rand 0.9.0", + "rand 0.9.1", "sha1", "thiserror 2.0.12", "utf-8", @@ -9155,21 +9523,21 @@ dependencies = [ [[package]] name = "typeid" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e" +checksum = "bc7d623258602320d5c55d1bc22793b57daff0ec7efc270ea7d55ce1d5f5471c" [[package]] name = "typenum" -version = "1.17.0" +version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" +checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" [[package]] name = "typetag" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "044fc3365ddd307c297fe0fe7b2e70588cdab4d0f62dc52055ca0d11b174cf0e" +checksum = "73f22b40dd7bfe8c14230cf9702081366421890435b2d625fa92b4acc4c3de6f" dependencies = [ "erased-serde", "inventory", @@ -9180,9 +9548,9 @@ dependencies = [ [[package]] name = "typetag-impl" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d9d30226ac9cbd2d1ff775f74e8febdab985dab14fb14aa2582c29a92d5555dc" +checksum = "35f5380909ffc31b4de4f4bdf96b877175a016aa2ca98cee39fcfd8c4d53d952" dependencies = [ "proc-macro2", "quote", @@ 
-9208,9 +9576,9 @@ checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] name = "unicode-ident" -version = "1.0.14" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-linebreak" @@ -9292,12 +9660,12 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.15.1" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587" +checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" dependencies = [ - "getrandom 0.3.1", - "rand 0.9.0", + "getrandom 0.3.2", + "rand 0.9.1", "serde", ] @@ -9314,9 +9682,9 @@ dependencies = [ [[package]] name = "valuable" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" [[package]] name = "vcpkg" @@ -9373,9 +9741,9 @@ dependencies = [ [[package]] name = "wait-timeout" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f200f5b12eb75f8c1ed65abd4b2db8a6e1b138a20de009dacee265a2498f3f6" +checksum = "09ac3b126d3914f9849036f826e054cbabdc8519970b8998ddaf3b5bd3c65f11" dependencies = [ "libc", ] @@ -9413,9 +9781,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -9499,23 +9867,22 @@ dependencies = [ [[package]] name = "wayland-backend" -version = "0.3.8" +version = "0.3.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7208998eaa3870dad37ec8836979581506e0c5c64c20c9e79e9d2a10d6f47bf" +checksum = "fe770181423e5fc79d3e2a7f4410b7799d5aab1de4372853de3c6aa13ca24121" dependencies = [ "cc", "downcast-rs", "rustix 0.38.44", - "scoped-tls", "smallvec", "wayland-sys", ] [[package]] name = "wayland-client" -version = "0.31.8" +version = "0.31.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2120de3d33638aaef5b9f4472bff75f07c56379cf76ea320bd3a3d65ecaf73f" +checksum = "978fa7c67b0847dbd6a9f350ca2569174974cd4082737054dbb7fbb79d7d9a61" dependencies = [ "bitflags 2.9.0", "rustix 0.38.44", @@ -9525,9 +9892,9 @@ dependencies = [ [[package]] name = "wayland-protocols" -version = "0.32.6" +version = "0.32.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0781cf46869b37e36928f7b432273c0995aa8aed9552c556fb18754420541efc" +checksum = "779075454e1e9a521794fed15886323ea0feda3f8b0fc1390f5398141310422a" dependencies = [ "bitflags 2.9.0", "wayland-backend", @@ -9537,9 +9904,9 @@ dependencies = [ [[package]] name = "wayland-protocols-wlr" -version = "0.3.6" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248a02e6f595aad796561fa82d25601bd2c8c3b145b1c7453fc8f94c1a58f8b2" 
+checksum = "1cb6cdc73399c0e06504c437fe3cf886f25568dd5454473d565085b36d6a8bbf" dependencies = [ "bitflags 2.9.0", "wayland-backend", @@ -9555,7 +9922,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "896fdafd5d28145fce7958917d69f2fd44469b1d4e861cb5961bcbeebc6d1484" dependencies = [ "proc-macro2", - "quick-xml 0.37.4", + "quick-xml 0.37.5", "quote", ] @@ -9565,8 +9932,6 @@ version = "0.31.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbcebb399c77d5aa9fa5db874806ee7b4eba4e73650948e8f93963f128896615" dependencies = [ - "dlib", - "log", "pkg-config", ] @@ -9636,9 +10001,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.8" +version = "0.26.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2210b291f7ea53617fbafcc4939f10914214ec15aace5ba62293a668f322c5c9" +checksum = "37493cadf42a2a939ed404698ded7fb378bf301b5011f973361779a3a74f8c93" dependencies = [ "rustls-pki-types", ] @@ -9717,7 +10082,7 @@ checksum = "24d643ce3fd3e5b54854602a080f34fb10ab75e0b813ee32d00ca2b44fa74762" dependencies = [ "either", "env_home", - "rustix 1.0.5", + "rustix 1.0.7", "winsafe", ] @@ -9793,15 +10158,6 @@ dependencies = [ "windows-targets 0.52.6", ] -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets 0.52.6", -] - [[package]] name = "windows-core" version = "0.56.0" @@ -9839,6 +10195,19 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" +dependencies = [ + "windows-implement 0.60.0", + "windows-interface 0.59.1", + "windows-link", + "windows-result 0.3.2", + "windows-strings 0.4.0", +] + [[package]] name = "windows-implement" version = "0.56.0" @@ -9872,6 +10241,17 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "windows-interface" version = "0.56.0" @@ -9905,11 +10285,22 @@ dependencies = [ "syn 2.0.101", ] +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "windows-link" -version = "0.1.0" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dccfd733ce2b1753b03b6d3c65edf020262ea35e20ccdf3e288043e6dd620e3" +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" [[package]] name = "windows-registry" @@ -9917,7 +10308,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3" dependencies = [ - "windows-result 0.3.1", + "windows-result 0.3.2", "windows-strings 0.3.1", "windows-targets 0.53.0", ] @@ -9942,9 +10333,9 @@ dependencies = [ [[package]] name = "windows-result" -version = "0.3.1" +version = "0.3.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "06374efe858fab7e4f881500e6e86ec8bc28f9462c47e5a9941a0142ad86b189" +checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" dependencies = [ "windows-link", ] @@ -9968,6 +10359,15 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-strings" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -10068,11 +10468,11 @@ dependencies = [ [[package]] name = "windows-version" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c12476c23a74725c539b24eae8bfc0dac4029c39cdb561d9f23616accd4ae26d" +checksum = "e04a5c6627e310a23ad2358483286c7df260c964eb2d003d8efd6d0f4e79265c" dependencies = [ - "windows-targets 0.53.0", + "windows-link", ] [[package]] @@ -10266,9 +10666,18 @@ dependencies = [ [[package]] name = "winnow" -version = "0.6.22" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" +checksum = "7a4191c47f15cc3ec71fcb4913cb83d58def65dd3787610213c649283b5ce178" +dependencies = [ + "memchr", +] + +[[package]] +name = "winnow" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9fb597c990f03753e08d3c29efbfcf2019a003b4bf4ba19225c158e1549f0f3" dependencies = [ "memchr", ] @@ -10300,9 +10709,9 @@ checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ "bitflags 2.9.0", ] @@ -10345,7 +10754,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2e33c08b174442ff80d5c791020696f9f8b4e4a87b8cfc7494aad6167ec44e1" dependencies = [ "base64 0.22.1", - "block2", + "block2 0.5.1", "cookie", "crossbeam-channel", "dpi", @@ -10353,7 +10762,7 @@ dependencies = [ "gdkx11", "gtk", "html5ever", - "http 1.2.0", + "http 1.3.1", "javascriptcore-rs", "jni", "kuchikiki", @@ -10421,13 +10830,12 @@ checksum = "ec107c4503ea0b4a98ef47356329af139c0a4f7750e621cf2973cd3385ebcb3d" [[package]] name = "xattr" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e105d177a3871454f754b33bb0ee637ecaaac997446375fd3e5d43a2ed00c909" +checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e" dependencies = [ "libc", - "linux-raw-sys 0.4.15", - "rustix 0.38.44", + "rustix 1.0.7", ] [[package]] @@ -10538,9 +10946,9 @@ dependencies = [ [[package]] name = "zbus" -version = "5.2.0" +version = "5.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fb67eadba43784b6fb14857eba0d8fc518686d3ee537066eb6086dc318e2c8a1" +checksum = "59c333f648ea1b647bc95dc1d34807c8e25ed7a6feff3394034dc4776054b236" dependencies = [ "async-broadcast", "async-executor", @@ -10555,7 +10963,7 @@ dependencies = [ "enumflags2", "event-listener", "futures-core", - "futures-util", + "futures-lite", "hex", "nix 0.29.0", "ordered-stream", @@ -10565,11 +10973,11 @@ dependencies = [ 
"tracing", "uds_windows", "windows-sys 0.59.0", - "winnow 0.6.22", + "winnow 0.7.9", "xdg-home", - "zbus_macros 5.2.0", - "zbus_names 4.1.0", - "zvariant 5.1.0", + "zbus_macros 5.5.0", + "zbus_names 4.2.0", + "zvariant 5.4.0", ] [[package]] @@ -10578,7 +10986,7 @@ version = "4.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "267db9407081e90bbfa46d841d3cbc60f59c0351838c4bc65199ecd79ab1983e" dependencies = [ - "proc-macro-crate 3.2.0", + "proc-macro-crate 3.3.0", "proc-macro2", "quote", "syn 2.0.101", @@ -10587,17 +10995,17 @@ dependencies = [ [[package]] name = "zbus_macros" -version = "5.2.0" +version = "5.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d49ebc960ceb660f2abe40a5904da975de6986f2af0d7884b39eec6528c57" +checksum = "f325ad10eb0d0a3eb060203494c3b7ec3162a01a59db75d2deee100339709fc0" dependencies = [ - "proc-macro-crate 3.2.0", + "proc-macro-crate 3.3.0", "proc-macro2", "quote", "syn 2.0.101", - "zbus_names 4.1.0", - "zvariant 5.1.0", - "zvariant_utils 3.0.2", + "zbus_names 4.2.0", + "zvariant 5.4.0", + "zvariant_utils 3.2.0", ] [[package]] @@ -10622,14 +11030,14 @@ dependencies = [ [[package]] name = "zbus_names" -version = "4.1.0" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "856b7a38811f71846fd47856ceee8bccaec8399ff53fb370247e66081ace647b" +checksum = "7be68e64bf6ce8db94f63e72f0c7eb9a60d733f7e0499e628dfab0f84d6bcb97" dependencies = [ "serde", "static_assertions", - "winnow 0.6.22", - "zvariant 5.1.0", + "winnow 0.7.9", + "zvariant 5.4.0", ] [[package]] @@ -10651,17 +11059,16 @@ version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ - "byteorder", "zerocopy-derive 0.7.35", ] [[package]] name = "zerocopy" -version = "0.8.23" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd97444d05a4328b90e75e503a34bad781f14e28a823ad3557f0750df1ebcbc6" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.8.23", + "zerocopy-derive 0.8.25", ] [[package]] @@ -10677,9 +11084,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.23" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6352c01d0edd5db859a63e2605f4ea3183ddbd15e2c4a9e7d32184df75e4f154" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", @@ -10688,18 +11095,18 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", @@ -10746,18 +11153,18 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "7.2.1" +version = "7.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" dependencies = [ "zstd-sys", ] [[package]] name = "zstd-sys" -version = "2.0.13+zstd.1.5.6" +version = "2.0.15+zstd.1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38ff0f21cfee8f97d94cef41359e0c89aa6113028ab0291aa8ca0038995a95aa" +checksum = "eb81183ddd97d0c74cedf1d50d85c8d08c1b8b68ee863bdee9e706eedba1a237" dependencies = [ "cc", "pkg-config", @@ -10807,18 +11214,18 @@ dependencies = [ [[package]] name = "zvariant" -version = "5.1.0" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1200ee6ac32f1e5a312e455a949a4794855515d34f9909f4a3e082d14e1a56f" +checksum = "b2df9ee044893fcffbdc25de30546edef3e32341466811ca18421e3cd6c5a3ac" dependencies = [ "endi", "enumflags2", "serde", "static_assertions", "url", - "winnow 0.6.22", - "zvariant_derive 5.1.0", - "zvariant_utils 3.0.2", + "winnow 0.7.9", + "zvariant_derive 5.4.0", + "zvariant_utils 3.2.0", ] [[package]] @@ -10827,7 +11234,7 @@ version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73e2ba546bda683a90652bac4a279bc146adad1386f25379cf73200d2002c449" dependencies = [ - "proc-macro-crate 3.2.0", + "proc-macro-crate 3.3.0", "proc-macro2", "quote", "syn 2.0.101", @@ -10836,15 +11243,15 @@ dependencies = [ [[package]] name = "zvariant_derive" -version = "5.1.0" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "687e3b97fae6c9104fbbd36c73d27d149abf04fb874e2efbd84838763daa8916" +checksum = "74170caa85b8b84cc4935f2d56a57c7a15ea6185ccdd7eadb57e6edd90f94b2f" dependencies = [ - "proc-macro-crate 3.2.0", + "proc-macro-crate 3.3.0", "proc-macro2", "quote", "syn 2.0.101", - "zvariant_utils 3.0.2", + "zvariant_utils 3.2.0", ] [[package]] @@ -10860,14 +11267,14 @@ dependencies = [ [[package]] name = "zvariant_utils" -version = "3.0.2" +version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20d1d011a38f12360e5fcccceeff5e2c42a8eb7f27f0dcba97a0862ede05c9c6" +checksum = "e16edfee43e5d7b553b77872d99bc36afdda75c223ca7ad5e3fbecd82ca5fc34" dependencies = [ "proc-macro2", "quote", "serde", "static_assertions", "syn 2.0.101", - "winnow 0.6.22", + "winnow 0.7.9", ] diff --git a/Cargo.toml b/Cargo.toml index 1a5369068c..90ae9c21b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -106,7 +106,6 @@ objc2-input-method-kit = "0.2.2" parking_lot = "0.12.3" percent-encoding = "2.2.0" portable-pty = "0.8.1" -q_chat = { path = "crates/q_chat" } r2d2 = "0.8.10" r2d2_sqlite = "0.25.0" rand = "0.9.0" diff --git a/crates/amzn-toolkit-telemetry/Cargo.toml b/crates/amzn-toolkit-telemetry-client/Cargo.toml similarity index 81% rename from crates/amzn-toolkit-telemetry/Cargo.toml rename to crates/amzn-toolkit-telemetry-client/Cargo.toml index 8a9ac5ca86..5ac7b22007 100644 --- a/crates/amzn-toolkit-telemetry/Cargo.toml +++ b/crates/amzn-toolkit-telemetry-client/Cargo.toml @@ -11,14 +11,10 @@ [package] edition = "2021" -name = "amzn-toolkit-telemetry" +name = "amzn-toolkit-telemetry-client" version = "1.0.0" authors = ["Grant Gurvis "] -exclude = [ - "/build", - "/Config", - "/build-tools/", -] +exclude = ["/build", "/Config", "/build-tools/"] publish = ["brazil"] description = "Rust client bindings for the toolkit-telemetry service" @@ -53,10 +49,7 @@ features = ["client"] [dependencies.aws-smithy-runtime-api] version = "1.1.3" -features = [ - 
"client", - "http-02x", -] +features = ["client", "http-02x"] [dependencies.aws-smithy-types] version = "1.1.3" @@ -79,16 +72,7 @@ features = ["test-util"] [features] behavior-version-latest = [] -default = [ - "rustls", - "rt-tokio", -] -rt-tokio = [ - "aws-smithy-async/rt-tokio", - "aws-smithy-types/rt-tokio", -] +default = ["rustls", "rt-tokio"] +rt-tokio = ["aws-smithy-async/rt-tokio", "aws-smithy-types/rt-tokio"] rustls = ["aws-smithy-runtime/tls-rustls"] -test-util = [ - "aws-credential-types/test-util", - "aws-smithy-runtime/test-util", -] +test-util = ["aws-credential-types/test-util", "aws-smithy-runtime/test-util"] diff --git a/crates/amzn-toolkit-telemetry/LICENSE b/crates/amzn-toolkit-telemetry-client/LICENSE similarity index 100% rename from crates/amzn-toolkit-telemetry/LICENSE rename to crates/amzn-toolkit-telemetry-client/LICENSE diff --git a/crates/amzn-toolkit-telemetry/src/auth_plugin.rs b/crates/amzn-toolkit-telemetry-client/src/auth_plugin.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/auth_plugin.rs rename to crates/amzn-toolkit-telemetry-client/src/auth_plugin.rs diff --git a/crates/amzn-toolkit-telemetry/src/client.rs b/crates/amzn-toolkit-telemetry-client/src/client.rs similarity index 97% rename from crates/amzn-toolkit-telemetry/src/client.rs rename to crates/amzn-toolkit-telemetry-client/src/client.rs index e602f0421e..08bfb13912 100644 --- a/crates/amzn-toolkit-telemetry/src/client.rs +++ b/crates/amzn-toolkit-telemetry-client/src/client.rs @@ -100,8 +100,8 @@ impl Client { /// operation call. For example, this can be used to add an additional HTTP header: /// /// ```ignore -/// # async fn wrapper() -> ::std::result::Result<(), amzn_toolkit_telemetry::Error> { -/// # let client: amzn_toolkit_telemetry::Client = unimplemented!(); +/// # async fn wrapper() -> ::std::result::Result<(), amzn_toolkit_telemetry_client::Error> { +/// # let client: amzn_toolkit_telemetry_client::Client = unimplemented!(); /// use ::http::header::{HeaderName, HeaderValue}; /// /// let result = client.post_error_report() diff --git a/crates/amzn-toolkit-telemetry/src/client/customize.rs b/crates/amzn-toolkit-telemetry-client/src/client/customize.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/client/customize.rs rename to crates/amzn-toolkit-telemetry-client/src/client/customize.rs diff --git a/crates/amzn-toolkit-telemetry/src/client/post_error_report.rs b/crates/amzn-toolkit-telemetry-client/src/client/post_error_report.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/client/post_error_report.rs rename to crates/amzn-toolkit-telemetry-client/src/client/post_error_report.rs diff --git a/crates/amzn-toolkit-telemetry/src/client/post_feedback.rs b/crates/amzn-toolkit-telemetry-client/src/client/post_feedback.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/client/post_feedback.rs rename to crates/amzn-toolkit-telemetry-client/src/client/post_feedback.rs diff --git a/crates/amzn-toolkit-telemetry/src/client/post_metrics.rs b/crates/amzn-toolkit-telemetry-client/src/client/post_metrics.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/client/post_metrics.rs rename to crates/amzn-toolkit-telemetry-client/src/client/post_metrics.rs diff --git a/crates/amzn-toolkit-telemetry/src/config.rs b/crates/amzn-toolkit-telemetry-client/src/config.rs similarity index 94% rename from crates/amzn-toolkit-telemetry/src/config.rs rename to crates/amzn-toolkit-telemetry-client/src/config.rs index 
9d8629ebd1..c70286e86e 100644 --- a/crates/amzn-toolkit-telemetry/src/config.rs +++ b/crates/amzn-toolkit-telemetry-client/src/config.rs @@ -1,6 +1,6 @@ // Code generated by software.amazon.smithy.rust.codegen.smithy-rs. DO NOT EDIT. -/// Configuration for a amzn_toolkit_telemetry service client. +/// Configuration for a amzn_toolkit_telemetry_client service client. /// /// /// Service configuration allows for customization of endpoints, region, credentials providers, @@ -211,7 +211,7 @@ impl Builder { /// # fn example() { /// use std::time::Duration; /// - /// use amzn_toolkit_telemetry::config::Config; + /// use amzn_toolkit_telemetry_client::config::Config; /// use aws_smithy_runtime::client::http::hyper_014::HyperClientBuilder; /// /// let https_connector = hyper_rustls::HttpsConnectorBuilder::new() @@ -248,7 +248,7 @@ impl Builder { /// # fn example() { /// use std::time::Duration; /// - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::{ /// Builder, /// Config, /// }; @@ -265,7 +265,7 @@ impl Builder { /// builder.set_http_client(Some(hyper_client)); /// } /// - /// let mut builder = amzn_toolkit_telemetry::Config::builder(); + /// let mut builder = amzn_toolkit_telemetry_client::Config::builder(); /// override_http_client(&mut builder); /// let config = builder.build(); /// # } @@ -288,7 +288,7 @@ impl Builder { /// # Examples /// Create a custom endpoint resolver that resolves a different endpoing per-stage, e.g. staging /// vs. production. ```no_run - /// use amzn_toolkit_telemetry::config::endpoint::{ + /// use amzn_toolkit_telemetry_client::config::endpoint::{ /// Endpoint, /// EndpointFuture, /// Params, @@ -309,10 +309,10 @@ impl Builder { /// let resolver = StageResolver { /// stage: std::env::var("STAGE").unwrap(), /// }; - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .endpoint_resolver(resolver) /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` pub fn endpoint_resolver( mut self, @@ -337,8 +337,8 @@ impl Builder { /// /// # Examples /// ```no_run - /// use amzn_toolkit_telemetry::config::Config; - /// use amzn_toolkit_telemetry::config::retry::RetryConfig; + /// use amzn_toolkit_telemetry_client::config::Config; + /// use amzn_toolkit_telemetry_client::config::retry::RetryConfig; /// /// let retry_config = RetryConfig::standard().with_max_attempts(5); /// let config = Config::builder().retry_config(retry_config).build(); @@ -352,8 +352,8 @@ impl Builder { /// /// # Examples /// ```no_run - /// use amzn_toolkit_telemetry::config::retry::RetryConfig; - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::retry::RetryConfig; + /// use amzn_toolkit_telemetry_client::config::{ /// Builder, /// Config, /// }; @@ -380,7 +380,7 @@ impl Builder { /// # Examples /// /// ```no_run - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::{ /// AsyncSleep, /// Config, /// SharedAsyncSleep, @@ -411,7 +411,7 @@ impl Builder { /// # Examples /// /// ```no_run - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::{ /// AsyncSleep, /// Builder, /// Config, @@ -448,8 +448,8 @@ impl Builder { /// /// ```no_run /// # use std::time::Duration; - /// use amzn_toolkit_telemetry::config::Config; - /// use 
amzn_toolkit_telemetry::config::timeout::TimeoutConfig; + /// use amzn_toolkit_telemetry_client::config::Config; + /// use amzn_toolkit_telemetry_client::config::timeout::TimeoutConfig; /// /// let timeout_config = TimeoutConfig::builder() /// .operation_attempt_timeout(Duration::from_secs(1)) @@ -467,8 +467,8 @@ impl Builder { /// /// ```no_run /// # use std::time::Duration; - /// use amzn_toolkit_telemetry::config::timeout::TimeoutConfig; - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::timeout::TimeoutConfig; + /// use amzn_toolkit_telemetry_client::config::{ /// Builder, /// Config, /// }; @@ -527,22 +527,22 @@ impl Builder { /// /// Disabling identity caching: /// ```no_run - /// use amzn_toolkit_telemetry::config::IdentityCache; + /// use amzn_toolkit_telemetry_client::config::IdentityCache; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .identity_cache(IdentityCache::no_cache()) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` /// /// Customizing lazy caching: /// ```no_run /// use std::time::Duration; /// - /// use amzn_toolkit_telemetry::config::IdentityCache; + /// use amzn_toolkit_telemetry_client::config::IdentityCache; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .identity_cache( /// IdentityCache::lazy() /// // change the load timeout to 10 seconds @@ -551,7 +551,7 @@ impl Builder { /// ) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` pub fn identity_cache(mut self, identity_cache: impl crate::config::ResolveCachedIdentity + 'static) -> Self { self.set_identity_cache(identity_cache); @@ -574,22 +574,22 @@ impl Builder { /// /// Disabling identity caching: /// ```no_run - /// use amzn_toolkit_telemetry::config::IdentityCache; + /// use amzn_toolkit_telemetry_client::config::IdentityCache; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .identity_cache(IdentityCache::no_cache()) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` /// /// Customizing lazy caching: /// ```no_run /// use std::time::Duration; /// - /// use amzn_toolkit_telemetry::config::IdentityCache; + /// use amzn_toolkit_telemetry_client::config::IdentityCache; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .identity_cache( /// IdentityCache::lazy() /// // change the load timeout to 10 seconds @@ -598,7 +598,7 @@ impl Builder { /// ) /// // ... 
/// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` pub fn set_identity_cache( &mut self, @@ -622,7 +622,7 @@ impl Builder { /// # mod tests { /// # #[test] /// # fn example() { - /// use amzn_toolkit_telemetry::config::Config; + /// use amzn_toolkit_telemetry_client::config::Config; /// use aws_smithy_runtime_api::client::interceptors::context::phase::BeforeTransmit; /// use aws_smithy_runtime_api::client::interceptors::{ /// Interceptor, @@ -675,7 +675,7 @@ impl Builder { /// # mod tests { /// # #[test] /// # fn example() { - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::{ /// Builder, /// Config, /// }; @@ -772,7 +772,7 @@ impl Builder { /// use aws_smithy_types::retry::ErrorKind; /// use std::error::Error as StdError; /// use std::marker::PhantomData; - /// use amzn_toolkit_telemetry::config::Config; + /// use amzn_toolkit_telemetry_client::config::Config; /// # struct SomeOperationError {} /// /// const RETRYABLE_ERROR_CODES: &[&str] = [ @@ -865,7 +865,7 @@ impl Builder { /// use aws_smithy_types::retry::ErrorKind; /// use std::error::Error as StdError; /// use std::marker::PhantomData; - /// use amzn_toolkit_telemetry::config::{Builder, Config}; + /// use amzn_toolkit_telemetry_client::config::{Builder, Config}; /// # struct SomeOperationError {} /// /// const RETRYABLE_ERROR_CODES: &[&str] = [ @@ -1001,13 +1001,13 @@ impl Builder { /// /// # Examples /// ```no_run - /// use amzn_toolkit_telemetry::config::{ + /// use amzn_toolkit_telemetry_client::config::{ /// Builder, /// Config, /// }; /// use aws_types::region::Region; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .region(Region::new("us-east-1")) /// .build(); /// ``` @@ -1058,25 +1058,25 @@ impl Builder { /// `behavior-version-latest` cargo feature. /// /// ```no_run - /// use amzn_toolkit_telemetry::config::BehaviorVersion; + /// use amzn_toolkit_telemetry_client::config::BehaviorVersion; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .behavior_version(BehaviorVersion::latest()) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` /// /// Customizing behavior major version: /// /// ```no_run - /// use amzn_toolkit_telemetry::config::BehaviorVersion; + /// use amzn_toolkit_telemetry_client::config::BehaviorVersion; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .behavior_version(BehaviorVersion::v2023_11_09()) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` pub fn behavior_version(mut self, behavior_version: crate::config::BehaviorVersion) -> Self { self.set_behavior_version(Some(behavior_version)); @@ -1096,25 +1096,25 @@ impl Builder { /// `behavior-version-latest` cargo feature. 
/// /// ```no_run - /// use amzn_toolkit_telemetry::config::BehaviorVersion; + /// use amzn_toolkit_telemetry_client::config::BehaviorVersion; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .behavior_version(BehaviorVersion::latest()) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` /// /// Customizing behavior major version: /// /// ```no_run - /// use amzn_toolkit_telemetry::config::BehaviorVersion; + /// use amzn_toolkit_telemetry_client::config::BehaviorVersion; /// - /// let config = amzn_toolkit_telemetry::Config::builder() + /// let config = amzn_toolkit_telemetry_client::Config::builder() /// .behavior_version(BehaviorVersion::v2023_11_09()) /// // ... /// .build(); - /// let client = amzn_toolkit_telemetry::Client::from_conf(config); + /// let client = amzn_toolkit_telemetry_client::Client::from_conf(config); /// ``` pub fn set_behavior_version(&mut self, behavior_version: Option) -> &mut Self { self.behavior_version = behavior_version; @@ -1182,7 +1182,7 @@ impl Builder { .map(|r| layer.store_put(::aws_types::region::SigningRegion::from(r))); Config { config: crate::config::Layer::from(layer.clone()) - .with_name("amzn_toolkit_telemetry::config::Config") + .with_name("amzn_toolkit_telemetry_client::config::Config") .freeze(), cloneable: layer, runtime_components: self.runtime_components, @@ -1281,7 +1281,7 @@ impl ConfigOverrideRuntimePlugin { let _ = resolver; Self { config: ::aws_smithy_types::config_bag::Layer::from(layer) - .with_name("amzn_toolkit_telemetry::config::ConfigOverrideRuntimePlugin") + .with_name("amzn_toolkit_telemetry_client::config::ConfigOverrideRuntimePlugin") .freeze(), components, } diff --git a/crates/amzn-toolkit-telemetry/src/config/endpoint.rs b/crates/amzn-toolkit-telemetry-client/src/config/endpoint.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/config/endpoint.rs rename to crates/amzn-toolkit-telemetry-client/src/config/endpoint.rs diff --git a/crates/amzn-toolkit-telemetry/src/config/interceptors.rs b/crates/amzn-toolkit-telemetry-client/src/config/interceptors.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/config/interceptors.rs rename to crates/amzn-toolkit-telemetry-client/src/config/interceptors.rs diff --git a/crates/amzn-toolkit-telemetry/src/config/retry.rs b/crates/amzn-toolkit-telemetry-client/src/config/retry.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/config/retry.rs rename to crates/amzn-toolkit-telemetry-client/src/config/retry.rs diff --git a/crates/amzn-toolkit-telemetry/src/config/timeout.rs b/crates/amzn-toolkit-telemetry-client/src/config/timeout.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/config/timeout.rs rename to crates/amzn-toolkit-telemetry-client/src/config/timeout.rs diff --git a/crates/amzn-toolkit-telemetry/src/error.rs b/crates/amzn-toolkit-telemetry-client/src/error.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/error.rs rename to crates/amzn-toolkit-telemetry-client/src/error.rs diff --git a/crates/amzn-toolkit-telemetry/src/error/sealed_unhandled.rs b/crates/amzn-toolkit-telemetry-client/src/error/sealed_unhandled.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/error/sealed_unhandled.rs rename to 
crates/amzn-toolkit-telemetry-client/src/error/sealed_unhandled.rs diff --git a/crates/amzn-toolkit-telemetry/src/error_meta.rs b/crates/amzn-toolkit-telemetry-client/src/error_meta.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/error_meta.rs rename to crates/amzn-toolkit-telemetry-client/src/error_meta.rs diff --git a/crates/amzn-toolkit-telemetry/src/json_errors.rs b/crates/amzn-toolkit-telemetry-client/src/json_errors.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/json_errors.rs rename to crates/amzn-toolkit-telemetry-client/src/json_errors.rs diff --git a/crates/amzn-toolkit-telemetry/src/lib.rs b/crates/amzn-toolkit-telemetry-client/src/lib.rs similarity index 99% rename from crates/amzn-toolkit-telemetry/src/lib.rs rename to crates/amzn-toolkit-telemetry-client/src/lib.rs index 0d30f01da8..56b7d2192d 100644 --- a/crates/amzn-toolkit-telemetry/src/lib.rs +++ b/crates/amzn-toolkit-telemetry-client/src/lib.rs @@ -15,7 +15,7 @@ #![allow(rustdoc::redundant_explicit_links)] #![warn(missing_docs)] #![cfg_attr(docsrs, feature(doc_auto_cfg))] -//! amzn-toolkit-telemetry +//! amzn-toolkit-telemetry-client //! //! # Crate Organization //! diff --git a/crates/amzn-toolkit-telemetry/src/meta.rs b/crates/amzn-toolkit-telemetry-client/src/meta.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/meta.rs rename to crates/amzn-toolkit-telemetry-client/src/meta.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation.rs b/crates/amzn-toolkit-telemetry-client/src/operation.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation.rs rename to crates/amzn-toolkit-telemetry-client/src/operation.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_error_report.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_error_report.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_error_report.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_error_report.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_error_report/_post_error_report_input.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_error_report/_post_error_report_input.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_error_report/_post_error_report_input.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_error_report/_post_error_report_input.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_error_report/_post_error_report_output.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_error_report/_post_error_report_output.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_error_report/_post_error_report_output.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_error_report/_post_error_report_output.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_error_report/builders.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_error_report/builders.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_error_report/builders.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_error_report/builders.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_feedback.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_feedback.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_feedback.rs rename to 
crates/amzn-toolkit-telemetry-client/src/operation/post_feedback.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_feedback/_post_feedback_input.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_feedback/_post_feedback_input.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_feedback/_post_feedback_input.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_feedback/_post_feedback_input.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_feedback/_post_feedback_output.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_feedback/_post_feedback_output.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_feedback/_post_feedback_output.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_feedback/_post_feedback_output.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_feedback/builders.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_feedback/builders.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_feedback/builders.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_feedback/builders.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_metrics.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_metrics.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_metrics.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_metrics.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_metrics/_post_metrics_input.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_metrics/_post_metrics_input.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_metrics/_post_metrics_input.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_metrics/_post_metrics_input.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_metrics/_post_metrics_output.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_metrics/_post_metrics_output.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_metrics/_post_metrics_output.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_metrics/_post_metrics_output.rs diff --git a/crates/amzn-toolkit-telemetry/src/operation/post_metrics/builders.rs b/crates/amzn-toolkit-telemetry-client/src/operation/post_metrics/builders.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/operation/post_metrics/builders.rs rename to crates/amzn-toolkit-telemetry-client/src/operation/post_metrics/builders.rs diff --git a/crates/amzn-toolkit-telemetry/src/primitives.rs b/crates/amzn-toolkit-telemetry-client/src/primitives.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/primitives.rs rename to crates/amzn-toolkit-telemetry-client/src/primitives.rs diff --git a/crates/amzn-toolkit-telemetry/src/primitives/event_stream.rs b/crates/amzn-toolkit-telemetry-client/src/primitives/event_stream.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/primitives/event_stream.rs rename to crates/amzn-toolkit-telemetry-client/src/primitives/event_stream.rs diff --git a/crates/amzn-toolkit-telemetry/src/primitives/sealed_enum_unknown.rs b/crates/amzn-toolkit-telemetry-client/src/primitives/sealed_enum_unknown.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/primitives/sealed_enum_unknown.rs rename to 
crates/amzn-toolkit-telemetry-client/src/primitives/sealed_enum_unknown.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_error_details.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_error_details.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_error_details.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_error_details.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_metadata_entry.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_metadata_entry.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_metadata_entry.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_metadata_entry.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_metric_datum.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_metric_datum.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_metric_datum.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_metric_datum.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_error_report.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_error_report.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_error_report.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_error_report.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_error_report_input.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_error_report_input.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_error_report_input.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_error_report_input.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_feedback.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_feedback.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_feedback.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_feedback.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_feedback_input.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_feedback_input.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_feedback_input.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_feedback_input.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_metrics.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_metrics.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_metrics.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_metrics.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_metrics_input.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_metrics_input.rs similarity index 100% rename from 
crates/amzn-toolkit-telemetry/src/protocol_serde/shape_post_metrics_input.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_post_metrics_input.rs diff --git a/crates/amzn-toolkit-telemetry/src/protocol_serde/shape_userdata.rs b/crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_userdata.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/protocol_serde/shape_userdata.rs rename to crates/amzn-toolkit-telemetry-client/src/protocol_serde/shape_userdata.rs diff --git a/crates/amzn-toolkit-telemetry/src/serialization_settings.rs b/crates/amzn-toolkit-telemetry-client/src/serialization_settings.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/serialization_settings.rs rename to crates/amzn-toolkit-telemetry-client/src/serialization_settings.rs diff --git a/crates/amzn-toolkit-telemetry/src/types.rs b/crates/amzn-toolkit-telemetry-client/src/types.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types.rs rename to crates/amzn-toolkit-telemetry-client/src/types.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_aws_product.rs b/crates/amzn-toolkit-telemetry-client/src/types/_aws_product.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_aws_product.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_aws_product.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_error_details.rs b/crates/amzn-toolkit-telemetry-client/src/types/_error_details.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_error_details.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_error_details.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_metadata_entry.rs b/crates/amzn-toolkit-telemetry-client/src/types/_metadata_entry.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_metadata_entry.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_metadata_entry.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_metric_datum.rs b/crates/amzn-toolkit-telemetry-client/src/types/_metric_datum.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_metric_datum.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_metric_datum.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_sentiment.rs b/crates/amzn-toolkit-telemetry-client/src/types/_sentiment.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_sentiment.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_sentiment.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_unit.rs b/crates/amzn-toolkit-telemetry-client/src/types/_unit.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_unit.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_unit.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/_userdata.rs b/crates/amzn-toolkit-telemetry-client/src/types/_userdata.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/_userdata.rs rename to crates/amzn-toolkit-telemetry-client/src/types/_userdata.rs diff --git a/crates/amzn-toolkit-telemetry/src/types/builders.rs b/crates/amzn-toolkit-telemetry-client/src/types/builders.rs similarity index 100% rename from crates/amzn-toolkit-telemetry/src/types/builders.rs rename to crates/amzn-toolkit-telemetry-client/src/types/builders.rs diff --git a/crates/aws-toolkit-telemetry-definitions/Cargo.toml b/crates/aws-toolkit-telemetry-definitions/Cargo.toml index 353222f519..d29de564c8 
100644 --- a/crates/aws-toolkit-telemetry-definitions/Cargo.toml +++ b/crates/aws-toolkit-telemetry-definitions/Cargo.toml @@ -16,7 +16,7 @@ serde_json.workspace = true syn = "2.0.101" [dependencies] -amzn-toolkit-telemetry = { path = "../amzn-toolkit-telemetry" } +amzn-toolkit-telemetry-client = { path = "../amzn-toolkit-telemetry-client" } serde.workspace = true [dev-dependencies] diff --git a/crates/aws-toolkit-telemetry-definitions/build.rs b/crates/aws-toolkit-telemetry-definitions/build.rs index 9628f288f0..d7d286fddc 100644 --- a/crates/aws-toolkit-telemetry-definitions/build.rs +++ b/crates/aws-toolkit-telemetry-definitions/build.rs @@ -162,11 +162,11 @@ fn main() { let passive = m.passive.unwrap_or_default(); let unit = match m.unit.map(|u| u.to_lowercase()).as_deref() { - Some("bytes") => quote!(::amzn_toolkit_telemetry::types::Unit::Bytes), - Some("count") => quote!(::amzn_toolkit_telemetry::types::Unit::Count), - Some("milliseconds") => quote!(::amzn_toolkit_telemetry::types::Unit::Milliseconds), - Some("percent") => quote!(::amzn_toolkit_telemetry::types::Unit::Percent), - Some("none") | None => quote!(::amzn_toolkit_telemetry::types::Unit::None), + Some("bytes") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Bytes), + Some("count") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Count), + Some("milliseconds") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Milliseconds), + Some("percent") => quote!(::amzn_toolkit_telemetry_client::types::Unit::Percent), + Some("none") | None => quote!(::amzn_toolkit_telemetry_client::types::Unit::None), Some(unknown) => { panic!("unknown unit: {:?}", unknown); }, @@ -200,7 +200,7 @@ fn main() { }; quote!( - ::amzn_toolkit_telemetry::types::MetadataEntry::builder() + ::amzn_toolkit_telemetry_client::types::MetadataEntry::builder() .key(#raw_name) #value .build() @@ -222,11 +222,11 @@ fn main() { impl #name { const NAME: &'static ::std::primitive::str = #raw_name; const PASSIVE: ::std::primitive::bool = #passive; - const UNIT: ::amzn_toolkit_telemetry::types::Unit = #unit; + const UNIT: ::amzn_toolkit_telemetry_client::types::Unit = #unit; } impl crate::IntoMetricDatum for #name { - fn into_metric_datum(self) -> ::amzn_toolkit_telemetry::types::MetricDatum { + fn into_metric_datum(self) -> ::amzn_toolkit_telemetry_client::types::MetricDatum { let metadata_entries = vec![ #( #metadata_entries, @@ -239,7 +239,7 @@ fn main() { |t| t.duration_since(std::time::UNIX_EPOCH).unwrap().as_millis() as ::std::primitive::i64 ); - ::amzn_toolkit_telemetry::types::MetricDatum::builder() + ::amzn_toolkit_telemetry_client::types::MetricDatum::builder() .metric_name(#name::NAME) .passive(#name::PASSIVE) .unit(#name::UNIT) diff --git a/crates/aws-toolkit-telemetry-definitions/src/lib.rs b/crates/aws-toolkit-telemetry-definitions/src/lib.rs index c1e2f2c8ca..cd3bee5889 100644 --- a/crates/aws-toolkit-telemetry-definitions/src/lib.rs +++ b/crates/aws-toolkit-telemetry-definitions/src/lib.rs @@ -1,5 +1,5 @@ pub trait IntoMetricDatum: Send { - fn into_metric_datum(self) -> amzn_toolkit_telemetry::types::MetricDatum; + fn into_metric_datum(self) -> amzn_toolkit_telemetry_client::types::MetricDatum; } include!(concat!(env!("OUT_DIR"), "/mod.rs")); diff --git a/crates/fig_auth/src/builder_id.rs b/crates/fig_auth/src/builder_id.rs index a7fe7f0b7c..e99a13f3dd 100644 --- a/crates/fig_auth/src/builder_id.rs +++ b/crates/fig_auth/src/builder_id.rs @@ -96,7 +96,7 @@ pub(crate) fn client(region: Region) -> Client { let retry_config = 
RetryConfig::standard().with_max_attempts(3); let sdk_config = aws_types::SdkConfig::builder() .http_client(fig_aws_common::http_client::client()) - .behavior_version(BehaviorVersion::v2024_03_28()) + .behavior_version(BehaviorVersion::v2025_01_17()) .endpoint_url(oidc_url(®ion)) .region(region) .retry_config(retry_config) diff --git a/crates/fig_aws_common/src/lib.rs b/crates/fig_aws_common/src/lib.rs index caa4dbc106..b9739f9109 100644 --- a/crates/fig_aws_common/src/lib.rs +++ b/crates/fig_aws_common/src/lib.rs @@ -17,7 +17,7 @@ pub fn app_name() -> AppName { } pub fn behavior_version() -> BehaviorVersion { - BehaviorVersion::v2024_03_28() + BehaviorVersion::v2025_01_17() } #[cfg(test)] diff --git a/crates/fig_telemetry/Cargo.toml b/crates/fig_telemetry/Cargo.toml index 735e350470..df3ad45e7d 100644 --- a/crates/fig_telemetry/Cargo.toml +++ b/crates/fig_telemetry/Cargo.toml @@ -12,7 +12,7 @@ workspace = true [dependencies] amzn-codewhisperer-client = { path = "../amzn-codewhisperer-client" } -amzn-toolkit-telemetry = { path = "../amzn-toolkit-telemetry" } +amzn-toolkit-telemetry-client = { path = "../amzn-toolkit-telemetry-client" } anyhow.workspace = true async-trait.workspace = true aws-credential-types = "1.1.6" diff --git a/crates/fig_telemetry/src/cognito.rs b/crates/fig_telemetry/src/cognito.rs index 4c26c0de54..f50b6125cb 100644 --- a/crates/fig_telemetry/src/cognito.rs +++ b/crates/fig_telemetry/src/cognito.rs @@ -1,4 +1,4 @@ -use amzn_toolkit_telemetry::config::BehaviorVersion; +use amzn_toolkit_telemetry_client::config::BehaviorVersion; use aws_credential_types::provider::error::CredentialsError; use aws_credential_types::{ Credentials, @@ -28,7 +28,7 @@ pub(crate) async fn get_cognito_credentials_send( telemetry_stage: &TelemetryStage, ) -> Result { let conf = aws_sdk_cognitoidentity::Config::builder() - .behavior_version(BehaviorVersion::v2024_03_28()) + .behavior_version(BehaviorVersion::v2025_01_17()) .region(telemetry_stage.region.clone()) .app_name(app_name()) .build(); diff --git a/crates/fig_telemetry/src/endpoint.rs b/crates/fig_telemetry/src/endpoint.rs index 0612322712..681d19af76 100644 --- a/crates/fig_telemetry/src/endpoint.rs +++ b/crates/fig_telemetry/src/endpoint.rs @@ -1,4 +1,4 @@ -use amzn_toolkit_telemetry::config::endpoint::{ +use amzn_toolkit_telemetry_client::config::endpoint::{ Endpoint, EndpointFuture, Params, diff --git a/crates/fig_telemetry/src/lib.rs b/crates/fig_telemetry/src/lib.rs index ef73492d07..aabd80cd1f 100644 --- a/crates/fig_telemetry/src/lib.rs +++ b/crates/fig_telemetry/src/lib.rs @@ -25,13 +25,13 @@ use amzn_codewhisperer_client::types::{ UserContext, UserTriggerDecisionEvent, }; -use amzn_toolkit_telemetry::config::{ +use amzn_toolkit_telemetry_client::config::{ BehaviorVersion, Region, }; -use amzn_toolkit_telemetry::error::DisplayErrorContext; -use amzn_toolkit_telemetry::types::AwsProduct; -use amzn_toolkit_telemetry::{ +use amzn_toolkit_telemetry_client::error::DisplayErrorContext; +use amzn_toolkit_telemetry_client::types::AwsProduct; +use amzn_toolkit_telemetry_client::{ Client as ToolkitTelemetryClient, Config, }; @@ -92,7 +92,7 @@ pub enum Error { #[error("Telemetry is disabled")] TelemetryDisabled, #[error(transparent)] - ClientError(#[from] amzn_toolkit_telemetry::operation::post_metrics::PostMetricsError), + ClientError(#[from] amzn_toolkit_telemetry_client::operation::post_metrics::PostMetricsError), } const PRODUCT: &str = "CodeWhisperer"; @@ -205,10 +205,10 @@ pub struct Client { impl Client { pub async fn 
new(telemetry_stage: TelemetryStage) -> Self { let client_id = util::get_client_id(); - let toolkit_telemetry_client = Some(amzn_toolkit_telemetry::Client::from_conf( + let toolkit_telemetry_client = Some(amzn_toolkit_telemetry_client::Client::from_conf( Config::builder() .http_client(fig_aws_common::http_client::client()) - .behavior_version(BehaviorVersion::v2024_03_28()) + .behavior_version(BehaviorVersion::v2025_01_17()) .endpoint_resolver(StaticEndpoint(telemetry_stage.endpoint)) .app_name(app_name()) .region(telemetry_stage.region.clone()) diff --git a/crates/fig_telemetry_core/Cargo.toml b/crates/fig_telemetry_core/Cargo.toml index f4a2704750..10cf49ef69 100644 --- a/crates/fig_telemetry_core/Cargo.toml +++ b/crates/fig_telemetry_core/Cargo.toml @@ -9,7 +9,7 @@ license.workspace = true [dependencies] amzn-codewhisperer-client = { path = "../amzn-codewhisperer-client" } -amzn-toolkit-telemetry = { path = "../amzn-toolkit-telemetry" } +amzn-toolkit-telemetry-client = { path = "../amzn-toolkit-telemetry-client" } async-trait.workspace = true aws-toolkit-telemetry-definitions = { path = "../aws-toolkit-telemetry-definitions" } fig_util.workspace = true diff --git a/crates/fig_telemetry_core/src/lib.rs b/crates/fig_telemetry_core/src/lib.rs index 85aa5d5a2e..26d6bd8412 100644 --- a/crates/fig_telemetry_core/src/lib.rs +++ b/crates/fig_telemetry_core/src/lib.rs @@ -5,7 +5,7 @@ use std::time::{ SystemTime, }; -pub use amzn_toolkit_telemetry::types::MetricDatum; +pub use amzn_toolkit_telemetry_client::types::MetricDatum; use aws_toolkit_telemetry_definitions::IntoMetricDatum; use aws_toolkit_telemetry_definitions::metrics::{ AmazonqDidSelectProfile, diff --git a/crates/kiro-cli/Cargo.toml b/crates/kiro-cli/Cargo.toml index 1554937843..cc0e39aadc 100644 --- a/crates/kiro-cli/Cargo.toml +++ b/crates/kiro-cli/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "q_cli" +name = "kiro_cli" authors.workspace = true edition.workspace = true homepage.workspace = true diff --git a/crates/q_chat/Cargo.toml b/crates/q_chat/Cargo.toml deleted file mode 100644 index 8997a1b37f..0000000000 --- a/crates/q_chat/Cargo.toml +++ /dev/null @@ -1,59 +0,0 @@ -[package] -name = "q_chat" -authors.workspace = true -edition.workspace = true -homepage.workspace = true -publish.workspace = true -version.workspace = true -license.workspace = true - -[dependencies] -anstream.workspace = true -aws-smithy-types = "1.2.10" -bstr.workspace = true -clap.workspace = true -color-print.workspace = true -convert_case.workspace = true -crossterm.workspace = true -semver.workspace = true -eyre.workspace = true -fig_api_client.workspace = true -fig_auth.workspace = true -fig_diagnostic.workspace = true -fig_os_shim.workspace = true -fig_install.workspace = true -fig_settings.workspace = true -fig_telemetry.workspace = true -fig_util.workspace = true -futures.workspace = true -glob.workspace = true -mcp_client.workspace = true -rand.workspace = true -regex.workspace = true -rustyline = { version = "15.0.0", features = ["derive", "custom-bindings"] } -serde.workspace = true -serde_json.workspace = true -shell-color.workspace = true -shell-words = "1.1" -shellexpand.workspace = true -shlex.workspace = true -similar.workspace = true -skim = "0.16.2" -spinners.workspace = true -syntect = { version = "5.2.0", features = [ "default-syntaxes", "default-themes" ]} -tempfile.workspace = true -thiserror.workspace = true -time.workspace = true -tokio.workspace = true -tracing.workspace = true -unicode-width.workspace = true -url.workspace = true 
-uuid.workspace = true -winnow.workspace = true -strip-ansi-escapes = "0.2.1" - -[dev-dependencies] -tracing-subscriber.workspace = true - -[lints] -workspace = true diff --git a/crates/q_chat/src/cli.rs b/crates/q_chat/src/cli.rs deleted file mode 100644 index a441887b9b..0000000000 --- a/crates/q_chat/src/cli.rs +++ /dev/null @@ -1,25 +0,0 @@ -use clap::Parser; - -#[derive(Debug, Clone, PartialEq, Eq, Default, Parser)] -pub struct Chat { - /// (Deprecated, use --trust-all-tools) Enabling this flag allows the model to execute - /// all commands without first accepting them. - #[arg(short, long, hide = true)] - pub accept_all: bool, - /// Print the first response to STDOUT without interactive mode. This will fail if the - /// prompt requests permissions to use a tool, unless --trust-all-tools is also used. - #[arg(long)] - pub no_interactive: bool, - /// The first question to ask - pub input: Option<String>, - /// Context profile to use - #[arg(long = "profile")] - pub profile: Option<String>, - /// Allows the model to use any tool to run commands without asking for confirmation. - #[arg(long)] - pub trust_all_tools: bool, - /// Trust only this set of tools. Example: trust some tools: - /// '--trust-tools=fs_read,fs_write', trust no tools: '--trust-tools=' - #[arg(long, value_delimiter = ',', value_name = "TOOL_NAMES")] - pub trust_tools: Option<Vec<String>>, -} diff --git a/crates/q_chat/src/command.rs b/crates/q_chat/src/command.rs deleted file mode 100644 index 43d07f1169..0000000000 --- a/crates/q_chat/src/command.rs +++ /dev/null @@ -1,1093 +0,0 @@ -use std::collections::HashSet; -use std::io::Write; - -use clap::{ - Parser, - Subcommand, -}; -use crossterm::style::Color; -use crossterm::{ - queue, - style, -}; -use eyre::Result; -use serde::{ - Deserialize, - Serialize, -}; - -#[derive(Debug, PartialEq, Eq)] -pub enum Command { - Ask { - prompt: String, - }, - Execute { - command: String, - }, - Clear, - Help, - Issue { - prompt: Option<String>, - }, - Quit, - Profile { - subcommand: ProfileSubcommand, - }, - Context { - subcommand: ContextSubcommand, - }, - PromptEditor { - initial_text: Option<String>, - }, - Compact { - prompt: Option<String>, - show_summary: bool, - help: bool, - }, - Tools { - subcommand: Option<ToolsSubcommand>, - }, - Prompts { - subcommand: Option<PromptsSubcommand>, - }, - Usage, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ProfileSubcommand { - List, - Create { name: String }, - Delete { name: String }, - Set { name: String }, - Rename { old_name: String, new_name: String }, - Help, -} - -impl ProfileSubcommand { - const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available commands - help Show an explanation for the profile command - list List all available profiles - create <> Create a new profile with the specified name - delete <> Delete the specified profile - set <> Switch to the specified profile - rename <> <> Rename a profile"}; - const CREATE_USAGE: &str = "/profile create "; - const DELETE_USAGE: &str = "/profile delete "; - const RENAME_USAGE: &str = "/profile rename "; - const SET_USAGE: &str = "/profile set "; - - fn usage_msg(header: impl AsRef<str>) -> String { - format!("{}\n\n{}", header.as_ref(), Self::AVAILABLE_COMMANDS) - } - - pub fn help_text() -> String { - color_print::cformat!( - r#" -(Beta) Profile Management - -Profiles allow you to organize and manage different sets of context files for different projects or tasks.
- -{} - -Notes -• The "global" profile contains context files that are available in all profiles -• The "default" profile is used when no profile is specified -• You can switch between profiles to work on different projects -• Each profile maintains its own set of context files -"#, - Self::AVAILABLE_COMMANDS - ) - } -} - -#[derive(Parser, Debug, Clone)] -#[command(name = "hooks", disable_help_flag = true, disable_help_subcommand = true)] -struct HooksCommand { - #[command(subcommand)] - command: HooksSubcommand, -} - -#[derive(Subcommand, Debug, Clone, Eq, PartialEq)] -pub enum HooksSubcommand { - Add { - name: String, - - #[arg(long, value_parser = ["per_prompt", "conversation_start"])] - trigger: String, - - #[arg(long, value_parser = clap::value_parser!(String))] - command: String, - - #[arg(long)] - global: bool, - }, - #[command(name = "rm")] - Remove { - name: String, - - #[arg(long)] - global: bool, - }, - Enable { - name: String, - - #[arg(long)] - global: bool, - }, - Disable { - name: String, - - #[arg(long)] - global: bool, - }, - EnableAll { - #[arg(long)] - global: bool, - }, - DisableAll { - #[arg(long)] - global: bool, - }, - Help, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ContextSubcommand { - Show { - expand: bool, - }, - Add { - global: bool, - force: bool, - paths: Vec<String>, - }, - Remove { - global: bool, - paths: Vec<String>, - }, - Clear { - global: bool, - }, - Hooks { - subcommand: Option<HooksSubcommand>, - }, - Help, -} - -impl ContextSubcommand { - const ADD_USAGE: &str = "/context add [--global] [--force] [path2...]"; - const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available commands - help Show an explanation for the context command - - show [--expand] Display the context rule configuration and matched files - --expand: Print out each matched file's content - - add [--global] [--force] <> - Add context rules (filenames or glob patterns) - --global: Add to global rules (available in all profiles) - --force: Include even if matched files exceed size limits - - rm [--global] <> Remove specified rules from current profile - --global: Remove specified rules globally - - clear [--global] Remove all rules from current profile - --global: Remove global rules - - hooks View and manage context hooks"}; - const CLEAR_USAGE: &str = "/context clear [--global]"; - const HOOKS_AVAILABLE_COMMANDS: &str = color_print::cstr!
{"Available subcommands - hooks help Show an explanation for context hooks commands - - hooks add [--global] <> Add a new command context hook - --global: Add to global hooks - --trigger <> When to trigger the hook, valid options: `per_prompt` or `conversation_start` - --command <> Shell command to execute - - hooks rm [--global] <> Remove an existing context hook - --global: Remove from global hooks - - hooks enable [--global] <> Enable an existing context hook - --global: Enable in global hooks - - hooks disable [--global] <> Disable an existing context hook - --global: Disable in global hooks - - hooks enable-all [--global] Enable all existing context hooks - --global: Enable all in global hooks - - hooks disable-all [--global] Disable all existing context hooks - --global: Disable all in global hooks"}; - const REMOVE_USAGE: &str = "/context rm [--global] [path2...]"; - const SHOW_USAGE: &str = "/context show [--expand]"; - - fn usage_msg(header: impl AsRef) -> String { - format!("{}\n\n{}", header.as_ref(), Self::AVAILABLE_COMMANDS) - } - - fn hooks_usage_msg(header: impl AsRef) -> String { - format!("{}\n\n{}", header.as_ref(), Self::HOOKS_AVAILABLE_COMMANDS) - } - - pub fn help_text() -> String { - color_print::cformat!( - r#" -(Beta) Context Rule Management - -Context rules determine which files are included in your Amazon Q session. -The files matched by these rules provide Amazon Q with additional information -about your project or environment. Adding relevant files helps Q generate -more accurate and helpful responses. - -In addition to files, you can specify hooks that will run commands and return -the output as context to Amazon Q. - -{} - -Notes -• You can add specific files or use glob patterns (e.g., "*.py", "src/**/*.js") -• Profile rules apply only to the current profile -• Global rules apply across all profiles -• Context is preserved between chat sessions -"#, - Self::AVAILABLE_COMMANDS - ) - } - - pub fn hooks_help_text() -> String { - color_print::cformat!( - r#" -(Beta) Context Hooks - -Use context hooks to specify shell commands to run. The output from these -commands will be appended to the prompt to Amazon Q. Hooks can be defined -in global or local profiles. - -Usage: /context hooks [SUBCOMMAND] - -Description - Show existing global or profile-specific hooks. - Alternatively, specify a subcommand to modify the hooks. - -{} - -Notes -• Hooks are executed in parallel -• 'conversation_start' hooks run on the first user prompt and are attached once to the conversation history sent to Amazon Q -• 'per_prompt' hooks run on each user prompt and are attached to the prompt, but are not stored in conversation history -"#, - Self::HOOKS_AVAILABLE_COMMANDS - ) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ToolsSubcommand { - Schema, - Trust { tool_names: HashSet }, - Untrust { tool_names: HashSet }, - TrustAll, - Reset, - ResetSingle { tool_name: String }, - Help, -} - -impl ToolsSubcommand { - const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available subcommands - help Show an explanation for the tools command - schema Show the input schema for all available tools - trust <> Trust a specific tool or tools for the session - untrust <> Revert a tool or tools to per-request confirmation - trustall Trust all tools (equivalent to deprecated /acceptall) - reset Reset all tools to default permission levels - reset <> Reset a single tool to default permission level"}; - const BASE_COMMAND: &str = color_print::cstr! 
{"Usage: /tools [SUBCOMMAND] - -Description - Show the current set of tools and their permission setting. - The permission setting states when user confirmation is required. Trusted tools never require confirmation. - Alternatively, specify a subcommand to modify the tool permissions."}; - - fn usage_msg(header: impl AsRef) -> String { - format!( - "{}\n\n{}\n\n{}", - header.as_ref(), - Self::BASE_COMMAND, - Self::AVAILABLE_COMMANDS - ) - } - - pub fn help_text() -> String { - color_print::cformat!( - r#" -Tool Permissions - -By default, Amazon Q will ask for your permission to use certain tools. You can control which tools you -trust so that no confirmation is required. These settings will last only for this session. - -{} - -{}"#, - Self::BASE_COMMAND, - Self::AVAILABLE_COMMANDS - ) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PromptsSubcommand { - List { search_word: Option }, - Get { get_command: PromptsGetCommand }, - Help, -} - -impl PromptsSubcommand { - const AVAILABLE_COMMANDS: &str = color_print::cstr! {"Available subcommands - help Show an explanation for the prompts command - list [search word] List available prompts from a tool or show all available prompts"}; - const BASE_COMMAND: &str = color_print::cstr! {"Usage: /prompts [SUBCOMMAND] - -Description - Show the current set of reusuable prompts from the current fleet of mcp servers."}; - - fn usage_msg(header: impl AsRef) -> String { - format!( - "{}\n\n{}\n\n{}", - header.as_ref(), - Self::BASE_COMMAND, - Self::AVAILABLE_COMMANDS - ) - } - - pub fn help_text() -> String { - color_print::cformat!( - r#" -Prompts - -Prompts are reusable templates that help you quickly access common workflows and tasks. -These templates are provided by the mcp servers you have installed and configured. - -To actually retrieve a prompt, directly start with the following command (without prepending /prompt get): - @<> [arg] Retrieve prompt specified -Or if you prefer the long way: - /prompts get <> [arg] Retrieve prompt specified - -{} - -{}"#, - Self::BASE_COMMAND, - Self::AVAILABLE_COMMANDS - ) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct PromptsGetCommand { - pub orig_input: Option, - pub params: PromptsGetParam, -} - -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct PromptsGetParam { - pub name: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub arguments: Option>, -} - -impl Command { - // Check if input is a common single-word command that should use slash prefix - fn check_common_command(input: &str) -> Option { - let input_lower = input.trim().to_lowercase(); - match input_lower.as_str() { - "exit" | "quit" | "q" | "exit()" => { - Some("Did you mean to use the command '/quit' to exit? Type '/quit' to exit.".to_string()) - }, - "clear" | "cls" => Some( - "Did you mean to use the command '/clear' to clear the conversation? Type '/clear' to clear." - .to_string(), - ), - "help" | "?" => Some( - "Did you mean to use the command '/help' for help? 
Type '/help' to see available commands.".to_string(), - ), - _ => None, - } - } - - pub fn parse(input: &str, output: &mut impl Write) -> Result<Self, String> { - let input = input.trim(); - - // Check for common single-word commands without slash prefix - if let Some(suggestion) = Self::check_common_command(input) { - return Err(suggestion); - } - - // Check if the input starts with a literal backslash followed by a slash - // This allows users to escape the slash if they actually want to start with one - if input.starts_with("\\/") { - return Ok(Self::Ask { - prompt: input[1..].to_string(), // Remove the backslash but keep the slash - }); - } - - if let Some(command) = input.strip_prefix("/") { - let parts: Vec<&str> = command.split_whitespace().collect(); - - if parts.is_empty() { - return Err("Empty command".to_string()); - } - - return Ok(match parts[0].to_lowercase().as_str() { - "clear" => Self::Clear, - "help" => Self::Help, - "compact" => { - let mut prompt = None; - let show_summary = true; - let mut help = false; - - // Check if "help" is the first subcommand - if parts.len() > 1 && parts[1].to_lowercase() == "help" { - help = true; - } else { - let mut remaining_parts = Vec::new(); - - remaining_parts.extend_from_slice(&parts[1..]); - - // If we have remaining parts after parsing flags, join them as the prompt - if !remaining_parts.is_empty() { - prompt = Some(remaining_parts.join(" ")); - } - } - - Self::Compact { - prompt, - show_summary, - help, - } - }, - "acceptall" => { - let _ = queue!( - output, - style::SetForegroundColor(Color::Yellow), - style::Print("\n/acceptall is deprecated. Use /tools instead.\n\n"), - style::SetForegroundColor(Color::Reset) - ); - - Self::Tools { - subcommand: Some(ToolsSubcommand::TrustAll), - } - }, - "editor" => { - if parts.len() > 1 { - Self::PromptEditor { - initial_text: Some(parts[1..].join(" ")), - } - } else { - Self::PromptEditor { initial_text: None } - } - }, - "issue" => { - if parts.len() > 1 { - Self::Issue { - prompt: Some(parts[1..].join(" ")), - } - } else { - Self::Issue { prompt: None } - } - }, - "q" | "exit" | "quit" => Self::Quit, - "profile" => { - if parts.len() < 2 { - return Ok(Self::Profile { - subcommand: ProfileSubcommand::Help, - }); - } - - macro_rules!
usage_err { - ($usage_str:expr) => { - return Err(format!( - "Invalid /profile arguments.\n\nUsage:\n {}", - $usage_str - )) - }; - } - - match parts[1].to_lowercase().as_str() { - "list" => Self::Profile { - subcommand: ProfileSubcommand::List, - }, - "create" => { - let name = parts.get(2); - match name { - Some(name) => Self::Profile { - subcommand: ProfileSubcommand::Create { - name: (*name).to_string(), - }, - }, - None => usage_err!(ProfileSubcommand::CREATE_USAGE), - } - }, - "delete" => { - let name = parts.get(2); - match name { - Some(name) => Self::Profile { - subcommand: ProfileSubcommand::Delete { - name: (*name).to_string(), - }, - }, - None => usage_err!(ProfileSubcommand::DELETE_USAGE), - } - }, - "rename" => { - let old_name = parts.get(2); - let new_name = parts.get(3); - match (old_name, new_name) { - (Some(old), Some(new)) => Self::Profile { - subcommand: ProfileSubcommand::Rename { - old_name: (*old).to_string(), - new_name: (*new).to_string(), - }, - }, - _ => usage_err!(ProfileSubcommand::RENAME_USAGE), - } - }, - "set" => { - let name = parts.get(2); - match name { - Some(name) => Self::Profile { - subcommand: ProfileSubcommand::Set { - name: (*name).to_string(), - }, - }, - None => usage_err!(ProfileSubcommand::SET_USAGE), - } - }, - "help" => Self::Profile { - subcommand: ProfileSubcommand::Help, - }, - other => { - return Err(ProfileSubcommand::usage_msg(format!("Unknown subcommand '{}'.", other))); - }, - } - }, - "context" => { - if parts.len() < 2 { - return Ok(Self::Context { - subcommand: ContextSubcommand::Help, - }); - } - - macro_rules! usage_err { - ($usage_str:expr) => { - return Err(format!( - "Invalid /context arguments.\n\nUsage:\n {}", - $usage_str - )) - }; - } - - match parts[1].to_lowercase().as_str() { - "show" => { - let mut expand = false; - for part in &parts[2..] { - if *part == "--expand" { - expand = true; - } else { - usage_err!(ContextSubcommand::SHOW_USAGE); - } - } - Self::Context { - subcommand: ContextSubcommand::Show { expand }, - } - }, - "add" => { - // Parse add command with paths and flags - let mut global = false; - let mut force = false; - let mut paths = Vec::new(); - - let args = match shlex::split(&parts[2..].join(" ")) { - Some(args) => args, - None => return Err("Failed to parse quoted arguments".to_string()), - }; - - for arg in &args { - if arg == "--global" { - global = true; - } else if arg == "--force" || arg == "-f" { - force = true; - } else { - paths.push(arg.to_string()); - } - } - - if paths.is_empty() { - usage_err!(ContextSubcommand::ADD_USAGE); - } - - Self::Context { - subcommand: ContextSubcommand::Add { global, force, paths }, - } - }, - "rm" => { - // Parse rm command with paths and --global flag - let mut global = false; - let mut paths = Vec::new(); - let args = match shlex::split(&parts[2..].join(" ")) { - Some(args) => args, - None => return Err("Failed to parse quoted arguments".to_string()), - }; - - for arg in &args { - if arg == "--global" { - global = true; - } else { - paths.push(arg.to_string()); - } - } - - if paths.is_empty() { - usage_err!(ContextSubcommand::REMOVE_USAGE); - } - - Self::Context { - subcommand: ContextSubcommand::Remove { global, paths }, - } - }, - "clear" => { - // Parse clear command with optional --global flag - let mut global = false; - - for part in &parts[2..] 
{ - if *part == "--global" { - global = true; - } else { - usage_err!(ContextSubcommand::CLEAR_USAGE); - } - } - - Self::Context { - subcommand: ContextSubcommand::Clear { global }, - } - }, - "help" => Self::Context { - subcommand: ContextSubcommand::Help, - }, - "hooks" => { - if parts.get(2).is_none() { - return Ok(Self::Context { - subcommand: ContextSubcommand::Hooks { subcommand: None }, - }); - }; - - match Self::parse_hooks(&parts) { - Ok(command) => command, - Err(err) => return Err(ContextSubcommand::hooks_usage_msg(err)), - } - }, - other => { - return Err(ContextSubcommand::usage_msg(format!("Unknown subcommand '{}'.", other))); - }, - } - }, - "tools" => { - if parts.len() < 2 { - return Ok(Self::Tools { subcommand: None }); - } - - match parts[1].to_lowercase().as_str() { - "schema" => Self::Tools { - subcommand: Some(ToolsSubcommand::Schema), - }, - "trust" => { - let mut tool_names = HashSet::new(); - for part in &parts[2..] { - tool_names.insert((*part).to_string()); - } - - if tool_names.is_empty() { - let _ = queue!( - output, - style::SetForegroundColor(Color::DarkGrey), - style::Print("\nPlease use"), - style::SetForegroundColor(Color::DarkGreen), - style::Print(" /tools trust "), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to trust tools.\n\n"), - style::Print("Use "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/tools"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to see all available tools.\n\n"), - style::SetForegroundColor(Color::Reset), - ); - } - - Self::Tools { - subcommand: Some(ToolsSubcommand::Trust { tool_names }), - } - }, - "untrust" => { - let mut tool_names = HashSet::new(); - for part in &parts[2..] { - tool_names.insert((*part).to_string()); - } - - if tool_names.is_empty() { - let _ = queue!( - output, - style::SetForegroundColor(Color::DarkGrey), - style::Print("\nPlease use"), - style::SetForegroundColor(Color::DarkGreen), - style::Print(" /tools untrust "), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to untrust tools.\n\n"), - style::Print("Use "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/tools"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to see all available tools.\n\n"), - style::SetForegroundColor(Color::Reset), - ); - } - - Self::Tools { - subcommand: Some(ToolsSubcommand::Untrust { tool_names }), - } - }, - "trustall" => Self::Tools { - subcommand: Some(ToolsSubcommand::TrustAll), - }, - "reset" => { - let tool_name = parts.get(2); - match tool_name { - Some(tool_name) => Self::Tools { - subcommand: Some(ToolsSubcommand::ResetSingle { - tool_name: (*tool_name).to_string(), - }), - }, - None => Self::Tools { - subcommand: Some(ToolsSubcommand::Reset), - }, - } - }, - "help" => Self::Tools { - subcommand: Some(ToolsSubcommand::Help), - }, - other => { - return Err(ToolsSubcommand::usage_msg(format!("Unknown subcommand '{}'.", other))); - }, - } - }, - "prompts" => { - let subcommand = parts.get(1); - match subcommand { - Some(c) if c.to_lowercase() == "list" => Self::Prompts { - subcommand: Some(PromptsSubcommand::List { - search_word: parts.get(2).map(|v| (*v).to_string()), - }), - }, - Some(c) if c.to_lowercase() == "help" => Self::Prompts { - subcommand: Some(PromptsSubcommand::Help), - }, - Some(c) if c.to_lowercase() == "get" => { - // Need to reconstruct the input because simple splitting of - // white space might not be sufficient - let command = parts[2..].join(" "); - let get_command = 
parse_input_to_prompts_get_command(command.as_str())?; - let subcommand = Some(PromptsSubcommand::Get { get_command }); - Self::Prompts { subcommand } - }, - Some(other) => { - return Err(PromptsSubcommand::usage_msg(format!( - "Unknown subcommand '{}'\n", - other - ))); - }, - None => Self::Prompts { - subcommand: Some(PromptsSubcommand::List { - search_word: parts.get(2).map(|v| (*v).to_string()), - }), - }, - } - }, - "usage" => Self::Usage, - unknown_command => { - // If the command starts with a slash but isn't recognized, - // return an error instead of treating it as a prompt - return Err(format!( - "Unknown command: '/{}'. Type '/help' to see available commands.\nTo use a literal slash at the beginning of your message, escape it with a backslash (e.g., '\\//hey' for '/hey').", - unknown_command - )); - }, - }); - } - - if let Some(command) = input.strip_prefix('@') { - let get_command = parse_input_to_prompts_get_command(command)?; - let subcommand = Some(PromptsSubcommand::Get { get_command }); - return Ok(Self::Prompts { subcommand }); - } - - if let Some(command) = input.strip_prefix("!") { - return Ok(Self::Execute { - command: command.to_string(), - }); - } - - Ok(Self::Ask { - prompt: input.to_string(), - }) - } - - // NOTE: Here we use clap to parse the hooks subcommand instead of parsing manually - // like the rest of the file. - // Since the hooks subcommand has a lot of options, this makes more sense. - // Ideally, we parse everything with clap instead of trying to do it manually. - fn parse_hooks(parts: &[&str]) -> Result { - // Skip the first two parts ("/context" and "hooks") - let args = match shlex::split(&parts[1..].join(" ")) { - Some(args) => args, - None => return Err("Failed to parse arguments".to_string()), - }; - - // Parse with Clap - HooksCommand::try_parse_from(args) - .map(|hooks_command| Self::Context { - subcommand: ContextSubcommand::Hooks { - subcommand: Some(hooks_command.command), - }, - }) - .map_err(|e| e.to_string()) - } -} - -fn parse_input_to_prompts_get_command(command: &str) -> Result { - let input = shell_words::split(command).map_err(|e| format!("Error splitting command for prompts: {:?}", e))?; - let mut iter = input.into_iter(); - let prompt_name = iter.next().ok_or("Prompt name needs to be specified")?; - let args = iter.collect::>(); - let params = PromptsGetParam { - name: prompt_name, - arguments: { if args.is_empty() { None } else { Some(args) } }, - }; - let orig_input = Some(command.to_string()); - Ok(PromptsGetCommand { orig_input, params }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_command_parse() { - let mut stdout = std::io::stdout(); - - macro_rules! profile { - ($subcommand:expr) => { - Command::Profile { - subcommand: $subcommand, - } - }; - } - macro_rules! context { - ($subcommand:expr) => { - Command::Context { - subcommand: $subcommand, - } - }; - } - macro_rules! 
compact { - ($prompt:expr, $show_summary:expr) => { - Command::Compact { - prompt: $prompt, - show_summary: $show_summary, - help: false, - } - }; - } - let tests = &[ - ("/compact", compact!(None, true)), - ( - "/compact custom prompt", - compact!(Some("custom prompt".to_string()), true), - ), - ("/profile list", profile!(ProfileSubcommand::List)), - ( - "/profile create new_profile", - profile!(ProfileSubcommand::Create { - name: "new_profile".to_string(), - }), - ), - ( - "/profile delete p", - profile!(ProfileSubcommand::Delete { name: "p".to_string() }), - ), - ( - "/profile rename old new", - profile!(ProfileSubcommand::Rename { - old_name: "old".to_string(), - new_name: "new".to_string(), - }), - ), - ( - "/profile set p", - profile!(ProfileSubcommand::Set { name: "p".to_string() }), - ), - ( - "/profile set p", - profile!(ProfileSubcommand::Set { name: "p".to_string() }), - ), - ("/context show", context!(ContextSubcommand::Show { expand: false })), - ( - "/context show --expand", - context!(ContextSubcommand::Show { expand: true }), - ), - ( - "/context add p1 p2", - context!(ContextSubcommand::Add { - global: false, - force: false, - paths: vec!["p1".into(), "p2".into()] - }), - ), - ( - "/context add --global --force p1 p2", - context!(ContextSubcommand::Add { - global: true, - force: true, - paths: vec!["p1".into(), "p2".into()] - }), - ), - ( - "/context rm p1 p2", - context!(ContextSubcommand::Remove { - global: false, - paths: vec!["p1".into(), "p2".into()] - }), - ), - ( - "/context rm --global p1 p2", - context!(ContextSubcommand::Remove { - global: true, - paths: vec!["p1".into(), "p2".into()] - }), - ), - ("/context clear", context!(ContextSubcommand::Clear { global: false })), - ( - "/context clear --global", - context!(ContextSubcommand::Clear { global: true }), - ), - ("/issue", Command::Issue { prompt: None }), - ("/issue there was an error in the chat", Command::Issue { - prompt: Some("there was an error in the chat".to_string()), - }), - ("/issue \"there was an error in the chat\"", Command::Issue { - prompt: Some("\"there was an error in the chat\"".to_string()), - }), - ( - "/context hooks", - context!(ContextSubcommand::Hooks { subcommand: None }), - ), - ( - "/context hooks add test --trigger per_prompt --command 'echo 1' --global", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::Add { - name: "test".to_string(), - global: true, - trigger: "per_prompt".to_string(), - command: "echo 1".to_string() - }) - }), - ), - ( - "/context hooks rm test --global", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::Remove { - name: "test".to_string(), - global: true - }) - }), - ), - ( - "/context hooks enable test --global", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::Enable { - name: "test".to_string(), - global: true - }) - }), - ), - ( - "/context hooks disable test", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::Disable { - name: "test".to_string(), - global: false - }) - }), - ), - ( - "/context hooks enable-all --global", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::EnableAll { global: true }) - }), - ), - ( - "/context hooks disable-all", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::DisableAll { global: false }) - }), - ), - ( - "/context hooks help", - context!(ContextSubcommand::Hooks { - subcommand: Some(HooksSubcommand::Help) - }), - ), - ]; - - for (input, parsed) in tests { - 
assert_eq!(&Command::parse(input, &mut stdout).unwrap(), parsed, "{}", input); - } - } - - #[test] - fn test_common_command_suggestions() { - let mut stdout = std::io::stdout(); - let test_cases = vec![ - ( - "exit", - "Did you mean to use the command '/quit' to exit? Type '/quit' to exit.", - ), - ( - "quit", - "Did you mean to use the command '/quit' to exit? Type '/quit' to exit.", - ), - ( - "q", - "Did you mean to use the command '/quit' to exit? Type '/quit' to exit.", - ), - ( - "clear", - "Did you mean to use the command '/clear' to clear the conversation? Type '/clear' to clear.", - ), - ( - "cls", - "Did you mean to use the command '/clear' to clear the conversation? Type '/clear' to clear.", - ), - ( - "help", - "Did you mean to use the command '/help' for help? Type '/help' to see available commands.", - ), - ( - "?", - "Did you mean to use the command '/help' for help? Type '/help' to see available commands.", - ), - ]; - - for (input, expected_message) in test_cases { - let result = Command::parse(input, &mut stdout); - assert!(result.is_err(), "Expected error for input: {}", input); - assert_eq!(result.unwrap_err(), expected_message); - } - } -} diff --git a/crates/q_chat/src/consts.rs b/crates/q_chat/src/consts.rs deleted file mode 100644 index 6850f7efab..0000000000 --- a/crates/q_chat/src/consts.rs +++ /dev/null @@ -1,19 +0,0 @@ -use super::token_counter::TokenCounter; - -// These limits are the internal undocumented values from the service for each item - -pub const MAX_CURRENT_WORKING_DIRECTORY_LEN: usize = 256; - -/// Limit to send the number of messages as part of chat. -pub const MAX_CONVERSATION_STATE_HISTORY_LEN: usize = 250; - -pub const MAX_TOOL_RESPONSE_SIZE: usize = 800_000; - -/// TODO: Use this to gracefully handle user message sizes. -#[allow(dead_code)] -pub const MAX_USER_MESSAGE_SIZE: usize = 600_000; - -/// In tokens -pub const CONTEXT_WINDOW_SIZE: usize = 200_000; - -pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold diff --git a/crates/q_chat/src/context.rs b/crates/q_chat/src/context.rs deleted file mode 100644 index 6be6aeb7fe..0000000000 --- a/crates/q_chat/src/context.rs +++ /dev/null @@ -1,1016 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::path::{ - Path, - PathBuf, -}; -use std::sync::Arc; - -use eyre::{ - Result, - eyre, -}; -use fig_os_shim::Context; -use fig_util::directories; -use glob::glob; -use regex::Regex; -use serde::{ - Deserialize, - Serialize, -}; -use tracing::debug; - -use super::hooks::{ - Hook, - HookExecutor, -}; - -pub const AMAZONQ_FILENAME: &str = "AmazonQ.md"; - -/// Configuration for context files, containing paths to include in the context. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(default)] -pub struct ContextConfig { - /// List of file paths or glob patterns to include in the context. - pub paths: Vec, - - /// Map of Hook Name to [`Hook`]. The hook name serves as the hook's ID. - pub hooks: HashMap, -} - -#[allow(dead_code)] -/// Manager for context files and profiles. -#[derive(Debug, Clone)] -pub struct ContextManager { - ctx: Arc, - - /// Global context configuration that applies to all profiles. - pub global_config: ContextConfig, - - /// Name of the current active profile. - pub current_profile: String, - - /// Context configuration for the current profile. 
- pub profile_config: ContextConfig, - - pub hook_executor: HookExecutor, -} - -#[allow(dead_code)] -impl ContextManager { - /// Create a new ContextManager with default settings. - /// - /// This will: - /// 1. Create the necessary directories if they don't exist - /// 2. Load the global configuration - /// 3. Load the default profile configuration - /// - /// # Returns - /// A Result containing the new ContextManager or an error - pub async fn new(ctx: Arc) -> Result { - let profiles_dir = directories::chat_profiles_dir(&ctx)?; - - ctx.fs().create_dir_all(&profiles_dir).await?; - - let global_config = load_global_config(&ctx).await?; - let current_profile = "default".to_string(); - let profile_config = load_profile_config(&ctx, ¤t_profile).await?; - - Ok(Self { - ctx, - global_config, - current_profile, - profile_config, - hook_executor: HookExecutor::new(), - }) - } - - /// Save the current configuration to disk. - /// - /// # Arguments - /// * `global` - If true, save the global configuration; otherwise, save the current profile - /// configuration - /// - /// # Returns - /// A Result indicating success or an error - async fn save_config(&self, global: bool) -> Result<()> { - if global { - let global_path = directories::chat_global_context_path(&self.ctx)?; - let contents = serde_json::to_string_pretty(&self.global_config) - .map_err(|e| eyre!("Failed to serialize global configuration: {}", e))?; - - self.ctx.fs().write(&global_path, contents).await?; - } else { - let profile_path = profile_context_path(&self.ctx, &self.current_profile)?; - if let Some(parent) = profile_path.parent() { - self.ctx.fs().create_dir_all(parent).await?; - } - let contents = serde_json::to_string_pretty(&self.profile_config) - .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; - - self.ctx.fs().write(&profile_path, contents).await?; - } - - Ok(()) - } - - /// Add paths to the context configuration. - /// - /// # Arguments - /// * `paths` - List of paths to add - /// * `global` - If true, add to global configuration; otherwise, add to current profile - /// configuration - /// * `force` - If true, skip validation that the path exists - /// - /// # Returns - /// A Result indicating success or an error - pub async fn add_paths(&mut self, paths: Vec, global: bool, force: bool) -> Result<()> { - let mut all_paths = self.global_config.paths.clone(); - all_paths.append(&mut self.profile_config.paths.clone()); - - // Validate paths exist before adding them - if !force { - let mut context_files = Vec::new(); - - // Check each path to make sure it exists or matches at least one file - for path in &paths { - // We're using a temporary context_files vector just for validation - // Pass is_validation=true to ensure we error if glob patterns don't match any files - match process_path(&self.ctx, path, &mut context_files, false, true).await { - Ok(_) => {}, // Path is valid - Err(e) => return Err(eyre!("Invalid path '{}': {}. Use --force to add anyway.", path, e)), - } - } - } - - // Add each path, checking for duplicates - for path in paths { - if all_paths.contains(&path) { - return Err(eyre!("Rule '{}' already exists.", path)); - } - if global { - self.global_config.paths.push(path); - } else { - self.profile_config.paths.push(path); - } - } - - // Save the updated configuration - self.save_config(global).await?; - - Ok(()) - } - - /// Remove paths from the context configuration. 
- /// - /// # Arguments - /// * `paths` - List of paths to remove - /// * `global` - If true, remove from global configuration; otherwise, remove from current - /// profile configuration - /// - /// # Returns - /// A Result indicating success or an error - pub async fn remove_paths(&mut self, paths: Vec, global: bool) -> Result<()> { - // Get reference to the appropriate config - let config = self.get_config_mut(global); - - // Track if any paths were removed - let mut removed_any = false; - - // Remove each path if it exists - for path in paths { - let original_len = config.paths.len(); - config.paths.retain(|p| p != &path); - - if config.paths.len() < original_len { - removed_any = true; - } - } - - if !removed_any { - return Err(eyre!("None of the specified paths were found in the context")); - } - - // Save the updated configuration - self.save_config(global).await?; - - Ok(()) - } - - /// List all available profiles. - /// - /// # Returns - /// A Result containing a vector of profile names, with "default" always first - pub async fn list_profiles(&self) -> Result> { - let mut profiles = Vec::new(); - - // Always include default profile - profiles.push("default".to_string()); - - // Read profile directory and extract profile names - let profiles_dir = directories::chat_profiles_dir(&self.ctx)?; - if profiles_dir.exists() { - let mut read_dir = self.ctx.fs().read_dir(&profiles_dir).await?; - while let Some(entry) = read_dir.next_entry().await? { - let path = entry.path(); - if let (true, Some(name)) = (path.is_dir(), path.file_name()) { - if name != "default" { - profiles.push(name.to_string_lossy().to_string()); - } - } - } - } - - // Sort non-default profiles alphabetically - if profiles.len() > 1 { - profiles[1..].sort(); - } - - Ok(profiles) - } - - /// List all available profiles using blocking operations. - /// - /// Similar to list_profiles but uses synchronous filesystem operations. - /// - /// # Returns - /// A Result containing a vector of profile names, with "default" always first - pub fn list_profiles_blocking(&self) -> Result> { - let mut profiles = Vec::new(); - - // Always include default profile - profiles.push("default".to_string()); - - // Read profile directory and extract profile names - let profiles_dir = directories::chat_profiles_dir(&self.ctx)?; - if profiles_dir.exists() { - for entry in std::fs::read_dir(profiles_dir)? { - let entry = entry?; - let path = entry.path(); - if let (true, Some(name)) = (path.is_dir(), path.file_name()) { - if name != "default" { - profiles.push(name.to_string_lossy().to_string()); - } - } - } - } - - // Sort non-default profiles alphabetically - if profiles.len() > 1 { - profiles[1..].sort(); - } - - Ok(profiles) - } - - /// Clear all paths from the context configuration. - /// - /// # Arguments - /// * `global` - If true, clear global configuration; otherwise, clear current profile - /// configuration - /// - /// # Returns - /// A Result indicating success or an error - pub async fn clear(&mut self, global: bool) -> Result<()> { - // Clear the appropriate config - if global { - self.global_config.paths.clear(); - } else { - self.profile_config.paths.clear(); - } - - // Save the updated configuration - self.save_config(global).await?; - - Ok(()) - } - - /// Create a new profile. 
- /// - /// # Arguments - /// * `name` - Name of the profile to create - /// - /// # Returns - /// A Result indicating success or an error - pub async fn create_profile(&self, name: &str) -> Result<()> { - validate_profile_name(name)?; - - // Check if profile already exists - let profile_path = profile_context_path(&self.ctx, name)?; - if profile_path.exists() { - return Err(eyre!("Profile '{}' already exists", name)); - } - - // Create empty profile configuration - let config = ContextConfig::default(); - let contents = serde_json::to_string_pretty(&config) - .map_err(|e| eyre!("Failed to serialize profile configuration: {}", e))?; - - // Create the file - if let Some(parent) = profile_path.parent() { - self.ctx.fs().create_dir_all(parent).await?; - } - self.ctx.fs().write(&profile_path, contents).await?; - - Ok(()) - } - - /// Delete a profile. - /// - /// # Arguments - /// * `name` - Name of the profile to delete - /// - /// # Returns - /// A Result indicating success or an error - pub async fn delete_profile(&self, name: &str) -> Result<()> { - if name == "default" { - return Err(eyre!("Cannot delete the default profile")); - } else if name == self.current_profile { - return Err(eyre!( - "Cannot delete the active profile. Switch to another profile first" - )); - } - - let profile_path = profile_dir_path(&self.ctx, name)?; - if !profile_path.exists() { - return Err(eyre!("Profile '{}' does not exist", name)); - } - - self.ctx.fs().remove_dir_all(&profile_path).await?; - - Ok(()) - } - - /// Rename a profile. - /// - /// # Arguments - /// * `old_name` - Current name of the profile - /// * `new_name` - New name for the profile - /// - /// # Returns - /// A Result indicating success or an error - pub async fn rename_profile(&mut self, old_name: &str, new_name: &str) -> Result<()> { - // Validate profile names - if old_name == "default" { - return Err(eyre!("Cannot rename the default profile")); - } - if new_name == "default" { - return Err(eyre!("Cannot rename to 'default' as it's a reserved profile name")); - } - - validate_profile_name(new_name)?; - - let old_profile_path = profile_dir_path(&self.ctx, old_name)?; - if !old_profile_path.exists() { - return Err(eyre!("Profile '{}' not found", old_name)); - } - - let new_profile_path = profile_dir_path(&self.ctx, new_name)?; - if new_profile_path.exists() { - return Err(eyre!("Profile '{}' already exists", new_name)); - } - - self.ctx.fs().rename(&old_profile_path, &new_profile_path).await?; - - // If the current profile is being renamed, update the current_profile field - if self.current_profile == old_name { - self.current_profile = new_name.to_string(); - self.profile_config = load_profile_config(&self.ctx, new_name).await?; - } - - Ok(()) - } - - /// Switch to a different profile. 
- /// - /// # Arguments - /// * `name` - Name of the profile to switch to - /// - /// # Returns - /// A Result indicating success or an error - pub async fn switch_profile(&mut self, name: &str) -> Result<()> { - validate_profile_name(name)?; - self.hook_executor.profile_cache.clear(); - - // Special handling for default profile - it always exists - if name == "default" { - // Load the default profile configuration - let profile_config = load_profile_config(&self.ctx, name).await?; - - // Update the current profile - self.current_profile = name.to_string(); - self.profile_config = profile_config; - - return Ok(()); - } - - // Check if profile exists - let profile_path = profile_context_path(&self.ctx, name)?; - if !profile_path.exists() { - return Err(eyre!("Profile '{}' does not exist. Use 'create' to create it", name)); - } - - // Update the current profile - self.current_profile = name.to_string(); - self.profile_config = load_profile_config(&self.ctx, name).await?; - - Ok(()) - } - - /// Get all context files (global + profile-specific). - /// - /// This method: - /// 1. Processes all paths in the global and profile configurations - /// 2. Expands glob patterns to include matching files - /// 3. Reads the content of each file - /// 4. Returns a vector of (filename, content) pairs - /// - /// # Arguments - /// * `force` - If true, include paths that don't exist yet - /// - /// # Returns - /// A Result containing a vector of (filename, content) pairs or an error - pub async fn get_context_files(&self, force: bool) -> Result> { - let mut context_files = Vec::new(); - - self.collect_context_files(&self.global_config.paths, &mut context_files, force) - .await?; - self.collect_context_files(&self.profile_config.paths, &mut context_files, force) - .await?; - - context_files.sort_by(|a, b| a.0.cmp(&b.0)); - context_files.dedup_by(|a, b| a.0 == b.0); - - Ok(context_files) - } - - pub async fn get_context_files_by_path(&self, force: bool, path: &str) -> Result> { - let mut context_files = Vec::new(); - process_path(&self.ctx, path, &mut context_files, force, true).await?; - Ok(context_files) - } - - /// Get all context files from the global configuration. - pub async fn get_global_context_files(&self, force: bool) -> Result> { - let mut context_files = Vec::new(); - - self.collect_context_files(&self.global_config.paths, &mut context_files, force) - .await?; - - Ok(context_files) - } - - /// Get all context files from the current profile configuration. - pub async fn get_current_profile_context_files(&self, force: bool) -> Result> { - let mut context_files = Vec::new(); - - self.collect_context_files(&self.profile_config.paths, &mut context_files, force) - .await?; - - Ok(context_files) - } - - async fn collect_context_files( - &self, - paths: &[String], - context_files: &mut Vec<(String, String)>, - force: bool, - ) -> Result<()> { - for path in paths { - // Use is_validation=false to handle non-matching globs gracefully - process_path(&self.ctx, path, context_files, force, false).await?; - } - Ok(()) - } - - fn get_config_mut(&mut self, global: bool) -> &mut ContextConfig { - if global { - &mut self.global_config - } else { - &mut self.profile_config - } - } - - /// Add hooks to the context config. If another hook with the same name already exists, throw an - /// error. - /// - /// # Arguments - /// * `hook` - name of the hook to delete - /// * `global` - If true, the add to the global config. If false, add to the current profile - /// config. 
-    /// Note: whether the hook runs at conversation start or per prompt is determined by the
-    /// hook's trigger.
-    pub async fn add_hook(&mut self, name: String, hook: Hook, global: bool) -> Result<()> {
-        let config = self.get_config_mut(global);
-
-        if config.hooks.contains_key(&name) {
-            return Err(eyre!("name already exists."));
-        }
-
-        config.hooks.insert(name, hook);
-        self.save_config(global).await
-    }
-
-    /// Delete a hook by name
-    /// # Arguments
-    /// * `name` - name of the hook to delete
-    /// * `global` - If true, delete from the global config. If false, delete from the current
-    ///   profile config
-    pub async fn remove_hook(&mut self, name: &str, global: bool) -> Result<()> {
-        let config = self.get_config_mut(global);
-
-        if !config.hooks.contains_key(name) {
-            return Err(eyre!("does not exist."));
-        }
-
-        config.hooks.remove(name);
-
-        self.save_config(global).await
-    }
-
-    /// Sets the "disabled" field on any [`Hook`] with the given name
-    /// # Arguments
-    /// * `disable` - Set "disabled" field to this value
-    pub async fn set_hook_disabled(&mut self, name: &str, global: bool, disable: bool) -> Result<()> {
-        let config = self.get_config_mut(global);
-
-        if !config.hooks.contains_key(name) {
-            return Err(eyre!("does not exist."));
-        }
-
-        if let Some(hook) = config.hooks.get_mut(name) {
-            hook.disabled = disable;
-        }
-
-        self.save_config(global).await
-    }
-
-    /// Sets the "disabled" field on all [`Hook`]s
-    /// # Arguments
-    /// * `disable` - Set all "disabled" fields to this value
-    pub async fn set_all_hooks_disabled(&mut self, global: bool, disable: bool) -> Result<()> {
-        let config = self.get_config_mut(global);
-
-        config.hooks.iter_mut().for_each(|(_, h)| h.disabled = disable);
-
-        self.save_config(global).await
-    }
-
-    /// Run all the currently enabled hooks from both the global and profile contexts.
-    /// Skipped hooks (disabled) will not appear in the output.
-    /// # Arguments
-    /// * `updates` - output stream to write hook run status to if Some, else do nothing if None
-    /// # Returns
-    /// A vector containing pairs of a [`Hook`] definition and its execution output
-    pub async fn run_hooks(&mut self, updates: Option<&mut impl Write>) -> Vec<(Hook, String)> {
-        let mut hooks: Vec<&Hook> = Vec::new();
-
-        // Set internal hook states
-        let configs = [
-            (&mut self.global_config.hooks, true),
-            (&mut self.profile_config.hooks, false),
-        ];
-
-        for (hook_list, is_global) in configs {
-            hooks.extend(hook_list.iter_mut().map(|(name, h)| {
-                h.name = name.to_string();
-                h.is_global = is_global;
-                &*h
-            }));
-        }
-
-        self.hook_executor.run_hooks(hooks, updates).await
-    }
-}
-
-fn profile_dir_path(ctx: &Context, profile_name: &str) -> Result<PathBuf> {
-    Ok(directories::chat_profiles_dir(ctx)?.join(profile_name))
-}
-
-/// Path to the context config file for `profile_name`.
-pub fn profile_context_path(ctx: &Context, profile_name: &str) -> Result<PathBuf> {
-    Ok(directories::chat_profiles_dir(ctx)?
-        .join(profile_name)
-        .join("context.json"))
-}
-
-/// Load the global context configuration.
-///
-/// If the global configuration file doesn't exist, returns a default configuration.
-async fn load_global_config(ctx: &Context) -> Result<ContextConfig> {
-    let global_path = directories::chat_global_context_path(&ctx)?;
-    debug!(?global_path, "loading global config");
-    if ctx.fs().exists(&global_path) {
-        let contents = ctx.fs().read_to_string(&global_path).await?;
-        let config: ContextConfig =
-            serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse global configuration: {}", e))?;
-        Ok(config)
-    } else {
-        // Return default global configuration with predefined paths
-        Ok(ContextConfig {
-            paths: vec![
-                ".amazonq/rules/**/*.md".to_string(),
-                "README.md".to_string(),
-                AMAZONQ_FILENAME.to_string(),
-            ],
-            hooks: HashMap::new(),
-        })
-    }
-}
-
-/// Load a profile's context configuration.
-///
-/// If the profile configuration file doesn't exist, creates a default configuration.
-async fn load_profile_config(ctx: &Context, profile_name: &str) -> Result<ContextConfig> {
-    let profile_path = profile_context_path(ctx, profile_name)?;
-    debug!(?profile_path, "loading profile config");
-    if ctx.fs().exists(&profile_path) {
-        let contents = ctx.fs().read_to_string(&profile_path).await?;
-        let config: ContextConfig =
-            serde_json::from_str(&contents).map_err(|e| eyre!("Failed to parse profile configuration: {}", e))?;
-        Ok(config)
-    } else {
-        // Return empty configuration for new profiles
-        Ok(ContextConfig::default())
-    }
-}
-
-/// Process a path, handling glob patterns and file types.
-///
-/// This method:
-/// 1. Expands the path (handling ~ for home directory)
-/// 2. If the path contains glob patterns, expands them
-/// 3. For each resulting path, adds the file to the context collection
-/// 4. Handles directories by including all files in the directory (non-recursive)
-/// 5. With force=true, includes paths that don't exist yet
-///
-/// # Arguments
-/// * `path` - The path to process
-/// * `context_files` - The collection to add files to
-/// * `force` - If true, include paths that don't exist yet
-/// * `is_validation` - If true, error when glob patterns don't match; if false, silently skip
-///
-/// # Returns
-/// A Result indicating success or an error
-async fn process_path(
-    ctx: &Context,
-    path: &str,
-    context_files: &mut Vec<(String, String)>,
-    force: bool,
-    is_validation: bool,
-) -> Result<()> {
-    // Expand ~ to home directory
-    let expanded_path = if path.starts_with('~') {
-        if let Some(home_dir) = ctx.env().home() {
-            home_dir.join(&path[2..]).to_string_lossy().to_string()
-        } else {
-            return Err(eyre!("Could not determine home directory"));
-        }
-    } else {
-        path.to_string()
-    };
-
-    // Handle absolute, relative paths, and glob patterns
-    let full_path = if expanded_path.starts_with('/') {
-        expanded_path
-    } else {
-        ctx.env()
-            .current_dir()?
-            .join(&expanded_path)
-            .to_string_lossy()
-            .to_string()
-    };
-
-    // Required in chroot testing scenarios so that we can use `Path::exists`.
- let full_path = ctx.fs().chroot_path_str(full_path); - - // Check if the path contains glob patterns - if full_path.contains('*') || full_path.contains('?') || full_path.contains('[') { - // Expand glob pattern - match glob(&full_path) { - Ok(entries) => { - let mut found_any = false; - - for entry in entries { - match entry { - Ok(path) => { - if path.is_file() { - add_file_to_context(ctx, &path, context_files).await?; - found_any = true; - } - }, - Err(e) => return Err(eyre!("Glob error: {}", e)), - } - } - - if !found_any && !force && is_validation { - // When validating paths (e.g., for /context add), error if no files match - return Err(eyre!("No files found matching glob pattern '{}'", full_path)); - } - // When just showing expanded files (e.g., for /context show --expand), - // silently skip non-matching patterns (don't add anything to context_files) - }, - Err(e) => return Err(eyre!("Invalid glob pattern '{}': {}", full_path, e)), - } - } else { - // Regular path - let path = Path::new(&full_path); - if path.exists() { - if path.is_file() { - add_file_to_context(ctx, path, context_files).await?; - } else if path.is_dir() { - // For directories, add all files in the directory (non-recursive) - let mut read_dir = ctx.fs().read_dir(path).await?; - while let Some(entry) = read_dir.next_entry().await? { - let path = entry.path(); - if path.is_file() { - add_file_to_context(ctx, &path, context_files).await?; - } - } - } - } else if !force && is_validation { - // When validating paths (e.g., for /context add), error if the path doesn't exist - return Err(eyre!("Path '{}' does not exist", full_path)); - } else if force { - // When using --force, we'll add the path even though it doesn't exist - // This allows users to add paths that will exist in the future - context_files.push((full_path.clone(), format!("(Path '{}' does not exist yet)", full_path))); - } - // When just showing expanded files (e.g., for /context show --expand), - // silently skip non-existent paths if is_validation is false - } - - Ok(()) -} - -/// Add a file to the context collection. -/// -/// This method: -/// 1. Reads the content of the file -/// 2. Adds the (filename, content) pair to the context collection -/// -/// # Arguments -/// * `path` - The path to the file -/// * `context_files` - The collection to add the file to -/// -/// # Returns -/// A Result indicating success or an error -async fn add_file_to_context(ctx: &Context, path: &Path, context_files: &mut Vec<(String, String)>) -> Result<()> { - let filename = path.to_string_lossy().to_string(); - let content = ctx.fs().read_to_string(path).await?; - context_files.push((filename, content)); - Ok(()) -} - -/// Validate a profile name. -/// -/// Profile names can only contain alphanumeric characters, hyphens, and underscores. 
-/// -/// # Arguments -/// * `name` - Name to validate -/// -/// # Returns -/// A Result indicating if the name is valid -fn validate_profile_name(name: &str) -> Result<()> { - // Check if name is empty - if name.is_empty() { - return Err(eyre!("Profile name cannot be empty")); - } - - // Check if name contains only allowed characters and starts with an alphanumeric character - let re = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$").unwrap(); - if !re.is_match(name) { - return Err(eyre!( - "Profile name must start with an alphanumeric character and can only contain alphanumeric characters, hyphens, and underscores" - )); - } - - Ok(()) -} - -#[cfg(test)] -mod tests { - use std::io::Stdout; - - use super::super::hooks::HookTrigger; - use super::*; - - // Helper function to create a test ContextManager with Context - pub async fn create_test_context_manager() -> Result { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let manager = ContextManager::new(ctx).await?; - Ok(manager) - } - - #[tokio::test] - async fn test_validate_profile_name() { - // Test valid names - assert!(validate_profile_name("valid").is_ok()); - assert!(validate_profile_name("valid-name").is_ok()); - assert!(validate_profile_name("valid_name").is_ok()); - assert!(validate_profile_name("valid123").is_ok()); - assert!(validate_profile_name("1valid").is_ok()); - assert!(validate_profile_name("9test").is_ok()); - - // Test invalid names - assert!(validate_profile_name("").is_err()); - assert!(validate_profile_name("invalid/name").is_err()); - assert!(validate_profile_name("invalid.name").is_err()); - assert!(validate_profile_name("invalid name").is_err()); - assert!(validate_profile_name("_invalid").is_err()); - assert!(validate_profile_name("-invalid").is_err()); - } - - #[tokio::test] - async fn test_profile_ops() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let ctx = Arc::clone(&manager.ctx); - - assert_eq!(manager.current_profile, "default"); - - // Create ops - manager.create_profile("test_profile").await?; - assert!(profile_context_path(&ctx, "test_profile")?.exists()); - assert!(manager.create_profile("test_profile").await.is_err()); - manager.create_profile("alt").await?; - - // Listing - let profiles = manager.list_profiles().await?; - assert!(profiles.contains(&"default".to_string())); - assert!(profiles.contains(&"test_profile".to_string())); - assert!(profiles.contains(&"alt".to_string())); - - // Switching - manager.switch_profile("test_profile").await?; - assert!(manager.switch_profile("notexists").await.is_err()); - - // Renaming - manager.rename_profile("alt", "renamed").await?; - assert!(!profile_context_path(&ctx, "alt")?.exists()); - assert!(profile_context_path(&ctx, "renamed")?.exists()); - - // Delete ops - assert!(manager.delete_profile("test_profile").await.is_err()); - manager.switch_profile("default").await?; - manager.delete_profile("test_profile").await?; - assert!(!profile_context_path(&ctx, "test_profile")?.exists()); - assert!(manager.delete_profile("test_profile").await.is_err()); - assert!(manager.delete_profile("default").await.is_err()); - - Ok(()) - } - - #[tokio::test] - async fn test_path_ops() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let ctx = Arc::clone(&manager.ctx); - - // Create some test files for matching. 
- ctx.fs().create_dir_all("test").await?; - ctx.fs().write("test/p1.md", "p1").await?; - ctx.fs().write("test/p2.md", "p2").await?; - - assert!( - manager.get_context_files(false).await?.is_empty(), - "no files should be returned for an empty profile when force is false" - ); - assert_eq!( - manager.get_context_files(true).await?.len(), - 2, - "default non-glob global files should be included when force is true" - ); - - manager.add_paths(vec!["test/*.md".to_string()], false, false).await?; - let files = manager.get_context_files(false).await?; - assert!(files[0].0.ends_with("p1.md")); - assert_eq!(files[0].1, "p1"); - assert!(files[1].0.ends_with("p2.md")); - assert_eq!(files[1].1, "p2"); - - assert!( - manager - .add_paths(vec!["test/*.txt".to_string()], false, false) - .await - .is_err(), - "adding a glob with no matching and without force should fail" - ); - - Ok(()) - } - - #[tokio::test] - async fn test_add_hook() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - // Test adding hook to profile config - manager.add_hook("test_hook".to_string(), hook.clone(), false).await?; - assert!(manager.profile_config.hooks.contains_key("test_hook")); - - // Test adding hook to global config - manager.add_hook("global_hook".to_string(), hook.clone(), true).await?; - assert!(manager.global_config.hooks.contains_key("global_hook")); - - // Test adding duplicate hook name - assert!(manager.add_hook("test_hook".to_string(), hook, false).await.is_err()); - - Ok(()) - } - - #[tokio::test] - async fn test_remove_hook() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook("test_hook".to_string(), hook, false).await?; - - // Test removing existing hook - manager.remove_hook("test_hook", false).await?; - assert!(!manager.profile_config.hooks.contains_key("test_hook")); - - // Test removing non-existent hook - assert!(manager.remove_hook("test_hook", false).await.is_err()); - - Ok(()) - } - - #[tokio::test] - async fn test_set_hook_disabled() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let hook = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook("test_hook".to_string(), hook, false).await?; - - // Test disabling hook - manager.set_hook_disabled("test_hook", false, true).await?; - assert!(manager.profile_config.hooks.get("test_hook").unwrap().disabled); - - // Test enabling hook - manager.set_hook_disabled("test_hook", false, false).await?; - assert!(!manager.profile_config.hooks.get("test_hook").unwrap().disabled); - - // Test with non-existent hook - assert!(manager.set_hook_disabled("nonexistent", false, true).await.is_err()); - - Ok(()) - } - - #[tokio::test] - async fn test_set_all_hooks_disabled() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook("hook1".to_string(), hook1, false).await?; - manager.add_hook("hook2".to_string(), hook2, false).await?; - - // Test disabling all hooks - manager.set_all_hooks_disabled(false, true).await?; - assert!(manager.profile_config.hooks.values().all(|h| h.disabled)); - - // Test enabling all hooks - 
manager.set_all_hooks_disabled(false, false).await?; - assert!(manager.profile_config.hooks.values().all(|h| !h.disabled)); - - Ok(()) - } - - #[tokio::test] - async fn test_run_hooks() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook("hook1".to_string(), hook1, false).await?; - manager.add_hook("hook2".to_string(), hook2, false).await?; - - // Run the hooks - let results = manager.run_hooks(None::<&mut Stdout>).await; - assert_eq!(results.len(), 2); // Should include both hooks - - Ok(()) - } - - #[tokio::test] - async fn test_hooks_across_profiles() -> Result<()> { - let mut manager = create_test_context_manager().await?; - let hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - let hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo test".to_string()); - - manager.add_hook("profile_hook".to_string(), hook1, false).await?; - manager.add_hook("global_hook".to_string(), hook2, true).await?; - - let results = manager.run_hooks(None::<&mut Stdout>).await; - assert_eq!(results.len(), 2); // Should include both hooks - - // Create and switch to a new profile - manager.create_profile("test_profile").await?; - manager.switch_profile("test_profile").await?; - - let results = manager.run_hooks(None::<&mut Stdout>).await; - assert_eq!(results.len(), 1); // Should include global hook - assert_eq!(results[0].0.name, "global_hook"); - - Ok(()) - } -} diff --git a/crates/q_chat/src/conversation_state.rs b/crates/q_chat/src/conversation_state.rs deleted file mode 100644 index 75a636c6d8..0000000000 --- a/crates/q_chat/src/conversation_state.rs +++ /dev/null @@ -1,1051 +0,0 @@ -use std::collections::{ - HashMap, - VecDeque, -}; -use std::sync::Arc; - -use fig_api_client::model::{ - AssistantResponseMessage, - ChatMessage, - ConversationState as FigConversationState, - Tool, - ToolInputSchema, - ToolResult, - ToolResultContentBlock, - ToolSpecification, - ToolUse, - UserInputMessage, - UserInputMessageContext, -}; -use fig_os_shim::Context; -use mcp_client::Prompt; -use tracing::{ - debug, - error, - warn, -}; - -use super::consts::{ - MAX_CHARS, - MAX_CONVERSATION_STATE_HISTORY_LEN, -}; -use super::context::ContextManager; -use super::hooks::{ - Hook, - HookTrigger, -}; -use super::message::{ - AssistantMessage, - ToolUseResult, - ToolUseResultBlock, - UserMessage, - UserMessageContent, - build_env_state, -}; -use super::token_counter::{ - CharCount, - CharCounter, -}; -use super::tools::{ - InputSchema, - QueuedTool, - ToolOrigin, - ToolSpec, - serde_value_to_document, -}; - -const CONTEXT_ENTRY_START_HEADER: &str = "--- CONTEXT ENTRY BEGIN ---\n"; -const CONTEXT_ENTRY_END_HEADER: &str = "--- CONTEXT ENTRY END ---\n\n"; - -use crate::util::shared_writer::SharedWriter; -/// Tracks state related to an ongoing conversation. -#[derive(Debug, Clone)] -pub struct ConversationState { - /// Randomly generated on creation. - conversation_id: String, - /// The next user message to be sent as part of the conversation. Required to be [Some] before - /// calling [Self::as_sendable_conversation_state]. - next_message: Option, - history: VecDeque<(UserMessage, AssistantMessage)>, - /// The range in the history sendable to the backend (start inclusive, end exclusive). 
-    valid_history_range: (usize, usize),
-    /// Similar to history in that stores user and assistant responses, except that it is not used
-    /// in message requests. Instead, the responses are expected to be in human-readable format,
-    /// e.g. user messages prefixed with '> '. Should also be used to store errors posted in the
-    /// chat.
-    pub transcript: VecDeque<String>,
-    pub tools: HashMap<ToolOrigin, Vec<Tool>>,
-    /// Context manager for handling sticky context files
-    pub context_manager: Option<ContextManager>,
-    /// Cached value representing the length of the user context message.
-    context_message_length: Option<usize>,
-    /// Stores the latest conversation summary created by /compact
-    latest_summary: Option<String>,
-    updates: Option<SharedWriter>,
-}
-
-impl ConversationState {
-    pub async fn new(
-        ctx: Arc<Context>,
-        conversation_id: &str,
-        tool_config: HashMap<String, ToolSpec>,
-        profile: Option<String>,
-        updates: Option<SharedWriter>,
-    ) -> Self {
-        // Initialize context manager
-        let context_manager = match ContextManager::new(ctx).await {
-            Ok(mut manager) => {
-                // Switch to specified profile if provided
-                if let Some(profile_name) = profile {
-                    if let Err(e) = manager.switch_profile(&profile_name).await {
-                        warn!("Failed to switch to profile {}: {}", profile_name, e);
-                    }
-                }
-                Some(manager)
-            },
-            Err(e) => {
-                warn!("Failed to initialize context manager: {}", e);
-                None
-            },
-        };
-
-        Self {
-            conversation_id: conversation_id.to_string(),
-            next_message: None,
-            history: VecDeque::new(),
-            valid_history_range: Default::default(),
-            transcript: VecDeque::with_capacity(MAX_CONVERSATION_STATE_HISTORY_LEN),
-            tools: tool_config
-                .into_values()
-                .fold(HashMap::<ToolOrigin, Vec<Tool>>::new(), |mut acc, v| {
-                    let tool = Tool::ToolSpecification(ToolSpecification {
-                        name: v.name,
-                        description: v.description,
-                        input_schema: v.input_schema.into(),
-                    });
-                    acc.entry(v.tool_origin)
-                        .and_modify(|tools| tools.push(tool.clone()))
-                        .or_insert(vec![tool]);
-                    acc
-                }),
-            context_manager,
-            context_message_length: None,
-            latest_summary: None,
-            updates,
-        }
-    }
-
-    pub fn history(&self) -> &VecDeque<(UserMessage, AssistantMessage)> {
-        &self.history
-    }
-
-    /// Clears the conversation history and optionally the summary.
-    pub fn clear(&mut self, preserve_summary: bool) {
-        self.next_message = None;
-        self.history.clear();
-        if !preserve_summary {
-            self.latest_summary = None;
-        }
-    }
-
-    /// Appends a collection of prompts into history and returns the last message in the collection.
-    /// It asserts that the collection ends with a prompt that assumes the role of user.
- pub fn append_prompts(&mut self, mut prompts: VecDeque) -> Option { - debug_assert!(self.next_message.is_none(), "next_message should not exist"); - debug_assert!(prompts.back().is_some_and(|p| p.role == mcp_client::Role::User)); - let last_msg = prompts.pop_back()?; - let (mut candidate_user, mut candidate_asst) = (None::, None::); - while let Some(prompt) = prompts.pop_front() { - let Prompt { role, content } = prompt; - match role { - mcp_client::Role::User => { - let user_msg = UserMessage::new_prompt(content.to_string()); - candidate_user.replace(user_msg); - }, - mcp_client::Role::Assistant => { - let assistant_msg = AssistantMessage::new_response(None, content.into()); - candidate_asst.replace(assistant_msg); - }, - } - if candidate_asst.is_some() && candidate_user.is_some() { - let asst = candidate_asst.take().unwrap(); - let user = candidate_user.take().unwrap(); - self.append_assistant_transcript(&asst); - self.history.push_back((user, asst)); - } - } - Some(last_msg.content.to_string()) - } - - pub fn next_user_message(&self) -> Option<&UserMessage> { - self.next_message.as_ref() - } - - pub fn reset_next_user_message(&mut self) { - self.next_message = None; - } - - pub async fn set_next_user_message(&mut self, input: String) { - debug_assert!(self.next_message.is_none(), "next_message should not exist"); - if let Some(next_message) = self.next_message.as_ref() { - warn!(?next_message, "next_message should not exist"); - } - - let input = if input.is_empty() { - warn!("input must not be empty when adding new messages"); - "Empty prompt".to_string() - } else { - input - }; - - let msg = UserMessage::new_prompt(input); - self.next_message = Some(msg); - } - - /// Sets the response message according to the currently set [Self::next_message]. - pub fn push_assistant_message(&mut self, message: AssistantMessage) { - debug_assert!(self.next_message.is_some(), "next_message should exist"); - let next_user_message = self.next_message.take().expect("next user message should exist"); - - self.append_assistant_transcript(&message); - self.history.push_back((next_user_message, message)); - } - - /// Returns the conversation id. - pub fn conversation_id(&self) -> &str { - self.conversation_id.as_ref() - } - - /// Returns the message id associated with the last assistant message, if present. - /// - /// This is equivalent to `utterance_id` in the Q API. - pub fn message_id(&self) -> Option<&str> { - self.history.back().and_then(|(_, msg)| msg.message_id()) - } - - /// Updates the history so that, when non-empty, the following invariants are in place: - /// 1. The history length is `<= MAX_CONVERSATION_STATE_HISTORY_LEN`. Oldest messages are - /// dropped. - /// 2. The first message is from the user, and does not contain tool results. Oldest messages - /// are dropped. - /// 3. If the last message from the assistant contains tool results, and a next user message is - /// set without tool results, then the user message will have "cancelled" tool results. - pub fn enforce_conversation_invariants(&mut self) { - // First set the valid range as the entire history - this will be truncated as necessary - // later below. - self.valid_history_range = (0, self.history.len()); - - // Trim the conversation history by finding the second oldest message from the user without - // tool results - this will be the new oldest message in the history. - // - // Note that we reserve extra slots for [ConversationState::context_messages]. 
- if (self.history.len() * 2) > MAX_CONVERSATION_STATE_HISTORY_LEN - 6 { - match self - .history - .iter() - .enumerate() - .skip(1) - .find(|(_, (m, _))| -> bool { !m.has_tool_use_results() }) - .map(|v| v.0) - { - Some(i) => { - debug!("removing the first {i} user/assistant response pairs in the history"); - self.valid_history_range.0 = i; - }, - None => { - debug!("no valid starting user message found in the history, clearing"); - self.valid_history_range = (0, 0); - // Edge case: if the next message contains tool results, then we have to just - // abandon them. - if self.next_message.as_ref().is_some_and(|m| m.has_tool_use_results()) { - debug!("abandoning tool results"); - self.next_message = Some(UserMessage::new_prompt( - "The conversation history has overflowed, clearing state".to_string(), - )); - } - }, - } - } - - // If the last message from the assistant contains tool uses AND next_message is set, we need to - // ensure that next_message contains tool results. - if let (Some((_, AssistantMessage::ToolUse { tool_uses, .. })), Some(user_msg)) = ( - self.history - .range(self.valid_history_range.0..self.valid_history_range.1) - .last(), - &mut self.next_message, - ) { - if !user_msg.has_tool_use_results() { - debug!( - "last assistant message contains tool uses, but next message is set and does not contain tool results. setting tool results as cancelled" - ); - *user_msg = UserMessage::new_cancelled_tool_uses( - user_msg.prompt().map(|p| p.to_string()), - tool_uses.iter().map(|t| t.id.as_str()), - ); - } - } - } - - pub fn add_tool_results(&mut self, tool_results: Vec) { - debug_assert!(self.next_message.is_none()); - self.next_message = Some(UserMessage::new_tool_use_results(tool_results)); - } - - /// Sets the next user message with "cancelled" tool results. - pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: Vec, deny_input: String) { - self.next_message = Some(UserMessage::new_cancelled_tool_uses( - Some(deny_input), - tools_to_be_abandoned.iter().map(|t| t.id.as_str()), - )); - } - - /// Returns a [FigConversationState] capable of being sent by [fig_api_client::StreamingClient]. - /// - /// Params: - /// - `run_hooks` - whether hooks should be executed and included as context - pub async fn as_sendable_conversation_state(&mut self, run_hooks: bool) -> FigConversationState { - debug_assert!(self.next_message.is_some()); - self.enforce_conversation_invariants(); - self.history.drain(self.valid_history_range.1..); - self.history.drain(..self.valid_history_range.0); - - self.backend_conversation_state(run_hooks, false) - .await - .into_fig_conversation_state() - .expect("unable to construct conversation state") - } - - /// Returns a conversation state representation which reflects the exact conversation to send - /// back to the model. - pub async fn backend_conversation_state(&mut self, run_hooks: bool, quiet: bool) -> BackendConversationState<'_> { - self.enforce_conversation_invariants(); - - // Run hooks and add to conversation start and next user message. 
- let mut conversation_start_context = None; - if let (true, Some(cm)) = (run_hooks, self.context_manager.as_mut()) { - let mut null_writer = SharedWriter::null(); - let updates = if quiet { - None - } else { - Some(self.updates.as_mut().unwrap_or(&mut null_writer)) - }; - - let hook_results = cm.run_hooks(updates).await; - conversation_start_context = Some(format_hook_context(hook_results.iter(), HookTrigger::ConversationStart)); - - // add per prompt content to next_user_message if available - if let Some(next_message) = self.next_message.as_mut() { - next_message.additional_context = format_hook_context(hook_results.iter(), HookTrigger::PerPrompt); - } - } - - let context_messages = self.context_messages(conversation_start_context).await; - - BackendConversationState { - conversation_id: self.conversation_id.as_str(), - next_user_message: self.next_message.as_ref(), - history: self - .history - .range(self.valid_history_range.0..self.valid_history_range.1), - context_messages, - tools: &self.tools, - } - } - - /// Returns a [FigConversationState] capable of replacing the history of the current - /// conversation with a summary generated by the model. - pub async fn create_summary_request(&mut self, custom_prompt: Option>) -> FigConversationState { - let summary_content = match custom_prompt { - Some(custom_prompt) => { - // Make the custom instructions much more prominent and directive - format!( - "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ - FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ - IMPORTANT CUSTOM INSTRUCTION: {}\n\n\ - Your task is to create a structured summary document containing:\n\ - 1) A bullet-point list of key topics/questions covered\n\ - 2) Bullet points for all significant tools executed and their results\n\ - 3) Bullet points for any code or technical information shared\n\ - 4) A section of key insights gained\n\n\ - FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ - ## CONVERSATION SUMMARY\n\ - * Topic 1: Key information\n\ - * Topic 2: Key information\n\n\ - ## TOOLS EXECUTED\n\ - * Tool X: Result Y\n\n\ - Remember this is a DOCUMENT not a chat response. The custom instruction above modifies what to prioritize.\n\ - FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).", - custom_prompt.as_ref() - ) - }, - None => { - // Default prompt - "[SYSTEM NOTE: This is an automated summarization request, not from the user]\n\n\ - FORMAT REQUIREMENTS: Create a structured, concise summary in bullet-point format. DO NOT respond conversationally. DO NOT address the user directly.\n\n\ - Your task is to create a structured summary document containing:\n\ - 1) A bullet-point list of key topics/questions covered\n\ - 2) Bullet points for all significant tools executed and their results\n\ - 3) Bullet points for any code or technical information shared\n\ - 4) A section of key insights gained\n\n\ - FORMAT THE SUMMARY IN THIRD PERSON, NOT AS A DIRECT RESPONSE. Example format:\n\n\ - ## CONVERSATION SUMMARY\n\ - * Topic 1: Key information\n\ - * Topic 2: Key information\n\n\ - ## TOOLS EXECUTED\n\ - * Tool X: Result Y\n\n\ - Remember this is a DOCUMENT not a chat response.\n\ - FILTER OUT CHAT CONVENTIONS (greetings, offers to help, etc).".to_string() - }, - }; - - let conv_state = self.backend_conversation_state(false, true).await; - - // Include everything but the last message in the history. 
- let history_len = conv_state.history.len(); - let history = if history_len < 2 { - vec![] - } else { - flatten_history(conv_state.history.take(history_len.saturating_sub(1))) - }; - - let mut summary_message = UserInputMessage { - content: summary_content, - user_input_message_context: None, - user_intent: None, - }; - - // If the last message contains tool uses, then add cancelled tool results to the summary - // message. - if let Some(ChatMessage::AssistantResponseMessage(AssistantResponseMessage { - tool_uses: Some(tool_uses), - .. - })) = history.last() - { - self.set_cancelled_tool_results(&mut summary_message, tool_uses); - } - - FigConversationState { - conversation_id: Some(self.conversation_id.clone()), - user_input_message: summary_message, - history: Some(history), - } - } - - pub fn replace_history_with_summary(&mut self, summary: String) { - self.history.drain(..(self.history.len().saturating_sub(1))); - self.latest_summary = Some(summary); - // If the last message contains tool results, then we add the results to the content field - // instead. This is required to avoid validation errors. - // TODO: this can break since the max user content size is less than the max tool response - // size! Alternative could be to set the last tool use as part of the context messages. - if let Some((user, _)) = self.history.back_mut() { - if let Some(tool_results) = user.tool_use_results() { - let tool_content: Vec = tool_results - .iter() - .flat_map(|tr| { - tr.content.iter().map(|c| match c { - ToolUseResultBlock::Json(document) => serde_json::to_string(&document) - .map_err(|err| error!(?err, "failed to serialize tool result")) - .unwrap_or_default(), - ToolUseResultBlock::Text(s) => s.clone(), - }) - }) - .collect::<_>(); - let mut tool_content = tool_content.join(" "); - if tool_content.is_empty() { - // To avoid validation errors with empty content, we need to make sure - // something is set. - tool_content.push_str(""); - } - user.content = UserMessageContent::Prompt { prompt: tool_content }; - } - } - } - - pub fn current_profile(&self) -> Option<&str> { - if let Some(cm) = self.context_manager.as_ref() { - Some(cm.current_profile.as_str()) - } else { - None - } - } - - /// Returns pairs of user and assistant messages to include as context in the message history - /// including both summaries and context files if available. - /// - /// TODO: - /// - Either add support for multiple context messages if the context is too large to fit inside - /// a single user message, or handle this case more gracefully. For now, always return 2 - /// messages. - /// - Cache this return for some period of time. - async fn context_messages( - &mut self, - conversation_start_context: Option, - ) -> Option> { - let mut context_content = String::new(); - - if let Some(summary) = &self.latest_summary { - context_content.push_str(CONTEXT_ENTRY_START_HEADER); - context_content.push_str("This summary contains ALL relevant information from our previous conversation including tool uses, results, code analysis, and file operations. 
YOU MUST reference this information when answering questions and explicitly acknowledge specific details from the summary when they're relevant to the current question.\n\n"); - context_content.push_str("SUMMARY CONTENT:\n"); - context_content.push_str(summary); - context_content.push('\n'); - context_content.push_str(CONTEXT_ENTRY_END_HEADER); - } - - // Add context files if available - if let Some(context_manager) = self.context_manager.as_mut() { - match context_manager.get_context_files(true).await { - Ok(files) => { - if !files.is_empty() { - context_content.push_str(CONTEXT_ENTRY_START_HEADER); - for (filename, content) in files { - context_content.push_str(&format!("[{}]\n{}\n", filename, content)); - } - context_content.push_str(CONTEXT_ENTRY_END_HEADER); - } - }, - Err(e) => { - warn!("Failed to get context files: {}", e); - }, - } - } - - if let Some(context) = conversation_start_context { - context_content.push_str(&context); - } - - if !context_content.is_empty() { - self.context_message_length = Some(context_content.len()); - let user_msg = UserMessage::new_prompt(context_content); - let assistant_msg = AssistantMessage::new_response(None, "I will fully incorporate this information when generating my responses, and explicitly acknowledge relevant parts of the summary when answering questions.".into()); - Some(vec![(user_msg, assistant_msg)]) - } else { - None - } - } - - /// The length of the user message used as context, if any. - pub fn context_message_length(&self) -> Option { - self.context_message_length - } - - /// Calculate the total character count in the conversation - pub async fn calculate_char_count(&mut self) -> CharCount { - self.backend_conversation_state(false, true).await.char_count() - } - - /// Get the current token warning level - pub async fn get_token_warning_level(&mut self) -> TokenWarningLevel { - let total_chars = self.calculate_char_count().await; - - if *total_chars >= MAX_CHARS { - TokenWarningLevel::Critical - } else { - TokenWarningLevel::None - } - } - - pub fn append_user_transcript(&mut self, message: &str) { - self.append_transcript(format!("> {}", message.replace("\n", "> \n"))); - } - - pub fn append_assistant_transcript(&mut self, message: &AssistantMessage) { - let tool_uses = message.tool_uses().map_or("none".to_string(), |tools| { - tools.iter().map(|tool| tool.name.clone()).collect::>().join(",") - }); - self.append_transcript(format!("{}\n[Tool uses: {tool_uses}]", message.content())); - } - - pub fn append_transcript(&mut self, message: String) { - if self.transcript.len() >= MAX_CONVERSATION_STATE_HISTORY_LEN { - self.transcript.pop_front(); - } - self.transcript.push_back(message); - } - - /// Mutates `msg` so that it will contain an appropriate [UserInputMessageContext] that - /// contains "cancelled" tool results for `tool_uses`. - fn set_cancelled_tool_results(&self, msg: &mut UserInputMessage, tool_uses: &[ToolUse]) { - match msg.user_input_message_context.as_mut() { - Some(ctx) => { - if ctx.tool_results.as_ref().is_none_or(|r| r.is_empty()) { - debug!( - "last assistant message contains tool uses, but next message is set and does not contain tool results. 
setting tool results as cancelled" - ); - ctx.tool_results = Some( - tool_uses - .iter() - .map(|tool_use| ToolResult { - tool_use_id: tool_use.tool_use_id.clone(), - content: vec![ToolResultContentBlock::Text( - "Tool use was cancelled by the user".to_string(), - )], - status: fig_api_client::model::ToolResultStatus::Error, - }) - .collect::>(), - ); - } - }, - None => { - debug!( - "last assistant message contains tool uses, but next message is set and does not contain tool results. setting tool results as cancelled" - ); - let tool_results = tool_uses - .iter() - .map(|tool_use| ToolResult { - tool_use_id: tool_use.tool_use_id.clone(), - content: vec![ToolResultContentBlock::Text( - "Tool use was cancelled by the user".to_string(), - )], - status: fig_api_client::model::ToolResultStatus::Error, - }) - .collect::>(); - let user_input_message_context = UserInputMessageContext { - shell_state: None, - env_state: Some(build_env_state()), - tool_results: Some(tool_results), - tools: if self.tools.is_empty() { - None - } else { - Some(self.tools.values().flatten().cloned().collect::>()) - }, - ..Default::default() - }; - msg.user_input_message_context = Some(user_input_message_context); - }, - } - } -} - -/// Represents a conversation state that can be converted into a [FigConversationState] (the type -/// used by the API client). Represents borrowed data, and reflects an exact [FigConversationState] -/// that can be generated from [ConversationState] at any point in time. -/// -/// This is intended to provide us ways to accurately assess the exact state that is sent to the -/// model without having to needlessly clone and mutate [ConversationState] in strange ways. -pub type BackendConversationState<'a> = BackendConversationStateImpl< - 'a, - std::collections::vec_deque::Iter<'a, (UserMessage, AssistantMessage)>, - Option>, ->; - -/// See [BackendConversationState] -#[derive(Debug, Clone)] -pub struct BackendConversationStateImpl<'a, T, U> { - pub conversation_id: &'a str, - pub next_user_message: Option<&'a UserMessage>, - pub history: T, - pub context_messages: U, - pub tools: &'a HashMap>, -} - -impl - BackendConversationStateImpl< - '_, - std::collections::vec_deque::Iter<'_, (UserMessage, AssistantMessage)>, - Option>, - > -{ - fn into_fig_conversation_state(self) -> eyre::Result { - let history = flatten_history(self.context_messages.unwrap_or_default().iter().chain(self.history)); - let mut user_input_message: UserInputMessage = self - .next_user_message - .cloned() - .map(UserMessage::into_user_input_message) - .ok_or(eyre::eyre!("next user message is not set"))?; - if let Some(ctx) = user_input_message.user_input_message_context.as_mut() { - ctx.tools = Some(self.tools.values().flatten().cloned().collect::>()); - } - - Ok(FigConversationState { - conversation_id: Some(self.conversation_id.to_string()), - user_input_message, - history: Some(history), - }) - } - - pub fn calculate_conversation_size(&self) -> ConversationSize { - let mut user_chars = 0; - let mut assistant_chars = 0; - let mut context_chars = 0; - - // Count the chars used by the messages in the history. - // this clone is cheap - let history = self.history.clone(); - for (user, assistant) in history { - user_chars += *user.char_count(); - assistant_chars += *assistant.char_count(); - } - - // Add any chars from context messages, if available. 
-        context_chars += self
-            .context_messages
-            .as_ref()
-            .map(|v| {
-                v.iter().fold(0, |acc, (user, assistant)| {
-                    acc + *user.char_count() + *assistant.char_count()
-                })
-            })
-            .unwrap_or_default();
-
-        ConversationSize {
-            context_messages: context_chars.into(),
-            user_messages: user_chars.into(),
-            assistant_messages: assistant_chars.into(),
-        }
-    }
-}
-
-/// Reflects a detailed accounting of the context window utilization for a given conversation.
-#[derive(Debug, Clone, Copy)]
-pub struct ConversationSize {
-    pub context_messages: CharCount,
-    pub user_messages: CharCount,
-    pub assistant_messages: CharCount,
-}
-
-/// Converts a list of user/assistant message pairs into a flattened list of ChatMessage.
-fn flatten_history<'a, T>(history: T) -> Vec<ChatMessage>
-where
-    T: Iterator<Item = &'a (UserMessage, AssistantMessage)>,
-{
-    history.fold(Vec::new(), |mut acc, (user, assistant)| {
-        acc.push(ChatMessage::UserInputMessage(user.clone().into_history_entry()));
-        acc.push(ChatMessage::AssistantResponseMessage(assistant.clone().into()));
-        acc
-    })
-}
-
-/// Character count warning levels for conversation size
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub enum TokenWarningLevel {
-    /// No warning, conversation is within normal limits
-    None,
-    /// Critical level - at single warning threshold (600K characters)
-    Critical,
-}
-
-impl From<InputSchema> for ToolInputSchema {
-    fn from(value: InputSchema) -> Self {
-        Self {
-            json: Some(serde_value_to_document(value.0)),
-        }
-    }
-}
-
-fn format_hook_context<'a>(hook_results: impl IntoIterator<Item = &'a (Hook, String)>, trigger: HookTrigger) -> String {
-    let mut context_content = String::new();
-
-    context_content.push_str(CONTEXT_ENTRY_START_HEADER);
-    context_content.push_str("This section (like others) contains important information that I want you to use in your responses. I have gathered this context from valuable programmatic script hooks. You must follow any requests and consider all of the information in this section");
-    if trigger == HookTrigger::ConversationStart {
-        context_content.push_str(" for the entire conversation");
-    }
-    context_content.push_str("\n\n");
-
-    for (hook, output) in hook_results.into_iter().filter(|(h, _)| h.trigger == trigger) {
-        context_content.push_str(&format!("'{}': {output}\n\n", &hook.name));
-    }
-    context_content.push_str(CONTEXT_ENTRY_END_HEADER);
-    context_content
-}
-
-#[cfg(test)]
-mod tests {
-    use fig_api_client::model::{
-        AssistantResponseMessage,
-        ToolResultStatus,
-    };
-
-    use super::super::context::{
-        AMAZONQ_FILENAME,
-        profile_context_path,
-    };
-    use super::super::message::AssistantToolUse;
-    use super::*;
-    use crate::tool_manager::ToolManager;
-
-    fn assert_conversation_state_invariants(state: FigConversationState, assertion_iteration: usize) {
-        if let Some(Some(msg)) = state.history.as_ref().map(|h| h.first()) {
-            assert!(
-                matches!(msg, ChatMessage::UserInputMessage(_)),
-                "{assertion_iteration}: First message in the history must be from the user, instead found: {:?}",
-                msg
-            );
-        }
-        if let Some(Some(msg)) = state.history.as_ref().map(|h| h.last()) {
-            assert!(
-                matches!(msg, ChatMessage::AssistantResponseMessage(_)),
-                "{assertion_iteration}: Last message in the history must be from the assistant, instead found: {:?}",
-                msg
-            );
-            // If the last message from the assistant contains tool uses, then the next user
-            // message must contain tool results.
-            match (state.user_input_message.user_input_message_context.as_ref(), msg) {
-                (
-                    Some(ctx),
-                    ChatMessage::AssistantResponseMessage(AssistantResponseMessage {
-                        tool_uses: Some(tool_uses),
-                        ..
- }), - ) if !tool_uses.is_empty() => { - assert!( - ctx.tool_results.as_ref().is_some_and(|r| !r.is_empty()), - "The user input message must contain tool results when the last assistant message contains tool uses" - ); - }, - _ => {}, - } - } - - if let Some(history) = state.history.as_ref() { - for (i, msg) in history.iter().enumerate() { - // User message checks. - if let ChatMessage::UserInputMessage(user) = msg { - assert!( - user.user_input_message_context - .as_ref() - .is_none_or(|ctx| ctx.tools.is_none()), - "the tool specification should be empty for all user messages in the history" - ); - - // Check that messages with tool results are immediately preceded by an - // assistant message with tool uses. - if user - .user_input_message_context - .as_ref() - .is_some_and(|ctx| ctx.tool_results.as_ref().is_some_and(|r| !r.is_empty())) - { - match history.get(i.checked_sub(1).unwrap_or_else(|| { - panic!( - "{assertion_iteration}: first message in the history should not contain tool results" - ) - })) { - Some(ChatMessage::AssistantResponseMessage(assistant)) => { - assert!(assistant.tool_uses.is_some()); - }, - _ => panic!( - "expected an assistant response message with tool uses at index: {}", - i - 1 - ), - } - } - } - } - } - - let actual_history_len = state.history.unwrap_or_default().len(); - assert!( - actual_history_len <= MAX_CONVERSATION_STATE_HISTORY_LEN, - "history should not extend past the max limit of {}, instead found length {}", - MAX_CONVERSATION_STATE_HISTORY_LEN, - actual_history_len - ); - - let ctx = state - .user_input_message - .user_input_message_context - .as_ref() - .expect("user input message context must exist"); - assert!( - ctx.tools.is_some(), - "Currently, the tool spec must be included in the next user message" - ); - } - - #[tokio::test] - async fn test_conversation_state_history_handling_truncation() { - let mut tool_manager = ToolManager::default(); - let mut conversation_state = ConversationState::new( - Context::new_fake(), - "fake_conv_id", - tool_manager.load_tools().await.unwrap(), - None, - None, - ) - .await; - - // First, build a large conversation history. We need to ensure that the order is always - // User -> Assistant -> User -> Assistant ...and so on. - conversation_state.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation_state.as_sendable_conversation_state(true).await; - assert_conversation_state_invariants(s, i); - conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); - conversation_state.set_next_user_message(i.to_string()).await; - } - } - - #[tokio::test] - async fn test_conversation_state_history_handling_with_tool_results() { - // Build a long conversation history of tool use results. 
- let mut tool_manager = ToolManager::default(); - let mut conversation_state = ConversationState::new( - Context::new_fake(), - "fake_conv_id", - tool_manager.load_tools().await.unwrap(), - None, - None, - ) - .await; - conversation_state.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation_state.as_sendable_conversation_state(true).await; - assert_conversation_state_invariants(s, i); - - conversation_state.push_assistant_message(AssistantMessage::new_tool_use(None, i.to_string(), vec![ - AssistantToolUse { - id: "tool_id".to_string(), - name: "tool name".to_string(), - args: serde_json::Value::Null, - }, - ])); - conversation_state.add_tool_results(vec![ToolUseResult { - tool_use_id: "tool_id".to_string(), - content: vec![], - status: ToolResultStatus::Success, - }]); - } - - // Build a long conversation history of user messages mixed in with tool results. - let mut conversation_state = ConversationState::new( - Context::new_fake(), - "fake_conv_id", - tool_manager.load_tools().await.unwrap(), - None, - None, - ) - .await; - conversation_state.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation_state.as_sendable_conversation_state(true).await; - assert_conversation_state_invariants(s, i); - if i % 3 == 0 { - conversation_state.push_assistant_message(AssistantMessage::new_tool_use(None, i.to_string(), vec![ - AssistantToolUse { - id: "tool_id".to_string(), - name: "tool name".to_string(), - args: serde_json::Value::Null, - }, - ])); - conversation_state.add_tool_results(vec![ToolUseResult { - tool_use_id: "tool_id".to_string(), - content: vec![], - status: ToolResultStatus::Success, - }]); - } else { - conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); - conversation_state.set_next_user_message(i.to_string()).await; - } - } - } - - #[tokio::test] - async fn test_conversation_state_with_context_files() { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - ctx.fs().write(AMAZONQ_FILENAME, "test context").await.unwrap(); - - let mut tool_manager = ToolManager::default(); - let mut conversation_state = ConversationState::new( - ctx, - "fake_conv_id", - tool_manager.load_tools().await.unwrap(), - None, - None, - ) - .await; - - // First, build a large conversation history. We need to ensure that the order is always - // User -> Assistant -> User -> Assistant ...and so on. - conversation_state.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { - let s = conversation_state.as_sendable_conversation_state(true).await; - - // Ensure that the first two messages are the fake context messages. 
- let hist = s.history.as_ref().unwrap(); - let user = &hist[0]; - let assistant = &hist[1]; - match (user, assistant) { - (ChatMessage::UserInputMessage(user), ChatMessage::AssistantResponseMessage(_)) => { - assert!( - user.content.contains("test context"), - "expected context message to contain context file, instead found: {}", - user.content - ); - }, - _ => panic!("Expected the first two messages to be from the user and the assistant"), - } - - assert_conversation_state_invariants(s, i); - - conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); - conversation_state.set_next_user_message(i.to_string()).await; - } - } - - #[tokio::test] - async fn test_conversation_state_additional_context() { - tracing_subscriber::fmt::try_init().ok(); - - let mut tool_manager = ToolManager::default(); - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let conversation_start_context = "conversation start context"; - let prompt_context = "prompt context"; - let config = serde_json::json!({ - "hooks": { - "test_per_prompt": { - "trigger": "per_prompt", - "type": "inline", - "command": format!("echo {}", prompt_context) - }, - "test_conversation_start": { - "trigger": "conversation_start", - "type": "inline", - "command": format!("echo {}", conversation_start_context) - } - } - }); - let config_path = profile_context_path(&ctx, "default").unwrap(); - ctx.fs().create_dir_all(config_path.parent().unwrap()).await.unwrap(); - ctx.fs() - .write(&config_path, serde_json::to_string(&config).unwrap()) - .await - .unwrap(); - let mut conversation_state = ConversationState::new( - ctx, - "fake_conv_id", - tool_manager.load_tools().await.unwrap(), - None, - Some(SharedWriter::stdout()), - ) - .await; - - // Simulate conversation flow - conversation_state.set_next_user_message("start".to_string()).await; - for i in 0..=5 { - let s = conversation_state.as_sendable_conversation_state(true).await; - let hist = s.history.as_ref().unwrap(); - #[allow(clippy::match_wildcard_for_single_variants)] - match &hist[0] { - ChatMessage::UserInputMessage(user) => { - assert!( - user.content.contains(conversation_start_context), - "expected to contain '{conversation_start_context}', instead found: {}", - user.content - ); - }, - _ => panic!("Expected user message."), - } - assert!( - s.user_input_message.content.contains(prompt_context), - "expected to contain '{prompt_context}', instead found: {}", - s.user_input_message.content - ); - - conversation_state.push_assistant_message(AssistantMessage::new_response(None, i.to_string())); - conversation_state.set_next_user_message(i.to_string()).await; - } - } -} diff --git a/crates/q_chat/src/hooks.rs b/crates/q_chat/src/hooks.rs deleted file mode 100644 index 036195ba17..0000000000 --- a/crates/q_chat/src/hooks.rs +++ /dev/null @@ -1,557 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::process::Stdio; -use std::time::{ - Duration, - Instant, -}; - -use bstr::ByteSlice; -use crossterm::style::{ - Color, - Stylize, -}; -use crossterm::{ - cursor, - execute, - queue, - style, - terminal, -}; -use eyre::{ - Result, - eyre, -}; -use futures::stream::{ - FuturesUnordered, - StreamExt, -}; -use serde::{ - Deserialize, - Serialize, -}; -use spinners::{ - Spinner, - Spinners, -}; - -use super::util::truncate_safe; - -const DEFAULT_TIMEOUT_MS: u64 = 30_000; -const DEFAULT_MAX_OUTPUT_SIZE: usize = 1024 * 10; -const DEFAULT_CACHE_TTL_SECONDS: u64 = 0; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub 
struct Hook {
-    pub trigger: HookTrigger,
-
-    pub r#type: HookType,
-
-    #[serde(default = "Hook::default_disabled")]
-    pub disabled: bool,
-
-    /// Max time the hook can run before it throws a timeout error
-    #[serde(default = "Hook::default_timeout_ms")]
-    pub timeout_ms: u64,
-
-    /// Max output size of the hook before it is truncated
-    #[serde(default = "Hook::default_max_output_size")]
-    pub max_output_size: usize,
-
-    /// How long the hook output is cached before it will be executed again
-    #[serde(default = "Hook::default_cache_ttl_seconds")]
-    pub cache_ttl_seconds: u64,
-
-    // Type-specific fields
-    /// The bash command to execute
-    pub command: Option<String>, // For inline hooks
-
-    // Internal data
-    #[serde(skip)]
-    pub name: String,
-    #[serde(skip)]
-    pub is_global: bool,
-}
-
-impl Hook {
-    pub fn new_inline_hook(trigger: HookTrigger, command: String) -> Self {
-        Self {
-            trigger,
-            r#type: HookType::Inline,
-            disabled: Self::default_disabled(),
-            timeout_ms: Self::default_timeout_ms(),
-            max_output_size: Self::default_max_output_size(),
-            cache_ttl_seconds: Self::default_cache_ttl_seconds(),
-            command: Some(command),
-            is_global: false,
-            name: "new hook".to_string(),
-        }
-    }
-
-    fn default_disabled() -> bool {
-        false
-    }
-
-    fn default_timeout_ms() -> u64 {
-        DEFAULT_TIMEOUT_MS
-    }
-
-    fn default_max_output_size() -> usize {
-        DEFAULT_MAX_OUTPUT_SIZE
-    }
-
-    fn default_cache_ttl_seconds() -> u64 {
-        DEFAULT_CACHE_TTL_SECONDS
-    }
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum HookType {
-    // Execute an inline shell command
-    Inline,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
-#[serde(rename_all = "snake_case")]
-pub enum HookTrigger {
-    ConversationStart,
-    PerPrompt,
-}
-
-#[derive(Debug, Clone)]
-pub struct CachedHook {
-    output: String,
-    expiry: Option<Instant>,
-}
-
-/// Maps a hook name to a [`CachedHook`]
-#[derive(Debug, Clone, Default)]
-pub struct HookExecutor {
-    pub global_cache: HashMap<String, CachedHook>,
-    pub profile_cache: HashMap<String, CachedHook>,
-}
-
-impl HookExecutor {
-    pub fn new() -> Self {
-        Self {
-            global_cache: HashMap::new(),
-            profile_cache: HashMap::new(),
-        }
-    }
-
-    /// Run and cache [`Hook`]s. Any hooks that are already cached will be returned without
-    /// executing. Hooks that fail to execute will not be returned.
-    ///
-    /// If `updates` is `Some`, progress on hook execution will be written to it.
-    /// Errors encountered with write operations to `updates` are ignored.
-    ///
-    /// Note: [`HookTrigger::ConversationStart`] hooks never leave the cache.
-    pub async fn run_hooks(&mut self, hooks: Vec<&Hook>, mut updates: Option<&mut impl Write>) -> Vec<(Hook, String)> {
-        let mut results = Vec::with_capacity(hooks.len());
-        let mut futures = FuturesUnordered::new();
-
-        // Start all hook futures OR fetch from cache if available.
-        // Why enumerate? We want to return the hook results in the order of hooks that we received;
-        // however, for output display we want to process hooks as they complete rather than the
-        // order they were started in. The index will be used later to sort them back to output order.
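
// Editor's sketch (standalone, not part of the deleted file): the ordering trick the
// comment above describes, reduced to its essentials. Each task is tagged with its
// submission index, polled to completion in whatever order it finishes, and the
// results are sorted back so callers see them in the order the hooks were supplied.
use futures::stream::{FuturesUnordered, StreamExt};

async fn run_in_submission_order<F, T>(tasks: Vec<F>) -> Vec<T>
where
    F: std::future::Future<Output = T>,
{
    let mut in_flight: FuturesUnordered<_> = tasks
        .into_iter()
        .enumerate()
        .map(|(index, task)| async move { (index, task.await) })
        .collect();
    let mut tagged = Vec::new();
    while let Some(result) = in_flight.next().await {
        tagged.push(result); // completion order
    }
    tagged.sort_by_key(|(index, _)| *index); // restore submission order
    tagged.into_iter().map(|(_, value)| value).collect()
}
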
- for (index, hook) in hooks.into_iter().enumerate() { - if hook.disabled { - continue; - } - - if let Some(cached) = self.get_cache(hook) { - results.push((index, (hook.clone(), cached.clone()))); - continue; - } - let future = self.execute_hook(hook); - futures.push(async move { (index, future.await) }); - } - - // Start caching the results added after whats already their (they are from the cache already) - let start_cache_index = results.len(); - - let mut succeeded = 0; - let total = futures.len(); - - let mut spinner = None; - let spinner_text = |complete: usize, total: usize| { - format!( - "{} of {} hooks finished", - complete.to_string().blue(), - total.to_string().blue(), - ) - }; - if total != 0 && updates.is_some() { - spinner = Some(Spinner::new(Spinners::Dots12, spinner_text(succeeded, total))); - } - - // Process results as they complete - let start_time = Instant::now(); - while let Some((index, (hook, result, duration))) = futures.next().await { - // If output is enabled, handle that first - if let Some(updates) = updates.as_deref_mut() { - if let Some(spinner) = spinner.as_mut() { - spinner.stop(); - - // Erase the spinner - let _ = execute!( - updates, - cursor::MoveToColumn(0), - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::Hide, - ); - } - match &result { - Ok(_) => { - let _ = queue!( - updates, - style::SetForegroundColor(style::Color::Green), - style::Print("✓ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(&hook.name), - style::ResetColor, - style::Print(" finished in "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{:.2} s\n", duration.as_secs_f32())), - style::ResetColor, - ); - }, - Err(e) => { - let _ = queue!( - updates, - style::SetForegroundColor(style::Color::Red), - style::Print("✗ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(&hook.name), - style::ResetColor, - style::Print(" failed after "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{:.2} s", duration.as_secs_f32())), - style::ResetColor, - style::Print(format!(": {}\n", e)), - ); - }, - } - } - - // Process results regardless of output enabled - if let Ok(output) = result { - succeeded += 1; - results.push((index, (hook.clone(), output))); - } - - // Display ending summary or add a new spinner - if let Some(updates) = updates.as_deref_mut() { - // The futures set size decreases each time we process one - if futures.is_empty() { - let symbol = if total == succeeded { - "✓".to_string().green() - } else { - "✗".to_string().red() - }; - - let _ = queue!( - updates, - style::SetForegroundColor(Color::Blue), - style::Print(format!("{symbol} {} in ", spinner_text(succeeded, total))), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{:.2} s\n", start_time.elapsed().as_secs_f32())), - style::ResetColor, - ); - } else { - spinner = Some(Spinner::new(Spinners::Dots, spinner_text(succeeded, total))); - } - } - } - drop(futures); - - // Fill cache with executed results, skipping what was already from cache - results.iter().skip(start_cache_index).for_each(|(_, (hook, output))| { - let expiry = match hook.trigger { - HookTrigger::ConversationStart => None, - HookTrigger::PerPrompt => Some(Instant::now() + Duration::from_secs(hook.cache_ttl_seconds)), - }; - self.insert_cache(hook, CachedHook { - output: output.clone(), - expiry, - }); - }); - - // Return back to order at request start - results.sort_by_key(|(idx, _)| *idx); - results.into_iter().map(|(_, r)| r).collect() - 
} - - async fn execute_hook<'a>(&self, hook: &'a Hook) -> (&'a Hook, Result, Duration) { - let start_time = Instant::now(); - let result = match hook.r#type { - HookType::Inline => self.execute_inline_hook(hook).await, - }; - - (hook, result, start_time.elapsed()) - } - - async fn execute_inline_hook(&self, hook: &Hook) -> Result { - let command = hook.command.as_ref().ok_or_else(|| eyre!("no command specified"))?; - - let command_future = tokio::process::Command::new("bash") - .arg("-c") - .arg(command) - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .output(); - let timeout = Duration::from_millis(hook.timeout_ms); - - // Run with timeout - match tokio::time::timeout(timeout, command_future).await { - Ok(result) => { - let result = result?; - if result.status.success() { - let stdout = result.stdout.to_str_lossy(); - let stdout = format!( - "{}{}", - truncate_safe(&stdout, hook.max_output_size), - if stdout.len() > hook.max_output_size { - " ... truncated" - } else { - "" - } - ); - Ok(stdout) - } else { - Err(eyre!("command returned non-zero exit code: {}", result.status)) - } - }, - Err(_) => Err(eyre!("command timed out after {} ms", timeout.as_millis())), - } - } - - /// Will return a cached hook's output if it exists and isn't expired. - fn get_cache(&self, hook: &Hook) -> Option { - let cache = if hook.is_global { - &self.global_cache - } else { - &self.profile_cache - }; - - cache.get(&hook.name).and_then(|o| { - if let Some(expiry) = o.expiry { - if Instant::now() < expiry { - Some(o.output.clone()) - } else { - None - } - } else { - Some(o.output.clone()) - } - }) - } - - fn insert_cache(&mut self, hook: &Hook, hook_output: CachedHook) { - let cache = if hook.is_global { - &mut self.global_cache - } else { - &mut self.profile_cache - }; - - cache.insert(hook.name.clone(), hook_output); - } -} - -#[cfg(test)] -mod tests { - use std::io::Stdout; - use std::time::Duration; - - use tokio::time::sleep; - - use super::*; - - #[test] - fn test_hook_creation() { - let command = "echo 'hello'"; - let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, command.to_string()); - - assert_eq!(hook.r#type, HookType::Inline); - assert!(!hook.disabled); - assert_eq!(hook.timeout_ms, DEFAULT_TIMEOUT_MS); - assert_eq!(hook.max_output_size, DEFAULT_MAX_OUTPUT_SIZE); - assert_eq!(hook.cache_ttl_seconds, DEFAULT_CACHE_TTL_SECONDS); - assert_eq!(hook.command, Some(command.to_string())); - assert_eq!(hook.trigger, HookTrigger::PerPrompt); - assert!(!hook.is_global); - } - - #[tokio::test] - async fn test_hook_executor_cached_conversation_start() { - let mut executor = HookExecutor::new(); - let mut hook1 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo 'test1'".to_string()); - hook1.is_global = true; - - let mut hook2 = Hook::new_inline_hook(HookTrigger::ConversationStart, "echo 'test2'".to_string()); - hook2.is_global = false; - - // First execution should run the command - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - - // Second execution should use cache - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(output.is_empty()); // Should not have run the 
hook, so no output. - } - - #[tokio::test] - async fn test_hook_executor_cached_per_prompt() { - let mut executor = HookExecutor::new(); - let mut hook1 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test1'".to_string()); - hook1.is_global = true; - hook1.cache_ttl_seconds = 60; - - let mut hook2 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test2'".to_string()); - hook2.is_global = false; - hook2.cache_ttl_seconds = 60; - - // First execution should run the command - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - - // Second execution should use cache - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(output.is_empty()); // Should not have run the hook, so no output. - } - - #[tokio::test] - async fn test_hook_executor_not_cached_per_prompt() { - let mut executor = HookExecutor::new(); - let mut hook1 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test1'".to_string()); - hook1.is_global = true; - - let mut hook2 = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test2'".to_string()); - hook2.is_global = false; - - // First execution should run the command - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - - // Second execution should use cache - let mut output = Vec::new(); - let results = executor.run_hooks(vec![&hook1, &hook2], Some(&mut output)).await; - - assert_eq!(results.len(), 2); - assert!(results[0].1.contains("test1")); - assert!(results[1].1.contains("test2")); - assert!(!output.is_empty()); - } - - #[tokio::test] - async fn test_hook_timeout() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "sleep 2".to_string()); - hook.timeout_ms = 100; // Set very short timeout - - let results = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; - - assert_eq!(results.len(), 0); // Should fail due to timeout - } - - #[tokio::test] - async fn test_disabled_hook() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test'".to_string()); - hook.disabled = true; - - let results = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; - - assert_eq!(results.len(), 0); // Disabled hook should not run - } - - #[tokio::test] - async fn test_cache_expiration() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "echo 'test'".to_string()); - hook.cache_ttl_seconds = 1; - - // First execution - let results1 = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; - assert_eq!(results1.len(), 1); - - // Wait for cache to expire - sleep(Duration::from_millis(1001)).await; - - // Second execution should run command again - let results2 = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; - assert_eq!(results2.len(), 1); - } - - #[test] - fn test_hook_cache_storage() { - let mut executor: HookExecutor = HookExecutor::new(); - let hook = 
Hook::new_inline_hook(HookTrigger::PerPrompt, "".to_string()); - - let cached_hook = CachedHook { - output: "test output".to_string(), - expiry: None, - }; - - executor.insert_cache(&hook, cached_hook.clone()); - - assert_eq!(executor.get_cache(&hook), Some("test output".to_string())); - } - - #[test] - fn test_hook_cache_storage_expired() { - let mut executor: HookExecutor = HookExecutor::new(); - let hook = Hook::new_inline_hook(HookTrigger::PerPrompt, "".to_string()); - - let cached_hook = CachedHook { - output: "test output".to_string(), - expiry: Some(Instant::now()), - }; - - executor.insert_cache(&hook, cached_hook.clone()); - - // Item should not return since it is expired - assert_eq!(executor.get_cache(&hook), None); - } - - #[tokio::test] - async fn test_max_output_size() { - let mut executor = HookExecutor::new(); - let mut hook = Hook::new_inline_hook( - HookTrigger::PerPrompt, - "for i in {1..1000}; do echo $i; done".to_string(), - ); - hook.max_output_size = 100; - - let results = executor.run_hooks(vec![&hook], None::<&mut Stdout>).await; - - assert!(results[0].1.len() <= hook.max_output_size + " ... truncated".len()); - } -} diff --git a/crates/q_chat/src/input_source.rs b/crates/q_chat/src/input_source.rs deleted file mode 100644 index d7bd446a77..0000000000 --- a/crates/q_chat/src/input_source.rs +++ /dev/null @@ -1,107 +0,0 @@ -use std::sync::Arc; - -use eyre::Result; -use rustyline::error::ReadlineError; -use rustyline::{ - EventHandler, - KeyEvent, -}; - -use super::context::ContextManager; -use super::prompt::rl; -use super::skim_integration::SkimCommandSelector; - -#[derive(Debug)] -pub struct InputSource(inner::Inner); - -mod inner { - use rustyline::Editor; - use rustyline::history::FileHistory; - - use super::super::prompt::ChatHelper; - - #[derive(Debug)] - pub enum Inner { - Readline(Editor), - #[allow(dead_code)] - Mock { - index: usize, - lines: Vec, - }, - } -} - -impl InputSource { - pub fn new( - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, - ) -> Result { - Ok(Self(inner::Inner::Readline(rl(sender, receiver)?))) - } - - pub fn put_skim_command_selector(&mut self, context_manager: Arc, tool_names: Vec) { - if let inner::Inner::Readline(rl) = &mut self.0 { - let key_char = match fig_settings::settings::get_string_opt("chat.skimCommandKey").as_deref() { - Some(key) if key.len() == 1 => key.chars().next().unwrap_or('s'), - _ => 's', // Default to 's' if setting is missing or invalid - }; - rl.bind_sequence( - KeyEvent::ctrl(key_char), - EventHandler::Conditional(Box::new(SkimCommandSelector::new(context_manager, tool_names))), - ); - } - } - - #[allow(dead_code)] - pub fn new_mock(lines: Vec) -> Self { - Self(inner::Inner::Mock { index: 0, lines }) - } - - pub fn read_line(&mut self, prompt: Option<&str>) -> Result, ReadlineError> { - match &mut self.0 { - inner::Inner::Readline(rl) => { - let prompt = prompt.unwrap_or_default(); - let curr_line = rl.readline(prompt); - match curr_line { - Ok(line) => { - let _ = rl.add_history_entry(line.as_str()); - Ok(Some(line)) - }, - Err(ReadlineError::Interrupted | ReadlineError::Eof) => Ok(None), - Err(err) => Err(err), - } - }, - inner::Inner::Mock { index, lines } => { - *index += 1; - Ok(lines.get(*index - 1).cloned()) - }, - } - } - - // We're keeping this method for potential future use - #[allow(dead_code)] - pub fn set_buffer(&mut self, content: &str) { - if let inner::Inner::Readline(rl) = &mut self.0 { - // Add to history so user can access it with up arrow - let _ = 
rl.add_history_entry(content); - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_mock_input_source() { - let l1 = "Hello,".to_string(); - let l2 = "Line 2".to_string(); - let l3 = "World!".to_string(); - let mut input = InputSource::new_mock(vec![l1.clone(), l2.clone(), l3.clone()]); - - assert_eq!(input.read_line(None).unwrap().unwrap(), l1); - assert_eq!(input.read_line(None).unwrap().unwrap(), l2); - assert_eq!(input.read_line(None).unwrap().unwrap(), l3); - assert!(input.read_line(None).unwrap().is_none()); - } -} diff --git a/crates/q_chat/src/lib.rs b/crates/q_chat/src/lib.rs deleted file mode 100644 index bf19615045..0000000000 --- a/crates/q_chat/src/lib.rs +++ /dev/null @@ -1,3848 +0,0 @@ -pub mod cli; -mod command; -mod consts; -mod context; -mod conversation_state; -mod hooks; -mod input_source; -mod message; -mod parse; -mod parser; -mod prompt; -mod skim_integration; -mod token_counter; -mod tool_manager; -mod tools; -pub mod util; - -use std::borrow::Cow; -use std::collections::{ - HashMap, - HashSet, - VecDeque, -}; -use std::io::{ - IsTerminal, - Read, - Write, -}; -use std::process::{ - Command as ProcessCommand, - ExitCode, -}; -use std::sync::Arc; -use std::time::Duration; -use std::{ - env, - fs, -}; - -use command::{ - Command, - PromptsSubcommand, - ToolsSubcommand, -}; -use consts::CONTEXT_WINDOW_SIZE; -use context::ContextManager; -use conversation_state::{ - ConversationState, - TokenWarningLevel, -}; -use crossterm::style::{ - Attribute, - Color, - Stylize, -}; -use crossterm::{ - cursor, - execute, - queue, - style, - terminal, -}; -use eyre::{ - ErrReport, - Result, - bail, -}; -use fig_api_client::StreamingClient; -use fig_api_client::clients::SendMessageOutput; -use fig_api_client::model::{ - ChatResponseStream, - Tool as FigTool, - ToolResultStatus, -}; -use fig_os_shim::Context; -use fig_settings::keys::UPDATE_AVAILABLE_KEY; -use fig_settings::{ - Settings, - State, -}; -use fig_util::CLI_BINARY_NAME; -use hooks::{ - Hook, - HookTrigger, -}; -use message::{ - AssistantMessage, - AssistantToolUse, - ToolUseResult, - ToolUseResultBlock, -}; -use rand::distr::{ - Alphanumeric, - SampleString, -}; -use semver::Version; - -/// Help text for the compact command -fn compact_help_text() -> String { - color_print::cformat!( - r#" -Conversation Compaction - -The /compact command summarizes the conversation history to free up context space -while preserving essential information. This is useful for long-running conversations -that may eventually reach memory constraints. 
- -Usage - /compact Summarize the conversation and clear history - /compact [prompt] Provide custom guidance for summarization - -When to use -• When you see the memory constraint warning message -• When a conversation has been running for a long time -• Before starting a new topic within the same session -• After completing complex tool operations - -How it works -• Creates an AI-generated summary of your conversation -• Retains key information, code, and tool executions in the summary -• Clears the conversation history to free up space -• The assistant will reference the summary context in future responses -"# - ) -} -use input_source::InputSource; -use mcp_client::{ - Prompt, - PromptGetResult, -}; -use parse::{ - ParseState, - interpret_markdown, -}; -use parser::{ - RecvErrorKind, - ResponseParser, -}; -use regex::Regex; -use serde_json::Map; -use spinners::{ - Spinner, - Spinners, -}; -use thiserror::Error; -use token_counter::{ - TokenCount, - TokenCounter, -}; -use tokio::signal::unix::{ - SignalKind, - signal, -}; -use tool_manager::{ - GetPromptError, - McpServerConfig, - PromptBundle, - ToolManager, - ToolManagerBuilder, -}; -use tools::gh_issue::GhIssueContext; -use tools::{ - QueuedTool, - Tool, - ToolPermissions, - ToolSpec, -}; -use tracing::{ - debug, - error, - info, - trace, - warn, -}; -use unicode_width::UnicodeWidthStr; -use util::{ - animate_output, - play_notification_bell, - region_check, -}; -use uuid::Uuid; -use winnow::Partial; -use winnow::stream::Offset; - -use crate::util::shared_writer::SharedWriter; -use crate::util::ui::draw_box; - -const WELCOME_TEXT: &str = color_print::cstr! {" -Welcome to - - █████╗ ███╗ ███╗ █████╗ ███████╗ ██████╗ ███╗ ██╗ ██████╗ -██╔══██╗████╗ ████║██╔══██╗╚══███╔╝██╔═══██╗████╗ ██║ ██╔═══██╗ -███████║██╔████╔██║███████║ ███╔╝ ██║ ██║██╔██╗ ██║ ██║ ██║ -██╔══██║██║╚██╔╝██║██╔══██║ ███╔╝ ██║ ██║██║╚██╗██║ ██║▄▄ ██║ -██║ ██║██║ ╚═╝ ██║██║ ██║███████╗╚██████╔╝██║ ╚████║ ╚██████╔╝ -╚═╝ ╚═╝╚═╝ ╚═╝╚═╝ ╚═╝╚══════╝ ╚═════╝ ╚═╝ ╚═══╝ ╚══▀▀═╝ - -"}; - -const SMALL_SCREEN_WECLOME_TEXT: &str = color_print::cstr! {" -Welcome to Amazon Q! -"}; - -const ROTATING_TIPS: [&str; 9] = [ - color_print::cstr! {"Get notified whenever Q CLI finishes responding. Just run q settings chat.enableNotifications true"}, - color_print::cstr! {"You can use /editor to edit your prompt with a vim-like experience"}, - color_print::cstr! {"You can execute bash commands by typing ! followed by the command"}, - color_print::cstr! {"Q can use tools without asking for confirmation every time. Give /tools trust a try"}, - color_print::cstr! {"You can programmatically inject context to your prompts by using hooks. Check out /context hooks help"}, - color_print::cstr! {"You can use /compact to replace the conversation history with its summary to free up the context space"}, - color_print::cstr! {"/usage shows you a visual breakdown of your current context window usage"}, - color_print::cstr! {"If you want to file an issue to the Q CLI team, just tell me, or run q issue"}, - color_print::cstr! {"You can enable custom tools with MCP servers. Learn more with /help"}, -]; - -const GREETING_BREAK_POINT: usize = 67; - -const POPULAR_SHORTCUTS: &str = color_print::cstr! {" - -/help all commands ctrl + j new lines ctrl + s fuzzy search -"}; - -const SMALL_SCREEN_POPULAR_SHORTCUTS: &str = color_print::cstr! {" - -/help all commands -ctrl + j new lines -ctrl + s fuzzy search - -"}; -const HELP_TEXT: &str = color_print::cstr! 
{" - -q (Amazon Q Chat) - -Commands: -/clear Clear the conversation history -/issue Report an issue or make a feature request -/editor Open $EDITOR (defaults to vi) to compose a prompt -/help Show this help dialogue -/quit Quit the application -/compact Summarize the conversation to free up context space - help Show help for the compact command - [prompt] Optional custom prompt to guide summarization -/tools View and manage tools and permissions - help Show an explanation for the trust command - trust Trust a specific tool or tools for the session - untrust Revert a tool or tools to per-request confirmation - trustall Trust all tools (equivalent to deprecated /acceptall) - reset Reset all tools to default permission levels -/profile Manage profiles - help Show profile help - list List profiles - set Set the current profile - create Create a new profile - delete Delete a profile - rename Rename a profile -/prompts View and retrieve prompts - help Show prompts help - list List or search available prompts - get Retrieve and send a prompt -/context Manage context files and hooks for the chat session - help Show context help - show Display current context rules configuration [--expand] - add Add file(s) to context [--global] [--force] - rm Remove file(s) from context [--global] - clear Clear all files from current context [--global] - hooks View and manage context hooks -/usage Show current session's context window usage - -MCP: -You can now configure the Amazon Q CLI to use MCP servers. \nLearn how: https://docs.aws.amazon.com/en_us/amazonq/latest/qdeveloper-ug/command-line-mcp.html - -Tips: -!{command} Quickly execute a command in your current session -Ctrl(^) + j Insert new-line to provide multi-line prompt. Alternatively, [Alt(⌥) + Enter(⏎)] -Ctrl(^) + s Fuzzy search commands and context files. Use Tab to select multiple items. - Change the keybind to ctrl+x with: q settings chat.skimCommandKey x (where x is any key) - -"}; - -const RESPONSE_TIMEOUT_CONTENT: &str = "Response timed out - message took too long to generate"; -const TRUST_ALL_TEXT: &str = color_print::cstr! {"All tools are now trusted (!). Amazon Q will execute tools without asking for confirmation.\ -\nAgents can sometimes do unexpected things so understand the risks. -\nLearn more at https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-chat-security.html#command-line-chat-trustall-safety"}; - -const TOOL_BULLET: &str = " ● "; -const CONTINUATION_LINE: &str = " ⋮ "; - -pub async fn launch_chat(args: cli::Chat) -> Result { - let trust_tools = args.trust_tools.map(|mut tools| { - if tools.len() == 1 && tools[0].is_empty() { - tools.pop(); - } - tools - }); - chat( - args.input, - args.no_interactive, - args.accept_all, - args.profile, - args.trust_all_tools, - trust_tools, - ) - .await -} - -pub async fn chat( - input: Option, - no_interactive: bool, - accept_all: bool, - profile: Option, - trust_all_tools: bool, - trust_tools: Option>, -) -> Result { - if !fig_util::system_info::in_cloudshell() && !fig_auth::is_logged_in().await { - bail!( - "You are not logged in, please log in with {}", - format!("{CLI_BINARY_NAME} login",).bold() - ); - } - - region_check("chat")?; - - let ctx = Context::new(); - - let stdin = std::io::stdin(); - // no_interactive flag or part of a pipe - let interactive = !no_interactive && stdin.is_terminal(); - let input = if !interactive && !stdin.is_terminal() { - // append to input string any extra info that was provided, e.g. 
via pipe - let mut input = input.unwrap_or_default(); - stdin.lock().read_to_string(&mut input)?; - Some(input) - } else { - input - }; - - let mut output = match interactive { - true => SharedWriter::stderr(), - false => SharedWriter::stdout(), - }; - - let client = match ctx.env().get("Q_MOCK_CHAT_RESPONSE") { - Ok(json) => create_stream(serde_json::from_str(std::fs::read_to_string(json)?.as_str())?), - _ => StreamingClient::new().await?, - }; - - let mcp_server_configs = match McpServerConfig::load_config(&mut output).await { - Ok(config) => { - execute!( - output, - style::Print( - "To learn more about MCP safety, see https://docs.aws.amazon.com/amazonq/latest/qdeveloper-ug/command-line-mcp-security.html\n" - ) - )?; - config - }, - Err(e) => { - warn!("No mcp server config loaded: {}", e); - McpServerConfig::default() - }, - }; - - // If profile is specified, verify it exists before starting the chat - if let Some(ref profile_name) = profile { - // Create a temporary context manager to check if the profile exists - match ContextManager::new(Arc::clone(&ctx)).await { - Ok(context_manager) => { - let profiles = context_manager.list_profiles().await?; - if !profiles.contains(profile_name) { - bail!( - "Profile '{}' does not exist. Available profiles: {}", - profile_name, - profiles.join(", ") - ); - } - }, - Err(e) => { - warn!("Failed to initialize context manager to verify profile: {}", e); - // Continue without verification if context manager can't be initialized - }, - } - } - - let conversation_id = Alphanumeric.sample_string(&mut rand::rng(), 9); - info!(?conversation_id, "Generated new conversation id"); - let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); - let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let mut tool_manager = ToolManagerBuilder::default() - .mcp_server_config(mcp_server_configs) - .prompt_list_sender(prompt_response_sender) - .prompt_list_receiver(prompt_request_receiver) - .conversation_id(&conversation_id) - .build()?; - let tool_config = tool_manager.load_tools().await?; - let mut tool_permissions = ToolPermissions::new(tool_config.len()); - if accept_all || trust_all_tools { - for tool in tool_config.values() { - tool_permissions.trust_tool(&tool.name); - } - - // Deprecation notice for --accept-all users - if accept_all && interactive { - queue!( - output, - style::SetForegroundColor(Color::Yellow), - style::Print("\n--accept-all, -a is deprecated. Use --trust-all-tools instead."), - style::SetForegroundColor(Color::Reset), - )?; - } - } else if let Some(trusted) = trust_tools.map(|vec| vec.into_iter().collect::>()) { - // --trust-all-tools takes precedence over --trust-tools=... 
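
// Editor's sketch (standalone, not part of the deleted file): the per-tool trust rule
// that the loop below applies. --trust-all-tools (and the deprecated --accept-all)
// trusts every tool; otherwise an explicit --trust-tools list trusts exactly the named
// tools and untrusts the rest; with neither flag, default permissions are kept.
use std::collections::HashSet;

/// Returns Some(true)/Some(false) to force-trust/untrust a tool, or None to leave its
/// default permission untouched.
fn trust_override(tool_name: &str, trust_all: bool, trusted: Option<&HashSet<String>>) -> Option<bool> {
    if trust_all {
        return Some(true);
    }
    trusted.map(|set| set.contains(tool_name))
}
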
- for tool in tool_config.values() { - if trusted.contains(&tool.name) { - tool_permissions.trust_tool(&tool.name); - } else { - tool_permissions.untrust_tool(&tool.name); - } - } - } - - let mut chat = ChatContext::new( - ctx, - &conversation_id, - Settings::new(), - State::new(), - output, - input, - InputSource::new(prompt_request_sender, prompt_response_receiver)?, - interactive, - client, - || terminal::window_size().map(|s| s.columns.into()).ok(), - tool_manager, - profile, - tool_config, - tool_permissions, - ) - .await?; - - let result = chat.try_chat().await.map(|_| ExitCode::SUCCESS); - drop(chat); // Explicit drop for clarity - - result -} - -/// Enum used to denote the origin of a tool use event -enum ToolUseStatus { - /// Variant denotes that the tool use event associated with chat context is a direct result of - /// a user request - Idle, - /// Variant denotes that the tool use event associated with the chat context is a result of a - /// retry for one or more previously attempted tool use. The tuple is the utterance id - /// associated with the original user request that necessitated the tool use - RetryInProgress(String), -} - -#[derive(Debug, Error)] -pub enum ChatError { - #[error("{0}")] - Client(#[from] fig_api_client::Error), - #[error("{0}")] - ResponseStream(#[from] parser::RecvError), - #[error("{0}")] - Std(#[from] std::io::Error), - #[error("{0}")] - Readline(#[from] rustyline::error::ReadlineError), - #[error("{0}")] - Custom(Cow<'static, str>), - #[error("interrupted")] - Interrupted { tool_uses: Option> }, - #[error( - "Tool approval required but --no-interactive was specified. Use --trust-all-tools to automatically approve tools." - )] - NonInteractiveToolApproval, - #[error(transparent)] - GetPromptError(#[from] GetPromptError), -} - -pub struct ChatContext { - ctx: Arc, - settings: Settings, - /// The [State] to use for the chat context. - state: State, - /// The [Write] destination for printing conversation text. - output: SharedWriter, - initial_input: Option, - input_source: InputSource, - interactive: bool, - /// The client to use to interact with the model. - client: StreamingClient, - /// Width of the terminal, required for [ParseState]. - terminal_width_provider: fn() -> Option, - spinner: Option, - /// [ConversationState]. - conversation_state: ConversationState, - /// State to track tools that need confirmation. - tool_permissions: ToolPermissions, - /// Telemetry events to be sent as part of the conversation. 
- tool_use_telemetry_events: HashMap, - /// State used to keep track of tool use relation - tool_use_status: ToolUseStatus, - /// Abstraction that consolidates custom tools with native ones - tool_manager: ToolManager, - /// Any failed requests that could be useful for error report/debugging - failed_request_ids: Vec, - /// Pending prompts to be sent - pending_prompts: VecDeque, -} - -impl ChatContext { - #[allow(clippy::too_many_arguments)] - pub async fn new( - ctx: Arc, - conversation_id: &str, - settings: Settings, - state: State, - output: SharedWriter, - input: Option, - input_source: InputSource, - interactive: bool, - client: StreamingClient, - terminal_width_provider: fn() -> Option, - tool_manager: ToolManager, - profile: Option, - tool_config: HashMap, - tool_permissions: ToolPermissions, - ) -> Result { - let ctx_clone = Arc::clone(&ctx); - let output_clone = output.clone(); - let conversation_state = - ConversationState::new(ctx_clone, conversation_id, tool_config, profile, Some(output_clone)).await; - Ok(Self { - ctx, - settings, - state, - output, - initial_input: input, - input_source, - interactive, - client, - terminal_width_provider, - spinner: None, - tool_permissions, - conversation_state, - tool_use_telemetry_events: HashMap::new(), - tool_use_status: ToolUseStatus::Idle, - tool_manager, - failed_request_ids: Vec::new(), - pending_prompts: VecDeque::new(), - }) - } -} - -impl Drop for ChatContext { - fn drop(&mut self) { - if let Some(spinner) = &mut self.spinner { - spinner.stop(); - } - - if self.interactive { - queue!( - self.output, - cursor::MoveToColumn(0), - style::SetAttribute(Attribute::Reset), - style::ResetColor, - cursor::Show - ) - .ok(); - } - - self.output.flush().ok(); - } -} - -/// The chat execution state. -/// -/// Intended to provide more robust handling around state transitions while dealing with, e.g., -/// tool validation, execution, response stream handling, etc. -#[derive(Debug)] -enum ChatState { - /// Prompt the user with `tool_uses`, if available. - PromptUser { - /// Tool uses to present to the user. - tool_uses: Option>, - /// Tracks the next tool in tool_uses that needs user acceptance. - pending_tool_index: Option, - /// Used to avoid displaying the tool info at inappropriate times, e.g. after clear or help - /// commands. - skip_printing_tools: bool, - }, - /// Handle the user input, depending on if any tools require execution. - HandleInput { - input: String, - tool_uses: Option>, - pending_tool_index: Option, - }, - /// Validate the list of tool uses provided by the model. - ValidateTools(Vec), - /// Execute the list of tools. - ExecuteTools(Vec), - /// Consume the response stream and display to the user. - HandleResponseStream(SendMessageOutput), - /// Compact the chat history. - CompactHistory { - tool_uses: Option>, - pending_tool_index: Option, - /// Custom prompt to include as part of history compaction. - prompt: Option, - /// Whether or not the summary should be shown on compact success. - show_summary: bool, - /// Whether or not to show the /compact help text. - help: bool, - }, - /// Exit the chat. 
-    Exit,
-}
-
-impl Default for ChatState {
-    fn default() -> Self {
-        Self::PromptUser {
-            tool_uses: None,
-            pending_tool_index: None,
-            skip_printing_tools: false,
-        }
-    }
-}
-
-impl ChatContext {
-    /// Opens the user's preferred editor to compose a prompt
-    fn open_editor(initial_text: Option<String>) -> Result<String, ChatError> {
-        // Create a temporary file with a unique name
-        let temp_dir = std::env::temp_dir();
-        let file_name = format!("q_prompt_{}.md", Uuid::new_v4());
-        let temp_file_path = temp_dir.join(file_name);
-
-        // Get the editor from the environment variable or use a default
-        let editor_cmd = env::var("EDITOR").unwrap_or_else(|_| "vi".to_string());
-
-        // Parse the editor command to handle arguments
-        let mut parts =
-            shlex::split(&editor_cmd).ok_or_else(|| ChatError::Custom("Failed to parse EDITOR command".into()))?;
-
-        if parts.is_empty() {
-            return Err(ChatError::Custom("EDITOR environment variable is empty".into()));
-        }
-
-        let editor_bin = parts.remove(0);
-
-        // Write initial content to the file if provided
-        let initial_content = initial_text.unwrap_or_default();
-        fs::write(&temp_file_path, &initial_content)
-            .map_err(|e| ChatError::Custom(format!("Failed to create temporary file: {}", e).into()))?;
-
-        // Open the editor with the parsed command and arguments
-        let mut cmd = ProcessCommand::new(editor_bin);
-        // Add any arguments that were part of the EDITOR variable
-        for arg in parts {
-            cmd.arg(arg);
-        }
-        // Add the file path as the last argument
-        let status = cmd
-            .arg(&temp_file_path)
-            .status()
-            .map_err(|e| ChatError::Custom(format!("Failed to open editor: {}", e).into()))?;
-
-        if !status.success() {
-            return Err(ChatError::Custom("Editor exited with non-zero status".into()));
-        }
-
-        // Read the content back
-        let content = fs::read_to_string(&temp_file_path)
-            .map_err(|e| ChatError::Custom(format!("Failed to read temporary file: {}", e).into()))?;
-
-        // Clean up the temporary file
-        let _ = fs::remove_file(&temp_file_path);
-
-        Ok(content.trim().to_string())
-    }
-
-    fn check_for_updates(&mut self) {
-        let exe_path = match std::env::current_exe().and_then(|p| p.canonicalize()) {
-            Ok(path) => path,
-            Err(_) => return, // Early return if we can't get the executable path
-        };
-
-        if let Some(exe_parent) = exe_path.parent() {
-            let local_bin = match fig_util::directories::home_local_bin().map(|p| p.canonicalize()) {
-                Ok(path) => path,
-                Err(_) => return,
-            };
-
-            if let Ok(local_bin) = local_bin {
-                if exe_parent != local_bin {
-                    let _ = self.state.remove_value(UPDATE_AVAILABLE_KEY);
-                    return;
-                }
-            }
-        }
-
-        tokio::spawn(async {
-            let result =
-                tokio::time::timeout(std::time::Duration::from_secs(3), fig_install::check_for_updates(false)).await;
-
-            match result {
-                Ok(Ok(Some(new_package))) => {
-                    if let Err(err) =
-                        fig_settings::state::set_value(UPDATE_AVAILABLE_KEY, new_package.version.to_string())
-                    {
-                        warn!(?err, "Error setting {UPDATE_AVAILABLE_KEY}: {err}");
-                    }
-                },
-                Ok(Ok(None)) => {},
-                Ok(Err(err)) => {
-                    warn!(?err, "Error checking for updates: {err}");
-                },
-                Err(_) => {
-                    warn!("Update check timed out");
-                },
-            }
-        });
-    }
-
-    async fn try_chat(&mut self) -> Result<()> {
-        self.check_for_updates();
-
-        let is_small_screen = self.terminal_width() < GREETING_BREAK_POINT;
-        if self.interactive && self.settings.get_bool_or("chat.greeting.enabled", true) {
-            execute!(
-                self.output,
-                style::Print(if is_small_screen {
-                    SMALL_SCREEN_WECLOME_TEXT
-                } else {
-                    WELCOME_TEXT
-                }),
-                style::Print("\n\n"),
-            )?;
-
-            let current_tip_index =
-
(self.state.get_int_or("chat.greeting.rotating_tips_current_index", 0) as usize) % ROTATING_TIPS.len(); - - let tip = ROTATING_TIPS[current_tip_index]; - if is_small_screen { - // If the screen is small, print the tip in a single line - execute!( - self.output, - style::Print("💡 ".to_string()), - style::Print(tip), - style::Print("\n") - )?; - } else { - draw_box( - self.output.clone(), - "Did you know?", - tip, - GREETING_BREAK_POINT, - Color::DarkGrey, - )?; - } - - execute!( - self.output, - style::Print(if is_small_screen { - SMALL_SCREEN_POPULAR_SHORTCUTS - } else { - POPULAR_SHORTCUTS - }), - style::Print( - "━" - .repeat(if is_small_screen { 0 } else { GREETING_BREAK_POINT }) - .dark_grey() - ) - )?; - execute!(self.output, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; - - // update the current tip index - let next_tip_index = (current_tip_index + 1) % ROTATING_TIPS.len(); - self.state - .set_value("chat.greeting.rotating_tips_current_index", next_tip_index)?; - } - - match self.state.get_string(UPDATE_AVAILABLE_KEY) { - Ok(Some(version)) => match Version::parse(&version) { - Ok(version) => { - let current_version = Version::parse(env!("CARGO_PKG_VERSION")).unwrap(); - if version > current_version { - execute!(self.output, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; - let content = format!("Run {} to update to the latest version", "q update".dark_green().bold()); - - if is_small_screen { - queue!( - self.output, - style::Print("🎉 New Update: "), - style::Print(content), - style::Print("\n") - )?; - } else { - draw_box( - self.output.clone(), - "New Update!", - &content, - GREETING_BREAK_POINT, - Color::DarkYellow, - )?; - } - execute!(self.output, style::Print("\n"), style::SetForegroundColor(Color::Reset))?; - } - }, - Err(err) => { - warn!(?err, "Error parsing {UPDATE_AVAILABLE_KEY}: {err}"); - let _ = fig_settings::state::remove_value(UPDATE_AVAILABLE_KEY); - }, - }, - Ok(None) => {}, - Err(err) => { - warn!(?err, "Error getting {UPDATE_AVAILABLE_KEY}: {err}"); - let _ = fig_settings::state::remove_value(UPDATE_AVAILABLE_KEY); - }, - } - - if self.interactive && self.all_tools_trusted() { - queue!( - self.output, - style::Print(format!( - "{}{TRUST_ALL_TEXT}\n\n", - if !is_small_screen { "\n" } else { "" } - )) - )?; - } - self.output.flush()?; - - let mut ctrl_c_stream = signal(SignalKind::interrupt())?; - - let mut next_state = Some(ChatState::PromptUser { - tool_uses: None, - pending_tool_index: None, - skip_printing_tools: true, - }); - - if let Some(user_input) = self.initial_input.take() { - if self.interactive { - execute!( - self.output, - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Magenta), - style::Print("> "), - style::SetAttribute(Attribute::Reset), - style::Print(&user_input), - style::Print("\n") - )?; - } - next_state = Some(ChatState::HandleInput { - input: user_input, - tool_uses: None, - pending_tool_index: None, - }); - } - - loop { - debug_assert!(next_state.is_some()); - let chat_state = next_state.take().unwrap_or_default(); - debug!(?chat_state, "changing to state"); - - let result = match chat_state { - ChatState::PromptUser { - tool_uses, - pending_tool_index, - skip_printing_tools, - } => { - // Cannot prompt in non-interactive mode no matter what. 
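// The handler arms below all follow the same shape: race the state handler against a SIGINT
// stream so that Ctrl-C maps to ChatError::Interrupted. A minimal, self-contained sketch of
// that pattern; the `do_work` name and the io::Error error type are illustrative only, not
// part of this patch:

use tokio::signal::unix::{signal, SignalKind};

async fn do_work() -> std::io::Result<()> {
    Ok(())
}

async fn run_interruptible() -> std::io::Result<()> {
    let mut ctrl_c = signal(SignalKind::interrupt())?;
    tokio::select! {
        // Whichever future completes first wins; the other is dropped.
        res = do_work() => res,
        Some(_) = ctrl_c.recv() => Err(std::io::Error::other("interrupted")),
    }
}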
- if !self.interactive { - return Ok(()); - } - self.prompt_user(tool_uses, pending_tool_index, skip_printing_tools) - .await - }, - ChatState::HandleInput { - input, - tool_uses, - pending_tool_index, - } => { - let tool_uses_clone = tool_uses.clone(); - tokio::select! { - res = self.handle_input(input, tool_uses, pending_tool_index) => res, - Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: tool_uses_clone }) - } - }, - ChatState::CompactHistory { - tool_uses, - pending_tool_index, - prompt, - show_summary, - help, - } => { - let tool_uses_clone = tool_uses.clone(); - tokio::select! { - res = self.compact_history(tool_uses, pending_tool_index, prompt, show_summary, help) => res, - Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: tool_uses_clone }) - } - }, - ChatState::ExecuteTools(tool_uses) => { - let tool_uses_clone = tool_uses.clone(); - tokio::select! { - res = self.tool_use_execute(tool_uses) => res, - Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: Some(tool_uses_clone) }) - } - }, - ChatState::ValidateTools(tool_uses) => { - tokio::select! { - res = self.validate_tools(tool_uses) => res, - Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: None }) - } - }, - ChatState::HandleResponseStream(response) => tokio::select! { - res = self.handle_response(response) => res, - Some(_) = ctrl_c_stream.recv() => Err(ChatError::Interrupted { tool_uses: None }) - }, - ChatState::Exit => return Ok(()), - }; - - next_state = Some(self.handle_state_execution_result(result).await?); - } - } - - /// Handles the result of processing a [ChatState], returning the next [ChatState] to change - /// to. - async fn handle_state_execution_result( - &mut self, - result: Result, - ) -> Result { - // Remove non-ASCII and ANSI characters. - let re = Regex::new(r"((\x9B|\x1B\[)[0-?]*[ -\/]*[@-~])|([^\x00-\x7F]+)").unwrap(); - match result { - Ok(state) => Ok(state), - Err(e) => { - macro_rules! print_err { - ($prepend_msg:expr, $err:expr) => {{ - queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Red), - )?; - - let report = eyre::Report::from($err); - - let text = re - .replace_all(&format!("{}: {:?}\n", $prepend_msg, report), "") - .into_owned(); - - queue!(self.output, style::Print(&text),)?; - self.conversation_state.append_transcript(text); - - execute!( - self.output, - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Reset), - )?; - }}; - } - - macro_rules! print_default_error { - ($err:expr) => { - print_err!("Amazon Q is having trouble responding right now", $err); - }; - } - - error!(?e, "An error occurred processing the current state"); - if self.interactive && self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - )?; - } - match e { - ChatError::Interrupted { tool_uses: inter } => { - execute!(self.output, style::Print("\n\n"))?; - // If there was an interrupt during tool execution, then we add fake - // messages to "reset" the chat state. 
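// The match below only injects the synthetic "interrupted" messages when there actually are
// abandoned tool uses. A small sketch of that guard, with hypothetical types standing in for
// the real conversation state:

fn reset_after_interrupt(pending: Option<Vec<String>>, transcript: &mut Vec<String>) {
    // Only act when at least one tool use was pending at the time of the interrupt.
    if let Some(tools) = pending.filter(|t| !t.is_empty()) {
        transcript.push(format!("{} tool use(s) were interrupted by the user", tools.len()));
        transcript.push("Waiting for the next user prompt".to_string());
    }
}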
- match inter { - Some(tool_uses) if !tool_uses.is_empty() => { - self.conversation_state.abandon_tool_use( - tool_uses, - "The user interrupted the tool execution.".to_string(), - ); - let _ = self.conversation_state.as_sendable_conversation_state(false).await; - self.conversation_state - .push_assistant_message(AssistantMessage::new_response( - None, - "Tool uses were interrupted, waiting for the next user prompt".to_string(), - )); - }, - _ => (), - } - }, - ChatError::Client(err) => match err { - // Errors from attempting to send too large of a conversation history. In - // this case, attempt to automatically compact the history for the user. - fig_api_client::Error::ContextWindowOverflow => { - let history_too_small = self - .conversation_state - .backend_conversation_state(false, true) - .await - .history - .len() - < 2; - if history_too_small { - print_err!( - "Your conversation is too large - try reducing the size of - the context being passed", - err - ); - return Ok(ChatState::PromptUser { - tool_uses: None, - pending_tool_index: None, - skip_printing_tools: false, - }); - } - - return Ok(ChatState::CompactHistory { - tool_uses: None, - pending_tool_index: None, - prompt: None, - show_summary: false, - help: false, - }); - }, - fig_api_client::Error::QuotaBreach(msg) => { - print_err!(msg, err); - }, - _ => { - print_default_error!(err); - }, - }, - _ => { - print_default_error!(e); - }, - } - self.conversation_state.enforce_conversation_invariants(); - self.conversation_state.reset_next_user_message(); - Ok(ChatState::PromptUser { - tool_uses: None, - pending_tool_index: None, - skip_printing_tools: false, - }) - }, - } - } - - /// Compacts the conversation history, replacing the history with a summary generated by the - /// model. - /// - /// The last two user messages in the history are not included in the compaction process. - async fn compact_history( - &mut self, - tool_uses: Option>, - pending_tool_index: Option, - custom_prompt: Option, - show_summary: bool, - help: bool, - ) -> Result { - let hist = self.conversation_state.history(); - debug!(?hist, "compacting history"); - - // If help flag is set, show compact command help - if help { - execute!( - self.output, - style::Print("\n"), - style::Print(compact_help_text()), - style::Print("\n") - )?; - - return Ok(ChatState::PromptUser { - tool_uses, - pending_tool_index, - skip_printing_tools: true, - }); - } - - if self.conversation_state.history().len() < 2 { - execute!( - self.output, - style::SetForegroundColor(Color::Yellow), - style::Print("\nConversation too short to compact.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - return Ok(ChatState::PromptUser { - tool_uses, - pending_tool_index, - skip_printing_tools: true, - }); - } - - // Send a request for summarizing the history. - let summary_state = self - .conversation_state - .create_summary_request(custom_prompt.as_ref()) - .await; - if self.interactive { - execute!(self.output, cursor::Hide, style::Print("\n"))?; - self.spinner = Some(Spinner::new(Spinners::Dots, "Creating summary...".to_string())); - } - let response = self.client.send_message(summary_state).await; - - // TODO(brandonskiser): This is a temporary hotfix for failing compaction. We should instead - // retry except with less context included. 
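// The hotfix below degrades gracefully: if even the summarization request overflows the
// context window, the history is cleared instead of retrying. A sketch of that decision,
// using an illustrative error enum rather than the real fig_api_client::Error:

enum ApiError {
    ContextWindowOverflow,
    Other(String),
}

fn summary_or_clear(res: Result<String, ApiError>, history: &mut Vec<String>) -> Result<Option<String>, ApiError> {
    match res {
        Ok(summary) => Ok(Some(summary)),
        Err(ApiError::ContextWindowOverflow) => {
            history.clear(); // give up on summarizing and drop the oversized history
            Ok(None)
        },
        Err(other) => Err(other), // all other errors propagate to the caller
    }
}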
- let response = match response { - Ok(res) => res, - Err(e) => match e { - fig_api_client::Error::ContextWindowOverflow => { - self.conversation_state.clear(true); - if self.interactive { - self.spinner.take(); - execute!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - style::SetForegroundColor(Color::Yellow), - style::Print( - "The context window usage has overflowed. Clearing the conversation history.\n\n" - ), - style::SetAttribute(Attribute::Reset) - )?; - } - return Ok(ChatState::PromptUser { - tool_uses, - pending_tool_index, - skip_printing_tools: true, - }); - }, - e => return Err(e.into()), - }, - }; - - let summary = { - let mut parser = ResponseParser::new(response); - loop { - match parser.recv().await { - Ok(parser::ResponseEvent::EndStream { message }) => { - break message.content().to_string(); - }, - Ok(_) => (), - Err(err) => { - if let Some(request_id) = &err.request_id { - self.failed_request_ids.push(request_id.clone()); - }; - return Err(err.into()); - }, - } - } - }; - - if self.interactive && self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - - if let Some(message_id) = self.conversation_state.message_id() { - fig_telemetry::send_chat_added_message( - self.conversation_state.conversation_id().to_owned(), - message_id.to_owned(), - self.conversation_state.context_message_length(), - ) - .await; - } - - self.conversation_state.replace_history_with_summary(summary.clone()); - - // Print output to the user. - { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print("✔ Conversation history has been compacted successfully!\n\n"), - style::SetForegroundColor(Color::DarkGrey) - )?; - - let mut output = Vec::new(); - if let Some(custom_prompt) = &custom_prompt { - execute!( - output, - style::Print(format!("• Custom prompt applied: {}\n", custom_prompt)) - )?; - } - animate_output(&mut self.output, &output)?; - - // Display the summary if the show_summary flag is set - if show_summary { - // Add a border around the summary for better visual separation - let terminal_width = self.terminal_width(); - let border = "═".repeat(terminal_width.min(80)); - execute!( - self.output, - style::Print("\n"), - style::SetForegroundColor(Color::Cyan), - style::Print(&border), - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print(" CONVERSATION SUMMARY"), - style::Print("\n"), - style::Print(&border), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - - execute!( - output, - style::Print(&summary), - style::Print("\n\n"), - style::SetForegroundColor(Color::Cyan), - style::Print("The conversation history has been replaced with this summary.\n"), - style::Print("It contains all important details from previous interactions.\n"), - )?; - animate_output(&mut self.output, &output)?; - - execute!( - self.output, - style::Print(&border), - style::Print("\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - // If a next message is set, then retry the request. - if self.conversation_state.next_user_message().is_some() { - Ok(ChatState::HandleResponseStream( - self.client - .send_message(self.conversation_state.as_sendable_conversation_state(false).await) - .await?, - )) - } else { - // Otherwise, return back to the prompt for any pending tool uses. 
- Ok(ChatState::PromptUser { - tool_uses, - pending_tool_index, - skip_printing_tools: true, - }) - } - } - - /// Read input from the user. - async fn prompt_user( - &mut self, - mut tool_uses: Option>, - pending_tool_index: Option, - skip_printing_tools: bool, - ) -> Result { - execute!(self.output, cursor::Show)?; - let tool_uses = tool_uses.take().unwrap_or_default(); - - // Check token usage and display warnings if needed - if pending_tool_index.is_none() { - // Only display warnings when not waiting for tool approval - if let Err(e) = self.display_char_warnings().await { - warn!("Failed to display character limit warnings: {}", e); - } - } - - let show_tool_use_confirmation_dialog = !skip_printing_tools && pending_tool_index.is_some(); - if show_tool_use_confirmation_dialog { - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print("\nAllow this action? Use '"), - style::SetForegroundColor(Color::Green), - style::Print("t"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("' to trust (always allow) this tool for the session. ["), - style::SetForegroundColor(Color::Green), - style::Print("y"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("/"), - style::SetForegroundColor(Color::Green), - style::Print("n"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("/"), - style::SetForegroundColor(Color::Green), - style::Print("t"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("]:\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - } - - // Do this here so that the skim integration sees an updated view of the context *during the current - // q session*. (e.g., if I add files to context, that won't show up for skim for the current - // q session unless we do this in prompt_user... 
unless you can find a better way) - if let Some(ref context_manager) = self.conversation_state.context_manager { - let tool_names = self.tool_manager.tn_map.keys().cloned().collect::>(); - self.input_source - .put_skim_command_selector(Arc::new(context_manager.clone()), tool_names); - } - execute!( - self.output, - style::SetForegroundColor(Color::Reset), - style::SetAttribute(Attribute::Reset) - )?; - let user_input = match self.read_user_input(&self.generate_tool_trust_prompt(), false) { - Some(input) => input, - None => return Ok(ChatState::Exit), - }; - - self.conversation_state.append_user_transcript(&user_input); - Ok(ChatState::HandleInput { - input: user_input, - tool_uses: Some(tool_uses), - pending_tool_index, - }) - } - - async fn handle_input( - &mut self, - mut user_input: String, - tool_uses: Option>, - pending_tool_index: Option, - ) -> Result { - let command_result = Command::parse(&user_input, &mut self.output); - - if let Err(error_message) = &command_result { - // Display error message for command parsing errors - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", error_message)), - style::SetForegroundColor(Color::Reset) - )?; - - return Ok(ChatState::PromptUser { - tool_uses, - pending_tool_index, - skip_printing_tools: true, - }); - } - - let command = command_result.unwrap(); - let mut tool_uses: Vec = tool_uses.unwrap_or_default(); - - Ok(match command { - Command::Ask { prompt } => { - // Check for a pending tool approval - if let Some(index) = pending_tool_index { - let tool_use = &mut tool_uses[index]; - - let is_trust = ["t", "T"].contains(&prompt.as_str()); - if ["y", "Y"].contains(&prompt.as_str()) || is_trust { - if is_trust { - self.tool_permissions.trust_tool(&tool_use.name); - } - tool_use.accepted = true; - - return Ok(ChatState::ExecuteTools(tool_uses)); - } - } else if !self.pending_prompts.is_empty() { - let prompts = self.pending_prompts.drain(0..).collect(); - user_input = self - .conversation_state - .append_prompts(prompts) - .ok_or(ChatError::Custom("Prompt append failed".into()))?; - } - - // Otherwise continue with normal chat on 'n' or other responses - self.tool_use_status = ToolUseStatus::Idle; - - if pending_tool_index.is_some() { - self.conversation_state.abandon_tool_use(tool_uses, user_input); - } else { - self.conversation_state.set_next_user_message(user_input).await; - } - - let conv_state = self.conversation_state.as_sendable_conversation_state(true).await; - - if self.interactive { - queue!(self.output, style::SetForegroundColor(Color::Magenta))?; - queue!(self.output, style::SetForegroundColor(Color::Reset))?; - queue!(self.output, cursor::Hide)?; - execute!(self.output, style::Print("\n"))?; - self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned())); - } - - self.send_tool_use_telemetry().await; - - ChatState::HandleResponseStream(self.client.send_message(conv_state).await?) - }, - Command::Execute { command } => { - queue!(self.output, style::Print('\n'))?; - std::process::Command::new("bash").args(["-c", &command]).status().ok(); - queue!(self.output, style::Print('\n'))?; - ChatState::PromptUser { - tool_uses: None, - pending_tool_index: None, - skip_printing_tools: false, - } - }, - Command::Clear => { - execute!(self.output, cursor::Show)?; - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print( - "\nAre you sure? This will erase the conversation history and context from hooks for the current session. 
" - ), - style::Print("["), - style::SetForegroundColor(Color::Green), - style::Print("y"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("/"), - style::SetForegroundColor(Color::Green), - style::Print("n"), - style::SetForegroundColor(Color::DarkGrey), - style::Print("]:\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - - // Setting `exit_on_single_ctrl_c` for better ux: exit the confirmation dialog rather than the CLI - let user_input = match self.read_user_input("> ".yellow().to_string().as_str(), true) { - Some(input) => input, - None => "".to_string(), - }; - - if ["y", "Y"].contains(&user_input.as_str()) { - self.conversation_state.clear(true); - if let Some(cm) = self.conversation_state.context_manager.as_mut() { - cm.hook_executor.global_cache.clear(); - cm.hook_executor.profile_cache.clear(); - } - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print("\nConversation history cleared.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } - - ChatState::PromptUser { - tool_uses: None, - pending_tool_index: None, - skip_printing_tools: true, - } - }, - Command::Compact { - prompt, - show_summary, - help, - } => { - self.compact_history(Some(tool_uses), pending_tool_index, prompt, show_summary, help) - .await? - }, - Command::Help => { - execute!(self.output, style::Print(HELP_TEXT))?; - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - Command::Issue { prompt } => { - let input = "I would like to report an issue or make a feature request"; - ChatState::HandleInput { - input: if let Some(prompt) = prompt { - format!("{input}: {prompt}") - } else { - input.to_string() - }, - tool_uses: Some(tool_uses), - pending_tool_index, - } - }, - Command::PromptEditor { initial_text } => { - match Self::open_editor(initial_text) { - Ok(content) => { - if content.trim().is_empty() { - execute!( - self.output, - style::SetForegroundColor(Color::Yellow), - style::Print("\nEmpty content from editor, not submitting.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - } else { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print("\nContent loaded from editor. Submitting prompt...\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - - // Display the content as if the user typed it - execute!( - self.output, - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Magenta), - style::Print("> "), - style::SetAttribute(Attribute::Reset), - style::Print(&content), - style::Print("\n") - )?; - - // Process the content as user input - ChatState::HandleInput { - input: content, - tool_uses: Some(tool_uses), - pending_tool_index, - } - } - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError opening editor: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - } - }, - Command::Quit => ChatState::Exit, - Command::Profile { subcommand } => { - if let Some(context_manager) = &mut self.conversation_state.context_manager { - macro_rules! print_err { - ($err:expr) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", $err)), - style::SetForegroundColor(Color::Reset) - )? 
- }; - } - - match subcommand { - command::ProfileSubcommand::List => { - let profiles = match context_manager.list_profiles().await { - Ok(profiles) => profiles, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError listing profiles: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - vec![] - }, - }; - - execute!(self.output, style::Print("\n"))?; - for profile in profiles { - if profile == context_manager.current_profile { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print("* "), - style::Print(&profile), - style::SetForegroundColor(Color::Reset), - style::Print("\n") - )?; - } else { - execute!( - self.output, - style::Print(" "), - style::Print(&profile), - style::Print("\n") - )?; - } - } - execute!(self.output, style::Print("\n"))?; - }, - command::ProfileSubcommand::Create { name } => { - match context_manager.create_profile(&name).await { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nCreated profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - context_manager - .switch_profile(&name) - .await - .map_err(|e| warn!(?e, "failed to switch to newly created profile")) - .ok(); - }, - Err(e) => print_err!(e), - } - }, - command::ProfileSubcommand::Delete { name } => { - match context_manager.delete_profile(&name).await { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDeleted profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - } - }, - command::ProfileSubcommand::Set { name } => match context_manager.switch_profile(&name).await { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nSwitched to profile: {}\n\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - }, - command::ProfileSubcommand::Rename { old_name, new_name } => { - match context_manager.rename_profile(&old_name, &new_name).await { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nRenamed profile: {} -> {}\n\n", old_name, new_name)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => print_err!(e), - } - }, - command::ProfileSubcommand::Help => { - execute!( - self.output, - style::Print("\n"), - style::Print(command::ProfileSubcommand::help_text()), - style::Print("\n") - )?; - }, - } - } - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - Command::Context { subcommand } => { - if let Some(context_manager) = &mut self.conversation_state.context_manager { - match subcommand { - command::ContextSubcommand::Show { expand } => { - // Display global context - execute!( - self.output, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - style::SetAttribute(Attribute::Reset), - )?; - let mut global_context_files = HashSet::new(); - let mut profile_context_files = HashSet::new(); - if context_manager.global_config.paths.is_empty() { - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for path in &context_manager.global_config.paths { - execute!(self.output, style::Print(format!(" {} ", path)))?; - if let Ok(context_files) = - 
context_manager.get_context_files_by_path(false, path).await - { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "({} match{})", - context_files.len(), - if context_files.len() == 1 { "" } else { "es" } - )), - style::SetForegroundColor(Color::Reset) - )?; - global_context_files.extend(context_files); - } - execute!(self.output, style::Print("\n"))?; - } - } - - // Display profile context - execute!( - self.output, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print(format!("\n👤 profile ({}):\n", context_manager.current_profile)), - style::SetAttribute(Attribute::Reset), - )?; - - if context_manager.profile_config.paths.is_empty() { - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for path in &context_manager.profile_config.paths { - execute!(self.output, style::Print(format!(" {} ", path)))?; - if let Ok(context_files) = - context_manager.get_context_files_by_path(false, path).await - { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "({} match{})", - context_files.len(), - if context_files.len() == 1 { "" } else { "es" } - )), - style::SetForegroundColor(Color::Reset) - )?; - profile_context_files.extend(context_files); - } - execute!(self.output, style::Print("\n"))?; - } - execute!(self.output, style::Print("\n"))?; - } - - if global_context_files.is_empty() && profile_context_files.is_empty() { - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print("No files in the current directory matched the rules above.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - let total = global_context_files.len() + profile_context_files.len(); - let total_tokens = global_context_files - .iter() - .map(|(_, content)| TokenCounter::count_tokens(content)) - .sum::() - + profile_context_files - .iter() - .map(|(_, content)| TokenCounter::count_tokens(content)) - .sum::(); - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::SetAttribute(Attribute::Bold), - style::Print(format!( - "{} matched file{} in use:\n", - total, - if total == 1 { "" } else { "s" } - )), - style::SetForegroundColor(Color::Reset), - style::SetAttribute(Attribute::Reset) - )?; - - for (filename, content) in global_context_files { - let est_tokens = TokenCounter::count_tokens(&content); - execute!( - self.output, - style::Print(format!("🌍 {} ", filename)), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("(~{} tkns)\n", est_tokens)), - style::SetForegroundColor(Color::Reset), - )?; - if expand { - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{}\n\n", content)), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - for (filename, content) in profile_context_files { - let est_tokens = TokenCounter::count_tokens(&content); - execute!( - self.output, - style::Print(format!("👤 {} ", filename)), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("(~{} tkns)\n", est_tokens)), - style::SetForegroundColor(Color::Reset), - )?; - if expand { - execute!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{}\n\n", content)), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - if expand { - execute!(self.output, style::Print(format!("{}\n\n", "▔".repeat(3))),)?; - } - - execute!( - 
self.output, - style::Print(format!("\nTotal: ~{} tokens\n\n", total_tokens)), - )?; - - execute!(self.output, style::Print("\n"))?; - } - }, - command::ContextSubcommand::Add { global, force, paths } => { - match context_manager.add_paths(paths.clone(), global, force).await { - Ok(_) => { - let target = if global { "global" } else { "profile" }; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nAdded {} path(s) to {} context.\n\n", - paths.len(), - target - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - command::ContextSubcommand::Remove { global, paths } => { - match context_manager.remove_paths(paths.clone(), global).await { - Ok(_) => { - let target = if global { "global" } else { "profile" }; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nRemoved {} path(s) from {} context.\n\n", - paths.len(), - target - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - command::ContextSubcommand::Clear { global } => match context_manager.clear(global).await { - Ok(_) => { - let target = if global { - "global".to_string() - } else { - format!("profile '{}'", context_manager.current_profile) - }; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nCleared context for {}\n\n", target)), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nError: {}\n\n", e)), - style::SetForegroundColor(Color::Reset) - )?; - }, - }, - command::ContextSubcommand::Help => { - execute!( - self.output, - style::Print("\n"), - style::Print(command::ContextSubcommand::help_text()), - style::Print("\n") - )?; - }, - command::ContextSubcommand::Hooks { subcommand } => { - fn map_chat_error(e: ErrReport) -> ChatError { - ChatError::Custom(e.to_string().into()) - } - - let scope = |g: bool| if g { "global" } else { "profile" }; - if let Some(subcommand) = subcommand { - match subcommand { - command::HooksSubcommand::Add { - name, - trigger, - command, - global, - } => { - let trigger = if trigger == "conversation_start" { - HookTrigger::ConversationStart - } else { - HookTrigger::PerPrompt - }; - - let result = context_manager - .add_hook(name.clone(), Hook::new_inline_hook(trigger, command), global) - .await; - match result { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nAdded {} hook '{name}'.\n\n", - scope(global) - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "\nCannot add {} hook '{name}': {}\n\n", - scope(global), - e - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - command::HooksSubcommand::Remove { name, global } => { - let result = context_manager.remove_hook(&name, global).await; - match result { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nRemoved {} hook '{name}'.\n\n", - scope(global) - )), - 
style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "\nCannot remove {} hook '{name}': {}\n\n", - scope(global), - e - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - command::HooksSubcommand::Enable { name, global } => { - let result = context_manager.set_hook_disabled(&name, global, false).await; - match result { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nEnabled {} hook '{name}'.\n\n", - scope(global) - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "\nCannot enable {} hook '{name}': {}\n\n", - scope(global), - e - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - command::HooksSubcommand::Disable { name, global } => { - let result = context_manager.set_hook_disabled(&name, global, true).await; - match result { - Ok(_) => { - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!( - "\nDisabled {} hook '{name}'.\n\n", - scope(global) - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - Err(e) => { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "\nCannot disable {} hook '{name}': {}\n\n", - scope(global), - e - )), - style::SetForegroundColor(Color::Reset) - )?; - }, - } - }, - command::HooksSubcommand::EnableAll { global } => { - context_manager - .set_all_hooks_disabled(global, false) - .await - .map_err(map_chat_error)?; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nEnabled all {} hooks.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - command::HooksSubcommand::DisableAll { global } => { - context_manager - .set_all_hooks_disabled(global, true) - .await - .map_err(map_chat_error)?; - execute!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nDisabled all {} hooks.\n\n", scope(global))), - style::SetForegroundColor(Color::Reset) - )?; - }, - command::HooksSubcommand::Help => { - execute!( - self.output, - style::Print("\n"), - style::Print(command::ContextSubcommand::hooks_help_text()), - style::Print("\n") - )?; - }, - } - } else { - fn print_hook_section( - output: &mut impl Write, - hooks: &HashMap, - trigger: HookTrigger, - ) -> Result<()> { - let section = match trigger { - HookTrigger::ConversationStart => "Conversation Start", - HookTrigger::PerPrompt => "Per Prompt", - }; - let hooks: Vec<(&String, &Hook)> = - hooks.iter().filter(|(_, h)| h.trigger == trigger).collect(); - - queue!( - output, - style::SetForegroundColor(Color::Cyan), - style::Print(format!(" {section}:\n")), - style::SetForegroundColor(Color::Reset), - )?; - - if hooks.is_empty() { - queue!( - output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(" \n"), - style::SetForegroundColor(Color::Reset) - )?; - } else { - for (name, hook) in hooks { - if hook.disabled { - queue!( - output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!(" {} (disabled)\n", name)), - style::SetForegroundColor(Color::Reset) - )?; - } else { - queue!(output, style::Print(format!(" {}\n", name)),)?; - } - } - } - Ok(()) - } - queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print("\n🌍 global:\n"), - 
style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut self.output, - &context_manager.global_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut self.output, - &context_manager.global_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - - queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Magenta), - style::Print(format!("\n👤 profile ({}):\n", &context_manager.current_profile)), - style::SetAttribute(Attribute::Reset), - )?; - - print_hook_section( - &mut self.output, - &context_manager.profile_config.hooks, - HookTrigger::ConversationStart, - ) - .map_err(map_chat_error)?; - print_hook_section( - &mut self.output, - &context_manager.profile_config.hooks, - HookTrigger::PerPrompt, - ) - .map_err(map_chat_error)?; - - execute!( - self.output, - style::Print(format!( - "\nUse {} to manage hooks.\n\n", - "/context hooks help".to_string().dark_green() - )), - )?; - } - }, - } - // fig_telemetry::send_context_command_executed - } else { - execute!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print("\nContext management is not available.\n\n"), - style::SetForegroundColor(Color::Reset) - )?; - } - - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - Command::Tools { subcommand } => { - let existing_tools: HashSet<&String> = self - .conversation_state - .tools - .values() - .flatten() - .map(|FigTool::ToolSpecification(spec)| &spec.name) - .collect(); - - match subcommand { - Some(ToolsSubcommand::Schema) => { - let schema_json = serde_json::to_string_pretty(&self.tool_manager.schema).map_err(|e| { - ChatError::Custom(format!("Error converting tool schema to string: {e}").into()) - })?; - queue!(self.output, style::Print(schema_json), style::Print("\n"))?; - }, - Some(ToolsSubcommand::Trust { tool_names }) => { - let (valid_tools, invalid_tools): (Vec, Vec) = tool_names - .into_iter() - .partition(|tool_name| existing_tools.contains(tool_name)); - - if !invalid_tools.is_empty() { - queue!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot trust '{}', ", invalid_tools.join("', '"))), - if invalid_tools.len() > 1 { - style::Print("they do not exist.") - } else { - style::Print("it does not exist.") - }, - style::SetForegroundColor(Color::Reset), - )?; - } - if !valid_tools.is_empty() { - valid_tools.iter().for_each(|t| self.tool_permissions.trust_tool(t)); - queue!( - self.output, - style::SetForegroundColor(Color::Green), - if valid_tools.len() > 1 { - style::Print(format!("\nTools '{}' are ", valid_tools.join("', '"))) - } else { - style::Print(format!("\nTool '{}' is ", valid_tools[0])) - }, - style::Print("now trusted. 
I will "), - style::SetAttribute(Attribute::Bold), - style::Print("not"), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Green), - style::Print(format!( - " ask for confirmation before running {}.", - if valid_tools.len() > 1 { - "these tools" - } else { - "this tool" - } - )), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - Some(ToolsSubcommand::Untrust { tool_names }) => { - let (valid_tools, invalid_tools): (Vec, Vec) = tool_names - .into_iter() - .partition(|tool_name| existing_tools.contains(tool_name)); - - if !invalid_tools.is_empty() { - queue!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!("\nCannot untrust '{}', ", invalid_tools.join("', '"))), - if invalid_tools.len() > 1 { - style::Print("they do not exist.") - } else { - style::Print("it does not exist.") - }, - style::SetForegroundColor(Color::Reset), - )?; - } - if !valid_tools.is_empty() { - valid_tools.iter().for_each(|t| self.tool_permissions.untrust_tool(t)); - queue!( - self.output, - style::SetForegroundColor(Color::Green), - if valid_tools.len() > 1 { - style::Print(format!("\nTools '{}' are ", valid_tools.join("', '"))) - } else { - style::Print(format!("\nTool '{}' is ", valid_tools[0])) - }, - style::Print("set to per-request confirmation."), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - Some(ToolsSubcommand::TrustAll) => { - self.conversation_state.tools.values().flatten().for_each( - |FigTool::ToolSpecification(spec)| { - self.tool_permissions.trust_tool(spec.name.as_str()); - }, - ); - queue!(self.output, style::Print(TRUST_ALL_TEXT),)?; - }, - Some(ToolsSubcommand::Reset) => { - self.tool_permissions.reset(); - queue!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print("\nReset all tools to the default permission levels."), - style::SetForegroundColor(Color::Reset), - )?; - }, - Some(ToolsSubcommand::ResetSingle { tool_name }) => { - if self.tool_permissions.has(&tool_name) { - self.tool_permissions.reset_tool(&tool_name); - queue!( - self.output, - style::SetForegroundColor(Color::Green), - style::Print(format!("\nReset tool '{}' to the default permission level.", tool_name)), - style::SetForegroundColor(Color::Reset), - )?; - } else { - queue!( - self.output, - style::SetForegroundColor(Color::Red), - style::Print(format!( - "\nTool '{}' does not exist or is already in default settings.", - tool_name - )), - style::SetForegroundColor(Color::Reset), - )?; - } - }, - Some(ToolsSubcommand::Help) => { - queue!( - self.output, - style::Print("\n"), - style::Print(command::ToolsSubcommand::help_text()), - )?; - }, - None => { - // No subcommand - print the current tools and their permissions. - // Determine how to format the output nicely. 
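// The listing below right-pads each tool name so the permission column starts at a fixed
// offset derived from the longest name. A minimal sketch of that alignment trick; the row
// data here is a placeholder, not taken from the patch:

fn format_two_columns(rows: &[(&str, &str)]) -> String {
    let longest = rows.iter().map(|(name, _)| name.len()).max().unwrap_or(0);
    rows.iter()
        .map(|(name, permission)| {
            // "{:>width$}" with an empty string prints `width` spaces of padding.
            let width = longest - name.len() + 4;
            format!("- {}{:>width$}{}\n", name, "", permission, width = width)
        })
        .collect()
}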
- let terminal_width = self.terminal_width(); - let longest = self - .conversation_state - .tools - .values() - .flatten() - .map(|FigTool::ToolSpecification(spec)| spec.name.len()) - .max() - .unwrap_or(0); - - queue!( - self.output, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print({ - // Adding 2 because of "- " preceding every tool name - let width = longest + 2 - "Tool".len() + 4; - format!("Tool{:>width$}Permission", "", width = width) - }), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::Print("▔".repeat(terminal_width)), - )?; - - self.conversation_state.tools.iter().for_each(|(origin, tools)| { - let to_display = - tools - .iter() - .fold(String::new(), |mut acc, FigTool::ToolSpecification(spec)| { - let width = longest - spec.name.len() + 4; - acc.push_str( - format!( - "- {}{:>width$}{}\n", - spec.name, - "", - self.tool_permissions.display_label(&spec.name), - width = width - ) - .as_str(), - ); - acc - }); - let _ = queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::Print(format!("{}:\n", origin)), - style::SetAttribute(Attribute::Reset), - style::Print(to_display), - style::Print("\n") - ); - }); - - queue!( - self.output, - style::Print("\nTrusted tools can be run without confirmation\n"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("\n{}\n", "* Default settings")), - style::Print("\n💡 Use "), - style::SetForegroundColor(Color::Green), - style::Print("/tools help"), - style::SetForegroundColor(Color::Reset), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to edit permissions."), - style::SetForegroundColor(Color::Reset), - )?; - }, - }; - - // Put spacing between previous output as to not be overwritten by - // during PromptUser. - self.output.flush()?; - - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - Command::Prompts { subcommand } => { - match subcommand { - Some(PromptsSubcommand::Help) => { - queue!(self.output, style::Print(command::PromptsSubcommand::help_text()))?; - }, - Some(PromptsSubcommand::Get { mut get_command }) => { - let orig_input = get_command.orig_input.take(); - let prompts = match self.tool_manager.get_prompt(get_command).await { - Ok(resp) => resp, - Err(e) => { - match e { - GetPromptError::AmbiguousPrompt(prompt_name, alt_msg) => { - queue!( - self.output, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(prompt_name), - style::SetForegroundColor(Color::Yellow), - style::Print(" is ambiguous. Use one of the following "), - style::SetForegroundColor(Color::Cyan), - style::Print(alt_msg), - style::SetForegroundColor(Color::Reset), - )?; - }, - GetPromptError::PromptNotFound(prompt_name) => { - queue!( - self.output, - style::Print("\n"), - style::SetForegroundColor(Color::Yellow), - style::Print("Prompt "), - style::SetForegroundColor(Color::Cyan), - style::Print(prompt_name), - style::SetForegroundColor(Color::Yellow), - style::Print(" not found. 
Use "), - style::SetForegroundColor(Color::Cyan), - style::Print("/prompts list"), - style::SetForegroundColor(Color::Yellow), - style::Print(" to see available prompts.\n"), - style::SetForegroundColor(Color::Reset), - )?; - }, - _ => return Err(ChatError::Custom(e.to_string().into())), - } - execute!(self.output, style::Print("\n"))?; - return Ok(ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - }); - }, - }; - if let Some(err) = prompts.error { - // If we are running into error we should just display the error - // and abort. - let to_display = serde_json::json!(err); - queue!( - self.output, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Error encountered while retrieving prompt:"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::SetForegroundColor(Color::Red), - style::Print( - serde_json::to_string_pretty(&to_display) - .unwrap_or_else(|_| format!("{:?}", &to_display)) - ), - style::SetForegroundColor(Color::Reset), - style::Print("\n"), - )?; - } else { - let prompts = prompts - .result - .ok_or(ChatError::Custom("Result field missing from prompt/get request".into()))?; - let prompts = serde_json::from_value::(prompts).map_err(|e| { - ChatError::Custom(format!("Failed to deserialize prompt/get result: {:?}", e).into()) - })?; - self.pending_prompts.clear(); - self.pending_prompts.append(&mut VecDeque::from(prompts.messages)); - return Ok(ChatState::HandleInput { - input: orig_input.unwrap_or_default(), - tool_uses: Some(tool_uses), - pending_tool_index, - }); - } - }, - subcommand => { - let search_word = match subcommand { - Some(PromptsSubcommand::List { search_word }) => search_word, - _ => None, - }; - let terminal_width = self.terminal_width(); - let mut prompts_wl = self.tool_manager.prompts.write().map_err(|e| { - ChatError::Custom( - format!("Poison error encountered while retrieving prompts: {}", e).into(), - ) - })?; - self.tool_manager.refresh_prompts(&mut prompts_wl)?; - let mut longest_name = ""; - let arg_pos = { - let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; - if optimal_case > terminal_width { - terminal_width / 3 - } else { - optimal_case - } - }; - queue!( - self.output, - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::Print("Prompt"), - style::SetAttribute(Attribute::Reset), - style::Print({ - let name_width = UnicodeWidthStr::width("Prompt"); - let padding = arg_pos.saturating_sub(name_width); - " ".repeat(padding) - }), - style::SetAttribute(Attribute::Bold), - style::Print("Arguments (* = required)"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - style::Print(format!("{}\n", "▔".repeat(terminal_width))), - )?; - let prompts_by_server = prompts_wl.iter().fold( - HashMap::<&String, Vec<&PromptBundle>>::new(), - |mut acc, (prompt_name, bundles)| { - if prompt_name.contains(search_word.as_deref().unwrap_or("")) { - if prompt_name.len() > longest_name.len() { - longest_name = prompt_name.as_str(); - } - for bundle in bundles { - acc.entry(&bundle.server_name) - .and_modify(|b| b.push(bundle)) - .or_insert(vec![bundle]); - } - } - acc - }, - ); - for (i, (server_name, bundles)) in prompts_by_server.iter().enumerate() { - if i > 0 { - queue!(self.output, style::Print("\n"))?; - } - queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::Print(server_name), - style::Print(" (MCP):"), - style::SetAttribute(Attribute::Reset), - style::Print("\n"), - )?; - for bundle in 
bundles { - queue!( - self.output, - style::Print("- "), - style::Print(&bundle.prompt_get.name), - style::Print({ - if bundle - .prompt_get - .arguments - .as_ref() - .is_some_and(|args| !args.is_empty()) - { - let name_width = UnicodeWidthStr::width(bundle.prompt_get.name.as_str()); - let padding = - arg_pos.saturating_sub(name_width) - UnicodeWidthStr::width("- "); - " ".repeat(padding) - } else { - "\n".to_owned() - } - }) - )?; - if let Some(args) = bundle.prompt_get.arguments.as_ref() { - for (i, arg) in args.iter().enumerate() { - queue!( - self.output, - style::SetForegroundColor(Color::DarkGrey), - style::Print(match arg.required { - Some(true) => format!("{}*", arg.name), - _ => arg.name.clone(), - }), - style::SetForegroundColor(Color::Reset), - style::Print(if i < args.len() - 1 { ", " } else { "\n" }), - )?; - } - } - } - } - }, - } - execute!(self.output, style::Print("\n"))?; - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - Command::Usage => { - let state = self.conversation_state.backend_conversation_state(true, true).await; - let data = state.calculate_conversation_size(); - - let context_token_count: TokenCount = data.context_messages.into(); - let assistant_token_count: TokenCount = data.assistant_messages.into(); - let user_token_count: TokenCount = data.user_messages.into(); - let total_token_used: TokenCount = - (data.context_messages + data.user_messages + data.assistant_messages).into(); - - let window_width = self.terminal_width(); - // set a max width for the progress bar for better aesthetic - let progress_bar_width = std::cmp::min(window_width, 80); - - let context_width = ((context_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) - * progress_bar_width as f64) as usize; - let assistant_width = ((assistant_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) - * progress_bar_width as f64) as usize; - let user_width = ((user_token_count.value() as f64 / CONTEXT_WINDOW_SIZE as f64) - * progress_bar_width as f64) as usize; - - let left_over_width = progress_bar_width - - std::cmp::min(context_width + assistant_width + user_width, progress_bar_width); - - queue!( - self.output, - style::Print(format!( - "\nCurrent context window ({} of {}k tokens used)\n", - total_token_used, - CONTEXT_WINDOW_SIZE / 1000 - )), - style::SetForegroundColor(Color::DarkCyan), - // add a nice visual to mimic "tiny" progress, so the overral progress bar doesn't look too - // empty - style::Print("|".repeat(if context_width == 0 && *context_token_count > 0 { - 1 - } else { - 0 - })), - style::Print("█".repeat(context_width)), - style::SetForegroundColor(Color::Blue), - style::Print("|".repeat(if assistant_width == 0 && *assistant_token_count > 0 { - 1 - } else { - 0 - })), - style::Print("█".repeat(assistant_width)), - style::SetForegroundColor(Color::Magenta), - style::Print("|".repeat(if user_width == 0 && *user_token_count > 0 { 1 } else { 0 })), - style::Print("█".repeat(user_width)), - style::SetForegroundColor(Color::DarkGrey), - style::Print("█".repeat(left_over_width)), - style::Print(" "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - "{:.2}%", - (total_token_used.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - )?; - - queue!(self.output, style::Print("\n\n"))?; - self.output.flush()?; - - queue!( - self.output, - style::SetForegroundColor(Color::DarkCyan), - style::Print("█ Context files: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - "~{} 
tokens ({:.2}%)\n", - context_token_count, - (context_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - style::SetForegroundColor(Color::Blue), - style::Print("█ Q responses: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - " ~{} tokens ({:.2}%)\n", - assistant_token_count, - (assistant_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - style::SetForegroundColor(Color::Magenta), - style::Print("█ Your prompts: "), - style::SetForegroundColor(Color::Reset), - style::Print(format!( - " ~{} tokens ({:.2}%)\n\n", - user_token_count, - (user_token_count.value() as f32 / CONTEXT_WINDOW_SIZE as f32) * 100.0 - )), - )?; - - queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::Print("\n💡 Pro Tips:\n"), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::DarkGrey), - style::Print("Run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/compact"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to replace the conversation history with its summary\n"), - style::Print("Run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/clear"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to erase the entire chat history\n"), - style::Print("Run "), - style::SetForegroundColor(Color::DarkGreen), - style::Print("/context show"), - style::SetForegroundColor(Color::DarkGrey), - style::Print(" to see tokens per context file\n\n"), - style::SetForegroundColor(Color::Reset), - )?; - - ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: true, - } - }, - }) - } - - async fn tool_use_execute(&mut self, mut tool_uses: Vec) -> Result { - // Verify tools have permissions. - for (index, tool) in tool_uses.iter_mut().enumerate() { - // Manually accepted by the user or otherwise verified already. - if tool.accepted { - continue; - } - - // If there is an override, we will use it. Otherwise fall back to Tool's default. - let allowed = if self.tool_permissions.has(&tool.name) { - self.tool_permissions.is_trusted(&tool.name) - } else { - !tool.tool.requires_acceptance(&self.ctx) - }; - - if self.settings.get_bool_or("chat.enableNotifications", false) { - play_notification_bell(!allowed); - } - - self.print_tool_descriptions(tool, allowed).await?; - - if allowed { - tool.accepted = true; - continue; - } - - let pending_tool_index = Some(index); - if !self.interactive { - // Cannot request in non-interactive, so fail. - return Err(ChatError::NonInteractiveToolApproval); - } - - return Ok(ChatState::PromptUser { - tool_uses: Some(tool_uses), - pending_tool_index, - skip_printing_tools: false, - }); - } - - // Execute the requested tools. 
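// Each tool invocation below is timed with std::time::Instant and reported as "secs.millis".
// A small sketch of that measurement wrapped around an arbitrary closure; `run_timed` is a
// hypothetical helper, not part of the patch:

use std::time::Instant;

fn run_timed<T>(f: impl FnOnce() -> T) -> (T, String) {
    let start = Instant::now();
    let output = f();
    let elapsed = start.elapsed();
    // Same "{}.{}" seconds/milliseconds formatting used for the completion banner below.
    (output, format!("{}.{}", elapsed.as_secs(), elapsed.subsec_millis()))
}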
- let mut tool_results = vec![]; - - for tool in tool_uses { - let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone()); - tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_accepted = true); - - let tool_start = std::time::Instant::now(); - let invoke_result = tool.tool.invoke(&self.ctx, &mut self.output).await; - - if self.interactive && self.spinner.is_some() { - queue!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - execute!(self.output, style::Print("\n"))?; - - let tool_time = std::time::Instant::now().duration_since(tool_start); - if let Tool::Custom(ct) = &tool.tool { - tool_telemetry = tool_telemetry.and_modify(|ev| { - ev.custom_tool_call_latency = Some(tool_time.as_secs() as usize); - ev.input_token_size = Some(ct.get_input_token_size()); - ev.is_custom_tool = true; - }); - } - let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis()); - - match invoke_result { - Ok(result) => { - debug!("tool result output: {:#?}", result); - execute!( - self.output, - style::Print(CONTINUATION_LINE), - style::Print("\n"), - style::SetForegroundColor(Color::Green), - style::SetAttribute(Attribute::Bold), - style::Print(format!(" ● Completed in {}s", tool_time)), - style::SetForegroundColor(Color::Reset), - style::Print("\n"), - )?; - - tool_telemetry = tool_telemetry.and_modify(|ev| ev.is_success = Some(true)); - if let Tool::Custom(_) = &tool.tool { - tool_telemetry - .and_modify(|ev| ev.output_token_size = Some(TokenCounter::count_tokens(result.as_str()))); - } - tool_results.push(ToolUseResult { - tool_use_id: tool.id, - content: vec![result.into()], - status: ToolResultStatus::Success, - }); - }, - Err(err) => { - error!(?err, "An error occurred processing the tool"); - execute!( - self.output, - style::Print(CONTINUATION_LINE), - style::Print("\n"), - style::SetAttribute(Attribute::Bold), - style::SetForegroundColor(Color::Red), - style::Print(format!(" ● Execution failed after {}s:\n", tool_time)), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(Color::Red), - style::Print(&err), - style::SetAttribute(Attribute::Reset), - style::Print("\n\n"), - )?; - - tool_telemetry.and_modify(|ev| ev.is_success = Some(false)); - tool_results.push(ToolUseResult { - tool_use_id: tool.id, - content: vec![ToolUseResultBlock::Text(format!( - "An error occurred processing the tool: \n{}", - &err - ))], - status: ToolResultStatus::Error, - }); - if let ToolUseStatus::Idle = self.tool_use_status { - self.tool_use_status = ToolUseStatus::RetryInProgress( - self.conversation_state - .message_id() - .map_or("No utterance id found".to_string(), |v| v.to_string()), - ); - } - }, - } - } - - self.conversation_state.add_tool_results(tool_results); - - self.send_tool_use_telemetry().await; - return Ok(ChatState::HandleResponseStream( - self.client - .send_message(self.conversation_state.as_sendable_conversation_state(false).await) - .await?, - )); - } - - async fn handle_response(&mut self, response: SendMessageOutput) -> Result { - let request_id = response.request_id().map(|s| s.to_string()); - let mut buf = String::new(); - let mut offset = 0; - let mut ended = false; - let mut parser = ResponseParser::new(response); - let mut state = ParseState::new(Some(self.terminal_width())); - - let mut tool_uses = Vec::new(); - let mut tool_name_being_recvd: Option = None; - - loop { - match parser.recv().await { - Ok(msg_event) => { - trace!("Consumed: {:?}", msg_event); - match 
msg_event { - parser::ResponseEvent::ToolUseStart { name } => { - // We need to flush the buffer here, otherwise text will not be - // printed while we are receiving tool use events. - buf.push('\n'); - tool_name_being_recvd = Some(name); - }, - parser::ResponseEvent::AssistantText(text) => { - buf.push_str(&text); - }, - parser::ResponseEvent::ToolUse(tool_use) => { - if self.interactive && self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - tool_uses.push(tool_use); - tool_name_being_recvd = None; - }, - parser::ResponseEvent::EndStream { message } => { - // This log is attempting to help debug instances where users encounter - // the response timeout message. - if message.content() == RESPONSE_TIMEOUT_CONTENT { - error!(?request_id, ?message, "Encountered an unexpected model response"); - } - self.conversation_state.push_assistant_message(message); - ended = true; - }, - } - }, - Err(recv_error) => { - if let Some(request_id) = &recv_error.request_id { - self.failed_request_ids.push(request_id.clone()); - }; - - match recv_error.source { - RecvErrorKind::StreamTimeout { source, duration } => { - error!( - recv_error.request_id, - ?source, - "Encountered a stream timeout after waiting for {}s", - duration.as_secs() - ); - if self.interactive { - execute!(self.output, cursor::Hide)?; - self.spinner = - Some(Spinner::new(Spinners::Dots, "Dividing up the work...".to_string())); - } - // For stream timeouts, we'll tell the model to try and split its response into - // smaller chunks. - self.conversation_state - .push_assistant_message(AssistantMessage::new_response( - None, - RESPONSE_TIMEOUT_CONTENT.to_string(), - )); - self.conversation_state - .set_next_user_message( - "You took too long to respond - try to split up the work into smaller steps." 
- .to_string(), - ) - .await; - self.send_tool_use_telemetry().await; - return Ok(ChatState::HandleResponseStream( - self.client - .send_message(self.conversation_state.as_sendable_conversation_state(false).await) - .await?, - )); - }, - RecvErrorKind::UnexpectedToolUseEos { - tool_use_id, - name, - message, - time_elapsed, - } => { - error!( - recv_error.request_id, - tool_use_id, name, "The response stream ended before the entire tool use was received" - ); - if self.interactive { - drop(self.spinner.take()); - queue!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - style::SetForegroundColor(Color::Yellow), - style::SetAttribute(Attribute::Bold), - style::Print(format!( - "Warning: received an unexpected error from the model after {:.2}s", - time_elapsed.as_secs_f64() - )), - )?; - if let Some(request_id) = recv_error.request_id { - queue!( - self.output, - style::Print(format!("\n request_id: {}", request_id)) - )?; - } - execute!(self.output, style::Print("\n\n"), style::SetAttribute(Attribute::Reset))?; - self.spinner = Some(Spinner::new( - Spinners::Dots, - "Trying to divide up the work...".to_string(), - )); - } - - self.conversation_state.push_assistant_message(*message); - let tool_results = vec![ToolUseResult { - tool_use_id, - content: vec![ToolUseResultBlock::Text( - "The generated tool was too large, try again but this time split up the work between multiple tool uses".to_string(), - )], - status: ToolResultStatus::Error, - }]; - self.conversation_state.add_tool_results(tool_results); - self.send_tool_use_telemetry().await; - return Ok(ChatState::HandleResponseStream( - self.client - .send_message(self.conversation_state.as_sendable_conversation_state(false).await) - .await?, - )); - }, - _ => return Err(recv_error.into()), - } - }, - } - - // Fix for the markdown parser copied over from q chat: - // this is a hack since otherwise the parser might report Incomplete with useful data - // still left in the buffer. I'm not sure how this is intended to be handled. - if ended { - buf.push('\n'); - } - - if tool_name_being_recvd.is_none() && !buf.is_empty() && self.interactive && self.spinner.is_some() { - drop(self.spinner.take()); - queue!( - self.output, - terminal::Clear(terminal::ClearType::CurrentLine), - cursor::MoveToColumn(0), - cursor::Show - )?; - } - - // Print the response for normal cases - loop { - let input = Partial::new(&buf[offset..]); - match interpret_markdown(input, &mut self.output, &mut state) { - Ok(parsed) => { - offset += parsed.offset_from(&input); - self.output.flush()?; - state.newline = state.set_newline; - state.set_newline = false; - }, - Err(err) => match err.into_inner() { - Some(err) => return Err(ChatError::Custom(err.to_string().into())), - None => break, // Data was incomplete - }, - } - - // TODO: We should buffer output based on how much we have to parse, not as a constant - // Do not remove unless you are nabochay :) - std::thread::sleep(Duration::from_millis(8)); - } - - // Set spinner after showing all of the assistant text content so far. 
- if let (Some(_name), true) = (&tool_name_being_recvd, self.interactive) { - queue!(self.output, cursor::Hide)?; - self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_string())); - } - - if ended { - if let Some(message_id) = self.conversation_state.message_id() { - fig_telemetry::send_chat_added_message( - self.conversation_state.conversation_id().to_owned(), - message_id.to_owned(), - self.conversation_state.context_message_length(), - ) - .await; - } - - if self.interactive && self.settings.get_bool_or("chat.enableNotifications", false) { - // For final responses (no tools suggested), always play the bell - play_notification_bell(tool_uses.is_empty()); - } - - if self.interactive { - queue!(self.output, style::ResetColor, style::SetAttribute(Attribute::Reset))?; - execute!(self.output, style::Print("\n"))?; - - for (i, citation) in &state.citations { - queue!( - self.output, - style::Print("\n"), - style::SetForegroundColor(Color::Blue), - style::Print(format!("[^{i}]: ")), - style::SetForegroundColor(Color::DarkGrey), - style::Print(format!("{citation}\n")), - style::SetForegroundColor(Color::Reset) - )?; - } - } - - break; - } - } - - if !tool_uses.is_empty() { - Ok(ChatState::ValidateTools(tool_uses)) - } else { - Ok(ChatState::PromptUser { - tool_uses: None, - pending_tool_index: None, - skip_printing_tools: false, - }) - } - } - - async fn validate_tools(&mut self, tool_uses: Vec) -> Result { - let conv_id = self.conversation_state.conversation_id().to_owned(); - debug!(?tool_uses, "Validating tool uses"); - let mut queued_tools: Vec = Vec::new(); - let mut tool_results: Vec = Vec::new(); - - for tool_use in tool_uses { - let tool_use_id = tool_use.id.clone(); - let tool_use_name = tool_use.name.clone(); - let mut tool_telemetry = ToolUseEventBuilder::new(conv_id.clone(), tool_use.id.clone()) - .set_tool_use_id(tool_use_id.clone()) - .set_tool_name(tool_use.name.clone()) - .utterance_id(self.conversation_state.message_id().map(|s| s.to_string())); - match self.tool_manager.get_tool_from_tool_use(tool_use) { - Ok(mut tool) => { - // Apply non-Q-generated context to tools - self.contextualize_tool(&mut tool); - - match tool.validate(&self.ctx).await { - Ok(()) => { - tool_telemetry.is_valid = Some(true); - queued_tools.push(QueuedTool { - id: tool_use_id.clone(), - name: tool_use_name, - tool, - accepted: false, - }); - }, - Err(err) => { - tool_telemetry.is_valid = Some(false); - tool_results.push(ToolUseResult { - tool_use_id: tool_use_id.clone(), - content: vec![ToolUseResultBlock::Text(format!( - "Failed to validate tool parameters: {err}" - ))], - status: ToolResultStatus::Error, - }); - }, - }; - }, - Err(err) => { - tool_telemetry.is_valid = Some(false); - tool_results.push(err.into()); - }, - } - self.tool_use_telemetry_events.insert(tool_use_id, tool_telemetry); - } - - // If we have any validation errors, then return them immediately to the model. 
- if !tool_results.is_empty() { - debug!(?tool_results, "Error found in the model tools"); - queue!( - self.output, - style::SetAttribute(Attribute::Bold), - style::Print("Tool validation failed: "), - style::SetAttribute(Attribute::Reset), - )?; - for tool_result in &tool_results { - for block in &tool_result.content { - let content: Option> = match block { - ToolUseResultBlock::Text(t) => Some(t.as_str().into()), - ToolUseResultBlock::Json(d) => serde_json::to_string(d) - .map_err(|err| error!(?err, "failed to serialize tool result content")) - .map(Into::into) - .ok(), - }; - if let Some(content) = content { - queue!( - self.output, - style::Print("\n"), - style::SetForegroundColor(Color::Red), - style::Print(format!("{}\n", content)), - style::SetForegroundColor(Color::Reset), - )?; - } - } - } - self.conversation_state.add_tool_results(tool_results); - self.send_tool_use_telemetry().await; - if let ToolUseStatus::Idle = self.tool_use_status { - self.tool_use_status = ToolUseStatus::RetryInProgress( - self.conversation_state - .message_id() - .map_or("No utterance id found".to_string(), |v| v.to_string()), - ); - } - - let response = self - .client - .send_message(self.conversation_state.as_sendable_conversation_state(false).await) - .await?; - return Ok(ChatState::HandleResponseStream(response)); - } - - Ok(ChatState::ExecuteTools(queued_tools)) - } - - /// Apply program context to tools that Q may not have. - // We cannot attach this any other way because Tools are constructed by deserializing - // output from Amazon Q. - // TODO: Is there a better way? - fn contextualize_tool(&self, tool: &mut Tool) { - #[allow(clippy::single_match)] - match tool { - Tool::GhIssue(gh_issue) => { - gh_issue.set_context(GhIssueContext { - // Ideally we avoid cloning, but this function is not called very often. - // Using references with lifetimes requires a large refactor, and Arc> - // seems like overkill and may incur some performance cost anyway. 
- context_manager: self.conversation_state.context_manager.clone(), - transcript: self.conversation_state.transcript.clone(), - failed_request_ids: self.failed_request_ids.clone(), - tool_permissions: self.tool_permissions.permissions.clone(), - interactive: self.interactive, - }); - }, - _ => (), - }; - } - - async fn print_tool_descriptions(&mut self, tool_use: &QueuedTool, trusted: bool) -> Result<(), ChatError> { - queue!( - self.output, - style::SetForegroundColor(Color::Magenta), - style::Print(format!( - "🛠️ Using tool: {}{}", - tool_use.tool.display_name(), - if trusted { " (trusted)".dark_green() } else { "".reset() } - )), - style::SetForegroundColor(Color::Reset) - )?; - if let Tool::Custom(ref tool) = tool_use.tool { - queue!( - self.output, - style::SetForegroundColor(Color::Reset), - style::Print(" from mcp server "), - style::SetForegroundColor(Color::Magenta), - style::Print(tool.client.get_server_name()), - style::SetForegroundColor(Color::Reset), - )?; - } - queue!(self.output, style::Print("\n"), style::Print(CONTINUATION_LINE))?; - queue!(self.output, style::Print("\n"))?; - queue!(self.output, style::Print(TOOL_BULLET))?; - - self.output.flush()?; - - tool_use - .tool - .queue_description(&self.ctx, &mut self.output) - .await - .map_err(|e| ChatError::Custom(format!("failed to print tool, `{}`: {}", tool_use.name, e).into()))?; - - Ok(()) - } - - /// Helper function to read user input with a prompt and Ctrl+C handling - fn read_user_input(&mut self, prompt: &str, exit_on_single_ctrl_c: bool) -> Option { - let mut ctrl_c = false; - loop { - match (self.input_source.read_line(Some(prompt)), ctrl_c) { - (Ok(Some(line)), _) => { - if line.trim().is_empty() { - continue; // Reprompt if the input is empty - } - return Some(line); - }, - (Ok(None), false) => { - if exit_on_single_ctrl_c { - return None; - } - execute!( - self.output, - style::Print(format!( - "\n(To exit the CLI, press Ctrl+C or Ctrl+D again or type {})\n\n", - "/quit".green() - )) - ) - .unwrap_or_default(); - ctrl_c = true; - }, - (Ok(None), true) => return None, // Exit if Ctrl+C was pressed twice - (Err(_), _) => return None, - } - } - } - - /// Helper function to generate a prompt based on the current context - fn generate_tool_trust_prompt(&self) -> String { - prompt::generate_prompt(self.conversation_state.current_profile(), self.all_tools_trusted()) - } - - async fn send_tool_use_telemetry(&mut self) { - for (_, mut event) in self.tool_use_telemetry_events.drain() { - event.user_input_id = match self.tool_use_status { - ToolUseStatus::Idle => self.conversation_state.message_id(), - ToolUseStatus::RetryInProgress(ref id) => Some(id.as_str()), - } - .map(|v| v.to_string()); - let event: fig_telemetry::EventType = event.into(); - let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - fig_telemetry::dispatch_or_send_event(app_event).await; - } - } - - fn terminal_width(&self) -> usize { - (self.terminal_width_provider)().unwrap_or(80) - } - - fn all_tools_trusted(&self) -> bool { - self.conversation_state.tools.values().flatten().all(|t| match t { - FigTool::ToolSpecification(t) => self.tool_permissions.is_trusted(&t.name), - }) - } - - /// Display character limit warnings based on current conversation size - async fn display_char_warnings(&mut self) -> Result<(), std::io::Error> { - let warning_level = self.conversation_state.get_token_warning_level().await; - - match warning_level { - TokenWarningLevel::Critical => { - // Memory constraint warning with gentler wording - execute!( - 
self.output, - style::SetForegroundColor(Color::Yellow), - style::SetAttribute(Attribute::Bold), - style::Print("\n⚠️ This conversation is getting lengthy.\n"), - style::SetAttribute(Attribute::Reset), - style::Print( - "To ensure continued smooth operation, please use /compact to summarize the conversation.\n\n" - ), - style::SetForegroundColor(Color::Reset) - )?; - }, - TokenWarningLevel::None => { - // No warning needed - }, - } - - Ok(()) - } -} - -#[derive(Debug)] -struct ToolUseEventBuilder { - pub conversation_id: String, - pub utterance_id: Option, - pub user_input_id: Option, - pub tool_use_id: Option, - pub tool_name: Option, - pub is_accepted: bool, - pub is_success: Option, - pub is_valid: Option, - pub is_custom_tool: bool, - pub input_token_size: Option, - pub output_token_size: Option, - pub custom_tool_call_latency: Option, -} - -impl ToolUseEventBuilder { - pub fn new(conv_id: String, tool_use_id: String) -> Self { - Self { - conversation_id: conv_id, - utterance_id: None, - user_input_id: None, - tool_use_id: Some(tool_use_id), - tool_name: None, - is_accepted: false, - is_success: None, - is_valid: None, - is_custom_tool: false, - input_token_size: None, - output_token_size: None, - custom_tool_call_latency: None, - } - } - - pub fn utterance_id(mut self, id: Option) -> Self { - self.utterance_id = id; - self - } - - pub fn set_tool_use_id(mut self, id: String) -> Self { - self.tool_use_id.replace(id); - self - } - - pub fn set_tool_name(mut self, name: String) -> Self { - self.tool_name.replace(name); - self - } -} - -impl From for fig_telemetry::EventType { - fn from(val: ToolUseEventBuilder) -> Self { - fig_telemetry::EventType::ToolUseSuggested { - conversation_id: val.conversation_id, - utterance_id: val.utterance_id, - user_input_id: val.user_input_id, - tool_use_id: val.tool_use_id, - tool_name: val.tool_name, - is_accepted: val.is_accepted, - is_success: val.is_success, - is_valid: val.is_valid, - is_custom_tool: val.is_custom_tool, - input_token_size: val.input_token_size, - output_token_size: val.output_token_size, - custom_tool_call_latency: val.custom_tool_call_latency, - } - } -} - -/// Testing helper -fn split_tool_use_event(value: &Map) -> Vec { - let tool_use_id = value.get("tool_use_id").unwrap().as_str().unwrap().to_string(); - let name = value.get("name").unwrap().as_str().unwrap().to_string(); - let args_str = value.get("args").unwrap().to_string(); - let split_point = args_str.len() / 2; - vec![ - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: None, - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: Some(args_str.split_at(split_point).0.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: Some(args_str.split_at(split_point).1.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: name.clone(), - input: None, - stop: Some(true), - }, - ] -} - -/// Testing helper -fn create_stream(model_responses: serde_json::Value) -> StreamingClient { - let mut mock = Vec::new(); - for response in model_responses.as_array().unwrap() { - let mut stream = Vec::new(); - for event in response.as_array().unwrap() { - match event { - serde_json::Value::String(assistant_text) => { - stream.push(ChatResponseStream::AssistantResponseEvent { - content: assistant_text.to_string(), - }); - }, - 
serde_json::Value::Object(tool_use) => { - stream.append(&mut split_tool_use_event(tool_use)); - }, - other => panic!("Unexpected value: {:?}", other), - } - } - mock.push(stream); - } - StreamingClient::mock(mock) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_flow() { - let _ = tracing_subscriber::fmt::try_init(); - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let test_client = create_stream(serde_json::json!([ - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file.txt", - } - } - ], - [ - "Hope that looks good to you!", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatContext::new( - Arc::clone(&ctx), - "fake_conv_id", - Settings::new_fake(), - State::new_fake(), - SharedWriter::stdout(), - None, - InputSource::new_mock(vec![ - "create a new file".to_string(), - "y".to_string(), - "exit".to_string(), - ]), - true, - test_client, - || Some(80), - tool_manager, - None, - tool_config, - ToolPermissions::new(0), - ) - .await - .unwrap() - .try_chat() - .await - .unwrap(); - - assert_eq!(ctx.fs().read_to_string("/file.txt").await.unwrap(), "Hello, world!\n"); - } - - #[tokio::test] - async fn test_flow_tool_permissions() { - let _ = tracing_subscriber::fmt::try_init(); - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let test_client = create_stream(serde_json::json!([ - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file1.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file2.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file3.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file4.txt", - } - } - ], - [ - "Ok, I won't make it.", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file5.txt", - } - } - ], - [ - "Done", - ], - [ - "Ok", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file6.txt", - } - } - ], - [ - "Ok, I won't make it.", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatContext::new( - Arc::clone(&ctx), - "fake_conv_id", - Settings::new_fake(), - State::new_fake(), - SharedWriter::stdout(), - None, - InputSource::new_mock(vec![ - "/tools".to_string(), - "/tools help".to_string(), - "create a new file".to_string(), - "y".to_string(), - "create a new file".to_string(), - "t".to_string(), - "create a new file".to_string(), // should make without prompting due to 't' - "/tools untrust fs_write".to_string(), - "create a file".to_string(), // prompt again due to untrust - "n".to_string(), // cancel - "/tools trust fs_write".to_string(), - "create a file".to_string(), // 
again without prompting due to '/tools trust' - "/tools reset".to_string(), - "create a file".to_string(), // prompt again due to reset - "n".to_string(), // cancel - "exit".to_string(), - ]), - true, - test_client, - || Some(80), - tool_manager, - None, - tool_config, - ToolPermissions::new(0), - ) - .await - .unwrap() - .try_chat() - .await - .unwrap(); - - assert_eq!(ctx.fs().read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(ctx.fs().read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); - assert!(!ctx.fs().exists("/file4.txt")); - assert_eq!(ctx.fs().read_to_string("/file5.txt").await.unwrap(), "Hello, world!\n"); - assert!(!ctx.fs().exists("/file6.txt")); - } - - #[tokio::test] - async fn test_flow_multiple_tools() { - let _ = tracing_subscriber::fmt::try_init(); - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let test_client = create_stream(serde_json::json!([ - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file1.txt", - } - }, - { - "tool_use_id": "2", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file2.txt", - } - } - ], - [ - "Done", - ], - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file3.txt", - } - }, - { - "tool_use_id": "2", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file4.txt", - } - } - ], - [ - "Done", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - .expect("Tools failed to load"); - ChatContext::new( - Arc::clone(&ctx), - "fake_conv_id", - Settings::new_fake(), - State::new_fake(), - SharedWriter::stdout(), - None, - InputSource::new_mock(vec![ - "create 2 new files parallel".to_string(), - "t".to_string(), - "/tools reset".to_string(), - "create 2 new files parallel".to_string(), - "y".to_string(), - "y".to_string(), - "exit".to_string(), - ]), - true, - test_client, - || Some(80), - tool_manager, - None, - tool_config, - ToolPermissions::new(0), - ) - .await - .unwrap() - .try_chat() - .await - .unwrap(); - - assert_eq!(ctx.fs().read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(ctx.fs().read_to_string("/file2.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(ctx.fs().read_to_string("/file3.txt").await.unwrap(), "Hello, world!\n"); - assert_eq!(ctx.fs().read_to_string("/file4.txt").await.unwrap(), "Hello, world!\n"); - } - - #[tokio::test] - async fn test_flow_tools_trust_all() { - let _ = tracing_subscriber::fmt::try_init(); - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let test_client = create_stream(serde_json::json!([ - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file1.txt", - } - } - ], - [ - "Done", - ], - [ - "Sure, I'll create a file for you", - { - "tool_use_id": "1", - "name": "fs_write", - "args": { - "command": "create", - "file_text": "Hello, world!", - "path": "/file3.txt", - } - } - ], - [ - "Ok I won't.", - ], - ])); - - let tool_manager = ToolManager::default(); - let tool_config = serde_json::from_str::>(include_str!("tools/tool_index.json")) - 
.expect("Tools failed to load"); - ChatContext::new( - Arc::clone(&ctx), - "fake_conv_id", - Settings::new_fake(), - State::new_fake(), - SharedWriter::stdout(), - None, - InputSource::new_mock(vec![ - "/tools trustall".to_string(), - "create a new file".to_string(), - "/tools reset".to_string(), - "create a new file".to_string(), - "exit".to_string(), - ]), - true, - test_client, - || Some(80), - tool_manager, - None, - tool_config, - ToolPermissions::new(0), - ) - .await - .unwrap() - .try_chat() - .await - .unwrap(); - - assert_eq!(ctx.fs().read_to_string("/file1.txt").await.unwrap(), "Hello, world!\n"); - assert!(!ctx.fs().exists("/file2.txt")); - } - - #[test] - fn test_editor_content_processing() { - // Since we no longer have template replacement, this test is simplified - let cases = vec![ - ("My content", "My content"), - ("My content with newline\n", "My content with newline"), - ("", ""), - ]; - - for (input, expected) in cases { - let processed = input.trim().to_string(); - assert_eq!(processed, expected.trim().to_string(), "Failed for input: {}", input); - } - } -} diff --git a/crates/q_chat/src/message.rs b/crates/q_chat/src/message.rs deleted file mode 100644 index 3a91d361f6..0000000000 --- a/crates/q_chat/src/message.rs +++ /dev/null @@ -1,407 +0,0 @@ -use std::env; - -use fig_api_client::model::{ - AssistantResponseMessage, - EnvState, - ShellState, - ToolResult, - ToolResultContentBlock, - ToolResultStatus, - ToolUse, - UserInputMessage, - UserInputMessageContext, -}; -use fig_util::Shell; -use serde::{ - Deserialize, - Serialize, -}; -use tracing::error; - -use super::consts::MAX_CURRENT_WORKING_DIRECTORY_LEN; -use super::tools::{ - InvokeOutput, - OutputKind, - document_to_serde_value, - serde_value_to_document, -}; -use super::util::truncate_safe; - -const USER_ENTRY_START_HEADER: &str = "--- USER MESSAGE BEGIN ---\n"; -const USER_ENTRY_END_HEADER: &str = "--- USER MESSAGE END ---\n\n"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserMessage { - pub additional_context: String, - pub env_context: UserEnvContext, - pub content: UserMessageContent, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum UserMessageContent { - Prompt { - /// The original prompt as input by the user. - prompt: String, - }, - CancelledToolUses { - /// The original prompt as input by the user, if any. - prompt: Option, - tool_use_results: Vec, - }, - ToolUseResults { - tool_use_results: Vec, - }, -} - -impl UserMessage { - /// Creates a new [UserMessage::Prompt], automatically detecting and adding the user's - /// environment [UserEnvContext]. 
- pub fn new_prompt(prompt: String) -> Self { - Self { - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::Prompt { prompt }, - } - } - - pub fn new_cancelled_tool_uses<'a>(prompt: Option, tool_use_ids: impl Iterator) -> Self { - Self { - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::CancelledToolUses { - prompt, - tool_use_results: tool_use_ids - .map(|id| ToolUseResult { - tool_use_id: id.to_string(), - content: vec![ToolUseResultBlock::Text( - "Tool use was cancelled by the user".to_string(), - )], - status: ToolResultStatus::Error, - }) - .collect(), - }, - } - } - - pub fn new_tool_use_results(results: Vec) -> Self { - Self { - additional_context: String::new(), - env_context: UserEnvContext::generate_new(), - content: UserMessageContent::ToolUseResults { - tool_use_results: results, - }, - } - } - - /// Converts this message into a [UserInputMessage] to be stored in the history of - /// [fig_api_client::model::ConversationState]. - pub fn into_history_entry(self) -> UserInputMessage { - UserInputMessage { - content: self.prompt().unwrap_or_default().to_string(), - user_input_message_context: Some(UserInputMessageContext { - shell_state: self.env_context.shell_state, - env_state: self.env_context.env_state, - tool_results: match self.content { - UserMessageContent::CancelledToolUses { tool_use_results, .. } - | UserMessageContent::ToolUseResults { tool_use_results } => { - Some(tool_use_results.into_iter().map(Into::into).collect()) - }, - UserMessageContent::Prompt { .. } => None, - }, - tools: None, - ..Default::default() - }), - user_intent: None, - } - } - - /// Converts this message into a [UserInputMessage] to be sent as - /// [FigConversationState::user_input_message]. - pub fn into_user_input_message(self) -> UserInputMessage { - let formatted_prompt = match self.prompt() { - Some(prompt) if !prompt.is_empty() => { - format!("{}{}{}", USER_ENTRY_START_HEADER, prompt, USER_ENTRY_END_HEADER) - }, - _ => String::new(), - }; - UserInputMessage { - content: format!("{} {}", self.additional_context, formatted_prompt) - .trim() - .to_string(), - user_input_message_context: Some(UserInputMessageContext { - shell_state: self.env_context.shell_state, - env_state: self.env_context.env_state, - tool_results: match self.content { - UserMessageContent::CancelledToolUses { tool_use_results, .. } - | UserMessageContent::ToolUseResults { tool_use_results } => { - Some(tool_use_results.into_iter().map(Into::into).collect()) - }, - UserMessageContent::Prompt { .. } => None, - }, - tools: None, - ..Default::default() - }), - user_intent: None, - } - } - - pub fn has_tool_use_results(&self) -> bool { - match self.content() { - UserMessageContent::CancelledToolUses { .. } | UserMessageContent::ToolUseResults { .. } => true, - UserMessageContent::Prompt { .. } => false, - } - } - - pub fn tool_use_results(&self) -> Option<&[ToolUseResult]> { - match self.content() { - UserMessageContent::Prompt { .. } => None, - UserMessageContent::CancelledToolUses { tool_use_results, .. 
} => Some(tool_use_results.as_slice()), - UserMessageContent::ToolUseResults { tool_use_results } => Some(tool_use_results.as_slice()), - } - } - - pub fn additional_context(&self) -> &str { - &self.additional_context - } - - pub fn content(&self) -> &UserMessageContent { - &self.content - } - - pub fn prompt(&self) -> Option<&str> { - match self.content() { - UserMessageContent::Prompt { prompt } => Some(prompt.as_str()), - UserMessageContent::CancelledToolUses { prompt, .. } => prompt.as_ref().map(|s| s.as_str()), - UserMessageContent::ToolUseResults { .. } => None, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolUseResult { - /// The ID for the tool request. - pub tool_use_id: String, - /// Content of the tool result. - pub content: Vec, - /// Status of the tool result. - pub status: ToolResultStatus, -} - -impl From for ToolUseResult { - fn from(value: ToolResult) -> Self { - Self { - tool_use_id: value.tool_use_id, - content: value.content.into_iter().map(Into::into).collect(), - status: value.status, - } - } -} - -impl From for ToolResult { - fn from(value: ToolUseResult) -> Self { - Self { - tool_use_id: value.tool_use_id, - content: value.content.into_iter().map(Into::into).collect(), - status: value.status, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ToolUseResultBlock { - Json(serde_json::Value), - Text(String), -} - -impl From for ToolResultContentBlock { - fn from(value: ToolUseResultBlock) -> Self { - match value { - ToolUseResultBlock::Json(v) => Self::Json(serde_value_to_document(v)), - ToolUseResultBlock::Text(s) => Self::Text(s), - } - } -} - -impl From for ToolUseResultBlock { - fn from(value: ToolResultContentBlock) -> Self { - match value { - ToolResultContentBlock::Json(v) => Self::Json(document_to_serde_value(v)), - ToolResultContentBlock::Text(s) => Self::Text(s), - } - } -} - -impl From for ToolUseResultBlock { - fn from(value: InvokeOutput) -> Self { - match value.output { - OutputKind::Text(text) => Self::Text(text), - OutputKind::Json(value) => Self::Json(value), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct UserEnvContext { - shell_state: Option, - env_state: Option, -} - -impl UserEnvContext { - pub fn generate_new() -> Self { - Self { - shell_state: Some(build_shell_state()), - env_state: Some(build_env_state()), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum AssistantMessage { - /// Normal response containing no tool uses. - Response { - message_id: Option, - content: String, - }, - /// An assistant message containing tool uses. - ToolUse { - message_id: Option, - content: String, - tool_uses: Vec, - }, -} - -impl AssistantMessage { - pub fn new_response(message_id: Option, content: String) -> Self { - Self::Response { message_id, content } - } - - pub fn new_tool_use(message_id: Option, content: String, tool_uses: Vec) -> Self { - Self::ToolUse { - message_id, - content, - tool_uses, - } - } - - pub fn message_id(&self) -> Option<&str> { - match self { - AssistantMessage::Response { message_id, .. } => message_id.as_ref().map(|s| s.as_str()), - AssistantMessage::ToolUse { message_id, .. } => message_id.as_ref().map(|s| s.as_str()), - } - } - - pub fn content(&self) -> &str { - match self { - AssistantMessage::Response { content, .. } => content.as_str(), - AssistantMessage::ToolUse { content, .. } => content.as_str(), - } - } - - pub fn tool_uses(&self) -> Option<&[AssistantToolUse]> { - match self { - AssistantMessage::ToolUse { tool_uses, .. 
} => Some(tool_uses.as_slice()), - AssistantMessage::Response { .. } => None, - } - } -} - -impl From for AssistantResponseMessage { - fn from(value: AssistantMessage) -> Self { - let (message_id, content, tool_uses) = match value { - AssistantMessage::Response { message_id, content } => (message_id, content, None), - AssistantMessage::ToolUse { - message_id, - content, - tool_uses, - } => ( - message_id, - content, - Some(tool_uses.into_iter().map(Into::into).collect()), - ), - }; - Self { - message_id, - content, - tool_uses, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AssistantToolUse { - /// The ID for the tool request. - pub id: String, - /// The name for the tool. - pub name: String, - /// The input to pass to the tool. - pub args: serde_json::Value, -} - -impl From for ToolUse { - fn from(value: AssistantToolUse) -> Self { - Self { - tool_use_id: value.id, - name: value.name, - input: serde_value_to_document(value.args), - } - } -} - -impl From for AssistantToolUse { - fn from(value: ToolUse) -> Self { - Self { - id: value.tool_use_id, - name: value.name, - args: document_to_serde_value(value.input), - } - } -} - -pub fn build_env_state() -> EnvState { - let mut env_state = EnvState { - operating_system: Some(env::consts::OS.into()), - ..Default::default() - }; - - match env::current_dir() { - Ok(current_dir) => { - env_state.current_working_directory = - Some(truncate_safe(¤t_dir.to_string_lossy(), MAX_CURRENT_WORKING_DIRECTORY_LEN).into()); - }, - Err(err) => { - error!(?err, "Attempted to fetch the CWD but it did not exist."); - }, - } - - env_state -} - -fn build_shell_state() -> ShellState { - // Try to grab the shell from the parent process via the `Shell::current_shell`, - // then try the `SHELL` env, finally just report bash - let shell_name = Shell::current_shell() - .or_else(|| { - let shell_name = env::var("SHELL").ok()?; - Shell::try_find_shell(shell_name) - }) - .unwrap_or(Shell::Bash) - .to_string(); - - ShellState { - shell_name, - shell_history: None, - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_env_state() { - let env_state = build_env_state(); - assert!(env_state.current_working_directory.is_some()); - assert!(env_state.operating_system.as_ref().is_some_and(|os| !os.is_empty())); - println!("{env_state:?}"); - } -} diff --git a/crates/q_chat/src/parse.rs b/crates/q_chat/src/parse.rs deleted file mode 100644 index db3f0cf382..0000000000 --- a/crates/q_chat/src/parse.rs +++ /dev/null @@ -1,762 +0,0 @@ -use std::io::Write; - -use crossterm::style::{ - Attribute, - Color, - Stylize, -}; -use crossterm::{ - Command, - style, -}; -use unicode_width::{ - UnicodeWidthChar, - UnicodeWidthStr, -}; -use winnow::Partial; -use winnow::ascii::{ - self, - digit1, - space0, - space1, - till_line_ending, -}; -use winnow::combinator::{ - alt, - delimited, - preceded, - repeat, - terminated, -}; -use winnow::error::{ - ErrMode, - ErrorKind, - ParserError, -}; -use winnow::prelude::*; -use winnow::stream::{ - AsChar, - Stream, -}; -use winnow::token::{ - any, - take_till, - take_until, - take_while, -}; - -const CODE_COLOR: Color = Color::Green; -const HEADING_COLOR: Color = Color::Magenta; -const BLOCKQUOTE_COLOR: Color = Color::DarkGrey; -const URL_TEXT_COLOR: Color = Color::Blue; -const URL_LINK_COLOR: Color = Color::DarkGrey; - -const DEFAULT_RULE_WIDTH: usize = 40; - -#[derive(Debug, thiserror::Error)] -pub enum Error<'a> { - #[error(transparent)] - Stdio(#[from] std::io::Error), - #[error("parse error {1}, input {0}")] - 
Winnow(Partial<&'a str>, ErrorKind), -} - -impl<'a> ParserError> for Error<'a> { - fn from_error_kind(input: &Partial<&'a str>, kind: ErrorKind) -> Self { - Self::Winnow(*input, kind) - } - - fn append( - self, - _input: &Partial<&'a str>, - _checkpoint: &winnow::stream::Checkpoint< - winnow::stream::Checkpoint<&'a str, &'a str>, - winnow::Partial<&'a str>, - >, - _kind: ErrorKind, - ) -> Self { - self - } -} - -#[derive(Debug)] -pub struct ParseState { - pub terminal_width: Option, - pub column: usize, - pub in_codeblock: bool, - pub bold: bool, - pub italic: bool, - pub strikethrough: bool, - pub set_newline: bool, - pub newline: bool, - pub citations: Vec<(String, String)>, -} - -impl ParseState { - pub fn new(terminal_width: Option) -> Self { - Self { - terminal_width, - column: 0, - in_codeblock: false, - bold: false, - italic: false, - strikethrough: false, - set_newline: false, - newline: true, - citations: vec![], - } - } -} - -pub fn interpret_markdown<'a, 'b>( - mut i: Partial<&'a str>, - mut o: impl Write + 'b, - state: &mut ParseState, -) -> PResult, Error<'a>> { - let mut error: Option> = None; - let start = i.checkpoint(); - - macro_rules! stateful_alt { - ($($fns:ident),*) => { - $({ - i.reset(&start); - match $fns(&mut o, state).parse_next(&mut i) { - Err(ErrMode::Backtrack(e)) => { - error = match error { - Some(error) => Some(error.or(e)), - None => Some(e), - }; - }, - res => { - return res.map(|_| i); - } - } - })* - }; - } - - match state.in_codeblock { - false => { - stateful_alt!( - // This pattern acts as a short circuit for alphanumeric plaintext - // More importantly, it's needed to support manual wordwrapping - text, - // multiline patterns - blockquote, - // linted_codeblock, - codeblock_begin, - // single line patterns - horizontal_rule, - heading, - bulleted_item, - numbered_item, - // inline patterns - code, - citation, - url, - bold, - italic, - strikethrough, - // symbols - less_than, - greater_than, - ampersand, - quot, - line_ending, - // fallback - fallback - ); - }, - true => { - stateful_alt!( - codeblock_less_than, - codeblock_greater_than, - codeblock_ampersand, - codeblock_quot, - codeblock_end, - codeblock_line_ending, - codeblock_fallback - ); - }, - } - - match error { - Some(e) => Err(ErrMode::Backtrack(e.append(&i, &start, ErrorKind::Alt))), - None => Err(ErrMode::assert(&i, "no parsers")), - } -} - -fn text<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let content = take_while(1.., |t| AsChar::is_alphanum(t) || "+,.!?\"".contains(t)).parse_next(i)?; - queue_newline_or_advance(&mut o, state, content.width())?; - queue(&mut o, style::Print(content))?; - Ok(()) - } -} - -fn heading<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let level = terminated(take_while(1.., |c| c == '#'), space1).parse_next(i)?; - let print = format!("{level} "); - - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::SetForegroundColor(HEADING_COLOR))?; - queue(&mut o, style::SetAttribute(Attribute::Bold))?; - queue(&mut o, style::Print(print)) - } -} - -fn bulleted_item<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return 
Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let ws = (space0, alt(("-", "*")), space1).parse_next(i)?.0; - let print = format!("{ws}• "); - - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::Print(print)) - } -} - -fn numbered_item<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let (ws, digits, _, _) = (space0, digit1, ".", space1).parse_next(i)?; - let print = format!("{ws}{digits}. "); - - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::Print(print)) - } -} - -fn horizontal_rule<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - ( - space0, - alt((take_while(3.., '-'), take_while(3.., '*'), take_while(3.., '_'))), - ) - .parse_next(i)?; - - state.column = 0; - state.set_newline = true; - - let rule_width = state.terminal_width.unwrap_or(DEFAULT_RULE_WIDTH); - queue(&mut o, style::Print(format!("{}\n", "━".repeat(rule_width)))) - } -} - -fn code<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "`".parse_next(i)?; - let code = terminated(take_until(0.., "`"), "`").parse_next(i)?; - let out = code.replace("&", "&").replace(">", ">").replace("<", "<"); - - queue_newline_or_advance(&mut o, state, out.width())?; - queue(&mut o, style::SetForegroundColor(Color::Green))?; - queue(&mut o, style::Print(out))?; - queue(&mut o, style::ResetColor) - } -} - -fn blockquote<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - let level = repeat::<_, _, Vec<&'_ str>, _, _>(1.., terminated(">", space0)) - .parse_next(i)? 
- .len(); - let print = "│ ".repeat(level); - - queue(&mut o, style::SetForegroundColor(BLOCKQUOTE_COLOR))?; - queue_newline_or_advance(&mut o, state, print.width())?; - queue(&mut o, style::Print(print)) - } -} - -fn bold<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - match state.newline { - true => { - alt(("**", "__")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::Bold))?; - }, - false => match state.bold { - true => { - alt(("**", "__")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::NormalIntensity))?; - }, - false => { - preceded(space1, alt(("**", "__"))).parse_next(i)?; - queue(&mut o, style::Print(' '))?; - queue(&mut o, style::SetAttribute(Attribute::Bold))?; - }, - }, - }; - - state.bold = !state.bold; - - Ok(()) - } -} - -fn italic<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - match state.newline { - true => { - alt(("*", "_")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::Italic))?; - }, - false => match state.italic { - true => { - alt(("*", "_")).parse_next(i)?; - queue(&mut o, style::SetAttribute(Attribute::NoItalic))?; - }, - false => { - preceded(space1, alt(("*", "_"))).parse_next(i)?; - queue(&mut o, style::Print(' '))?; - queue(&mut o, style::SetAttribute(Attribute::Italic))?; - }, - }, - }; - - state.italic = !state.italic; - - Ok(()) - } -} - -fn strikethrough<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "~~".parse_next(i)?; - state.strikethrough = !state.strikethrough; - match state.strikethrough { - true => queue(&mut o, style::SetAttribute(Attribute::CrossedOut)), - false => queue(&mut o, style::SetAttribute(Attribute::NotCrossedOut)), - } - } -} - -fn citation<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let num = delimited("[[", digit1, "]]").parse_next(i)?; - let link = delimited("(", take_till(0.., ')'), ")").parse_next(i)?; - - state.citations.push((num.to_owned(), link.to_owned())); - - queue_newline_or_advance(&mut o, state, num.width() + 1)?; - queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; - queue(&mut o, style::Print(format!("[^{num}]")))?; - queue(&mut o, style::ResetColor) - } -} - -fn url<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - // Save the current input position - let start = i.checkpoint(); - - // Try to match the first part of URL pattern "[text]" - let display = match delimited::<_, _, _, _, Error<'a>, _, _, _>("[", take_until(1.., "]("), "]").parse_next(i) { - Ok(display) => display, - Err(_) => { - // If it doesn't match, reset position and fail - i.reset(&start); - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - }, - }; - - // Try to match the second part of URL pattern "(url)" - let link = match delimited::<_, _, _, _, Error<'a>, _, _, _>("(", take_till(0.., ')'), ")").parse_next(i) { - Ok(link) => link, - Err(_) => { - // If it doesn't match, reset position and fail - i.reset(&start); - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - }, - }; - - // Only generate output if the complete URL pattern matches - 
queue_newline_or_advance(&mut o, state, display.width() + 1)?; - queue(&mut o, style::SetForegroundColor(URL_TEXT_COLOR))?; - queue(&mut o, style::Print(format!("{display} ")))?; - queue(&mut o, style::SetForegroundColor(URL_LINK_COLOR))?; - state.column += link.width(); - queue(&mut o, style::Print(link))?; - queue(&mut o, style::ResetColor) - } -} - -fn less_than<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "<".parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('<')) - } -} - -fn greater_than<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ">".parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('>')) - } -} - -fn ampersand<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "&".parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('&')) - } -} - -fn quot<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - """.parse_next(i)?; - queue_newline_or_advance(&mut o, state, 1)?; - queue(&mut o, style::Print('"')) - } -} - -fn line_ending<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ascii::line_ending.parse_next(i)?; - - state.column = 0; - state.set_newline = true; - - queue(&mut o, style::ResetColor)?; - queue(&mut o, style::SetAttribute(style::Attribute::Reset))?; - queue(&mut o, style::Print("\n")) - } -} - -fn fallback<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let fallback = any.parse_next(i)?; - if let Some(width) = fallback.width() { - queue_newline_or_advance(&mut o, state, width)?; - if fallback != ' ' || state.column != 1 { - queue(&mut o, style::Print(fallback))?; - } - } - - Ok(()) - } -} - -fn queue_newline_or_advance<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, - width: usize, -) -> Result<(), ErrMode>> { - if let Some(terminal_width) = state.terminal_width { - if state.column > 0 && state.column + width > terminal_width { - state.column = width; - queue(&mut o, style::Print('\n'))?; - return Ok(()); - } - } - - // else - state.column += width; - - Ok(()) -} - -fn queue<'a>(o: &mut impl Write, command: impl Command) -> Result<(), ErrMode>> { - use crossterm::QueueableCommand; - o.queue(command).map_err(|err| ErrMode::Cut(Error::Stdio(err)))?; - Ok(()) -} - -fn codeblock_begin<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - if !state.newline { - return Err(ErrMode::from_error_kind(i, ErrorKind::Fail)); - } - - // We don't want to do anything special to text inside codeblocks so we wait for all of it - // The alternative is to switch between parse rules at the top level but that's slightly involved - let language = preceded("```", till_line_ending).parse_next(i)?; - ascii::line_ending.parse_next(i)?; - - state.in_codeblock = true; - - if !language.is_empty() { - queue(&mut o, style::Print(format!("{}\n", language).bold()))?; - } - 
- queue(&mut o, style::SetForegroundColor(CODE_COLOR))?; - - Ok(()) - } -} - -fn codeblock_end<'a, 'b>( - mut o: impl Write + 'b, - state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "```".parse_next(i)?; - state.in_codeblock = false; - queue(&mut o, style::ResetColor) - } -} - -fn codeblock_less_than<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "<".parse_next(i)?; - queue(&mut o, style::Print('<')) - } -} - -fn codeblock_greater_than<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ">".parse_next(i)?; - queue(&mut o, style::Print('>')) - } -} - -fn codeblock_ampersand<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - "&".parse_next(i)?; - queue(&mut o, style::Print('&')) - } -} - -fn codeblock_quot<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - """.parse_next(i)?; - queue(&mut o, style::Print('"')) - } -} - -fn codeblock_line_ending<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - ascii::line_ending.parse_next(i)?; - queue(&mut o, style::Print("\n")) - } -} - -fn codeblock_fallback<'a, 'b>( - mut o: impl Write + 'b, - _state: &'b mut ParseState, -) -> impl FnMut(&mut Partial<&'a str>) -> PResult<(), Error<'a>> + 'b { - move |i| { - let fallback = any.parse_next(i)?; - queue(&mut o, style::Print(fallback)) - } -} - -#[cfg(test)] -mod tests { - use std::io::Write; - - use winnow::stream::Offset; - - use super::*; - - macro_rules! 
validate { - ($test:ident, $input:literal, [$($commands:expr),+ $(,)?]) => { - #[test] - fn $test() -> eyre::Result<()> { - use crossterm::ExecutableCommand; - - let mut input = $input.trim().to_owned(); - input.push(' '); - input.push(' '); - - let mut state = ParseState::new(Some(80)); - let mut presult = vec![]; - let mut offset = 0; - - loop { - let input = Partial::new(&input[offset..]); - match interpret_markdown(input, &mut presult, &mut state) { - Ok(parsed) => { - offset += parsed.offset_from(&input); - state.newline = state.set_newline; - state.set_newline = false; - }, - Err(err) => match err.into_inner() { - Some(err) => panic!("{err}"), - None => break, // Data was incomplete - }, - } - } - - presult.flush()?; - let presult = String::from_utf8(presult)?; - - let mut wresult: Vec = vec![]; - $(wresult.execute($commands)?;)+ - let wresult = String::from_utf8(wresult)?; - - assert_eq!(presult.trim(), wresult); - - Ok(()) - } - }; - } - - validate!(text_1, "hello world!", [style::Print("hello world!")]); - validate!(linted_codeblock_1, "```java\nhello world!```", [ - style::SetAttribute(Attribute::Bold), - style::Print("java\n"), - style::SetAttribute(Attribute::Reset), - style::SetForegroundColor(CODE_COLOR), - style::Print("hello world!"), - style::ResetColor, - ]); - validate!(code_1, "`print`", [ - style::SetForegroundColor(CODE_COLOR), - style::Print("print"), - style::ResetColor, - ]); - validate!(url_1, "[google](google.com)", [ - style::SetForegroundColor(URL_TEXT_COLOR), - style::Print("google "), - style::SetForegroundColor(URL_LINK_COLOR), - style::Print("google.com"), - style::ResetColor, - ]); - validate!(citation_1, "[[1]](google.com)", [ - style::SetForegroundColor(URL_TEXT_COLOR), - style::Print("[^1]"), - style::ResetColor, - ]); - validate!(bold_1, "**hello**", [ - style::SetAttribute(Attribute::Bold), - style::Print("hello"), - style::SetAttribute(Attribute::NormalIntensity) - ]); - validate!(italic_1, "*hello*", [ - style::SetAttribute(Attribute::Italic), - style::Print("hello"), - style::SetAttribute(Attribute::NoItalic) - ]); - validate!(strikethrough_1, "~~hello~~", [ - style::SetAttribute(Attribute::CrossedOut), - style::Print("hello"), - style::SetAttribute(Attribute::NotCrossedOut) - ]); - validate!(less_than_1, "<", [style::Print('<')]); - validate!(greater_than_1, ".>.", [style::Print(".>.")]); - validate!(ampersand_1, "&", [style::Print('&')]); - validate!(quote_1, """, [style::Print('"')]); - validate!(fallback_1, "+ % @ . ? ", [style::Print("+ % @ . ?")]); - validate!(horizontal_rule_1, "---", [style::Print("━".repeat(80))]); - validate!(heading_1, "# Hello World", [ - style::SetForegroundColor(HEADING_COLOR), - style::SetAttribute(Attribute::Bold), - style::Print("# Hello World"), - ]); - validate!(bulleted_item_1, "- bullet", [style::Print("• bullet")]); - validate!(bulleted_item_2, "* bullet", [style::Print("• bullet")]); - validate!(numbered_item_1, "1. number", [style::Print("1. 
number")]); - validate!(blockquote_1, "> hello", [ - style::SetForegroundColor(BLOCKQUOTE_COLOR), - style::Print("│ hello"), - ]); - validate!(square_bracket_1, "[test]", [style::Print("[test]")]); - validate!(square_bracket_2, "Text with [brackets]", [style::Print( - "Text with [brackets]" - )]); - validate!(square_bracket_empty, "[]", [style::Print("[]")]); - validate!(square_bracket_array, "a[i]", [style::Print("a[i]")]); - validate!(square_bracket_url_like_1, "[text] without url part", [style::Print( - "[text] without url part" - )]); - validate!(square_bracket_url_like_2, "[text](without url part", [style::Print( - "[text](without url part" - )]); -} diff --git a/crates/q_chat/src/parser.rs b/crates/q_chat/src/parser.rs deleted file mode 100644 index ffa7854dc2..0000000000 --- a/crates/q_chat/src/parser.rs +++ /dev/null @@ -1,375 +0,0 @@ -use std::time::{ - Duration, - Instant, -}; - -use eyre::Result; -use fig_api_client::clients::SendMessageOutput; -use fig_api_client::model::ChatResponseStream; -use rand::distr::{ - Alphanumeric, - SampleString, -}; -use thiserror::Error; -use tracing::{ - error, - info, - trace, -}; - -use super::message::{ - AssistantMessage, - AssistantToolUse, -}; - -#[derive(Debug, Error)] -pub struct RecvError { - /// The request id associated with the [SendMessageOutput] stream. - pub request_id: Option, - #[source] - pub source: RecvErrorKind, -} - -impl std::fmt::Display for RecvError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "Failed to receive the next message: ")?; - if let Some(request_id) = self.request_id.as_ref() { - write!(f, "request_id: {}, error: ", request_id)?; - } - write!(f, "{}", self.source)?; - Ok(()) - } -} - -#[derive(Debug, Error)] -pub enum RecvErrorKind { - #[error("{0}")] - Client(#[from] fig_api_client::Error), - #[error("{0}")] - Json(#[from] serde_json::Error), - /// An error was encountered while waiting for the next event in the stream after a noticeably - /// long wait time. - /// - /// *Context*: the client can throw an error after ~100s of waiting with no response, likely due - /// to an exceptionally complex tool use taking too long to generate. - #[error("The stream ended after {}s: {source}", .duration.as_secs())] - StreamTimeout { - source: fig_api_client::Error, - duration: std::time::Duration, - }, - /// Unexpected end of stream while receiving a tool use. - /// - /// *Context*: the stream can unexpectedly end with `Ok(None)` while waiting for an - /// exceptionally complex tool use. This is due to some proxy server dropping idle - /// connections after some timeout is reached. - /// - /// TODO: should this be removed? - #[error("Unexpected end of stream for tool: {} with id: {}", .name, .tool_use_id)] - UnexpectedToolUseEos { - tool_use_id: String, - name: String, - message: Box, - time_elapsed: Duration, - }, -} - -/// State associated with parsing a [ChatResponseStream] into a [Message]. -/// -/// # Usage -/// -/// You should repeatedly call [Self::recv] to receive [ResponseEvent]'s until a -/// [ResponseEvent::EndStream] value is returned. -#[derive(Debug)] -pub struct ResponseParser { - /// The response to consume and parse into a sequence of [Ev]. - response: SendMessageOutput, - /// Buffer to hold the next event in [SendMessageOutput]. - peek: Option, - /// Message identifier for the assistant's response. Randomly generated on creation. - message_id: String, - /// Buffer for holding the accumulated assistant response. 
- assistant_text: String, - /// Tool uses requested by the model. - tool_uses: Vec, - /// Whether or not we are currently receiving tool use delta events. Tuple of - /// `Some((tool_use_id, name))` if true, [None] otherwise. - parsing_tool_use: Option<(String, String)>, -} - -impl ResponseParser { - pub fn new(response: SendMessageOutput) -> Self { - let message_id = Alphanumeric.sample_string(&mut rand::rng(), 9); - info!(?message_id, "Generated new message id"); - Self { - response, - peek: None, - message_id, - assistant_text: String::new(), - tool_uses: Vec::new(), - parsing_tool_use: None, - } - } - - /// Consumes the associated [ConverseStreamResponse] until a valid [ResponseEvent] is parsed. - pub async fn recv(&mut self) -> Result { - if let Some((id, name)) = self.parsing_tool_use.take() { - let tool_use = self.parse_tool_use(id, name).await?; - self.tool_uses.push(tool_use.clone()); - return Ok(ResponseEvent::ToolUse(tool_use)); - } - - // First, handle discarding AssistantResponseEvent's that immediately precede a - // CodeReferenceEvent. - let peek = self.peek().await?; - if let Some(ChatResponseStream::AssistantResponseEvent { content }) = peek { - // Cloning to bypass borrowchecker stuff. - let content = content.clone(); - self.next().await?; - match self.peek().await? { - Some(ChatResponseStream::CodeReferenceEvent(_)) => (), - _ => { - self.assistant_text.push_str(&content); - return Ok(ResponseEvent::AssistantText(content)); - }, - } - } - - loop { - match self.next().await { - Ok(Some(output)) => match output { - ChatResponseStream::AssistantResponseEvent { content } => { - self.assistant_text.push_str(&content); - return Ok(ResponseEvent::AssistantText(content)); - }, - ChatResponseStream::InvalidStateEvent { reason, message } => { - error!(%reason, %message, "invalid state event"); - }, - ChatResponseStream::ToolUseEvent { - tool_use_id, - name, - input, - stop, - } => { - debug_assert!(input.is_none(), "Unexpected initial content in first tool use event"); - debug_assert!( - stop.is_none_or(|v| !v), - "Unexpected immediate stop in first tool use event" - ); - self.parsing_tool_use = Some((tool_use_id.clone(), name.clone())); - return Ok(ResponseEvent::ToolUseStart { name }); - }, - _ => {}, - }, - Ok(None) => { - let message_id = Some(self.message_id.clone()); - let content = std::mem::take(&mut self.assistant_text); - let message = if self.tool_uses.is_empty() { - AssistantMessage::new_response(message_id, content) - } else { - AssistantMessage::new_tool_use( - message_id, - content, - self.tool_uses.clone().into_iter().collect(), - ) - }; - return Ok(ResponseEvent::EndStream { message }); - }, - Err(err) => return Err(err), - } - } - } - - /// Consumes the response stream until a valid [ToolUse] is parsed. - /// - /// The arguments are the fields from the first [ChatResponseStream::ToolUseEvent] consumed. - async fn parse_tool_use(&mut self, id: String, name: String) -> Result { - let mut tool_string = String::new(); - let start = Instant::now(); - while let Some(ChatResponseStream::ToolUseEvent { .. }) = self.peek().await? { - if let Some(ChatResponseStream::ToolUseEvent { input, stop, .. }) = self.next().await? 
{ - if let Some(i) = input { - tool_string.push_str(&i); - } - if let Some(true) = stop { - break; - } - } - } - - let args = match serde_json::from_str(&tool_string) { - Ok(args) => args, - Err(err) if !tool_string.is_empty() => { - // If we failed deserializing after waiting for a long time, then this is most - // likely bedrock responding with a stop event for some reason without actually - // including the tool contents. Essentially, the tool was too large. - // Timeouts have been seen as short as ~1 minute, so setting the time to 30. - let time_elapsed = start.elapsed(); - if self.peek().await?.is_none() && time_elapsed > Duration::from_secs(30) { - error!( - "Received an unexpected end of stream after spending ~{}s receiving tool events", - time_elapsed.as_secs_f64() - ); - self.tool_uses.push(AssistantToolUse { - id: id.clone(), - name: name.clone(), - args: serde_json::Value::Object( - [( - "key".to_string(), - serde_json::Value::String( - "WARNING: the actual tool use arguments were too complicated to be generated" - .to_string(), - ), - )] - .into_iter() - .collect(), - ), - }); - let message = Box::new(AssistantMessage::new_tool_use( - Some(self.message_id.clone()), - std::mem::take(&mut self.assistant_text), - self.tool_uses.clone().into_iter().collect(), - )); - return Err(self.error(RecvErrorKind::UnexpectedToolUseEos { - tool_use_id: id, - name, - message, - time_elapsed, - })); - } else { - return Err(self.error(err)); - } - }, - // if the tool just does not need any input - _ => serde_json::json!({}), - }; - Ok(AssistantToolUse { id, name, args }) - } - - /// Returns the next event in the [SendMessageOutput] without consuming it. - async fn peek(&mut self) -> Result, RecvError> { - if self.peek.is_some() { - return Ok(self.peek.as_ref()); - } - match self.next().await? { - Some(v) => { - self.peek = Some(v); - Ok(self.peek.as_ref()) - }, - None => Ok(None), - } - } - - /// Consumes the next [SendMessageOutput] event. - async fn next(&mut self) -> Result, RecvError> { - if let Some(ev) = self.peek.take() { - return Ok(Some(ev)); - } - trace!("Attempting to recv next event"); - let start = std::time::Instant::now(); - let result = self.response.recv().await; - let duration = std::time::Instant::now().duration_since(start); - match result { - Ok(r) => { - trace!(?r, "Received new event"); - Ok(r) - }, - Err(err) => { - if duration.as_secs() >= 59 { - Err(self.error(RecvErrorKind::StreamTimeout { source: err, duration })) - } else { - Err(self.error(err)) - } - }, - } - } - - fn request_id(&self) -> Option<&str> { - self.response.request_id() - } - - /// Helper to create a new [RecvError] populated with the associated request id for the stream. - fn error(&self, source: impl Into) -> RecvError { - RecvError { - request_id: self.request_id().map(str::to_string), - source: source.into(), - } - } -} - -#[derive(Debug)] -pub enum ResponseEvent { - /// Text returned by the assistant. This should be displayed to the user as it is received. - AssistantText(String), - /// Notification that a tool use is being received. - ToolUseStart { name: String }, - /// A tool use requested by the assistant. This should be displayed to the user as it is - /// received. - ToolUse(AssistantToolUse), - /// Represents the end of the response. No more events will be returned. - EndStream { - /// The completed message containing all of the assistant text and tool use events - /// previously emitted. This should be stored in the conversation history and sent in - /// subsequent requests. 
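As the usage note on ResponseParser says, recv() is meant to be polled until an EndStream event is returned. A minimal consumer loop might look like the following sketch; it assumes recv() yields Result<ResponseEvent, RecvError> as in the code above, and the printing is purely illustrative:

```rust
// Sketch: draining a ResponseParser until the stream ends.
async fn drain_response(response: SendMessageOutput) -> Result<AssistantMessage, RecvError> {
    let mut parser = ResponseParser::new(response);
    loop {
        match parser.recv().await? {
            ResponseEvent::AssistantText(text) => print!("{text}"),
            ResponseEvent::ToolUseStart { name } => eprintln!("receiving tool use: {name}"),
            ResponseEvent::ToolUse(tool_use) => eprintln!("tool requested: {}", tool_use.name),
            ResponseEvent::EndStream { message } => return Ok(message),
        }
    }
}
```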
- message: AssistantMessage, - }, -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - async fn test_parse() { - let _ = tracing_subscriber::fmt::try_init(); - let tool_use_id = "TEST_ID".to_string(); - let tool_name = "execute_bash".to_string(); - let tool_args = serde_json::json!({ - "command": "echo hello" - }) - .to_string(); - let tool_use_split_at = 5; - let mut events = vec![ - ChatResponseStream::AssistantResponseEvent { - content: "hi".to_string(), - }, - ChatResponseStream::AssistantResponseEvent { - content: " there".to_string(), - }, - ChatResponseStream::AssistantResponseEvent { - content: "IGNORE ME PLEASE".to_string(), - }, - ChatResponseStream::CodeReferenceEvent(()), - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: None, - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: Some(tool_args.as_str().split_at(tool_use_split_at).0.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: Some(tool_args.as_str().split_at(tool_use_split_at).1.to_string()), - stop: None, - }, - ChatResponseStream::ToolUseEvent { - tool_use_id: tool_use_id.clone(), - name: tool_name.clone(), - input: None, - stop: Some(true), - }, - ]; - events.reverse(); - let mock = SendMessageOutput::Mock(events); - let mut parser = ResponseParser::new(mock); - - for _ in 0..5 { - println!("{:?}", parser.recv().await.unwrap()); - } - } -} diff --git a/crates/q_chat/src/prompt.rs b/crates/q_chat/src/prompt.rs deleted file mode 100644 index 9811f977f1..0000000000 --- a/crates/q_chat/src/prompt.rs +++ /dev/null @@ -1,364 +0,0 @@ -use std::borrow::Cow; - -use crossterm::style::Stylize; -use eyre::Result; -use rustyline::completion::{ - Completer, - FilenameCompleter, - extract_word, -}; -use rustyline::error::ReadlineError; -use rustyline::highlight::{ - CmdKind, - Highlighter, -}; -use rustyline::history::DefaultHistory; -use rustyline::validate::{ - ValidationContext, - ValidationResult, - Validator, -}; -use rustyline::{ - Cmd, - Completer, - CompletionType, - Config, - Context, - EditMode, - Editor, - EventHandler, - Helper, - Hinter, - KeyCode, - KeyEvent, - Modifiers, -}; -use winnow::stream::AsChar; - -pub const COMMANDS: &[&str] = &[ - "/clear", - "/help", - "/editor", - "/issue", - // "/acceptall", /// Functional, but deprecated in favor of /tools trustall - "/quit", - "/tools", - "/tools trust", - "/tools untrust", - "/tools trustall", - "/tools reset", - "/profile", - "/profile help", - "/profile list", - "/profile create", - "/profile delete", - "/profile rename", - "/profile set", - "/context help", - "/context show", - "/context show --expand", - "/context add", - "/context add --global", - "/context rm", - "/context rm --global", - "/context clear", - "/context clear --global", - "/context hooks help", - "/context hooks add", - "/context hooks rm", - "/context hooks enable", - "/context hooks disable", - "/context hooks enable-all", - "/context hooks disable-all", - "/compact", - "/compact help", - "/usage", -]; - -pub fn generate_prompt(current_profile: Option<&str>, warning: bool) -> String { - let warning_symbol = if warning { "!".red().to_string() } else { "".to_string() }; - let profile_part = current_profile - .filter(|&p| p != "default") - .map(|p| format!("[{p}] ").cyan().to_string()) - .unwrap_or_default(); - - format!("{profile_part}{warning_symbol}{}", "> 
".magenta()) -} - -/// Complete commands that start with a slash -fn complete_command(word: &str, start: usize) -> (usize, Vec) { - ( - start, - COMMANDS - .iter() - .filter(|p| p.starts_with(word)) - .map(|s| (*s).to_owned()) - .collect(), - ) -} - -/// A wrapper around FilenameCompleter that provides enhanced path detection -/// and completion capabilities for the chat interface. -pub struct PathCompleter { - /// The underlying filename completer from rustyline - filename_completer: FilenameCompleter, -} - -impl PathCompleter { - /// Creates a new PathCompleter instance - pub fn new() -> Self { - Self { - filename_completer: FilenameCompleter::new(), - } - } - - /// Attempts to complete a file path at the given position in the line - pub fn complete_path( - &self, - line: &str, - pos: usize, - ctx: &Context<'_>, - ) -> Result<(usize, Vec), ReadlineError> { - // Use the filename completer to get path completions - match self.filename_completer.complete(line, pos, ctx) { - Ok((pos, completions)) => { - // Convert the filename completer's pairs to strings - let file_completions: Vec = completions.iter().map(|pair| pair.replacement.clone()).collect(); - - // Return the completions if we have any - Ok((pos, file_completions)) - }, - Err(err) => Err(err), - } - } -} - -pub struct PromptCompleter { - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, -} - -impl PromptCompleter { - fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { - PromptCompleter { sender, receiver } - } - - fn complete_prompt(&self, word: &str) -> Result, ReadlineError> { - let sender = &self.sender; - let receiver = &self.receiver; - sender - .send(if !word.is_empty() { Some(word.to_string()) } else { None }) - .map_err(|e| ReadlineError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))?; - let prompt_info = receiver - .recv() - .map_err(|e| ReadlineError::Io(std::io::Error::new(std::io::ErrorKind::Other, e.to_string())))? 
- .iter() - .map(|n| format!("@{n}")) - .collect::>(); - - Ok(prompt_info) - } -} - -pub struct ChatCompleter { - path_completer: PathCompleter, - prompt_completer: PromptCompleter, -} - -impl ChatCompleter { - fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { - Self { - path_completer: PathCompleter::new(), - prompt_completer: PromptCompleter::new(sender, receiver), - } - } -} - -impl Completer for ChatCompleter { - type Candidate = String; - - fn complete( - &self, - line: &str, - pos: usize, - _ctx: &Context<'_>, - ) -> Result<(usize, Vec), ReadlineError> { - let (start, word) = extract_word(line, pos, None, |c| c.is_space()); - - // Handle command completion - if word.starts_with('/') { - return Ok(complete_command(word, start)); - } - - if line.starts_with('@') { - let search_word = line.strip_prefix('@').unwrap_or(""); - if let Ok(completions) = self.prompt_completer.complete_prompt(search_word) { - if !completions.is_empty() { - return Ok((0, completions)); - } - } - } - - // Handle file path completion as fallback - if let Ok((pos, completions)) = self.path_completer.complete_path(line, pos, _ctx) { - if !completions.is_empty() { - return Ok((pos, completions)); - } - } - - // Default: no completions - Ok((start, Vec::new())) - } -} - -/// Custom validator for multi-line input -pub struct MultiLineValidator; - -impl Validator for MultiLineValidator { - fn validate(&self, ctx: &mut ValidationContext<'_>) -> rustyline::Result { - let input = ctx.input(); - - // Check for explicit multi-line markers - if input.starts_with("```") && !input.ends_with("```") { - return Ok(ValidationResult::Incomplete); - } - - // Check for backslash continuation - if input.ends_with('\\') { - return Ok(ValidationResult::Incomplete); - } - - Ok(ValidationResult::Valid(None)) - } -} - -#[derive(Helper, Completer, Hinter)] -pub struct ChatHelper { - #[rustyline(Completer)] - completer: ChatCompleter, - #[rustyline(Hinter)] - hinter: (), - validator: MultiLineValidator, -} - -impl Validator for ChatHelper { - fn validate(&self, ctx: &mut ValidationContext<'_>) -> rustyline::Result { - self.validator.validate(ctx) - } -} - -impl Highlighter for ChatHelper { - fn highlight_hint<'h>(&self, hint: &'h str) -> Cow<'h, str> { - Cow::Owned(format!("\x1b[1m{hint}\x1b[m")) - } - - fn highlight<'l>(&self, line: &'l str, _pos: usize) -> Cow<'l, str> { - Cow::Borrowed(line) - } - - fn highlight_char(&self, _line: &str, _pos: usize, _kind: CmdKind) -> bool { - false - } -} - -pub fn rl( - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, -) -> Result> { - let edit_mode = match fig_settings::settings::get_string_opt("chat.editMode").as_deref() { - Some("vi" | "vim") => EditMode::Vi, - _ => EditMode::Emacs, - }; - let config = Config::builder() - .history_ignore_space(true) - .completion_type(CompletionType::List) - .edit_mode(edit_mode) - .build(); - let h = ChatHelper { - completer: ChatCompleter::new(sender, receiver), - hinter: (), - validator: MultiLineValidator, - }; - let mut rl = Editor::with_config(config)?; - rl.set_helper(Some(h)); - - // Add custom keybinding for Alt+Enter to insert a newline - rl.bind_sequence( - KeyEvent(KeyCode::Enter, Modifiers::ALT), - EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), - ); - - // Add custom keybinding for Ctrl+J to insert a newline - rl.bind_sequence( - KeyEvent(KeyCode::Char('j'), Modifiers::CTRL), - EventHandler::Simple(Cmd::Insert(1, "\n".to_string())), - ); - - Ok(rl) -} - -#[cfg(test)] -mod tests { - 
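Taken together, rl() wires the completer, validator, and newline keybindings into a rustyline Editor, and generate_prompt() supplies the prompt string. A hedged usage sketch (the channel element types are inferred from the completer code above, and the error handling is simplified):

```rust
// Sketch: building the editor and reading a single line of chat input.
// In the real chat loop the two channels connect to the prompt/tool machinery
// rather than being left dangling as they are here.
fn read_one_line() -> eyre::Result<Option<String>> {
    let (request_tx, _request_rx) = std::sync::mpsc::channel::<Option<String>>();
    let (_response_tx, response_rx) = std::sync::mpsc::channel::<Vec<String>>();
    let mut editor = rl(request_tx, response_rx)?;
    match editor.readline(&generate_prompt(None, false)) {
        Ok(line) => Ok(Some(line)),
        Err(rustyline::error::ReadlineError::Interrupted | rustyline::error::ReadlineError::Eof) => Ok(None),
        Err(err) => Err(eyre::eyre!(err)),
    }
}
```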
use super::*; - - #[test] - fn test_generate_prompt() { - // Test default prompt (no profile) - assert_eq!(generate_prompt(None, false), "> ".magenta().to_string()); - // Test default prompt with warning - assert_eq!(generate_prompt(None, true), format!("{}{}", "!".red(), "> ".magenta())); - // Test default profile (should be same as no profile) - assert_eq!(generate_prompt(Some("default"), false), "> ".magenta().to_string()); - // Test custom profile - assert_eq!( - generate_prompt(Some("test-profile"), false), - format!("{}{}", "[test-profile] ".cyan(), "> ".magenta()) - ); - // Test another custom profile with warning - assert_eq!( - generate_prompt(Some("dev"), true), - format!("{}{}{}", "[dev] ".cyan(), "!".red(), "> ".magenta()) - ); - } - - #[test] - fn test_chat_completer_command_completion() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); - let line = "/h"; - let pos = 2; // Position at the end of "/h" - - // Create a mock context with empty history - let empty_history = DefaultHistory::new(); - let ctx = Context::new(&empty_history); - - // Get completions - let (start, completions) = completer.complete(line, pos, &ctx).unwrap(); - - // Verify start position - assert_eq!(start, 0); - - // Verify completions contain expected commands - assert!(completions.contains(&"/help".to_string())); - } - - #[test] - fn test_chat_completer_no_completion() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); - let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); - let line = "Hello, how are you?"; - let pos = line.len(); - - // Create a mock context with empty history - let empty_history = DefaultHistory::new(); - let ctx = Context::new(&empty_history); - - // Get completions - let (_, completions) = completer.complete(line, pos, &ctx).unwrap(); - - // Verify no completions are returned for regular text - assert!(completions.is_empty()); - } -} diff --git a/crates/q_chat/src/skim_integration.rs b/crates/q_chat/src/skim_integration.rs deleted file mode 100644 index 026576be14..0000000000 --- a/crates/q_chat/src/skim_integration.rs +++ /dev/null @@ -1,378 +0,0 @@ -use std::io::{ - BufReader, - Cursor, - Write, - stdout, -}; - -use crossterm::execute; -use crossterm::terminal::{ - EnterAlternateScreen, - LeaveAlternateScreen, -}; -use eyre::{ - Result, - eyre, -}; -use rustyline::{ - Cmd, - ConditionalEventHandler, - EventContext, - RepeatCount, -}; -use skim::prelude::*; -use tempfile::NamedTempFile; - -use super::context::ContextManager; - -pub fn select_profile_with_skim(context_manager: &ContextManager) -> Result> { - let profiles = context_manager.list_profiles_blocking()?; - - launch_skim_selector(&profiles, "Select profile: ", false) - .map(|selected| selected.and_then(|s| s.into_iter().next())) -} - -pub struct SkimCommandSelector { - context_manager: Arc, - tool_names: Vec, -} - -impl SkimCommandSelector { - /// This allows the ConditionalEventHandler handle function to be bound to a KeyEvent. 
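The doc comment on new() explains the purpose of this type: because SkimCommandSelector implements ConditionalEventHandler, it can be attached directly to a rustyline key binding. A rough sketch of such a binding (Ctrl+S is an arbitrary example key, not the binding used by the real chat setup, and the editor generics are assumed):

```rust
// Sketch: binding the skim command selector to a key in the line editor.
use std::sync::Arc;

use rustyline::{EventHandler, KeyCode, KeyEvent, Modifiers};

fn bind_selector<H: rustyline::Helper, I: rustyline::history::History>(
    editor: &mut rustyline::Editor<H, I>,
    context_manager: Arc<ContextManager>,
    tool_names: Vec<String>,
) {
    editor.bind_sequence(
        KeyEvent(KeyCode::Char('s'), Modifiers::CTRL),
        EventHandler::Conditional(Box::new(SkimCommandSelector::new(context_manager, tool_names))),
    );
}
```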
- pub fn new(context_manager: Arc, tool_names: Vec) -> Self { - Self { - context_manager, - tool_names, - } - } -} - -impl ConditionalEventHandler for SkimCommandSelector { - fn handle( - &self, - _evt: &rustyline::Event, - _n: RepeatCount, - _positive: bool, - _ctx: &EventContext<'_>, - ) -> Option { - // Launch skim command selector with the context manager if available - match select_command(self.context_manager.as_ref(), &self.tool_names) { - Ok(Some(command)) => Some(Cmd::Insert(1, command)), - _ => { - // If cancelled or error, do nothing - Some(Cmd::Noop) - }, - } - } -} - -pub fn get_available_commands() -> Vec { - // Import the COMMANDS array directly from prompt.rs - // This is the single source of truth for available commands - let commands_array = super::prompt::COMMANDS; - - let mut commands = Vec::new(); - for &cmd in commands_array { - commands.push(cmd.to_string()); - } - - commands -} - -/// Format commands for skim display -/// Create a standard set of skim options with consistent styling -fn create_skim_options(prompt: &str, multi: bool) -> Result { - SkimOptionsBuilder::default() - .height("100%".to_string()) - .prompt(prompt.to_string()) - .reverse(true) - .multi(multi) - .build() - .map_err(|e| eyre!("Failed to build skim options: {}", e)) -} - -/// Run skim with the given options and items in an alternate screen -/// This helper function handles entering/exiting the alternate screen and running skim -fn run_skim_with_options(options: &SkimOptions, items: SkimItemReceiver) -> Result>>> { - // Enter alternate screen to prevent skim output from persisting in terminal history - execute!(stdout(), EnterAlternateScreen).map_err(|e| eyre!("Failed to enter alternate screen: {}", e))?; - - let selected_items = - Skim::run_with(options, Some(items)).and_then(|out| if out.is_abort { None } else { Some(out.selected_items) }); - - execute!(stdout(), LeaveAlternateScreen).map_err(|e| eyre!("Failed to leave alternate screen: {}", e))?; - - Ok(selected_items) -} - -/// Extract string selections from skim items -fn extract_selections(items: Vec>) -> Vec { - items.iter().map(|item| item.output().to_string()).collect() -} - -/// Launch skim with the given items and return the selected item -pub fn launch_skim_selector(items: &[String], prompt: &str, multi: bool) -> Result>> { - let mut temp_file_for_skim_input = NamedTempFile::new()?; - temp_file_for_skim_input.write_all(items.join("\n").as_bytes())?; - - let options = create_skim_options(prompt, multi)?; - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(BufReader::new(std::fs::File::open(temp_file_for_skim_input.path())?)); - - // Run skim and get selected items - match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => { - let selections = extract_selections(items); - Ok(Some(selections)) - }, - _ => Ok(None), // User cancelled or no selection - } -} - -/// Select files using skim -pub fn select_files_with_skim() -> Result>> { - // Create skim options with appropriate settings - let options = create_skim_options("Select files: ", true)?; - - // Create a command that will be executed by skim - // This avoids loading all files into memory at once - let find_cmd = "find . 
-type f -not -path '*/\\.*'"; - - // Create a command collector that will execute the find command - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(BufReader::new( - std::process::Command::new("sh") - .args(["-c", find_cmd]) - .stdout(std::process::Stdio::piped()) - .spawn()? - .stdout - .ok_or_else(|| eyre!("Failed to get stdout from command"))?, - )); - - // Run skim with the command output as a stream - match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => { - let selections = extract_selections(items); - Ok(Some(selections)) - }, - _ => Ok(None), // User cancelled or no selection - } -} - -/// Select context paths using skim -pub fn select_context_paths_with_skim(context_manager: &ContextManager) -> Result, bool)>> { - let mut global_paths = Vec::new(); - let mut profile_paths = Vec::new(); - - // Get global paths - for path in &context_manager.global_config.paths { - global_paths.push(format!("(global) {}", path)); - } - - // Get profile-specific paths - for path in &context_manager.profile_config.paths { - profile_paths.push(format!("(profile: {}) {}", context_manager.current_profile, path)); - } - - // Combine paths, but keep track of which are global - let mut all_paths = Vec::new(); - all_paths.extend(global_paths); - all_paths.extend(profile_paths); - - if all_paths.is_empty() { - return Ok(None); // No paths to select - } - - // Create skim options - let options = create_skim_options("Select paths to remove: ", true)?; - - // Create item reader - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(Cursor::new(all_paths.join("\n"))); - - // Run skim and get selected paths - match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => { - let selected_paths = extract_selections(items); - - // Check if any global paths were selected - let has_global = selected_paths.iter().any(|p| p.starts_with("(global)")); - - // Extract the actual paths from the formatted strings - let paths: Vec = selected_paths - .iter() - .map(|p| { - // Extract the path part after the prefix - let parts: Vec<&str> = p.splitn(2, ") ").collect(); - if parts.len() > 1 { - parts[1].to_string() - } else { - p.clone() - } - }) - .collect(); - - Ok(Some((paths, has_global))) - }, - _ => Ok(None), // User cancelled selection - } -} - -/// Launch the command selector and handle the selected command -pub fn select_command(context_manager: &ContextManager, tools: &[String]) -> Result> { - let commands = get_available_commands(); - - match launch_skim_selector(&commands, "Select command: ", false)? { - Some(selections) if !selections.is_empty() => { - let selected_command = &selections[0]; - - match CommandType::from_str(selected_command) { - Some(CommandType::ContextAdd(cmd)) => { - // For context add commands, we need to select files - match select_files_with_skim()? { - Some(files) if !files.is_empty() => { - // Construct the full command with selected files - let mut cmd = cmd.clone(); - for file in files { - cmd.push_str(&format!(" {}", file)); - } - Ok(Some(cmd)) - }, - _ => Ok(Some(selected_command.clone())), /* User cancelled file selection, return just the - * command */ - } - }, - Some(CommandType::ContextRemove(cmd)) => { - // For context rm commands, we need to select from existing context paths - match select_context_paths_with_skim(context_manager)? 
{ - Some((paths, has_global)) if !paths.is_empty() => { - // Construct the full command with selected paths - let mut full_cmd = cmd.clone(); - if has_global { - full_cmd.push_str(" --global"); - } - for path in paths { - full_cmd.push_str(&format!(" {}", path)); - } - Ok(Some(full_cmd)) - }, - Some((_, _)) => Ok(Some(format!("{} (No paths selected)", cmd))), - None => Ok(Some(selected_command.clone())), // User cancelled path selection - } - }, - Some(CommandType::Tools(_)) => { - let options = create_skim_options("Select tool: ", false)?; - let item_reader = SkimItemReader::default(); - let items = item_reader.of_bufread(Cursor::new(tools.join("\n"))); - let selected_tool = match run_skim_with_options(&options, items)? { - Some(items) if !items.is_empty() => Some(items[0].output().to_string()), - _ => None, - }; - - match selected_tool { - Some(tool) => Ok(Some(format!("{} {}", selected_command, tool))), - None => Ok(Some(selected_command.clone())), /* User cancelled tool selection, return just the - * command */ - } - }, - Some(cmd @ CommandType::Profile(_)) if cmd.needs_profile_selection() => { - // For profile operations that need a profile name, show profile selector - match select_profile_with_skim(context_manager)? { - Some(profile) => { - let full_cmd = format!("{} {}", selected_command, profile); - Ok(Some(full_cmd)) - }, - None => Ok(Some(selected_command.clone())), // User cancelled profile selection - } - }, - Some(CommandType::Profile(_)) => { - // For other profile operations (like create), just return the command - Ok(Some(selected_command.clone())) - }, - None => { - // Command doesn't need additional parameters - Ok(Some(selected_command.clone())) - }, - } - }, - _ => Ok(None), // User cancelled command selection - } -} - -#[derive(PartialEq)] -enum CommandType { - ContextAdd(String), - ContextRemove(String), - Tools(&'static str), - Profile(&'static str), -} - -impl CommandType { - fn needs_profile_selection(&self) -> bool { - matches!(self, CommandType::Profile("set" | "delete" | "rename")) - } - - fn from_str(cmd: &str) -> Option { - if cmd.starts_with("/context add") { - Some(CommandType::ContextAdd(cmd.to_string())) - } else if cmd.starts_with("/context rm") { - Some(CommandType::ContextRemove(cmd.to_string())) - } else { - match cmd { - "/tools trust" => Some(CommandType::Tools("trust")), - "/tools untrust" => Some(CommandType::Tools("untrust")), - "/profile set" => Some(CommandType::Profile("set")), - "/profile delete" => Some(CommandType::Profile("delete")), - "/profile rename" => Some(CommandType::Profile("rename")), - "/profile create" => Some(CommandType::Profile("create")), - _ => None, - } - } - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashSet; - - use super::*; - - /// Test to verify that all hardcoded command strings in select_command - /// are present in the COMMANDS array from prompt.rs - #[test] - fn test_hardcoded_commands_in_commands_array() { - // Get the set of available commands from prompt.rs - let available_commands: HashSet = get_available_commands().iter().cloned().collect(); - - // List of hardcoded commands used in select_command - let hardcoded_commands = vec![ - "/context add", - "/context add --global", - "/context rm", - "/context rm --global", - "/tools trust", - "/tools untrust", - "/profile set", - "/profile delete", - "/profile rename", - "/profile create", - ]; - - // Check that each hardcoded command is in the COMMANDS array - for cmd in hardcoded_commands { - assert!( - available_commands.contains(cmd), - "Command 
'{}' is used in select_command but not defined in COMMANDS array", - cmd - ); - - // This should assert that all the commands we assert are present in the match statement of - // select_command() - assert!( - CommandType::from_str(cmd).is_some(), - "Command '{}' cannot be parsed into a CommandType", - cmd - ); - } - } -} diff --git a/crates/q_chat/src/token_counter.rs b/crates/q_chat/src/token_counter.rs deleted file mode 100644 index 1e651b96b4..0000000000 --- a/crates/q_chat/src/token_counter.rs +++ /dev/null @@ -1,251 +0,0 @@ -use std::ops::Deref; - -use super::conversation_state::{ - BackendConversationState, - ConversationSize, -}; -use super::message::{ - AssistantMessage, - ToolUseResult, - ToolUseResultBlock, - UserMessage, - UserMessageContent, -}; - -#[derive(Debug, Clone, Copy)] -pub struct CharCount(usize); - -impl CharCount { - pub fn value(&self) -> usize { - self.0 - } -} - -impl Deref for CharCount { - type Target = usize; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for CharCount { - fn from(value: usize) -> Self { - Self(value) - } -} - -impl std::ops::Add for CharCount { - type Output = CharCount; - - fn add(self, rhs: Self) -> Self::Output { - Self(self.value() + rhs.value()) - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub struct TokenCount(usize); - -impl TokenCount { - pub fn value(&self) -> usize { - self.0 - } -} - -impl Deref for TokenCount { - type Target = usize; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl From for TokenCount { - fn from(value: CharCount) -> Self { - Self(TokenCounter::count_tokens_char_count(value.value())) - } -} - -impl std::fmt::Display for TokenCount { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.0) - } -} - -pub struct TokenCounter; - -impl TokenCounter { - pub const TOKEN_TO_CHAR_RATIO: usize = 3; - - /// Estimates the number of tokens in the input content. - /// Currently uses a simple heuristic: content length / TOKEN_TO_CHAR_RATIO - /// - /// Rounds up to the nearest multiple of 10 to avoid giving users a false sense of precision. - pub fn count_tokens(content: &str) -> usize { - Self::count_tokens_char_count(content.len()) - } - - fn count_tokens_char_count(count: usize) -> usize { - (count / Self::TOKEN_TO_CHAR_RATIO + 5) / 10 * 10 - } - - pub const fn token_to_chars(token: usize) -> usize { - token * Self::TOKEN_TO_CHAR_RATIO - } -} - -/// A trait for types that represent some number of characters (aka bytes). For use in calculating -/// context window size utilization. -pub trait CharCounter { - /// Returns the number of characters contained within this type. 
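The heuristic above is deliberately coarse: divide the character count by TOKEN_TO_CHAR_RATIO, then round to a multiple of ten so the number does not suggest more precision than it has. Working through the arithmetic with the same string used by the test further down (all divisions are integer divisions):

```rust
// "This is a test sentence." is 24 characters long.
// 24 / 3 = 8, 8 + 5 = 13, 13 / 10 * 10 = 10
assert_eq!(TokenCounter::count_tokens("This is a test sentence."), 10);
// The inverse direction used when budgeting context space:
assert_eq!(TokenCounter::token_to_chars(10), 30);
```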
- /// - /// One "character" is essentially the same as one "byte" - fn char_count(&self) -> CharCount; -} - -impl CharCounter for BackendConversationState<'_> { - fn char_count(&self) -> CharCount { - self.calculate_conversation_size().char_count() - } -} - -impl CharCounter for ConversationSize { - fn char_count(&self) -> CharCount { - self.user_messages + self.assistant_messages + self.context_messages - } -} - -impl CharCounter for UserMessage { - fn char_count(&self) -> CharCount { - let mut total_chars = 0; - total_chars += self.additional_context().len(); - match self.content() { - UserMessageContent::Prompt { prompt } => { - total_chars += prompt.len(); - }, - UserMessageContent::CancelledToolUses { - prompt, - tool_use_results, - } => { - total_chars += prompt.as_ref().map_or(0, String::len); - total_chars += tool_use_results.as_slice().char_count().0; - }, - UserMessageContent::ToolUseResults { tool_use_results } => { - total_chars += tool_use_results.as_slice().char_count().0; - }, - } - total_chars.into() - } -} - -impl CharCounter for AssistantMessage { - fn char_count(&self) -> CharCount { - let mut total_chars = 0; - total_chars += self.content().len(); - if let Some(tool_uses) = self.tool_uses() { - total_chars += tool_uses - .iter() - .map(|v| calculate_value_char_count(&v.args)) - .reduce(|acc, e| acc + e) - .unwrap_or_default(); - } - total_chars.into() - } -} - -impl CharCounter for &[ToolUseResult] { - fn char_count(&self) -> CharCount { - self.iter() - .flat_map(|v| &v.content) - .fold(0, |acc, v| { - acc + match v { - ToolUseResultBlock::Json(v) => calculate_value_char_count(v), - ToolUseResultBlock::Text(s) => s.len(), - } - }) - .into() - } -} - -fn calculate_value_char_count(document: &serde_json::Value) -> usize { - match document { - serde_json::Value::Null => 1, - serde_json::Value::Bool(_) => 1, - serde_json::Value::Number(_) => 1, - serde_json::Value::String(s) => s.len(), - serde_json::Value::Array(vec) => vec.iter().fold(0, |acc, v| acc + calculate_value_char_count(v)), - serde_json::Value::Object(map) => map.values().fold(0, |acc, v| acc + calculate_value_char_count(v)), - } -} - -#[cfg(test)] -mod tests { - - use super::*; - - #[test] - fn test_token_count() { - let text = "This is a test sentence."; - let count = TokenCounter::count_tokens(text); - assert_eq!(count, (text.len() / 3 + 5) / 10 * 10); - } - - #[test] - fn test_calculate_value_char_count() { - // Test simple types - assert_eq!( - calculate_value_char_count(&serde_json::Value::String("hello".to_string())), - 5 - ); - assert_eq!( - calculate_value_char_count(&serde_json::Value::Number(serde_json::Number::from(123))), - 1 - ); - assert_eq!(calculate_value_char_count(&serde_json::Value::Bool(true)), 1); - assert_eq!(calculate_value_char_count(&serde_json::Value::Null), 1); - - // Test array - let array = serde_json::Value::Array(vec![ - serde_json::Value::String("test".to_string()), - serde_json::Value::Number(serde_json::Number::from(42)), - serde_json::Value::Bool(false), - ]); - assert_eq!(calculate_value_char_count(&array), 6); // "test" (4) + Number (1) + Bool (1) - - // Test object - let mut obj = serde_json::Map::new(); - obj.insert("key1".to_string(), serde_json::Value::String("value1".to_string())); - obj.insert( - "key2".to_string(), - serde_json::Value::Number(serde_json::Number::from(99)), - ); - let object = serde_json::Value::Object(obj); - assert_eq!(calculate_value_char_count(&object), 7); // "value1" (6) + Number (1) - - // Test nested structure - let mut nested_obj = 
serde_json::Map::new(); - let mut inner_obj = serde_json::Map::new(); - inner_obj.insert( - "inner_key".to_string(), - serde_json::Value::String("inner_value".to_string()), - ); - nested_obj.insert("outer_key".to_string(), serde_json::Value::Object(inner_obj)); - nested_obj.insert( - "array_key".to_string(), - serde_json::Value::Array(vec![ - serde_json::Value::String("item1".to_string()), - serde_json::Value::String("item2".to_string()), - ]), - ); - - let complex = serde_json::Value::Object(nested_obj); - assert_eq!(calculate_value_char_count(&complex), 21); // "inner_value" (11) + "item1" (5) + "item2" (5) - - // Test empty structures - assert_eq!(calculate_value_char_count(&serde_json::Value::Array(vec![])), 0); - assert_eq!( - calculate_value_char_count(&serde_json::Value::Object(serde_json::Map::new())), - 0 - ); - } -} diff --git a/crates/q_chat/src/tool_manager.rs b/crates/q_chat/src/tool_manager.rs deleted file mode 100644 index dfd251cb74..0000000000 --- a/crates/q_chat/src/tool_manager.rs +++ /dev/null @@ -1,1019 +0,0 @@ -use std::collections::HashMap; -use std::hash::{ - DefaultHasher, - Hasher, -}; -use std::io::Write; -use std::path::PathBuf; -use std::sync::mpsc::RecvTimeoutError; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; - -use convert_case::Casing; -use crossterm::{ - cursor, - execute, - queue, - style, - terminal, -}; -use fig_api_client::model::{ - ToolResult, - ToolResultContentBlock, - ToolResultStatus, -}; -use futures::{ - StreamExt, - stream, -}; -use mcp_client::{ - JsonRpcResponse, - PromptGet, -}; -use serde::{ - Deserialize, - Serialize, -}; -use thiserror::Error; -use tokio::sync::Mutex; -use tracing::error; - -use super::command::PromptsGetCommand; -use super::message::AssistantToolUse; -use super::tools::custom_tool::{ - CustomToolClient, - CustomToolConfig, -}; -use super::tools::execute_bash::ExecuteBash; -use super::tools::fs_read::FsRead; -use super::tools::fs_write::FsWrite; -use super::tools::gh_issue::GhIssue; -use super::tools::use_aws::UseAws; -use super::tools::{ - Tool, - ToolOrigin, -}; -use crate::tools::ToolSpec; -use crate::tools::custom_tool::CustomTool; - -const NAMESPACE_DELIMITER: &str = "___"; -// This applies for both mcp server and tool name since in the end the tool name as seen by the -// model is just {server_name}{NAMESPACE_DELIMITER}{tool_name} -const VALID_TOOL_NAME: &str = "^[a-zA-Z][a-zA-Z0-9_]*$"; -const SPINNER_CHARS: [char; 10] = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']; - -#[derive(Debug, Error)] -pub enum GetPromptError { - #[error("Prompt with name {0} does not exist")] - PromptNotFound(String), - #[error("Prompt {0} is offered by more than one server. Use one of the following {1}")] - AmbiguousPrompt(String, String), - #[error("Missing client")] - MissingClient, - #[error("Missing prompt name")] - MissingPromptName, - #[error("Synchronization error: {0}")] - Synchronization(String), - #[error("Missing prompt bundle")] - MissingPromptInfo, - #[error(transparent)] - General(#[from] eyre::Report), -} - -/// Messages used for communication between the tool initialization thread and the loading -/// display thread. These messages control the visual loading indicators shown to -/// the user during tool initialization. -enum LoadingMsg { - /// Indicates a new tool is being initialized and should be added to the loading - /// display. The String parameter is the name of the tool being initialized. 
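As the constants above spell out, a tool exposed by an MCP server is shown to the model as {server_name}___{tool_name}, and both pieces must match VALID_TOOL_NAME; later in this file the combined name is also rejected if it exceeds 64 characters. A small sketch of that composition and check (namespaced_tool_name is a hypothetical helper, not part of the original module):

```rust
// Sketch: composing and validating a namespaced MCP tool name.
fn namespaced_tool_name(server_name: &str, tool_name: &str) -> Result<String, String> {
    let full = format!("{server_name}{NAMESPACE_DELIMITER}{tool_name}");
    let re = regex::Regex::new(VALID_TOOL_NAME).expect("VALID_TOOL_NAME is a valid regex");
    if !re.is_match(server_name) || !re.is_match(tool_name) {
        return Err(format!("'{full}' contains characters outside of {VALID_TOOL_NAME}"));
    }
    if full.len() > 64 {
        return Err(format!("'{full}' exceeds the 64 character limit"));
    }
    Ok(full)
}

// e.g. namespaced_tool_name("fetch", "get_page") == Ok("fetch___get_page".to_string())
```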
- Add(String), - /// Indicates a tool has finished initializing successfully and should be removed from - /// the loading display. The String parameter is the name of the tool that - /// completed initialization. - Done(String), - /// Represents an error that occurred during tool initialization. - /// Contains the name of the server that failed to initialize and the error message. - Error { name: String, msg: eyre::Report }, - /// Represents a warning that occurred during tool initialization. - /// Contains the name of the server that generated the warning and the warning message. - Warn { name: String, msg: eyre::Report }, -} - -/// Represents the state of a loading indicator for a tool being initialized. -/// -/// This struct tracks timing information for each tool's loading status display in the terminal. -/// -/// # Fields -/// * `init_time` - When initialization for this tool began, used to calculate load time -struct StatusLine { - init_time: std::time::Instant, -} - -// This is to mirror claude's config set up -#[derive(Clone, Serialize, Deserialize, Debug, Default)] -#[serde(rename_all = "camelCase")] -pub struct McpServerConfig { - mcp_servers: HashMap, -} - -impl McpServerConfig { - pub async fn load_config(output: &mut impl Write) -> eyre::Result { - let mut cwd = std::env::current_dir()?; - cwd.push(".amazonq/mcp.json"); - let expanded_path = shellexpand::tilde("~/.aws/amazonq/mcp.json"); - let global_path = PathBuf::from(expanded_path.as_ref()); - let global_buf = tokio::fs::read(global_path).await.ok(); - let local_buf = tokio::fs::read(cwd).await.ok(); - let conf = match (global_buf, local_buf) { - (Some(global_buf), Some(local_buf)) => { - let mut global_conf = Self::from_slice(&global_buf, output, "global")?; - let local_conf = Self::from_slice(&local_buf, output, "local")?; - for (server_name, config) in local_conf.mcp_servers { - if global_conf.mcp_servers.insert(server_name.clone(), config).is_some() { - queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print("MCP config conflict for "), - style::SetForegroundColor(style::Color::Green), - style::Print(server_name), - style::ResetColor, - style::Print(". Using workspace version.\n") - )?; - } - } - global_conf - }, - (None, Some(local_buf)) => Self::from_slice(&local_buf, output, "local")?, - (Some(global_buf), None) => Self::from_slice(&global_buf, output, "global")?, - _ => Default::default(), - }; - output.flush()?; - Ok(conf) - } - - fn from_slice(slice: &[u8], output: &mut impl Write, location: &str) -> eyre::Result { - match serde_json::from_slice::(slice) { - Ok(config) => Ok(config), - Err(e) => { - queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("WARNING: "), - style::ResetColor, - style::Print(format!("Error reading {location} mcp config: {e}\n")), - style::Print("Please check to make sure config is correct. 
Discarding.\n"), - )?; - Ok(McpServerConfig::default()) - }, - } - } -} - -#[derive(Default)] -pub struct ToolManagerBuilder { - mcp_server_config: Option, - prompt_list_sender: Option>>, - prompt_list_receiver: Option>>, - conversation_id: Option, -} - -impl ToolManagerBuilder { - pub fn mcp_server_config(mut self, config: McpServerConfig) -> Self { - self.mcp_server_config.replace(config); - self - } - - pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { - self.prompt_list_sender.replace(sender); - self - } - - pub fn prompt_list_receiver(mut self, receiver: std::sync::mpsc::Receiver>) -> Self { - self.prompt_list_receiver.replace(receiver); - self - } - - pub fn conversation_id(mut self, conversation_id: &str) -> Self { - self.conversation_id.replace(conversation_id.to_string()); - self - } - - pub fn build(mut self) -> eyre::Result { - let McpServerConfig { mcp_servers } = self.mcp_server_config.ok_or(eyre::eyre!("Missing mcp server config"))?; - debug_assert!(self.conversation_id.is_some()); - let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; - let regex = regex::Regex::new(VALID_TOOL_NAME)?; - let mut hasher = DefaultHasher::new(); - let pre_initialized = mcp_servers - .into_iter() - .map(|(server_name, server_config)| { - let snaked_cased_name = server_name.to_case(convert_case::Case::Snake); - let sanitized_server_name = sanitize_name(snaked_cased_name, ®ex, &mut hasher); - let custom_tool_client = CustomToolClient::from_config(sanitized_server_name.clone(), server_config); - (sanitized_server_name, custom_tool_client) - }) - .collect::>(); - - // Send up task to update user on server loading status - let (tx, rx) = std::sync::mpsc::channel::(); - // Using a hand rolled thread because it's just easier to do this than do deal with the Send - // requirements that comes with holding onto the stdout lock. 
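ToolManagerBuilder gathers the merged MCP config, the prompt tab-completion channels, and a conversation id before build() spins up the clients and the loading display. A hedged sketch of how a chat session might assemble it (the channel element types are inferred from the completer and prompt-listing code, and the dangling endpoints would normally be handed to the prompt completer):

```rust
// Sketch: building a ToolManager from the merged MCP configuration.
async fn build_tool_manager(conversation_id: &str) -> eyre::Result<ToolManager> {
    let mut stdout = std::io::stdout();
    let config = McpServerConfig::load_config(&mut stdout).await?;
    let (prompt_list_tx, _prompt_list_rx) = std::sync::mpsc::channel::<Vec<String>>();
    let (_search_tx, search_rx) = std::sync::mpsc::channel::<Option<String>>();
    let mut manager = ToolManagerBuilder::default()
        .mcp_server_config(config)
        .prompt_list_sender(prompt_list_tx)
        .prompt_list_receiver(search_rx)
        .conversation_id(conversation_id)
        .build()?;
    let _tool_specs = manager.load_tools().await?;
    Ok(manager)
}
```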
- let loading_display_task = std::thread::spawn(move || { - let stdout = std::io::stdout(); - let mut stdout_lock = stdout.lock(); - let mut loading_servers = HashMap::::new(); - let mut spinner_logo_idx: usize = 0; - let mut complete: usize = 0; - let mut failed: usize = 0; - loop { - match rx.recv_timeout(std::time::Duration::from_millis(50)) { - Ok(recv_result) => match recv_result { - LoadingMsg::Add(name) => { - let init_time = std::time::Instant::now(); - let status_line = StatusLine { init_time }; - execute!(stdout_lock, cursor::MoveToColumn(0))?; - if !loading_servers.is_empty() { - // TODO: account for terminal width - execute!(stdout_lock, cursor::MoveUp(1))?; - } - loading_servers.insert(name.clone(), status_line); - let total = loading_servers.len(); - execute!(stdout_lock, terminal::Clear(terminal::ClearType::CurrentLine))?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; - }, - LoadingMsg::Done(name) => { - if let Some(status_line) = loading_servers.get(&name) { - complete += 1; - let time_taken = - (std::time::Instant::now() - status_line.init_time).as_secs_f64().abs(); - let time_taken = format!("{:.2}", time_taken); - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_success_message(&name, &time_taken, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; - } - }, - LoadingMsg::Error { name, msg } => { - failed += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - }, - LoadingMsg::Warn { name, msg } => { - complete += 1; - execute!( - stdout_lock, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, &mut stdout_lock)?; - let total = loading_servers.len(); - queue_init_message(spinner_logo_idx, complete, failed, total, &mut stdout_lock)?; - stdout_lock.flush()?; - }, - }, - Err(RecvTimeoutError::Timeout) => { - spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); - execute!( - stdout_lock, - cursor::SavePosition, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - style::Print(SPINNER_CHARS[spinner_logo_idx]), - cursor::RestorePosition - )?; - }, - _ => break, - } - } - Ok::<_, eyre::Report>(()) - }); - let mut clients = HashMap::>::new(); - for (mut name, init_res) in pre_initialized { - let _ = tx.send(LoadingMsg::Add(name.clone())); - match init_res { - Ok(client) => { - let mut client = Arc::new(client); - while let Some(collided_client) = clients.insert(name.clone(), client) { - // to avoid server name collision we are going to circumvent this by - // appending the name with 1 - name.push('1'); - client = collided_client; - } - }, - Err(e) => { - error!("Error initializing mcp client for server {}: {:?}", name, &e); - let event = fig_telemetry::EventType::McpServerInit { - conversation_id: conversation_id.clone(), - init_failure_reason: Some(e.to_string()), - number_of_tools: 0, - }; - tokio::spawn(async move { - let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - 
fig_telemetry::dispatch_or_send_event(app_event).await; - }); - let _ = tx.send(LoadingMsg::Error { - name: name.clone(), - msg: e, - }); - }, - } - } - let loading_display_task = Some(loading_display_task); - let loading_status_sender = Some(tx); - - // Set up task to handle prompt requests - let sender = self.prompt_list_sender.take(); - let receiver = self.prompt_list_receiver.take(); - let prompts = Arc::new(SyncRwLock::new(HashMap::default())); - // TODO: accommodate hot reload of mcp servers - if let (Some(sender), Some(receiver)) = (sender, receiver) { - let clients = clients.iter().fold(HashMap::new(), |mut acc, (n, c)| { - acc.insert(n.to_string(), Arc::downgrade(c)); - acc - }); - let prompts_clone = prompts.clone(); - tokio::task::spawn_blocking(move || { - let receiver = Arc::new(std::sync::Mutex::new(receiver)); - loop { - let search_word = receiver.lock().map_err(|e| eyre::eyre!("{:?}", e))?.recv()?; - if clients - .values() - .any(|client| client.upgrade().is_some_and(|c| c.is_prompts_out_of_date())) - { - let mut prompts_wl = prompts_clone.write().map_err(|e| { - eyre::eyre!( - "Error retrieving write lock on prompts for tab complete {}", - e.to_string() - ) - })?; - *prompts_wl = clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let Some(client) = client.upgrade() else { - return acc; - }; - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error retrieving read lock for prompt gets for tab complete"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.to_string()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - acc - }, - ); - } - let prompts_rl = prompts_clone.read().map_err(|e| { - eyre::eyre!( - "Error retrieving read lock on prompts for tab complete {}", - e.to_string() - ) - })?; - let filtered_prompts = prompts_rl - .iter() - .flat_map(|(prompt_name, bundles)| { - if bundles.len() > 1 { - bundles - .iter() - .map(|b| format!("{}/{}", b.server_name, prompt_name)) - .collect() - } else { - vec![prompt_name.to_owned()] - } - }) - .filter(|n| { - if let Some(p) = &search_word { - n.contains(p) - } else { - true - } - }) - .collect::>(); - if let Err(e) = sender.send(filtered_prompts) { - error!("Error sending prompts to chat helper: {:?}", e); - } - } - #[allow(unreachable_code)] - Ok::<(), eyre::Report>(()) - }); - } - - Ok(ToolManager { - conversation_id, - clients, - prompts, - loading_display_task, - loading_status_sender, - ..Default::default() - }) - } -} - -#[derive(Clone, Debug)] -/// A collection of information that is used for the following purposes: -/// - Checking if prompt info cached is out of date -/// - Retrieve new prompt info -pub struct PromptBundle { - /// The server name from which the prompt is offered / exposed - pub server_name: String, - /// The prompt get (info with which a prompt is retrieved) cached - pub prompt_get: PromptGet, -} - -/// Categorizes different types of tool name validation failures: -/// - `TooLong`: The tool name exceeds the maximum allowed length -/// - `IllegalChar`: The tool name contains characters that are not allowed -/// - `EmptyDescription`: The tool description is empty or missing -#[allow(dead_code)] -enum OutOfSpecName { - TooLong(String), 
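The prompt-listing task above also shows the disambiguation rule used for tab completion: when more than one server exposes a prompt with the same name, each candidate is rendered as server_name/prompt_name, otherwise the bare prompt name is offered. The same rule in isolation (completion_names is an illustrative helper mirroring the fold above):

```rust
// Sketch: how prompt names are presented for completion.
// `bundles` corresponds to one entry of the prompts cache kept by ToolManager.
fn completion_names(prompt_name: &str, bundles: &[PromptBundle]) -> Vec<String> {
    if bundles.len() > 1 {
        bundles
            .iter()
            .map(|b| format!("{}/{}", b.server_name, prompt_name))
            .collect()
    } else {
        vec![prompt_name.to_owned()]
    }
}
```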
- IllegalChar(String), - EmptyDescription(String), -} - -/// Manages the lifecycle and interactions with tools from various sources, including MCP servers. -/// This struct is responsible for initializing tools, handling tool requests, and maintaining -/// a cache of available prompts from connected servers. -#[derive(Default)] -pub struct ToolManager { - /// Unique identifier for the current conversation. - /// This ID is used to track and associate tools with a specific chat session. - pub conversation_id: String, - - /// Map of server names to their corresponding client instances. - /// These clients are used to communicate with MCP servers. - pub clients: HashMap>, - - /// Cache for prompts collected from different servers. - /// Key: prompt name - /// Value: a list of PromptBundle that has a prompt of this name. - /// This cache helps resolve prompt requests efficiently and handles - /// cases where multiple servers offer prompts with the same name. - pub prompts: Arc>>>, - - /// Handle to the thread that displays loading status for tool initialization. - /// This thread provides visual feedback to users during the tool loading process. - loading_display_task: Option>>, - - /// Channel sender for communicating with the loading display thread. - /// Used to send status updates about tool initialization progress. - loading_status_sender: Option>, - - /// Mapping from sanitized tool names to original tool names. - /// This is used to handle tool name transformations that may occur during initialization - /// to ensure tool names comply with naming requirements. - pub tn_map: HashMap, - - /// A cache of tool's input schema for all of the available tools. - /// This is mainly used to show the user what the tools look like from the perspective of the - /// model. - pub schema: HashMap, -} - -impl ToolManager { - pub async fn load_tools(&mut self) -> eyre::Result> { - let tx = self.loading_status_sender.take(); - let display_task = self.loading_display_task.take(); - let tool_specs = { - let tool_specs = serde_json::from_str::>(include_str!("tools/tool_index.json"))?; - Arc::new(Mutex::new(tool_specs)) - }; - let conversation_id = self.conversation_id.clone(); - let regex = Arc::new(regex::Regex::new(VALID_TOOL_NAME)?); - let load_tool = self - .clients - .iter() - .map(|(server_name, client)| { - let client_clone = client.clone(); - let server_name_clone = server_name.clone(); - let tx_clone = tx.clone(); - let regex_clone = regex.clone(); - let tool_specs_clone = tool_specs.clone(); - let conversation_id = conversation_id.clone(); - async move { - let tool_spec = client_clone.init().await; - let mut sanitized_mapping = HashMap::::new(); - match tool_spec { - Ok((server_name, specs)) => { - // Each mcp server might have multiple tools. - // To avoid naming conflicts we are going to namespace it. - // This would also help us locate which mcp server to call the tool from. - let mut out_of_spec_tool_names = Vec::::new(); - let mut hasher = DefaultHasher::new(); - let number_of_tools = specs.len(); - // Sanitize tool names to ensure they comply with the naming requirements: - // 1. If the name already matches the regex pattern and doesn't contain the namespace delimiter, use it as is - // 2. 
Otherwise, remove invalid characters and handle special cases: - // - Remove namespace delimiters - // - Ensure the name starts with an alphabetic character - // - Generate a hash-based name if the sanitized result is empty - // This ensures all tool names are valid identifiers that can be safely used in the system - // If after all of the aforementioned modification the combined tool - // name we have exceeds a length of 64, we surface it as an error - for mut spec in specs { - let sn = if !regex_clone.is_match(&spec.name) { - let mut sn = sanitize_name(spec.name.clone(), ®ex_clone, &mut hasher); - while sanitized_mapping.contains_key(&sn) { - sn.push('1'); - } - sn - } else { - spec.name.clone() - }; - let full_name = format!("{}{}{}", server_name, NAMESPACE_DELIMITER, sn); - if full_name.len() > 64 { - out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name)); - continue; - } else if spec.description.is_empty() { - out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name)); - continue; - } - if sn != spec.name { - sanitized_mapping.insert(full_name.clone(), format!("{}{}{}", server_name, NAMESPACE_DELIMITER, spec.name)); - } - spec.name = full_name; - spec.tool_origin = ToolOrigin::McpServer(server_name.clone()); - tool_specs_clone.lock().await.insert(spec.name.clone(), spec); - } - // Send server load success metric datum - tokio::spawn(async move { - let event = fig_telemetry::EventType::McpServerInit { conversation_id, init_failure_reason: None, number_of_tools }; - let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - fig_telemetry::dispatch_or_send_event(app_event).await; - }); - // Tool name translation. This is beyond of the scope of what is - // considered a "server load". Reasoning being: - // - Failures here are not related to server load - // - There is not a whole lot we can do with this data - if let Some(tx_clone) = &tx_clone { - let send_result = if !out_of_spec_tool_names.is_empty() { - let msg = out_of_spec_tool_names.iter().fold( - String::from("The following tools are out of spec. They will be excluded from the list of available tools:\n"), - |mut acc, name| { - let (tool_name, msg) = match name { - OutOfSpecName::TooLong(tool_name) => (tool_name.as_str(), "tool name exceeds max length of 64 when combined with server name"), - OutOfSpecName::IllegalChar(tool_name) => (tool_name.as_str(), "tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$"), - OutOfSpecName::EmptyDescription(tool_name) => (tool_name.as_str(), "tool schema contains empty description"), - }; - acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str()); - acc - } - ); - tx_clone.send(LoadingMsg::Error { - name: server_name.clone(), - msg: eyre::eyre!(msg), - }) - // TODO: if no tools are valid, we need to offload the server - // from the fleet (i.e. 
kill the server) - } else if !sanitized_mapping.is_empty() { - let warn = sanitized_mapping.iter().fold(String::from("The following tool names are changed:\n"), |mut acc, (k, v)| { - acc.push_str(format!(" - {} -> {}\n", v, k).as_str()); - acc - }); - tx_clone.send(LoadingMsg::Warn { - name: server_name.clone(), - msg: eyre::eyre!(warn), - }) - } else { - tx_clone.send(LoadingMsg::Done(server_name.clone())) - }; - if let Err(e) = send_result { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - Err(e) => { - error!("Error obtaining tool spec for {}: {:?}", server_name_clone, e); - let init_failure_reason = Some(e.to_string()); - tokio::spawn(async move { - let event = fig_telemetry::EventType::McpServerInit { conversation_id, init_failure_reason, number_of_tools: 0 }; - let app_event = fig_telemetry::AppTelemetryEvent::new(event).await; - fig_telemetry::dispatch_or_send_event(app_event).await; - }); - if let Some(tx_clone) = &tx_clone { - if let Err(e) = tx_clone.send(LoadingMsg::Error { - name: server_name_clone, - msg: e, - }) { - error!("Error while sending status update to display task: {:?}", e); - } - } - }, - } - Ok::<_, eyre::Report>(Some(sanitized_mapping)) - } - }) - .collect::>(); - // TODO: do we want to introduce a timeout here? - self.tn_map = stream::iter(load_tool) - .map(|async_closure| tokio::task::spawn(async_closure)) - .buffer_unordered(20) - .collect::>() - .await - .into_iter() - .filter_map(|r| r.ok()) - .filter_map(|r| r.ok()) - .flatten() - .flatten() - .collect::>(); - drop(tx); - if let Some(display_task) = display_task { - if let Err(e) = display_task.join() { - error!("Error while joining status display task: {:?}", e); - } - } - let tool_specs = { - let mutex = - Arc::try_unwrap(tool_specs).map_err(|e| eyre::eyre!("Error unwrapping arc for tool specs {:?}", e))?; - mutex.into_inner() - }; - // caching the tool names for skim operations - for tool_name in tool_specs.keys() { - if !self.tn_map.contains_key(tool_name) { - self.tn_map.insert(tool_name.clone(), tool_name.clone()); - } - } - self.schema = tool_specs.clone(); - Ok(tool_specs) - } - - pub fn get_tool_from_tool_use(&self, value: AssistantToolUse) -> Result { - let map_err = |parse_error| ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "Failed to validate tool parameters: {parse_error}. The model has either suggested tool parameters which are incompatible with the existing tools, or has suggested one or more tool that does not exist in the list of known tools." 
- ))], - status: ToolResultStatus::Error, - }; - - Ok(match value.name.as_str() { - "fs_read" => Tool::FsRead(serde_json::from_value::(value.args).map_err(map_err)?), - "fs_write" => Tool::FsWrite(serde_json::from_value::(value.args).map_err(map_err)?), - "execute_bash" => Tool::ExecuteBash(serde_json::from_value::(value.args).map_err(map_err)?), - "use_aws" => Tool::UseAws(serde_json::from_value::(value.args).map_err(map_err)?), - "report_issue" => Tool::GhIssue(serde_json::from_value::(value.args).map_err(map_err)?), - // Note that this name is namespaced with server_name{DELIMITER}tool_name - name => { - let name = self.tn_map.get(name).map_or(name, String::as_str); - let (server_name, tool_name) = name.split_once(NAMESPACE_DELIMITER).ok_or(ToolResult { - tool_use_id: value.id.clone(), - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{name}\" is supplied with incorrect name" - ))], - status: ToolResultStatus::Error, - })?; - let Some(client) = self.clients.get(server_name) else { - return Err(ToolResult { - tool_use_id: value.id, - content: vec![ToolResultContentBlock::Text(format!( - "The tool, \"{server_name}\" is not supported by the client" - ))], - status: ToolResultStatus::Error, - }); - }; - // The tool input schema has the shape of { type, properties }. - // The field "params" expected by MCP is { name, arguments }, where name is the - // name of the tool being invoked, - // https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools. - // The field "arguments" is where ToolUse::args belong. - let mut params = serde_json::Map::::new(); - params.insert("name".to_owned(), serde_json::Value::String(tool_name.to_owned())); - params.insert("arguments".to_owned(), value.args); - let params = serde_json::Value::Object(params); - let custom_tool = CustomTool { - name: tool_name.to_owned(), - client: client.clone(), - method: "tools/call".to_owned(), - params: Some(params), - }; - Tool::Custom(custom_tool) - }, - }) - } - - #[allow(clippy::await_holding_lock)] - pub async fn get_prompt(&self, get_command: PromptsGetCommand) -> Result { - let (server_name, prompt_name) = match get_command.params.name.split_once('/') { - None => (None::, Some(get_command.params.name.clone())), - Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), - }; - let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; - // We need to use a sync lock here because this lock is also used in a blocking thread, - // necessitated by the fact that said thread is also responsible for using a sync channel, - // which is itself necessitated by the fact that consumer of said channel is calling from a - // sync function - let mut prompts_wl = self - .prompts - .write() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - let mut maybe_bundles = prompts_wl.get(&prompt_name); - let mut has_retried = false; - 'blk: loop { - match (maybe_bundles, server_name.as_ref(), has_retried) { - // If we have more than one eligible clients but no server name specified - (Some(bundles), None, _) if bundles.len() > 1 => { - break 'blk Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { - bundles.iter().fold("\n".to_string(), |mut acc, b| { - acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); - acc - }) - })); - }, - // Normal case where we have enough info to proceed - // Note that if bundle exists, it should never be empty - (Some(bundles), sn, _) => { - let bundle = if bundles.len() > 1 { - let 
Some(server_name) = sn else { - maybe_bundles = None; - continue 'blk; - }; - let bundle = bundles.iter().find(|b| b.server_name == *server_name); - match bundle { - Some(bundle) => bundle, - None => { - maybe_bundles = None; - continue 'blk; - }, - } - } else { - bundles.first().ok_or(GetPromptError::MissingPromptInfo)? - }; - let server_name = bundle.server_name.clone(); - let client = self.clients.get(&server_name).ok_or(GetPromptError::MissingClient)?; - // Here we lazily update the out of date cache - if client.is_prompts_out_of_date() { - let prompt_gets = client.list_prompt_gets(); - let prompt_gets = prompt_gets - .read() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - for (prompt_name, prompt_get) in prompt_gets.iter() { - prompts_wl - .entry(prompt_name.to_string()) - .and_modify(|bundles| { - let mut is_modified = false; - for bundle in &mut *bundles { - let mut updated_bundle = PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }; - if bundle.server_name == *server_name { - std::mem::swap(bundle, &mut updated_bundle); - is_modified = true; - break; - } - } - if !is_modified { - bundles.push(PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }); - } - }) - .or_insert(vec![PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - } - let PromptsGetCommand { params, .. } = get_command; - let PromptBundle { prompt_get, .. } = prompts_wl - .get(&prompt_name) - .and_then(|bundles| bundles.iter().find(|b| b.server_name == server_name)) - .ok_or(GetPromptError::MissingPromptInfo)?; - // Here we need to convert the positional arguments into key value pair - // The assignment order is assumed to be the order of args as they are - // presented in PromptGet::arguments - let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, ¶ms.arguments) { - let params = schema.iter().zip(value.iter()).fold( - HashMap::::new(), - |mut acc, (prompt_get_arg, value)| { - acc.insert(prompt_get_arg.name.clone(), value.clone()); - acc - }, - ); - Some(serde_json::json!(params)) - } else { - None - }; - let params = { - let mut params = serde_json::Map::new(); - params.insert("name".to_string(), serde_json::Value::String(prompt_name)); - if let Some(args) = args { - params.insert("arguments".to_string(), args); - } - Some(serde_json::Value::Object(params)) - }; - let resp = client.request("prompts/get", params).await?; - break 'blk Ok(resp); - }, - // If we have no eligible clients this would mean one of the following: - // - The prompt does not exist, OR - // - This is the first time we have a query / our cache is out of date - // Both of which means we would have to requery - (None, _, false) => { - has_retried = true; - self.refresh_prompts(&mut prompts_wl)?; - maybe_bundles = prompts_wl.get(&prompt_name); - continue 'blk; - }, - (_, _, true) => { - break 'blk Err(GetPromptError::PromptNotFound(prompt_name)); - }, - } - } - } - - pub fn refresh_prompts(&self, prompts_wl: &mut HashMap>) -> Result<(), GetPromptError> { - *prompts_wl = self.clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error encountered while retrieving read lock"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.to_string()) - .and_modify(|bundles| { - 
bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - acc - }, - ); - Ok(()) - } -} - -fn sanitize_name(orig: String, regex: ®ex::Regex, hasher: &mut impl Hasher) -> String { - if regex.is_match(&orig) && !orig.contains(NAMESPACE_DELIMITER) { - return orig; - } - let sanitized: String = orig - .chars() - .filter(|c| c.is_ascii_alphabetic() || c.is_ascii_digit() || *c == '_') - .collect::() - .replace(NAMESPACE_DELIMITER, ""); - if sanitized.is_empty() { - hasher.write(orig.as_bytes()); - let hash = format!("{:03}", hasher.finish() % 1000); - return format!("a{}", hash); - } - match sanitized.chars().next() { - Some(c) if c.is_ascii_alphabetic() => sanitized, - Some(_) => { - format!("a{}", sanitized) - }, - None => { - hasher.write(orig.as_bytes()); - format!("a{}", hasher.finish()) - }, - } -} - -fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" loaded in "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(format!("{time_taken} s\n")), - )?) -} - -fn queue_init_message( - spinner_logo_idx: usize, - complete: usize, - failed: usize, - total: usize, - output: &mut impl Write, -) -> eyre::Result<()> { - if total == complete { - queue!( - output, - style::SetForegroundColor(style::Color::Green), - style::Print("✓"), - style::ResetColor, - )?; - } else if total == complete + failed { - queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗"), - style::ResetColor, - )?; - } else { - queue!(output, style::Print(SPINNER_CHARS[spinner_logo_idx]))?; - } - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Blue), - style::Print(format!(" {}", complete)), - style::ResetColor, - style::Print(" of "), - style::SetForegroundColor(style::Color::Blue), - style::Print(format!("{} ", total)), - style::ResetColor, - style::Print("mcp servers initialized\n"), - )?) -} - -fn queue_failure_message(name: &str, fail_load_msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" has failed to load:\n- "), - style::Print(fail_load_msg), - style::Print("\n"), - style::Print("- run with Q_LOG_LEVEL=trace and see $TMPDIR/qlog for detail\n"), - style::ResetColor, - )?) -} - -fn queue_warn_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { - Ok(queue!( - output, - style::SetForegroundColor(style::Color::Yellow), - style::Print("⚠ "), - style::SetForegroundColor(style::Color::Blue), - style::Print(name), - style::ResetColor, - style::Print(" has the following warning:\n"), - style::Print(msg), - style::ResetColor, - )?) 
-} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sanitize_server_name() { - let regex = regex::Regex::new(VALID_TOOL_NAME).unwrap(); - let mut hasher = DefaultHasher::new(); - let orig_name = "@awslabs.cdk-mcp-server"; - let sanitized_server_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); - assert_eq!(sanitized_server_name, "awslabscdkmcpserver"); - - let orig_name = "good_name"; - let sanitized_good_name = sanitize_name(orig_name.to_string(), ®ex, &mut hasher); - assert_eq!(sanitized_good_name, orig_name); - - let all_bad_name = "@@@@@"; - let sanitized_all_bad_name = sanitize_name(all_bad_name.to_string(), ®ex, &mut hasher); - assert!(regex.is_match(&sanitized_all_bad_name)); - - let with_delim = format!("a{}b{}c", NAMESPACE_DELIMITER, NAMESPACE_DELIMITER); - let sanitized = sanitize_name(with_delim, ®ex, &mut hasher); - assert_eq!(sanitized, "abc"); - } -} diff --git a/crates/q_chat/src/tools/custom_tool.rs b/crates/q_chat/src/tools/custom_tool.rs deleted file mode 100644 index e034837bac..0000000000 --- a/crates/q_chat/src/tools/custom_tool.rs +++ /dev/null @@ -1,241 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::sync::Arc; -use std::sync::atomic::Ordering; - -use crossterm::{ - queue, - style, -}; -use eyre::Result; -use fig_os_shim::Context; -use mcp_client::{ - Client as McpClient, - ClientConfig as McpClientConfig, - JsonRpcResponse, - JsonRpcStdioTransport, - MessageContent, - PromptGet, - ServerCapabilities, - StdioTransport, - ToolCallResult, -}; -use serde::{ - Deserialize, - Serialize, -}; -use tokio::sync::RwLock; -use tracing::warn; - -use super::{ - InvokeOutput, - ToolSpec, -}; -use crate::CONTINUATION_LINE; -use crate::token_counter::TokenCounter; - -// TODO: support http transport type -#[derive(Clone, Serialize, Deserialize, Debug)] -pub struct CustomToolConfig { - pub command: String, - #[serde(default)] - pub args: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub env: Option>, - #[serde(default = "default_timeout")] - pub timeout: u64, -} - -fn default_timeout() -> u64 { - 120 * 1000 -} - -#[derive(Debug)] -pub enum CustomToolClient { - Stdio { - server_name: String, - client: McpClient, - server_capabilities: RwLock>, - }, -} - -impl CustomToolClient { - // TODO: add support for http transport - pub fn from_config(server_name: String, config: CustomToolConfig) -> Result { - let CustomToolConfig { - command, - args, - env, - timeout, - } = config; - let mcp_client_config = McpClientConfig { - server_name: server_name.clone(), - bin_path: command.clone(), - args, - timeout, - client_info: serde_json::json!({ - "name": "Q CLI Chat", - "version": "1.0.0" - }), - env, - }; - let client = McpClient::::from_config(mcp_client_config)?; - Ok(CustomToolClient::Stdio { - server_name, - client, - server_capabilities: RwLock::new(None), - }) - } - - pub async fn init(&self) -> Result<(String, Vec)> { - match self { - CustomToolClient::Stdio { - client, - server_name, - server_capabilities, - } => { - // We'll need to first initialize. 
This is the handshake every client and server - // needs to do before proceeding to anything else - let init_resp = client.init().await?; - // We'll be scrapping this for background server load: https://github.com/aws/amazon-q-developer-cli/issues/1466 - // So don't worry about the tidiness for now - let is_tool_supported = init_resp - .get("result") - .is_some_and(|r| r.get("capabilities").is_some_and(|cap| cap.get("tools").is_some())); - server_capabilities.write().await.replace(init_resp); - // Assuming a shape of return as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#listing-tools - let tools = if is_tool_supported { - // And now we make the server tell us what tools they have - let resp = client.request("tools/list", None).await?; - match resp.result.and_then(|r| r.get("tools").cloned()) { - Some(value) => serde_json::from_value::>(value)?, - None => Default::default(), - } - } else { - Default::default() - }; - Ok((server_name.clone(), tools)) - }, - } - } - - pub fn get_server_name(&self) -> &str { - match self { - CustomToolClient::Stdio { server_name, .. } => server_name.as_str(), - } - } - - pub async fn request(&self, method: &str, params: Option) -> Result { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.request(method, params).await?), - } - } - - pub fn list_prompt_gets(&self) -> Arc>> { - match self { - CustomToolClient::Stdio { client, .. } => client.prompt_gets.clone(), - } - } - - #[allow(dead_code)] - pub async fn notify(&self, method: &str, params: Option) -> Result<()> { - match self { - CustomToolClient::Stdio { client, .. } => Ok(client.notify(method, params).await?), - } - } - - pub fn is_prompts_out_of_date(&self) -> bool { - match self { - CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.load(Ordering::Relaxed), - } - } - - pub fn prompts_updated(&self) { - match self { - CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.store(false, Ordering::Relaxed), - } - } -} - -/// Represents a custom tool that can be invoked through the Model Context Protocol (MCP). -#[derive(Clone, Debug)] -pub struct CustomTool { - /// Actual tool name as recognized by its MCP server. This differs from the tool names as they - /// are seen by the model since they are not prefixed by its MCP server name. - pub name: String, - /// Reference to the client that manages communication with the tool's server process. - pub client: Arc, - /// The method name to call on the tool's server, following the JSON-RPC convention. - /// This corresponds to a specific functionality provided by the tool. - pub method: String, - /// Optional parameters to pass to the tool when invoking the method. - /// Structured as a JSON value to accommodate various parameter types and structures. - pub params: Option, -} - -impl CustomTool { - pub async fn invoke(&self, _ctx: &Context, _updates: &mut impl Write) -> Result { - // Assuming a response shape as per https://spec.modelcontextprotocol.io/specification/2024-11-05/server/tools/#calling-tools - let resp = self.client.request(self.method.as_str(), self.params.clone()).await?; - let result = resp - .result - .ok_or(eyre::eyre!("{} invocation failed to produce a result", self.name))?; - - match serde_json::from_value::(result.clone()) { - Ok(mut de_result) => { - for content in &mut de_result.content { - if let MessageContent::Image { data, .. 
} = content { - *data = format!("Redacted base64 encoded string of an image of size {}", data.len()); - } - } - Ok(InvokeOutput { - output: super::OutputKind::Json(serde_json::json!(de_result)), - }) - }, - Err(e) => { - warn!("Tool call result deserialization failed: {:?}", e); - Ok(InvokeOutput { - output: super::OutputKind::Json(result.clone()), - }) - }, - } - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Running "), - style::SetForegroundColor(style::Color::Green), - style::Print(&self.name), - style::ResetColor, - )?; - if let Some(params) = &self.params { - let params = match serde_json::to_string_pretty(params) { - Ok(params) => params - .split("\n") - .map(|p| format!("{CONTINUATION_LINE} {p}")) - .collect::>() - .join("\n"), - _ => format!("{:?}", params), - }; - queue!( - updates, - style::Print(" with the param:\n"), - style::Print(params), - style::ResetColor, - )?; - } else { - queue!(updates, style::Print("\n"))?; - } - Ok(()) - } - - pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { - Ok(()) - } - - pub fn get_input_token_size(&self) -> usize { - TokenCounter::count_tokens(self.method.as_str()) - + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) - } -} diff --git a/crates/q_chat/src/tools/execute_bash.rs b/crates/q_chat/src/tools/execute_bash.rs deleted file mode 100644 index 5640cecc49..0000000000 --- a/crates/q_chat/src/tools/execute_bash.rs +++ /dev/null @@ -1,373 +0,0 @@ -use std::collections::VecDeque; -use std::io::Write; -use std::process::{ - ExitStatus, - Stdio, -}; -use std::str::from_utf8; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::{ - Context as EyreContext, - Result, -}; -use fig_os_shim::Context; -use serde::Deserialize; -use tokio::io::AsyncBufReadExt; -use tokio::select; -use tracing::error; - -use super::super::util::truncate_safe; -use super::{ - InvokeOutput, - MAX_TOOL_RESPONSE_SIZE, - OutputKind, -}; - -const READONLY_COMMANDS: &[&str] = &["ls", "cat", "echo", "pwd", "which", "head", "tail", "find", "grep"]; - -#[derive(Debug, Clone, Deserialize)] -pub struct ExecuteBash { - pub command: String, -} - -impl ExecuteBash { - pub fn requires_acceptance(&self) -> bool { - let Some(args) = shlex::split(&self.command) else { - return true; - }; - - const DANGEROUS_PATTERNS: &[&str] = &["<(", "$(", "`", ">", "&&", "||", "&", ";"]; - if args - .iter() - .any(|arg| DANGEROUS_PATTERNS.iter().any(|p| arg.contains(p))) - { - return true; - } - - // Split commands by pipe and check each one - let mut current_cmd = Vec::new(); - let mut all_commands = Vec::new(); - - for arg in args { - if arg == "|" { - if !current_cmd.is_empty() { - all_commands.push(current_cmd); - } - current_cmd = Vec::new(); - } else if arg.contains("|") { - // if pipe appears without spacing e.g. 
`echo myimportantfile|args rm` it won't get - // parsed out, in this case - we want to verify before running - return true; - } else { - current_cmd.push(arg); - } - } - if !current_cmd.is_empty() { - all_commands.push(current_cmd); - } - - // Check if each command in the pipe chain starts with a safe command - for cmd_args in all_commands { - match cmd_args.first() { - // Special casing for `find` so that we support most cases while safeguarding - // against unwanted mutations - Some(cmd) - if cmd == "find" - && cmd_args - .iter() - .any(|arg| arg.contains("-exec") || arg.contains("-delete")) => - { - return true; - }, - Some(cmd) if !READONLY_COMMANDS.contains(&cmd.as_str()) => return true, - None => return true, - _ => (), - } - } - - false - } - - pub async fn invoke(&self, updates: impl Write) -> Result { - let output = run_command(&self.command, MAX_TOOL_RESPONSE_SIZE / 3, Some(updates)).await?; - let result = serde_json::json!({ - "exit_status": output.exit_status.unwrap_or(0).to_string(), - "stdout": output.stdout, - "stderr": output.stderr, - }); - - Ok(InvokeOutput { - output: OutputKind::Json(result), - }) - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!(updates, style::Print("I will run the following shell command: "),)?; - - // TODO: Could use graphemes for a better heuristic - if self.command.len() > 20 { - queue!(updates, style::Print("\n"),)?; - } - - Ok(queue!( - updates, - style::SetForegroundColor(Color::Green), - style::Print(&self.command), - style::Print("\n\n"), - style::ResetColor - )?) - } - - pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { - // TODO: probably some small amount of PATH checking - Ok(()) - } -} - -pub struct CommandResult { - pub exit_status: Option, - /// Truncated stdout - pub stdout: String, - /// Truncated stderr - pub stderr: String, -} - -/// Run a bash command. -/// # Arguments -/// * `max_result_size` - max size of output streams, truncating if required -/// * `updates` - output stream to push informational messages about the progress -/// # Returns -/// A [`CommandResult`] -pub async fn run_command( - command: &str, - max_result_size: usize, - mut updates: Option, -) -> Result { - // We need to maintain a handle on stderr and stdout, but pipe it to the terminal as well - let mut child = tokio::process::Command::new("bash") - .arg("-c") - .arg(command) - .stdin(Stdio::inherit()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()) - .spawn() - .wrap_err_with(|| format!("Unable to spawn command '{}'", command))?; - - let stdout_final: String; - let stderr_final: String; - let exit_status: ExitStatus; - - // Buffered output vs all-at-once - if let Some(u) = updates.as_mut() { - let stdout = child.stdout.take().unwrap(); - let stdout = tokio::io::BufReader::new(stdout); - let mut stdout = stdout.lines(); - - let stderr = child.stderr.take().unwrap(); - let stderr = tokio::io::BufReader::new(stderr); - let mut stderr = stderr.lines(); - - const LINE_COUNT: usize = 1024; - let mut stdout_buf = VecDeque::with_capacity(LINE_COUNT); - let mut stderr_buf = VecDeque::with_capacity(LINE_COUNT); - - let mut stdout_done = false; - let mut stderr_done = false; - exit_status = loop { - select! 
{ - biased; - line = stdout.next_line(), if !stdout_done => match line { - Ok(Some(line)) => { - writeln!(u, "{line}")?; - if stdout_buf.len() >= LINE_COUNT { - stdout_buf.pop_front(); - } - stdout_buf.push_back(line); - }, - Ok(None) => stdout_done = true, - Err(err) => error!(%err, "Failed to read stdout of child process"), - }, - line = stderr.next_line(), if !stderr_done => match line { - Ok(Some(line)) => { - writeln!(u, "{line}")?; - if stderr_buf.len() >= LINE_COUNT { - stderr_buf.pop_front(); - } - stderr_buf.push_back(line); - }, - Ok(None) => stderr_done = true, - Err(err) => error!(%err, "Failed to read stderr of child process"), - }, - exit_status = child.wait() => { - break exit_status; - }, - }; - } - .wrap_err_with(|| format!("No exit status for '{}'", command))?; - - u.flush()?; - - stdout_final = stdout_buf.into_iter().collect::>().join("\n"); - stderr_final = stderr_buf.into_iter().collect::>().join("\n"); - } else { - // Take output all at once since we are not reporting anything in real time - // - // NOTE: If we don't split this logic, then any writes to stdout while calling - // this function concurrently may cause the piped child output to be ignored - - let output = child - .wait_with_output() - .await - .wrap_err_with(|| format!("No exit status for '{}'", command))?; - - exit_status = output.status; - stdout_final = from_utf8(&output.stdout).unwrap_or_default().to_string(); - stderr_final = from_utf8(&output.stderr).unwrap_or_default().to_string(); - } - - Ok(CommandResult { - exit_status: exit_status.code(), - stdout: format!( - "{}{}", - truncate_safe(&stdout_final, max_result_size), - if stdout_final.len() > max_result_size { - " ... truncated" - } else { - "" - } - ), - stderr: format!( - "{}{}", - truncate_safe(&stderr_final, max_result_size), - if stderr_final.len() > max_result_size { - " ... truncated" - } else { - "" - } - ), - }) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[ignore = "todo: fix failing on musl for some reason"] - #[tokio::test] - async fn test_execute_bash_tool() { - let mut stdout = std::io::stdout(); - - // Verifying stdout - let v = serde_json::json!({ - "command": "echo Hello, world!", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert_eq!(json.get("stdout").unwrap(), "Hello, world!"); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - - // Verifying stderr - let v = serde_json::json!({ - "command": "echo Hello, world! 
1>&2", - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &0.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), "Hello, world!"); - } else { - panic!("Expected JSON output"); - } - - // Verifying exit code - let v = serde_json::json!({ - "command": "exit 1", - "interactive": false - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&mut stdout) - .await - .unwrap(); - if let OutputKind::Json(json) = out.output { - assert_eq!(json.get("exit_status").unwrap(), &1.to_string()); - assert_eq!(json.get("stdout").unwrap(), ""); - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - panic!("Expected JSON output"); - } - } - - #[test] - fn test_requires_acceptance_for_readonly_commands() { - let cmds = &[ - // Safe commands - ("ls ~", false), - ("ls -al ~", false), - ("pwd", false), - ("echo 'Hello, world!'", false), - ("which aws", false), - // Potentially dangerous readonly commands - ("echo hi > myimportantfile", true), - ("ls -al >myimportantfile", true), - ("echo hi 2> myimportantfile", true), - ("echo hi >> myimportantfile", true), - ("echo $(rm myimportantfile)", true), - ("echo `rm myimportantfile`", true), - ("echo hello && rm myimportantfile", true), - ("echo hello&&rm myimportantfile", true), - ("ls nonexistantpath || rm myimportantfile", true), - ("echo myimportantfile | xargs rm", true), - ("echo myimportantfile|args rm", true), - ("echo <(rm myimportantfile)", true), - ("cat <<< 'some string here' > myimportantfile", true), - ("echo '\n#!/usr/bin/env bash\necho hello\n' > myscript.sh", true), - ("cat < myimportantfile\nhello world\nEOF", true), - // Safe piped commands - ("find . -name '*.rs' | grep main", false), - ("ls -la | grep .git", false), - ("cat file.txt | grep pattern | head -n 5", false), - // Unsafe piped commands - ("find . -name '*.rs' | rm", true), - ("ls -la | grep .git | rm -rf", true), - ("echo hello | sudo rm -rf /", true), - // `find` command arguments - ("find important-dir/ -exec rm {} \\;", true), - ("find . 
-name '*.c' -execdir gcc -o '{}.out' '{}' \\;", true), - ("find important-dir/ -delete", true), - ("find important-dir/ -name '*.txt'", false), - ]; - for (cmd, expected) in cmds { - let tool = serde_json::from_value::(serde_json::json!({ - "command": cmd, - })) - .unwrap(); - assert_eq!( - tool.requires_acceptance(), - *expected, - "expected command: `{}` to have requires_acceptance: `{}`", - cmd, - expected - ); - } - } -} diff --git a/crates/q_chat/src/tools/fs_read.rs b/crates/q_chat/src/tools/fs_read.rs deleted file mode 100644 index 9ff07eeb83..0000000000 --- a/crates/q_chat/src/tools/fs_read.rs +++ /dev/null @@ -1,669 +0,0 @@ -use std::collections::VecDeque; -use std::fs::Metadata; -use std::io::Write; -use std::os::unix::fs::PermissionsExt; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::{ - Result, - bail, -}; -use fig_os_shim::Context; -use serde::{ - Deserialize, - Serialize, -}; -use syntect::util::LinesWithEndings; -use tracing::{ - debug, - warn, -}; - -use super::{ - InvokeOutput, - MAX_TOOL_RESPONSE_SIZE, - OutputKind, - format_path, - sanitize_path_tool_arg, -}; - -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "mode")] -pub enum FsRead { - Line(FsLine), - Directory(FsDirectory), - Search(FsSearch), -} - -impl FsRead { - pub async fn validate(&mut self, ctx: &Context) -> Result<()> { - match self { - FsRead::Line(fs_line) => fs_line.validate(ctx).await, - FsRead::Directory(fs_directory) => fs_directory.validate(ctx).await, - FsRead::Search(fs_search) => fs_search.validate(ctx).await, - } - } - - pub async fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { - match self { - FsRead::Line(fs_line) => fs_line.queue_description(ctx, updates).await, - FsRead::Directory(fs_directory) => fs_directory.queue_description(updates), - FsRead::Search(fs_search) => fs_search.queue_description(updates), - } - } - - pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { - match self { - FsRead::Line(fs_line) => fs_line.invoke(ctx, updates).await, - FsRead::Directory(fs_directory) => fs_directory.invoke(ctx, updates).await, - FsRead::Search(fs_search) => fs_search.invoke(ctx, updates).await, - } - } -} - -/// Read lines from a file. 
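The FsRead tool introduced above is an internally tagged enum: serde reads the "mode" field of the tool arguments to choose Line, Directory, or Search, and the remaining fields fill in that variant. A minimal standalone sketch of the pattern, assuming serde and serde_json as dependencies; FsReadSketch and its fields are illustrative, not the crate's own types.

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(tag = "mode")]
enum FsReadSketch {
    Line { path: String, start_line: Option<i32>, end_line: Option<i32> },
    Directory { path: String, depth: Option<usize> },
    Search { path: String, pattern: String },
}

fn main() -> Result<(), serde_json::Error> {
    // The "mode" tag selects the variant; unknown or missing tags become deserialization errors.
    let v: FsReadSketch = serde_json::from_value(serde_json::json!({
        "mode": "Search",
        "path": "/tmp/notes.txt",
        "pattern": "hello",
    }))?;
    println!("{v:?}");
    Ok(())
}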
-#[derive(Debug, Clone, Deserialize)] -pub struct FsLine { - pub path: String, - pub start_line: Option, - pub end_line: Option, -} - -impl FsLine { - const DEFAULT_END_LINE: i32 = -1; - const DEFAULT_START_LINE: i32 = 1; - - pub async fn validate(&mut self, ctx: &Context) -> Result<()> { - let path = sanitize_path_tool_arg(ctx, &self.path); - if !path.exists() { - bail!("'{}' does not exist", self.path); - } - let is_file = ctx.fs().symlink_metadata(&path).await?.is_file(); - if !is_file { - bail!("'{}' is not a file", self.path); - } - Ok(()) - } - - pub async fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { - let path = sanitize_path_tool_arg(ctx, &self.path); - let line_count = ctx.fs().read_to_string(&path).await?.lines().count(); - queue!( - updates, - style::Print("Reading file: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.path), - style::ResetColor, - style::Print(", "), - )?; - - let start = convert_negative_index(line_count, self.start_line()) + 1; - let end = convert_negative_index(line_count, self.end_line()) + 1; - match (start, end) { - _ if start == 1 && end == line_count => Ok(queue!(updates, style::Print("all lines".to_string()))?), - _ if end == line_count => Ok(queue!( - updates, - style::Print("from line "), - style::SetForegroundColor(Color::Green), - style::Print(start), - style::ResetColor, - style::Print(" to end of file"), - )?), - _ => Ok(queue!( - updates, - style::Print("from line "), - style::SetForegroundColor(Color::Green), - style::Print(start), - style::ResetColor, - style::Print(" to "), - style::SetForegroundColor(Color::Green), - style::Print(end), - style::ResetColor, - )?), - } - } - - pub async fn invoke(&self, ctx: &Context, _updates: &mut impl Write) -> Result { - let path = sanitize_path_tool_arg(ctx, &self.path); - debug!(?path, "Reading"); - let file = ctx.fs().read_to_string(&path).await?; - let line_count = file.lines().count(); - let (start, end) = ( - convert_negative_index(line_count, self.start_line()), - convert_negative_index(line_count, self.end_line()), - ); - - // safety check to ensure end is always greater than start - let end = end.max(start); - - if start >= line_count { - bail!( - "starting index: {} is outside of the allowed range: ({}, {})", - self.start_line(), - -(line_count as i64), - line_count - ); - } - - // The range should be inclusive on both ends. - let file_contents = file - .lines() - .skip(start) - .take(end - start + 1) - .collect::>() - .join("\n"); - - let byte_count = file_contents.len(); - if byte_count > MAX_TOOL_RESPONSE_SIZE { - bail!( - "This tool only supports reading {MAX_TOOL_RESPONSE_SIZE} bytes at a -time. You tried to read {byte_count} bytes. Try executing with fewer lines specified." - ); - } - - Ok(InvokeOutput { - output: OutputKind::Text(file_contents), - }) - } - - fn start_line(&self) -> i32 { - self.start_line.unwrap_or(Self::DEFAULT_START_LINE) - } - - fn end_line(&self) -> i32 { - self.end_line.unwrap_or(Self::DEFAULT_END_LINE) - } -} - -/// Search in a file. 
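To make the line-range semantics above concrete (1-based, inclusive on both ends, non-positive values counted from the end of the file), here is a self-contained sketch; convert_to_zero_based is an illustrative stand-in for the convert_negative_index helper defined further down in this file.

// Mirrors convert_negative_index: <= 0 counts back from the end, positive is 1-based.
fn convert_to_zero_based(line_count: usize, i: i32) -> usize {
    if i <= 0 {
        (line_count as i32 + i).max(0) as usize
    } else {
        i as usize - 1
    }
}

fn main() {
    let file = "one\ntwo\nthree\nfour\nfive";
    let line_count = file.lines().count();
    // start_line = 2, end_line = -1 selects lines 2 through 5, inclusive.
    let start = convert_to_zero_based(line_count, 2);
    let end = convert_to_zero_based(line_count, -1).max(start);
    let selected: Vec<&str> = file.lines().skip(start).take(end - start + 1).collect();
    assert_eq!(selected, ["two", "three", "four", "five"]);
}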
-#[derive(Debug, Clone, Deserialize)] -pub struct FsSearch { - pub path: String, - pub pattern: String, - pub context_lines: Option, -} - -impl FsSearch { - const CONTEXT_LINE_PREFIX: &str = " "; - const DEFAULT_CONTEXT_LINES: usize = 2; - const MATCHING_LINE_PREFIX: &str = "→ "; - - pub async fn validate(&mut self, ctx: &Context) -> Result<()> { - let path = sanitize_path_tool_arg(ctx, &self.path); - let relative_path = format_path(ctx.env().current_dir()?, &path); - if !path.exists() { - bail!("File not found: {}", relative_path); - } - if !ctx.fs().symlink_metadata(path).await?.is_file() { - bail!("Path is not a file: {}", relative_path); - } - if self.pattern.is_empty() { - bail!("Search pattern cannot be empty"); - } - Ok(()) - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Searching: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.path), - style::ResetColor, - style::Print(" for pattern: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.pattern.to_lowercase()), - style::ResetColor, - )?; - Ok(()) - } - - pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { - let file_path = sanitize_path_tool_arg(ctx, &self.path); - let pattern = &self.pattern; - let relative_path = format_path(ctx.env().current_dir()?, &file_path); - - let file_content = ctx.fs().read_to_string(&file_path).await?; - let lines: Vec<&str> = LinesWithEndings::from(&file_content).collect(); - - let mut results = Vec::new(); - let mut total_matches = 0; - - // Case insensitive search - let pattern_lower = pattern.to_lowercase(); - for (line_num, line) in lines.iter().enumerate() { - if line.to_lowercase().contains(&pattern_lower) { - total_matches += 1; - let start = line_num.saturating_sub(self.context_lines()); - let end = lines.len().min(line_num + self.context_lines() + 1); - let mut context_text = Vec::new(); - (start..end).for_each(|i| { - let prefix = if i == line_num { - Self::MATCHING_LINE_PREFIX - } else { - Self::CONTEXT_LINE_PREFIX - }; - let line_text = lines[i].to_string(); - context_text.push(format!("{}{}: {}", prefix, i + 1, line_text)); - }); - let match_text = context_text.join(""); - results.push(SearchMatch { - line_number: line_num + 1, - context: match_text, - }); - } - } - - queue!( - updates, - style::SetForegroundColor(Color::Yellow), - style::ResetColor, - style::Print(format!( - "Found {} matches for pattern '{}' in {}\n", - total_matches, pattern, relative_path - )), - style::Print("\n"), - style::ResetColor, - )?; - - Ok(InvokeOutput { - output: OutputKind::Text(serde_json::to_string(&results)?), - }) - } - - fn context_lines(&self) -> usize { - self.context_lines.unwrap_or(Self::DEFAULT_CONTEXT_LINES) - } -} - -/// List directory contents. 
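A compact sketch of the case-insensitive search above: every matching line is reported with its 1-based line number plus a window of surrounding context lines. The function name is illustrative; the real tool additionally marks the matching line with a prefix and serializes the results as JSON.

fn search_with_context(content: &str, pattern: &str, context: usize) -> Vec<(usize, String)> {
    let lines: Vec<&str> = content.lines().collect();
    let needle = pattern.to_lowercase();
    let mut matches = Vec::new();
    for (i, line) in lines.iter().enumerate() {
        if line.to_lowercase().contains(&needle) {
            let start = i.saturating_sub(context);
            let end = lines.len().min(i + context + 1);
            // 1-based line number of the hit, plus `context` lines on either side.
            matches.push((i + 1, lines[start..end].join("\n")));
        }
    }
    matches
}

fn main() {
    let hits = search_with_context("alpha\nHello world\nbeta", "HELLO", 1);
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].0, 2);
    assert_eq!(hits[0].1, "alpha\nHello world\nbeta");
}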
-#[derive(Debug, Clone, Deserialize)] -pub struct FsDirectory { - pub path: String, - pub depth: Option, -} - -impl FsDirectory { - const DEFAULT_DEPTH: usize = 0; - - pub async fn validate(&mut self, ctx: &Context) -> Result<()> { - let path = sanitize_path_tool_arg(ctx, &self.path); - let relative_path = format_path(ctx.env().current_dir()?, &path); - if !path.exists() { - bail!("Directory not found: {}", relative_path); - } - if !ctx.fs().symlink_metadata(path).await?.is_dir() { - bail!("Path is not a directory: {}", relative_path); - } - Ok(()) - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Reading directory: "), - style::SetForegroundColor(Color::Green), - style::Print(&self.path), - style::ResetColor, - style::Print(" "), - )?; - let depth = self.depth.unwrap_or_default(); - Ok(queue!( - updates, - style::Print(format!("with maximum depth of {}", depth)) - )?) - } - - pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { - let path = sanitize_path_tool_arg(ctx, &self.path); - let cwd = ctx.env().current_dir()?; - let max_depth = self.depth(); - debug!(?path, max_depth, "Reading directory at path with depth"); - let mut result = Vec::new(); - let mut dir_queue = VecDeque::new(); - dir_queue.push_back((path, 0)); - while let Some((path, depth)) = dir_queue.pop_front() { - if depth > max_depth { - break; - } - let relative_path = format_path(&cwd, &path); - if !relative_path.is_empty() { - queue!( - updates, - style::Print("Reading: "), - style::SetForegroundColor(Color::Green), - style::Print(&relative_path), - style::ResetColor, - style::Print("\n"), - )?; - } - let mut read_dir = ctx.fs().read_dir(path).await?; - while let Some(ent) = read_dir.next_entry().await? { - use std::os::unix::fs::MetadataExt; - let md = ent.metadata().await?; - let formatted_mode = format_mode(md.permissions().mode()).into_iter().collect::(); - - let modified_timestamp = md.modified()?.duration_since(std::time::UNIX_EPOCH)?.as_secs(); - let datetime = time::OffsetDateTime::from_unix_timestamp(modified_timestamp as i64).unwrap(); - let formatted_date = datetime - .format(time::macros::format_description!( - "[month repr:short] [day] [hour]:[minute]" - )) - .unwrap(); - - // Mostly copying "The Long Format" from `man ls`. - // TODO: query user/group database to convert uid/gid to names? - result.push(format!( - "{}{} {} {} {} {} {} {}", - format_ftype(&md), - formatted_mode, - md.nlink(), - md.uid(), - md.gid(), - md.size(), - formatted_date, - ent.path().to_string_lossy() - )); - if md.is_dir() { - dir_queue.push_back((ent.path(), depth + 1)); - } - } - } - - let file_count = result.len(); - let result = result.join("\n"); - let byte_count = result.len(); - if byte_count > MAX_TOOL_RESPONSE_SIZE { - bail!( - "This tool only supports reading up to {MAX_TOOL_RESPONSE_SIZE} bytes at a time. You tried to read {byte_count} bytes ({file_count} files). Try executing with fewer lines specified." - ); - } - - Ok(InvokeOutput { - output: OutputKind::Text(result), - }) - } - - fn depth(&self) -> usize { - self.depth.unwrap_or(Self::DEFAULT_DEPTH) - } -} - -/// Converts negative 1-based indices to positive 0-based indices. 
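The directory listing above is a breadth-first walk with a depth cap: entries are queued together with their depth, and the loop stops once it pops an entry deeper than the limit. A minimal synchronous sketch with std::fs under that assumption (the real code goes through the async fs shim and emits ls-style metadata per entry):

use std::collections::VecDeque;
use std::path::PathBuf;

fn list_to_depth(root: PathBuf, max_depth: usize) -> std::io::Result<Vec<PathBuf>> {
    let mut entries = Vec::new();
    let mut queue = VecDeque::from([(root, 0usize)]);
    while let Some((dir, depth)) = queue.pop_front() {
        if depth > max_depth {
            // Queue is in nondecreasing depth order, so nothing left is shallow enough.
            break;
        }
        for entry in std::fs::read_dir(&dir)? {
            let entry = entry?;
            let path = entry.path();
            if entry.file_type()?.is_dir() {
                queue.push_back((path.clone(), depth + 1));
            }
            entries.push(path);
        }
    }
    Ok(entries)
}

fn main() -> std::io::Result<()> {
    for path in list_to_depth(PathBuf::from("."), 1)? {
        println!("{}", path.display());
    }
    Ok(())
}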
-fn convert_negative_index(line_count: usize, i: i32) -> usize { - if i <= 0 { - (line_count as i32 + i).max(0) as usize - } else { - i as usize - 1 - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct SearchMatch { - line_number: usize, - context: String, -} - -fn format_ftype(md: &Metadata) -> char { - if md.is_symlink() { - 'l' - } else if md.is_file() { - '-' - } else if md.is_dir() { - 'd' - } else { - warn!("unknown file metadata: {:?}", md); - '-' - } -} - -/// Formats a permissions mode into the form used by `ls`, e.g. `0o644` to `rw-r--r--` -fn format_mode(mode: u32) -> [char; 9] { - let mut mode = mode & 0o777; - let mut res = ['-'; 9]; - fn octal_to_chars(val: u32) -> [char; 3] { - match val { - 1 => ['-', '-', 'x'], - 2 => ['-', 'w', '-'], - 3 => ['-', 'w', 'x'], - 4 => ['r', '-', '-'], - 5 => ['r', '-', 'x'], - 6 => ['r', 'w', '-'], - 7 => ['r', 'w', 'x'], - _ => ['-', '-', '-'], - } - } - for c in res.rchunks_exact_mut(3) { - c.copy_from_slice(&octal_to_chars(mode & 0o7)); - mode /= 0o10; - } - res -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - - const TEST_FILE_CONTENTS: &str = "\ -1: Hello world! -2: This is line 2 -3: asdf -4: Hello world! -"; - - const TEST_FILE_PATH: &str = "/test_file.txt"; - const TEST_HIDDEN_FILE_PATH: &str = "/aaaa2/.hidden"; - - /// Sets up the following filesystem structure: - /// ```text - /// test_file.txt - /// /home/testuser/ - /// /aaaa1/ - /// /bbbb1/ - /// /cccc1/ - /// /aaaa2/ - /// .hidden - /// ``` - async fn setup_test_directory() -> Arc { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let fs = ctx.fs(); - fs.write(TEST_FILE_PATH, TEST_FILE_CONTENTS).await.unwrap(); - fs.create_dir_all("/aaaa1/bbbb1/cccc1").await.unwrap(); - fs.create_dir_all("/aaaa2").await.unwrap(); - fs.write(TEST_HIDDEN_FILE_PATH, "this is a hidden file").await.unwrap(); - ctx - } - - #[test] - fn test_negative_index_conversion() { - assert_eq!(convert_negative_index(5, -100), 0); - assert_eq!(convert_negative_index(5, -1), 4); - } - - #[test] - fn test_fs_read_deser() { - serde_json::from_value::(serde_json::json!({ "path": "/test_file.txt", "mode": "Line" })).unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "end_line": 5 }), - ) - .unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "start_line": -1 }), - ) - .unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Line", "start_line": None:: }), - ) - .unwrap(); - serde_json::from_value::(serde_json::json!({ "path": "/", "mode": "Directory" })).unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Directory", "depth": 2 }), - ) - .unwrap(); - serde_json::from_value::( - serde_json::json!({ "path": "/test_file.txt", "mode": "Search", "pattern": "hello" }), - ) - .unwrap(); - } - - #[tokio::test] - async fn test_fs_read_line_invoke() { - let ctx = setup_test_directory().await; - let lines = TEST_FILE_CONTENTS.lines().collect::>(); - let mut stdout = std::io::stdout(); - - macro_rules! 
assert_lines { - ($start_line:expr, $end_line:expr, $expected:expr) => { - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "mode": "Line", - "start_line": $start_line, - "end_line": $end_line, - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert_eq!(text, $expected.join("\n"), "actual(left) does not equal - expected(right) for (start_line, end_line): ({:?}, {:?})", $start_line, $end_line); - } else { - panic!("expected text output"); - } - } - } - assert_lines!(None::, None::, lines[..]); - assert_lines!(1, 2, lines[..=1]); - assert_lines!(1, -1, lines[..]); - assert_lines!(2, 1, lines[1..=1]); - assert_lines!(-2, -1, lines[2..]); - assert_lines!(-2, None::, lines[2..]); - assert_lines!(2, None::, lines[1..]); - } - - #[tokio::test] - async fn test_fs_read_line_past_eof() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "mode": "Line", - "start_line": 100, - "end_line": None::, - }); - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .is_err() - ); - } - - #[test] - fn test_format_mode() { - macro_rules! assert_mode { - ($actual:expr, $expected:expr) => { - assert_eq!(format_mode($actual).iter().collect::(), $expected); - }; - } - assert_mode!(0o000, "---------"); - assert_mode!(0o700, "rwx------"); - assert_mode!(0o744, "rwxr--r--"); - assert_mode!(0o641, "rw-r----x"); - } - - #[tokio::test] - async fn test_fs_read_directory_invoke() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - // Testing without depth - let v = serde_json::json!({ - "mode": "Directory", - "path": "/", - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - assert_eq!(text.lines().collect::>().len(), 4); - } else { - panic!("expected text output"); - } - - // Testing with depth level 1 - let v = serde_json::json!({ - "mode": "Directory", - "path": "/", - "depth": 1, - }); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(text) = output.output { - let lines = text.lines().collect::>(); - assert_eq!(lines.len(), 7); - assert!( - !lines.iter().any(|l| l.contains("cccc1")), - "directory at depth level 2 should not be included in output" - ); - } else { - panic!("expected text output"); - } - } - - #[tokio::test] - async fn test_fs_read_search_invoke() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - macro_rules! 
invoke_search { - ($value:tt) => {{ - let v = serde_json::json!($value); - let output = serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - if let OutputKind::Text(value) = output.output { - serde_json::from_str::>(&value).unwrap() - } else { - panic!("expected Text output") - } - }}; - } - - let matches = invoke_search!({ - "mode": "Search", - "path": TEST_FILE_PATH, - "pattern": "hello", - }); - assert_eq!(matches.len(), 2); - assert_eq!(matches[0].line_number, 1); - assert_eq!( - matches[0].context, - format!( - "{}1: 1: Hello world!\n{}2: 2: This is line 2\n{}3: 3: asdf\n", - FsSearch::MATCHING_LINE_PREFIX, - FsSearch::CONTEXT_LINE_PREFIX, - FsSearch::CONTEXT_LINE_PREFIX - ) - ); - } -} diff --git a/crates/q_chat/src/tools/fs_write.rs b/crates/q_chat/src/tools/fs_write.rs deleted file mode 100644 index 576937c0ba..0000000000 --- a/crates/q_chat/src/tools/fs_write.rs +++ /dev/null @@ -1,953 +0,0 @@ -use std::io::Write; -use std::path::Path; -use std::sync::LazyLock; - -use crossterm::queue; -use crossterm::style::{ - self, - Color, -}; -use eyre::{ - ContextCompat as _, - Result, - bail, - eyre, -}; -use fig_os_shim::Context; -use serde::Deserialize; -use similar::DiffableStr; -use syntect::easy::HighlightLines; -use syntect::highlighting::ThemeSet; -use syntect::parsing::SyntaxSet; -use syntect::util::{ - LinesWithEndings, - as_24_bit_terminal_escaped, -}; -use tracing::{ - error, - warn, -}; - -use super::{ - InvokeOutput, - format_path, - sanitize_path_tool_arg, - supports_truecolor, -}; - -static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); -static THEME_SET: LazyLock = LazyLock::new(ThemeSet::load_defaults); - -#[derive(Debug, Clone, Deserialize)] -#[serde(tag = "command")] -pub enum FsWrite { - /// The tool spec should only require `file_text`, but the model sometimes doesn't want to - /// provide it. Thus, including `new_str` as a fallback check, if it's available. - #[serde(rename = "create")] - Create { - path: String, - file_text: Option, - new_str: Option, - }, - #[serde(rename = "str_replace")] - StrReplace { - path: String, - old_str: String, - new_str: String, - }, - #[serde(rename = "insert")] - Insert { - path: String, - insert_line: usize, - new_str: String, - }, - #[serde(rename = "append")] - Append { path: String, new_str: String }, -} - -impl FsWrite { - pub async fn invoke(&self, ctx: &Context, updates: &mut impl Write) -> Result { - let fs = ctx.fs(); - let cwd = ctx.env().current_dir()?; - match self { - FsWrite::Create { path, .. 
} => { - let file_text = self.canonical_create_command_text(); - let path = sanitize_path_tool_arg(ctx, path); - if let Some(parent) = path.parent() { - fs.create_dir_all(parent).await?; - } - - let invoke_description = if fs.exists(&path) { "Replacing: " } else { "Creating: " }; - queue!( - updates, - style::Print(invoke_description), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - - write_to_file(ctx, path, file_text).await?; - Ok(Default::default()) - }, - FsWrite::StrReplace { path, old_str, new_str } => { - let path = sanitize_path_tool_arg(ctx, path); - let file = fs.read_to_string(&path).await?; - let matches = file.match_indices(old_str).collect::>(); - queue!( - updates, - style::Print("Updating: "), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - match matches.len() { - 0 => Err(eyre!("no occurrences of \"{old_str}\" were found")), - 1 => { - let file = file.replacen(old_str, new_str, 1); - fs.write(path, file).await?; - Ok(Default::default()) - }, - x => Err(eyre!("{x} occurrences of old_str were found when only 1 is expected")), - } - }, - FsWrite::Insert { - path, - insert_line, - new_str, - } => { - let path = sanitize_path_tool_arg(ctx, path); - let mut file = fs.read_to_string(&path).await?; - queue!( - updates, - style::Print("Updating: "), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - - // Get the index of the start of the line to insert at. - let num_lines = file.lines().enumerate().map(|(i, _)| i + 1).last().unwrap_or(1); - let insert_line = insert_line.clamp(&0, &num_lines); - let mut i = 0; - for _ in 0..*insert_line { - let line_len = &file[i..].find("\n").map_or(file[i..].len(), |i| i + 1); - i += line_len; - } - file.insert_str(i, new_str); - write_to_file(ctx, &path, file).await?; - Ok(Default::default()) - }, - FsWrite::Append { path, new_str } => { - let path = sanitize_path_tool_arg(ctx, path); - - queue!( - updates, - style::Print("Appending to: "), - style::SetForegroundColor(Color::Green), - style::Print(format_path(cwd, &path)), - style::ResetColor, - style::Print("\n"), - )?; - - let mut file = fs.read_to_string(&path).await?; - if !file.ends_with_newline() { - file.push('\n'); - } - file.push_str(new_str); - write_to_file(ctx, path, file).await?; - Ok(Default::default()) - }, - } - } - - pub fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { - let cwd = ctx.env().current_dir()?; - self.print_relative_path(ctx, updates)?; - match self { - FsWrite::Create { path, .. } => { - let file_text = self.canonical_create_command_text(); - let relative_path = format_path(cwd, path); - let prev = if ctx.fs().exists(path) { - let file = ctx.fs().read_to_string_sync(path)?; - stylize_output_if_able(ctx, path, &file) - } else { - Default::default() - }; - let new = stylize_output_if_able(ctx, &relative_path, &file_text); - print_diff(updates, &prev, &new, 1)?; - Ok(()) - }, - FsWrite::Insert { - path, - insert_line, - new_str, - } => { - let relative_path = format_path(cwd, path); - let file = ctx.fs().read_to_string_sync(&relative_path)?; - - // Diff the old with the new by adding extra context around the line being inserted - // at. 
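The str_replace branch above refuses to edit unless old_str occurs exactly once in the file, which keeps the model from making an ambiguous replacement. A standalone sketch of that guard; the function name is illustrative.

fn replace_exactly_once(file: &str, old_str: &str, new_str: &str) -> Result<String, String> {
    match file.match_indices(old_str).count() {
        0 => Err(format!("no occurrences of {old_str:?} were found")),
        1 => Ok(file.replacen(old_str, new_str, 1)),
        n => Err(format!("{n} occurrences of old_str were found when only 1 is expected")),
    }
}

fn main() {
    assert_eq!(replace_exactly_once("let x = 1;\n", "1", "2").unwrap(), "let x = 2;\n");
    assert!(replace_exactly_once("a a", "a", "b").is_err()); // ambiguous: two matches
    assert!(replace_exactly_once("abc", "zzz", "b").is_err()); // nothing to replace
}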
- let (prefix, start_line, suffix, _) = get_lines_with_context(&file, *insert_line, *insert_line, 3); - let insert_line_content = LinesWithEndings::from(&file) - // don't include any content if insert_line is 0 - .nth(insert_line.checked_sub(1).unwrap_or(usize::MAX)) - .unwrap_or_default(); - let old = [prefix, insert_line_content, suffix].join(""); - let new = [prefix, insert_line_content, new_str, suffix].join(""); - - let old = stylize_output_if_able(ctx, &relative_path, &old); - let new = stylize_output_if_able(ctx, &relative_path, &new); - print_diff(updates, &old, &new, start_line)?; - Ok(()) - }, - FsWrite::StrReplace { path, old_str, new_str } => { - let relative_path = format_path(cwd, path); - let file = ctx.fs().read_to_string_sync(&relative_path)?; - let (start_line, _) = match line_number_at(&file, old_str) { - Some((start_line, end_line)) => (start_line, end_line), - _ => (0, 0), - }; - let old_str = stylize_output_if_able(ctx, &relative_path, old_str); - let new_str = stylize_output_if_able(ctx, &relative_path, new_str); - print_diff(updates, &old_str, &new_str, start_line)?; - - Ok(()) - }, - FsWrite::Append { path, new_str } => { - let relative_path = format_path(cwd, path); - let start_line = ctx.fs().read_to_string_sync(&relative_path)?.lines().count() + 1; - let file = stylize_output_if_able(ctx, &relative_path, new_str); - print_diff(updates, &Default::default(), &file, start_line)?; - Ok(()) - }, - } - } - - pub async fn validate(&mut self, ctx: &Context) -> Result<()> { - match self { - FsWrite::Create { path, .. } => { - if path.is_empty() { - bail!("Path must not be empty") - }; - }, - FsWrite::StrReplace { path, .. } | FsWrite::Insert { path, .. } => { - let path = sanitize_path_tool_arg(ctx, path); - if !path.exists() { - bail!("The provided path must exist in order to replace or insert contents into it") - } - }, - FsWrite::Append { path, new_str } => { - if path.is_empty() { - bail!("Path must not be empty") - }; - if new_str.is_empty() { - bail!("Content to append must not be empty") - }; - }, - } - - Ok(()) - } - - fn print_relative_path(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { - let cwd = ctx.env().current_dir()?; - let path = match self { - FsWrite::Create { path, .. } => path, - FsWrite::StrReplace { path, .. } => path, - FsWrite::Insert { path, .. } => path, - FsWrite::Append { path, .. } => path, - }; - let relative_path = format_path(cwd, path); - queue!( - updates, - style::Print("Path: "), - style::SetForegroundColor(Color::Green), - style::Print(&relative_path), - style::ResetColor, - style::Print("\n\n"), - )?; - Ok(()) - } - - /// Returns the text to use for the [FsWrite::Create] command. This is required since we can't - /// rely on the model always providing `file_text`. - fn canonical_create_command_text(&self) -> String { - match self { - FsWrite::Create { file_text, new_str, .. } => match (file_text, new_str) { - (Some(file_text), _) => file_text.clone(), - (None, Some(new_str)) => { - warn!("required field `file_text` is missing, using the provided `new_str` instead"); - new_str.clone() - }, - _ => { - warn!("no content provided for the create command"); - String::new() - }, - }, - _ => String::new(), - } - } -} - -/// Writes `content` to `path`, adding a newline if necessary. 
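canonical_create_command_text above tolerates a model that sends new_str instead of the required file_text for the create command. A small sketch of that fallback order; canonical_create_text is an illustrative free function, not the crate's method.

/// Prefer `file_text`, fall back to `new_str`, and default to an empty file,
/// mirroring the warning paths in canonical_create_command_text.
fn canonical_create_text(file_text: Option<&str>, new_str: Option<&str>) -> String {
    match (file_text, new_str) {
        (Some(text), _) => text.to_string(),
        (None, Some(text)) => {
            eprintln!("warning: `file_text` is missing, using `new_str` instead");
            text.to_string()
        },
        (None, None) => {
            eprintln!("warning: no content provided for the create command");
            String::new()
        },
    }
}

fn main() {
    assert_eq!(canonical_create_text(Some("a\n"), Some("b\n")), "a\n");
    assert_eq!(canonical_create_text(None, Some("b\n")), "b\n");
    assert_eq!(canonical_create_text(None, None), "");
}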
-async fn write_to_file(ctx: &Context, path: impl AsRef, mut content: String) -> Result<()> { - if !content.ends_with_newline() { - content.push('\n'); - } - ctx.fs().write(path.as_ref(), content).await?; - Ok(()) -} - -/// Returns a prefix/suffix pair before and after the content dictated by `[start_line, end_line]` -/// within `content`. The updated start and end lines containing the original context along with -/// the suffix and prefix are returned. -/// -/// Params: -/// - `start_line` - 1-indexed starting line of the content. -/// - `end_line` - 1-indexed ending line of the content. -/// - `context_lines` - number of lines to include before the start and end. -/// -/// Returns `(prefix, new_start_line, suffix, new_end_line)` -fn get_lines_with_context( - content: &str, - start_line: usize, - end_line: usize, - context_lines: usize, -) -> (&str, usize, &str, usize) { - let line_count = content.lines().count(); - // We want to support end_line being 0, in which case we should be able to set the first line - // as the suffix. - let zero_check_inc = if end_line == 0 { 0 } else { 1 }; - - // Convert to 0-indexing. - let (start_line, end_line) = ( - start_line.saturating_sub(1).clamp(0, line_count - 1), - end_line.saturating_sub(1).clamp(0, line_count - 1), - ); - let new_start_line = 0.max(start_line.saturating_sub(context_lines)); - let new_end_line = (line_count - 1).min(end_line + context_lines); - - // Build prefix - let mut prefix_start = 0; - for line in LinesWithEndings::from(content).take(new_start_line) { - prefix_start += line.len(); - } - let mut prefix_end = prefix_start; - for line in LinesWithEndings::from(&content[prefix_start..]).take(start_line - new_start_line) { - prefix_end += line.len(); - } - - // Build suffix - let mut suffix_start = 0; - for line in LinesWithEndings::from(content).take(end_line + zero_check_inc) { - suffix_start += line.len(); - } - let mut suffix_end = suffix_start; - for line in LinesWithEndings::from(&content[suffix_start..]).take(new_end_line - end_line) { - suffix_end += line.len(); - } - - ( - &content[prefix_start..prefix_end], - new_start_line + 1, - &content[suffix_start..suffix_end], - new_end_line + zero_check_inc, - ) -} - -/// Prints a git-diff style comparison between `old_str` and `new_str`. -/// - `start_line` - 1-indexed line number that `old_str` and `new_str` start at. -fn print_diff( - updates: &mut impl Write, - old_str: &StylizedFile, - new_str: &StylizedFile, - start_line: usize, -) -> Result<()> { - let diff = similar::TextDiff::from_lines(&old_str.content, &new_str.content); - - // First, get the gutter width required for both the old and new lines. - let (mut max_old_i, mut max_new_i) = (1, 1); - for change in diff.iter_all_changes() { - if let Some(i) = change.old_index() { - max_old_i = i + start_line; - } - if let Some(i) = change.new_index() { - max_new_i = i + start_line; - } - } - let old_line_num_width = terminal_width_required_for_line_count(max_old_i); - let new_line_num_width = terminal_width_required_for_line_count(max_new_i); - - // Now, print - fn fmt_index(i: Option, start_line: usize) -> String { - match i { - Some(i) => (i + start_line).to_string(), - _ => " ".to_string(), - } - } - for change in diff.iter_all_changes() { - // Define the colors per line. 
- let (text_color, gutter_bg_color, line_bg_color) = match (change.tag(), new_str.truecolor) { - (similar::ChangeTag::Equal, true) => (style::Color::Reset, new_str.gutter_bg, new_str.line_bg), - (similar::ChangeTag::Delete, true) => ( - style::Color::Reset, - style::Color::Rgb { r: 79, g: 40, b: 40 }, - style::Color::Rgb { r: 36, g: 25, b: 28 }, - ), - (similar::ChangeTag::Insert, true) => ( - style::Color::Reset, - style::Color::Rgb { r: 40, g: 67, b: 43 }, - style::Color::Rgb { r: 24, g: 38, b: 30 }, - ), - (similar::ChangeTag::Equal, false) => (style::Color::Reset, new_str.gutter_bg, new_str.line_bg), - (similar::ChangeTag::Delete, false) => (style::Color::Red, new_str.gutter_bg, new_str.line_bg), - (similar::ChangeTag::Insert, false) => (style::Color::Green, new_str.gutter_bg, new_str.line_bg), - }; - // Define the change tag character to print, if any. - let sign = match change.tag() { - similar::ChangeTag::Equal => " ", - similar::ChangeTag::Delete => "-", - similar::ChangeTag::Insert => "+", - }; - - let old_i_str = fmt_index(change.old_index(), start_line); - let new_i_str = fmt_index(change.new_index(), start_line); - - // Print the gutter and line numbers. - queue!(updates, style::SetBackgroundColor(gutter_bg_color))?; - queue!( - updates, - style::SetForegroundColor(text_color), - style::Print(sign), - style::Print(" ") - )?; - queue!( - updates, - style::Print(format!( - "{:>old_line_num_width$}", - old_i_str, - old_line_num_width = old_line_num_width - )) - )?; - if sign == " " { - queue!(updates, style::Print(", "))?; - } else { - queue!(updates, style::Print(" "))?; - } - queue!( - updates, - style::Print(format!( - "{:>new_line_num_width$}", - new_i_str, - new_line_num_width = new_line_num_width - )) - )?; - // Print the line. - queue!( - updates, - style::SetForegroundColor(style::Color::Reset), - style::Print(":"), - style::SetForegroundColor(text_color), - style::SetBackgroundColor(line_bg_color), - style::Print(" "), - style::Print(change), - style::ResetColor, - )?; - } - queue!( - updates, - crossterm::terminal::Clear(crossterm::terminal::ClearType::UntilNewLine), - style::Print("\n"), - )?; - - Ok(()) -} - -/// Returns a 1-indexed line number range of the start and end of `needle` inside `file`. -fn line_number_at(file: impl AsRef, needle: impl AsRef) -> Option<(usize, usize)> { - let file = file.as_ref(); - let needle = needle.as_ref(); - if let Some((i, _)) = file.match_indices(needle).next() { - let start = file[..i].matches("\n").count(); - let end = needle.matches("\n").count(); - Some((start + 1, start + end + 1)) - } else { - None - } -} - -/// Returns the number of terminal cells required for displaying line numbers. This is used to -/// determine how many characters the gutter should allocate when displaying line numbers for a -/// text file. -/// -/// For example, `10` and `99` both take 2 cells, whereas `100` and `999` take 3. -fn terminal_width_required_for_line_count(line_count: usize) -> usize { - line_count.to_string().chars().count() -} - -fn stylize_output_if_able(ctx: &Context, path: impl AsRef, file_text: &str) -> StylizedFile { - if supports_truecolor(ctx) { - match stylized_file(path, file_text) { - Ok(s) => return s, - Err(err) => { - error!(?err, "unable to syntax highlight the output"); - }, - } - } - StylizedFile { - truecolor: false, - content: file_text.to_string(), - gutter_bg: style::Color::Reset, - line_bg: style::Color::Reset, - } -} - -/// Represents a [String] that is potentially stylized with truecolor escape codes. 
-#[derive(Debug)] -struct StylizedFile { - /// Whether or not the file is stylized with 24bit color. - truecolor: bool, - /// File content. If [Self::truecolor] is true, then it has escape codes for styling with 24bit - /// color. - content: String, - /// Background color for the gutter. - gutter_bg: style::Color, - /// Background color for the line content. - line_bg: style::Color, -} - -impl Default for StylizedFile { - fn default() -> Self { - Self { - truecolor: false, - content: Default::default(), - gutter_bg: style::Color::Reset, - line_bg: style::Color::Reset, - } - } -} - -/// Returns a 24bit terminal escaped syntax-highlighted [String] of the file pointed to by `path`, -/// if able. -fn stylized_file(path: impl AsRef, file_text: impl AsRef) -> Result { - let ps = &*SYNTAX_SET; - let ts = &*THEME_SET; - - let extension = path - .as_ref() - .extension() - .wrap_err("missing extension")? - .to_str() - .wrap_err("not utf8")?; - - let syntax = ps - .find_syntax_by_extension(extension) - .wrap_err_with(|| format!("missing extension: {}", extension))?; - - let theme = &ts.themes["base16-ocean.dark"]; - let mut highlighter = HighlightLines::new(syntax, theme); - let file_text = file_text.as_ref().lines(); - let mut file = String::new(); - for line in file_text { - let mut ranges = Vec::new(); - ranges.append(&mut highlighter.highlight_line(line, ps)?); - let mut escaped_line = as_24_bit_terminal_escaped(&ranges[..], false); - escaped_line.push_str(&format!( - "{}\n", - crossterm::terminal::Clear(crossterm::terminal::ClearType::UntilNewLine), - )); - file.push_str(&escaped_line); - } - - let (line_bg, gutter_bg) = match (theme.settings.background, theme.settings.gutter) { - (Some(line_bg), Some(gutter_bg)) => (line_bg, gutter_bg), - (Some(line_bg), None) => (line_bg, line_bg), - _ => bail!("missing theme"), - }; - Ok(StylizedFile { - truecolor: true, - content: file, - gutter_bg: syntect_to_crossterm_color(gutter_bg), - line_bg: syntect_to_crossterm_color(line_bg), - }) -} - -fn syntect_to_crossterm_color(syntect: syntect::highlighting::Color) -> style::Color { - style::Color::Rgb { - r: syntect.r, - g: syntect.g, - b: syntect.b, - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use super::*; - - const TEST_FILE_CONTENTS: &str = "\ -1: Hello world! -2: This is line 2 -3: asdf -4: Hello world! -"; - - const TEST_FILE_PATH: &str = "/test_file.txt"; - const TEST_HIDDEN_FILE_PATH: &str = "/aaaa2/.hidden"; - - /// Sets up the following filesystem structure: - /// ```text - /// test_file.txt - /// /home/testuser/ - /// /aaaa1/ - /// /bbbb1/ - /// /cccc1/ - /// /aaaa2/ - /// .hidden - /// ``` - async fn setup_test_directory() -> Arc { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let fs = ctx.fs(); - fs.write(TEST_FILE_PATH, TEST_FILE_CONTENTS).await.unwrap(); - fs.create_dir_all("/aaaa1/bbbb1/cccc1").await.unwrap(); - fs.create_dir_all("/aaaa2").await.unwrap(); - fs.write(TEST_HIDDEN_FILE_PATH, "this is a hidden file").await.unwrap(); - ctx - } - - #[test] - fn test_fs_write_deserialize() { - let path = "/my-file"; - let file_text = "hello world"; - - // create - let v = serde_json::json!({ - "path": path, - "command": "create", - "file_text": file_text - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Create { .. 
})); - - // str_replace - let v = serde_json::json!({ - "path": path, - "command": "str_replace", - "old_str": "prev string", - "new_str": "new string", - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::StrReplace { .. })); - - // insert - let v = serde_json::json!({ - "path": path, - "command": "insert", - "insert_line": 3, - "new_str": "new string", - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Insert { .. })); - - // append - let v = serde_json::json!({ - "path": path, - "command": "append", - "new_str": "appended content", - }); - let fw = serde_json::from_value::(v).unwrap(); - assert!(matches!(fw, FsWrite::Append { .. })); - } - - #[tokio::test] - async fn test_fs_write_tool_create() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - let file_text = "Hello, world!"; - let v = serde_json::json!({ - "path": "/my-file", - "command": "create", - "file_text": file_text - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - assert_eq!( - ctx.fs().read_to_string("/my-file").await.unwrap(), - format!("{}\n", file_text) - ); - - let file_text = "Goodbye, world!\nSee you later"; - let v = serde_json::json!({ - "path": "/my-file", - "command": "create", - "file_text": file_text - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - // File should end with a newline - assert_eq!( - ctx.fs().read_to_string("/my-file").await.unwrap(), - format!("{}\n", file_text) - ); - - let file_text = "This is a new string"; - let v = serde_json::json!({ - "path": "/my-file", - "command": "create", - "new_str": file_text - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - assert_eq!( - ctx.fs().read_to_string("/my-file").await.unwrap(), - format!("{}\n", file_text) - ); - } - - #[tokio::test] - async fn test_fs_write_tool_str_replace() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - // No instances found - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "str_replace", - "old_str": "asjidfopjaieopr", - "new_str": "1623749", - }); - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .is_err() - ); - - // Multiple instances found - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "str_replace", - "old_str": "Hello world!", - "new_str": "Goodbye world!", - }); - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .is_err() - ); - - // Single instance found and replaced - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "str_replace", - "old_str": "1: Hello world!", - "new_str": "1: Goodbye world!", - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - assert_eq!( - ctx.fs() - .read_to_string(TEST_FILE_PATH) - .await - .unwrap() - .lines() - .next() - .unwrap(), - "1: Goodbye world!", - "expected the only occurrence to be replaced" - ); - } - - #[tokio::test] - async fn test_fs_write_tool_insert_at_beginning() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - let new_str = "1: New first line!\n"; - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "insert", - "insert_line": 0, - "new_str": new_str, - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - 
.unwrap(); - let actual = ctx.fs().read_to_string(TEST_FILE_PATH).await.unwrap(); - assert_eq!( - format!("{}\n", actual.lines().next().unwrap()), - new_str, - "expected the first line to be updated to '{}'", - new_str - ); - assert_eq!( - actual.lines().skip(1).collect::>(), - TEST_FILE_CONTENTS.lines().collect::>(), - "the rest of the file should not have been updated" - ); - } - - #[tokio::test] - async fn test_fs_write_tool_insert_after_first_line() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - let new_str = "2: New second line!\n"; - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "insert", - "insert_line": 1, - "new_str": new_str, - }); - - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - let actual = ctx.fs().read_to_string(TEST_FILE_PATH).await.unwrap(); - assert_eq!( - format!("{}\n", actual.lines().nth(1).unwrap()), - new_str, - "expected the second line to be updated to '{}'", - new_str - ); - assert_eq!( - actual.lines().skip(2).collect::>(), - TEST_FILE_CONTENTS.lines().skip(1).collect::>(), - "the rest of the file should not have been updated" - ); - } - - #[tokio::test] - async fn test_fs_write_tool_insert_when_no_newlines_in_file() { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let mut stdout = std::io::stdout(); - - let test_file_path = "/file.txt"; - let test_file_contents = "hello there"; - ctx.fs().write(test_file_path, test_file_contents).await.unwrap(); - - let new_str = "test"; - - // First, test appending - let v = serde_json::json!({ - "path": test_file_path, - "command": "insert", - "insert_line": 1, - "new_str": new_str, - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - let actual = ctx.fs().read_to_string(test_file_path).await.unwrap(); - assert_eq!(actual, format!("{}{}\n", test_file_contents, new_str)); - - // Then, test prepending - let v = serde_json::json!({ - "path": test_file_path, - "command": "insert", - "insert_line": 0, - "new_str": new_str, - }); - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - let actual = ctx.fs().read_to_string(test_file_path).await.unwrap(); - assert_eq!(actual, format!("{}{}{}\n", new_str, test_file_contents, new_str)); - } - - #[tokio::test] - async fn test_fs_write_tool_append() { - let ctx = setup_test_directory().await; - let mut stdout = std::io::stdout(); - - // Test appending to existing file - let content_to_append = "5: Appended line"; - let v = serde_json::json!({ - "path": TEST_FILE_PATH, - "command": "append", - "new_str": content_to_append, - }); - - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await - .unwrap(); - - let actual = ctx.fs().read_to_string(TEST_FILE_PATH).await.unwrap(); - assert_eq!( - actual, - format!("{}{}\n", TEST_FILE_CONTENTS, content_to_append), - "Content should be appended to the end of the file with a newline added" - ); - - // Test appending to non-existent file (should fail) - let new_file_path = "/new_append_file.txt"; - let content = "This is a new file created by append"; - let v = serde_json::json!({ - "path": new_file_path, - "command": "append", - "new_str": content, - }); - - let result = serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut stdout) - .await; - - assert!(result.is_err(), "Appending to non-existent file should fail"); - } - - #[test] - fn test_lines_with_context() { - let content = 
"Hello\nWorld!\nhow\nare\nyou\ntoday?"; - assert_eq!(get_lines_with_context(content, 1, 1, 1), ("", 1, "World!\n", 2)); - assert_eq!(get_lines_with_context(content, 0, 0, 2), ("", 1, "Hello\nWorld!\n", 2)); - assert_eq!( - get_lines_with_context(content, 2, 4, 50), - ("Hello\n", 1, "you\ntoday?", 6) - ); - assert_eq!(get_lines_with_context(content, 4, 100, 2), ("World!\nhow\n", 2, "", 6)); - } - - #[test] - fn test_gutter_width() { - assert_eq!(terminal_width_required_for_line_count(1), 1); - assert_eq!(terminal_width_required_for_line_count(9), 1); - assert_eq!(terminal_width_required_for_line_count(10), 2); - assert_eq!(terminal_width_required_for_line_count(99), 2); - assert_eq!(terminal_width_required_for_line_count(100), 3); - assert_eq!(terminal_width_required_for_line_count(999), 3); - } -} diff --git a/crates/q_chat/src/tools/gh_issue.rs b/crates/q_chat/src/tools/gh_issue.rs deleted file mode 100644 index ace4663873..0000000000 --- a/crates/q_chat/src/tools/gh_issue.rs +++ /dev/null @@ -1,222 +0,0 @@ -use std::collections::{ - HashMap, - VecDeque, -}; -use std::io::Write; - -use crossterm::style::Color; -use crossterm::{ - queue, - style, -}; -use eyre::{ - Result, - WrapErr, - eyre, -}; -use fig_os_shim::Context; -use serde::Deserialize; - -use super::super::context::ContextManager; -use super::super::util::issue::IssueCreator; -use super::{ - InvokeOutput, - ToolPermission, -}; -use crate::token_counter::TokenCounter; - -#[derive(Debug, Clone, Deserialize)] -pub struct GhIssue { - pub title: String, - pub expected_behavior: Option, - pub actual_behavior: Option, - pub steps_to_reproduce: Option, - - #[serde(skip_deserializing)] - pub context: Option, -} - -#[derive(Debug, Clone)] -pub struct GhIssueContext { - pub context_manager: Option, - pub transcript: VecDeque, - pub failed_request_ids: Vec, - pub tool_permissions: HashMap, - pub interactive: bool, -} - -/// Max amount of characters to include in the transcript. -const MAX_TRANSCRIPT_CHAR_LEN: usize = 3_000; - -impl GhIssue { - pub async fn invoke(&self, _updates: impl Write) -> Result { - let Some(context) = self.context.as_ref() else { - return Err(eyre!( - "report_issue: Required tool context (GhIssueContext) not set by the program." - )); - }; - - // Prepare additional details from the chat session - let additional_environment = [ - Self::get_chat_settings(context), - Self::get_request_ids(context), - Self::get_context(context).await, - ] - .join("\n\n"); - - // Add chat history to the actual behavior text. 
- let actual_behavior = self.actual_behavior.as_ref().map_or_else( - || Self::get_transcript(context), - |behavior| format!("{behavior}\n\n{}\n", Self::get_transcript(context)), - ); - - let _ = IssueCreator { - title: Some(self.title.clone()), - expected_behavior: self.expected_behavior.clone(), - actual_behavior: Some(actual_behavior), - steps_to_reproduce: self.steps_to_reproduce.clone(), - additional_environment: Some(additional_environment), - } - .create_url() - .await - .wrap_err("failed to invoke gh issue tool"); - - Ok(Default::default()) - } - - pub fn set_context(&mut self, context: GhIssueContext) { - self.context = Some(context); - } - - fn get_transcript(context: &GhIssueContext) -> String { - let mut transcript_str = String::from("```\n[chat-transcript]\n"); - let mut is_truncated = false; - let transcript: Vec = context.transcript - .iter() - .rev() // To take last N items - .scan(0, |user_msg_char_count, line| { - if *user_msg_char_count >= MAX_TRANSCRIPT_CHAR_LEN { - is_truncated = true; - return None; - } - let remaining_chars = MAX_TRANSCRIPT_CHAR_LEN - *user_msg_char_count; - let trimmed_line = if line.len() > remaining_chars { - &line[..remaining_chars] - } else { - line - }; - *user_msg_char_count += trimmed_line.len(); - - // backticks will mess up the markdown - let text = trimmed_line.replace("```", r"\```"); - Some(text) - }) - .collect::>() - .into_iter() - .rev() // Now return items to the proper order - .collect(); - - if !transcript.is_empty() { - transcript_str.push_str(&transcript.join("\n\n")); - } else { - transcript_str.push_str("No chat history found."); - } - - if is_truncated { - transcript_str.push_str("\n\n(...truncated)"); - } - transcript_str.push_str("\n```"); - transcript_str - } - - fn get_request_ids(context: &GhIssueContext) -> String { - format!( - "[chat-failed_request_ids]\n{}", - if context.failed_request_ids.is_empty() { - "none".to_string() - } else { - context.failed_request_ids.join("\n") - } - ) - } - - async fn get_context(context: &GhIssueContext) -> String { - let mut ctx_str = "[chat-context]\n".to_string(); - let Some(ctx_manager) = &context.context_manager else { - ctx_str.push_str("No context available."); - return ctx_str; - }; - - ctx_str.push_str(&format!("current_profile={}\n", ctx_manager.current_profile)); - match ctx_manager.list_profiles().await { - Ok(profiles) if !profiles.is_empty() => { - ctx_str.push_str(&format!("profiles=\n{}\n\n", profiles.join("\n"))); - }, - _ => ctx_str.push_str("profiles=none\n\n"), - } - - // Context file categories - if ctx_manager.global_config.paths.is_empty() { - ctx_str.push_str("global_context=none\n\n"); - } else { - ctx_str.push_str(&format!( - "global_context=\n{}\n\n", - &ctx_manager.global_config.paths.join("\n") - )); - } - - if ctx_manager.profile_config.paths.is_empty() { - ctx_str.push_str("profile_context=none\n\n"); - } else { - ctx_str.push_str(&format!( - "profile_context=\n{}\n\n", - &ctx_manager.profile_config.paths.join("\n") - )); - } - - // Handle context files - match ctx_manager.get_context_files(false).await { - Ok(context_files) if !context_files.is_empty() => { - ctx_str.push_str("files=\n"); - let total_size: usize = context_files - .iter() - .map(|(file, content)| { - let size = TokenCounter::count_tokens(content); - ctx_str.push_str(&format!("{}, {} tkns\n", file, size)); - size - }) - .sum(); - ctx_str.push_str(&format!("total context size={total_size} tkns")); - }, - _ => ctx_str.push_str("files=none"), - } - - ctx_str - } - - fn get_chat_settings(context: 
&GhIssueContext) -> String { - let mut result_str = "[chat-settings]\n".to_string(); - result_str.push_str(&format!("interactive={}", context.interactive)); - - result_str.push_str("\n\n[chat-trusted_tools]"); - for (tool, permission) in context.tool_permissions.iter() { - result_str.push_str(&format!("\n{tool}={}", permission.trusted)); - } - - result_str - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - Ok(queue!( - updates, - style::Print("I will prepare a github issue with our conversation history.\n\n"), - style::SetForegroundColor(Color::Green), - style::Print(format!("Title: {}\n", &self.title)), - style::ResetColor - )?) - } - - pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { - Ok(()) - } -} diff --git a/crates/q_chat/src/tools/mod.rs b/crates/q_chat/src/tools/mod.rs deleted file mode 100644 index 279586736b..0000000000 --- a/crates/q_chat/src/tools/mod.rs +++ /dev/null @@ -1,433 +0,0 @@ -pub mod custom_tool; -pub mod execute_bash; -pub mod fs_read; -pub mod fs_write; -pub mod gh_issue; -pub mod use_aws; - -use std::collections::HashMap; -use std::io::Write; -use std::path::{ - Path, - PathBuf, -}; - -use aws_smithy_types::{ - Document, - Number as SmithyNumber, -}; -use crossterm::style::Stylize; -use custom_tool::CustomTool; -use execute_bash::ExecuteBash; -use eyre::Result; -use fig_os_shim::Context; -use fs_read::FsRead; -use fs_write::FsWrite; -use gh_issue::GhIssue; -use serde::{ - Deserialize, - Serialize, -}; -use use_aws::UseAws; - -use super::consts::MAX_TOOL_RESPONSE_SIZE; - -/// Represents an executable tool use. -#[derive(Debug, Clone)] -pub enum Tool { - FsRead(FsRead), - FsWrite(FsWrite), - ExecuteBash(ExecuteBash), - UseAws(UseAws), - Custom(CustomTool), - GhIssue(GhIssue), -} - -impl Tool { - /// The display name of a tool - pub fn display_name(&self) -> String { - match self { - Tool::FsRead(_) => "fs_read", - Tool::FsWrite(_) => "fs_write", - Tool::ExecuteBash(_) => "execute_bash", - Tool::UseAws(_) => "use_aws", - Tool::Custom(custom_tool) => &custom_tool.name, - Tool::GhIssue(_) => "gh_issue", - } - .to_owned() - } - - /// Whether or not the tool should prompt the user to accept before [Self::invoke] is called. 
- pub fn requires_acceptance(&self, _ctx: &Context) -> bool { - match self { - Tool::FsRead(_) => false, - Tool::FsWrite(_) => true, - Tool::ExecuteBash(execute_bash) => execute_bash.requires_acceptance(), - Tool::UseAws(use_aws) => use_aws.requires_acceptance(), - Tool::Custom(_) => true, - Tool::GhIssue(_) => false, - } - } - - /// Invokes the tool asynchronously - pub async fn invoke(&self, context: &Context, updates: &mut impl Write) -> Result { - match self { - Tool::FsRead(fs_read) => fs_read.invoke(context, updates).await, - Tool::FsWrite(fs_write) => fs_write.invoke(context, updates).await, - Tool::ExecuteBash(execute_bash) => execute_bash.invoke(updates).await, - Tool::UseAws(use_aws) => use_aws.invoke(context, updates).await, - Tool::Custom(custom_tool) => custom_tool.invoke(context, updates).await, - Tool::GhIssue(gh_issue) => gh_issue.invoke(updates).await, - } - } - - /// Queues up a tool's intention in a human readable format - pub async fn queue_description(&self, ctx: &Context, updates: &mut impl Write) -> Result<()> { - match self { - Tool::FsRead(fs_read) => fs_read.queue_description(ctx, updates).await, - Tool::FsWrite(fs_write) => fs_write.queue_description(ctx, updates), - Tool::ExecuteBash(execute_bash) => execute_bash.queue_description(updates), - Tool::UseAws(use_aws) => use_aws.queue_description(updates), - Tool::Custom(custom_tool) => custom_tool.queue_description(updates), - Tool::GhIssue(gh_issue) => gh_issue.queue_description(updates), - } - } - - /// Validates the tool with the arguments supplied - pub async fn validate(&mut self, ctx: &Context) -> Result<()> { - match self { - Tool::FsRead(fs_read) => fs_read.validate(ctx).await, - Tool::FsWrite(fs_write) => fs_write.validate(ctx).await, - Tool::ExecuteBash(execute_bash) => execute_bash.validate(ctx).await, - Tool::UseAws(use_aws) => use_aws.validate(ctx).await, - Tool::Custom(custom_tool) => custom_tool.validate(ctx).await, - Tool::GhIssue(gh_issue) => gh_issue.validate(ctx).await, - } - } -} - -#[derive(Debug, Clone)] -pub struct ToolPermission { - pub trusted: bool, -} - -#[derive(Debug, Clone)] -/// Holds overrides for tool permissions. -/// Tools that do not have an associated ToolPermission should use -/// their default logic to determine to permission. -pub struct ToolPermissions { - pub permissions: HashMap, -} - -impl ToolPermissions { - pub fn new(capacity: usize) -> Self { - Self { - permissions: HashMap::with_capacity(capacity), - } - } - - pub fn is_trusted(&self, tool_name: &str) -> bool { - self.permissions.get(tool_name).is_some_and(|perm| perm.trusted) - } - - /// Returns a label to describe the permission status for a given tool. 
- pub fn display_label(&self, tool_name: &str) -> String { - if self.has(tool_name) { - if self.is_trusted(tool_name) { - format!(" {}", "trusted".dark_green().bold()) - } else { - format!(" {}", "not trusted".dark_grey()) - } - } else { - Self::default_permission_label(tool_name) - } - } - - pub fn trust_tool(&mut self, tool_name: &str) { - self.permissions - .insert(tool_name.to_string(), ToolPermission { trusted: true }); - } - - pub fn untrust_tool(&mut self, tool_name: &str) { - self.permissions - .insert(tool_name.to_string(), ToolPermission { trusted: false }); - } - - pub fn reset(&mut self) { - self.permissions.clear(); - } - - pub fn reset_tool(&mut self, tool_name: &str) { - self.permissions.remove(tool_name); - } - - pub fn has(&self, tool_name: &str) -> bool { - self.permissions.contains_key(tool_name) - } - - /// Provide default permission labels for the built-in set of tools. - /// Unknown tools are assumed to be "Per-request" - // This "static" way avoids needing to construct a tool instance. - fn default_permission_label(tool_name: &str) -> String { - let label = match tool_name { - "fs_read" => "trusted".dark_green().bold(), - "fs_write" => "not trusted".dark_grey(), - "execute_bash" => "trust read-only commands".dark_grey(), - "use_aws" => "trust read-only commands".dark_grey(), - "report_issue" => "trusted".dark_green().bold(), - _ => "not trusted".dark_grey(), - }; - - format!("{} {label}", "*".reset()) - } -} - -/// A tool specification to be sent to the model as part of a conversation. Maps to -/// [BedrockToolSpecification]. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolSpec { - pub name: String, - pub description: String, - #[serde(alias = "inputSchema")] - pub input_schema: InputSchema, - #[serde(skip_serializing, default = "tool_origin")] - pub tool_origin: ToolOrigin, -} - -#[derive(Debug, Clone, Deserialize, Eq, PartialEq, Hash)] -pub enum ToolOrigin { - Native, - McpServer(String), -} - -impl std::fmt::Display for ToolOrigin { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - ToolOrigin::Native => write!(f, "Built-in"), - ToolOrigin::McpServer(server) => write!(f, "{} (MCP)", server), - } - } -} - -fn tool_origin() -> ToolOrigin { - ToolOrigin::Native -} - -#[derive(Debug, Clone)] -pub struct QueuedTool { - pub id: String, - pub name: String, - pub accepted: bool, - pub tool: Tool, -} - -/// The schema specification describing a tool's fields. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct InputSchema(pub serde_json::Value); - -/// The output received from invoking a [Tool]. 
-#[derive(Debug, Default)] -pub struct InvokeOutput { - pub output: OutputKind, -} - -impl InvokeOutput { - pub fn as_str(&self) -> &str { - match &self.output { - OutputKind::Text(s) => s.as_str(), - OutputKind::Json(j) => j.as_str().unwrap_or_default(), - } - } -} - -#[non_exhaustive] -#[derive(Debug)] -pub enum OutputKind { - Text(String), - Json(serde_json::Value), -} - -impl Default for OutputKind { - fn default() -> Self { - Self::Text(String::new()) - } -} - -pub fn serde_value_to_document(value: serde_json::Value) -> Document { - match value { - serde_json::Value::Null => Document::Null, - serde_json::Value::Bool(bool) => Document::Bool(bool), - serde_json::Value::Number(number) => { - if let Some(num) = number.as_u64() { - Document::Number(SmithyNumber::PosInt(num)) - } else if number.as_i64().is_some_and(|n| n < 0) { - Document::Number(SmithyNumber::NegInt(number.as_i64().unwrap())) - } else { - Document::Number(SmithyNumber::Float(number.as_f64().unwrap_or_default())) - } - }, - serde_json::Value::String(string) => Document::String(string), - serde_json::Value::Array(vec) => { - Document::Array(vec.clone().into_iter().map(serde_value_to_document).collect::<_>()) - }, - serde_json::Value::Object(map) => Document::Object( - map.into_iter() - .map(|(k, v)| (k, serde_value_to_document(v))) - .collect::<_>(), - ), - } -} - -pub fn document_to_serde_value(value: Document) -> serde_json::Value { - use serde_json::Value; - match value { - Document::Object(map) => Value::Object( - map.into_iter() - .map(|(k, v)| (k, document_to_serde_value(v))) - .collect::<_>(), - ), - Document::Array(vec) => Value::Array(vec.clone().into_iter().map(document_to_serde_value).collect::<_>()), - Document::Number(number) => { - if let Ok(v) = TryInto::::try_into(number) { - Value::Number(v.into()) - } else if let Ok(v) = TryInto::::try_into(number) { - Value::Number(v.into()) - } else { - Value::Number( - serde_json::Number::from_f64(number.to_f64_lossy()) - .unwrap_or(serde_json::Number::from_f64(0.0).expect("converting from 0.0 will not fail")), - ) - } - }, - Document::String(s) => serde_json::Value::String(s), - Document::Bool(b) => serde_json::Value::Bool(b), - Document::Null => serde_json::Value::Null, - } -} - -/// Performs tilde expansion and other required sanitization modifications for handling tool use -/// path arguments. -/// -/// Required since path arguments are defined by the model. -#[allow(dead_code)] -fn sanitize_path_tool_arg(ctx: &Context, path: impl AsRef) -> PathBuf { - let mut res = PathBuf::new(); - // Expand `~` only if it is the first part. - let mut path = path.as_ref().components(); - match path.next() { - Some(p) if p.as_os_str() == "~" => { - res.push(ctx.env().home().unwrap_or_default()); - }, - Some(p) => res.push(p), - None => return res, - } - for p in path { - res.push(p); - } - // For testing scenarios, we need to make sure paths are appropriately handled in chroot test - // file systems since they are passed directly from the model. - ctx.fs().chroot_path(res) -} - -/// Converts `path` to a relative path according to the current working directory `cwd`. 
-fn absolute_to_relative(cwd: impl AsRef, path: impl AsRef) -> Result { - let cwd = cwd.as_ref().canonicalize()?; - let path = path.as_ref().canonicalize()?; - let mut cwd_parts = cwd.components().peekable(); - let mut path_parts = path.components().peekable(); - - // Skip common prefix - while let (Some(a), Some(b)) = (cwd_parts.peek(), path_parts.peek()) { - if a == b { - cwd_parts.next(); - path_parts.next(); - } else { - break; - } - } - - // ".." for any uncommon parts, then just append the rest of the path. - let mut relative = PathBuf::new(); - for _ in cwd_parts { - relative.push(".."); - } - for part in path_parts { - relative.push(part); - } - - Ok(relative) -} - -/// Small helper for formatting the path as a relative path, if able. -fn format_path(cwd: impl AsRef, path: impl AsRef) -> String { - absolute_to_relative(cwd, path.as_ref()) - .map(|p| p.to_string_lossy().to_string()) - // If we have three consecutive ".." then it should probably just stay as an absolute path. - .map(|p| { - if p.starts_with("../../..") { - path.as_ref().to_string_lossy().to_string() - } else { - p - } - }) - .unwrap_or(path.as_ref().to_string_lossy().to_string()) -} - -fn supports_truecolor(ctx: &Context) -> bool { - // Simple override to disable truecolor since shell_color doesn't use Context. - !ctx.env().get("Q_DISABLE_TRUECOLOR").is_ok_and(|s| !s.is_empty()) - && shell_color::get_color_support().contains(shell_color::ColorSupport::TERM24BIT) -} - -#[cfg(test)] -mod tests { - use fig_os_shim::EnvProvider; - - use super::*; - - #[tokio::test] - async fn test_tilde_path_expansion() { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - - let actual = sanitize_path_tool_arg(&ctx, "~"); - assert_eq!( - actual, - ctx.fs().chroot_path(ctx.env().home().unwrap()), - "tilde should expand" - ); - let actual = sanitize_path_tool_arg(&ctx, "~/hello"); - assert_eq!( - actual, - ctx.fs().chroot_path(ctx.env().home().unwrap().join("hello")), - "tilde should expand" - ); - let actual = sanitize_path_tool_arg(&ctx, "/~"); - assert_eq!( - actual, - ctx.fs().chroot_path("/~"), - "tilde should not expand when not the first component" - ); - } - - #[tokio::test] - async fn test_format_path() { - async fn assert_paths(cwd: &str, path: &str, expected: &str) { - let ctx = Context::builder().with_test_home().await.unwrap().build_fake(); - let fs = ctx.fs(); - let cwd = sanitize_path_tool_arg(&ctx, cwd); - let path = sanitize_path_tool_arg(&ctx, path); - fs.create_dir_all(&cwd).await.unwrap(); - fs.create_dir_all(&path).await.unwrap(); - // Using `contains` since the chroot test directory will prefix the formatted path with a tmpdir - // path. 
- assert!(format_path(cwd, path).contains(expected)); - } - assert_paths("/Users/testuser/src", "/Users/testuser/Downloads", "../Downloads").await; - assert_paths( - "/Users/testuser/projects/MyProject/src", - "/Volumes/projects/MyProject/src", - "/Volumes/projects/MyProject/src", - ) - .await; - } -} diff --git a/crates/q_chat/src/tools/tool_index.json b/crates/q_chat/src/tools/tool_index.json deleted file mode 100644 index 397d856cfa..0000000000 --- a/crates/q_chat/src/tools/tool_index.json +++ /dev/null @@ -1,176 +0,0 @@ -{ - "execute_bash": { - "name": "execute_bash", - "description": "Execute the specified bash command.", - "input_schema": { - "type": "object", - "properties": { - "command": { - "type": "string", - "description": "Bash command to execute" - } - }, - "required": [ - "command" - ] - } - }, - "fs_read": { - "name": "fs_read", - "description": "Tool for reading files (for example, `cat -n`) and directories (for example, `ls -la`). The behavior of this tool is determined by the `mode` parameter. The available modes are:\n- line: Show lines in a file, given by an optional `start_line` and optional `end_line`.\n- directory: List directory contents. Content is returned in the \"long format\" of ls (that is, `ls -la`).\n- search: Search for a pattern in a file. The pattern is a string. The matching is case insensitive.\n\nExample Usage:\n1. Read all lines from a file: command=\"line\", path=\"/path/to/file.txt\"\n2. Read the last 5 lines from a file: command=\"line\", path=\"/path/to/file.txt\", start_line=-5\n3. List the files in the home directory: command=\"line\", path=\"~\"\n4. Recursively list files in a directory to a max depth of 2: command=\"line\", path=\"/path/to/directory\", depth=2\n5. Search for all instances of \"test\" in a file: command=\"search\", path=\"/path/to/file.txt\", pattern=\"test\"\n", - "input_schema": { - "type": "object", - "properties": { - "path": { - "description": "Path to the file or directory. The path should be absolute, or otherwise start with ~ for the user's home.", - "type": "string" - }, - "mode": { - "type": "string", - "enum": [ - "Line", - "Directory", - "Search" - ], - "description": "The mode to run in: `Line`, `Directory`, `Search`. `Line` and `Search` are only for text files, and `Directory` is only for directories." - }, - "start_line": { - "type": "integer", - "description": "Starting line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", - "default": 1 - }, - "end_line": { - "type": "integer", - "description": "Ending line number (optional, for Line mode). A negative index represents a line number starting from the end of the file.", - "default": -1 - }, - "pattern": { - "type": "string", - "description": "Pattern to search for (required, for Search mode). Case insensitive. The pattern matching is performed per line." 
- }, - "context_lines": { - "type": "integer", - "description": "Number of context lines around search results (optional, for Search mode)", - "default": 2 - }, - "depth": { - "type": "integer", - "description": "Depth of a recursive directory listing (optional, for Directory mode)", - "default": 0 - } - }, - "required": [ - "path", - "mode" - ] - } - }, - "fs_write": { - "name": "fs_write", - "description": "A tool for creating and editing files\n * The `create` command will override the file at `path` if it already exists as a file, and otherwise create a new file\n * The `append` command will add content to the end of an existing file, automatically adding a newline if the file doesn't end with one. The file must exist.\n Notes for using the `str_replace` command:\n * The `old_str` parameter should match EXACTLY one or more consecutive lines from the original file. Be mindful of whitespaces!\n * If the `old_str` parameter is not unique in the file, the replacement will not be performed. Make sure to include enough context in `old_str` to make it unique\n * The `new_str` parameter should contain the edited lines that should replace the `old_str`.", - "input_schema": { - "type": "object", - "properties": { - "command": { - "type": "string", - "enum": [ - "create", - "str_replace", - "insert", - "append" - ], - "description": "The commands to run. Allowed options are: `create`, `str_replace`, `insert`, `append`." - }, - "file_text": { - "description": "Required parameter of `create` command, with the content of the file to be created.", - "type": "string" - }, - "insert_line": { - "description": "Required parameter of `insert` command. The `new_str` will be inserted AFTER the line `insert_line` of `path`.", - "type": "integer" - }, - "new_str": { - "description": "Required parameter of `str_replace` command containing the new string. Required parameter of `insert` command containing the string to insert. Required parameter of `append` command containing the content to append to the file.", - "type": "string" - }, - "old_str": { - "description": "Required parameter of `str_replace` command containing the string in `path` to replace.", - "type": "string" - }, - "path": { - "description": "Absolute path to file or directory, e.g. `/repo/file.py` or `/repo`.", - "type": "string" - } - }, - "required": [ - "command", - "path" - ] - } - }, - "use_aws": { - "name": "use_aws", - "description": "Make an AWS CLI api call with the specified service, operation, and parameters. All arguments MUST conform to the AWS CLI specification. Should the output of the invocation indicate a malformed command, invoke help to obtain the the correct command.", - "input_schema": { - "type": "object", - "properties": { - "service_name": { - "type": "string", - "description": "The name of the AWS service. If you want to query s3, you should use s3api if possible." - }, - "operation_name": { - "type": "string", - "description": "The name of the operation to perform." - }, - "parameters": { - "type": "object", - "description": "The parameters for the operation. The parameter keys MUST conform to the AWS CLI specification. You should prefer to use JSON Syntax over shorthand syntax wherever possible. For parameters that are booleans, prioritize using flags with no value. Denote these flags with flag names as key and an empty string as their value. You should also prefer kebab case." - }, - "region": { - "type": "string", - "description": "Region name for calling the operation on AWS." 
- }, - "profile_name": { - "type": "string", - "description": "Optional: AWS profile name to use from ~/.aws/credentials. Defaults to default profile if not specified." - }, - "label": { - "type": "string", - "description": "Human readable description of the api that is being called." - } - }, - "required": [ - "region", - "service_name", - "operation_name", - "label" - ] - } - }, - "gh_issue": { - "name": "report_issue", - "description": "Opens the browser to a pre-filled gh (GitHub) issue template to report chat issues, bugs, or feature requests. Pre-filled information includes the conversation transcript, chat context, and chat request IDs from the service.", - "input_schema": { - "type": "object", - "properties": { - "title": { - "type": "string", - "description": "The title of the GitHub issue." - }, - "expected_behavior": { - "type": "string", - "description": "Optional: The expected chat behavior or action that did not happen." - }, - "actual_behavior": { - "type": "string", - "description": "Optional: The actual chat behavior that happened and demonstrates the issue or lack of a feature." - }, - "steps_to_reproduce": { - "type": "string", - "description": "Optional: Previous user chat requests or steps that were taken that may have resulted in the issue or error response." - } - }, - "required": ["title"] - } - } -} diff --git a/crates/q_chat/src/tools/use_aws.rs b/crates/q_chat/src/tools/use_aws.rs deleted file mode 100644 index cfdf97c3ff..0000000000 --- a/crates/q_chat/src/tools/use_aws.rs +++ /dev/null @@ -1,315 +0,0 @@ -use std::collections::HashMap; -use std::io::Write; -use std::process::Stdio; - -use bstr::ByteSlice; -use convert_case::{ - Case, - Casing, -}; -use crossterm::{ - queue, - style, -}; -use eyre::{ - Result, - WrapErr, -}; -use fig_os_shim::Context; -use serde::Deserialize; - -use super::{ - InvokeOutput, - MAX_TOOL_RESPONSE_SIZE, - OutputKind, -}; - -const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; - -/// The environment variable name where we set additional metadata for the AWS CLI user agent. -const USER_AGENT_ENV_VAR: &str = "AWS_EXECUTION_ENV"; -const USER_AGENT_APP_NAME: &str = "AmazonQ-For-CLI"; -const USER_AGENT_VERSION_KEY: &str = "Version"; -const USER_AGENT_VERSION_VALUE: &str = env!("CARGO_PKG_VERSION"); - -// TODO: we should perhaps composite this struct with an interface that we can use to mock the -// actual cli with. That will allow us to more thoroughly test it. 
-#[derive(Debug, Clone, Deserialize)]
-pub struct UseAws {
-    pub service_name: String,
-    pub operation_name: String,
-    pub parameters: Option<HashMap<String, serde_json::Value>>,
-    pub region: String,
-    pub profile_name: Option<String>,
-    pub label: Option<String>,
-}
-
-impl UseAws {
-    pub fn requires_acceptance(&self) -> bool {
-        !READONLY_OPS.iter().any(|op| self.operation_name.starts_with(op))
-    }
-
-    pub async fn invoke(&self, _ctx: &Context, _updates: impl Write) -> Result<InvokeOutput> {
-        let mut command = tokio::process::Command::new("aws");
-
-        // Set up environment variables
-        let mut env_vars: std::collections::HashMap<String, String> = std::env::vars().collect();
-
-        // Set up additional metadata for the AWS CLI user agent
-        let user_agent_metadata_value = format!(
-            "{} {}/{}",
-            USER_AGENT_APP_NAME, USER_AGENT_VERSION_KEY, USER_AGENT_VERSION_VALUE
-        );
-
-        // If the user agent metadata env var already exists, append to it, otherwise set it
-        if let Some(existing_value) = env_vars.get(USER_AGENT_ENV_VAR) {
-            if !existing_value.is_empty() {
-                env_vars.insert(
-                    USER_AGENT_ENV_VAR.to_string(),
-                    format!("{} {}", existing_value, user_agent_metadata_value),
-                );
-            } else {
-                env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value);
-            }
-        } else {
-            env_vars.insert(USER_AGENT_ENV_VAR.to_string(), user_agent_metadata_value);
-        }
-
-        command.envs(env_vars).arg("--region").arg(&self.region);
-        if let Some(profile_name) = self.profile_name.as_deref() {
-            command.arg("--profile").arg(profile_name);
-        }
-        command.arg(&self.service_name).arg(&self.operation_name);
-        if let Some(parameters) = self.cli_parameters() {
-            for (name, val) in parameters {
-                command.arg(name);
-                if !val.is_empty() {
-                    command.arg(val);
-                }
-            }
-        }
-        let output = command
-            .stdout(Stdio::piped())
-            .stderr(Stdio::piped())
-            .spawn()
-            .wrap_err_with(|| format!("Unable to spawn command '{:?}'", self))?
-            .wait_with_output()
-            .await
-            .wrap_err_with(|| format!("Unable to spawn command '{:?}'", self))?;
-        let status = output.status.code().unwrap_or(0).to_string();
-        let stdout = output.stdout.to_str_lossy();
-        let stderr = output.stderr.to_str_lossy();
-
-        let stdout = format!(
-            "{}{}",
-            &stdout[0..stdout.len().min(MAX_TOOL_RESPONSE_SIZE / 3)],
-            if stdout.len() > MAX_TOOL_RESPONSE_SIZE / 3 {
-                " ... truncated"
-            } else {
-                ""
-            }
-        );
-
-        let stderr = format!(
-            "{}{}",
-            &stderr[0..stderr.len().min(MAX_TOOL_RESPONSE_SIZE / 3)],
-            if stderr.len() > MAX_TOOL_RESPONSE_SIZE / 3 {
-                " ... 
truncated" - } else { - "" - } - ); - - if status.eq("0") { - Ok(InvokeOutput { - output: OutputKind::Json(serde_json::json!({ - "exit_status": status, - "stdout": stdout, - "stderr": stderr.clone() - })), - }) - } else { - Err(eyre::eyre!(stderr)) - } - } - - pub fn queue_description(&self, updates: &mut impl Write) -> Result<()> { - queue!( - updates, - style::Print("Running aws cli command:\n\n"), - style::Print(format!("Service name: {}\n", self.service_name)), - style::Print(format!("Operation name: {}\n", self.operation_name)), - )?; - if let Some(parameters) = &self.parameters { - queue!(updates, style::Print("Parameters: \n".to_string()))?; - for (name, value) in parameters { - match value { - serde_json::Value::String(s) if s.is_empty() => { - queue!(updates, style::Print(format!("- {}\n", name)))?; - }, - _ => { - queue!(updates, style::Print(format!("- {}: {}\n", name, value)))?; - }, - } - } - } - - if let Some(ref profile_name) = self.profile_name { - queue!(updates, style::Print(format!("Profile name: {}\n", profile_name)))?; - } else { - queue!(updates, style::Print("Profile name: default\n".to_string()))?; - } - - queue!(updates, style::Print(format!("Region: {}", self.region)))?; - - if let Some(ref label) = self.label { - queue!(updates, style::Print(format!("\nLabel: {}", label)))?; - } - Ok(()) - } - - pub async fn validate(&mut self, _ctx: &Context) -> Result<()> { - Ok(()) - } - - /// Returns the CLI arguments properly formatted as kebab case if parameters is - /// [Option::Some], otherwise None - fn cli_parameters(&self) -> Option> { - if let Some(parameters) = &self.parameters { - let mut params = vec![]; - for (param_name, val) in parameters { - let param_name = format!("--{}", param_name.trim_start_matches("--").to_case(Case::Kebab)); - let param_val = val.as_str().map(|s| s.to_string()).unwrap_or(val.to_string()); - params.push((param_name, param_val)); - } - Some(params) - } else { - None - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - macro_rules! use_aws { - ($value:tt) => { - serde_json::from_value::(serde_json::json!($value)).unwrap() - }; - } - - #[test] - fn test_requires_acceptance() { - let cmd = use_aws! {{ - "service_name": "ecs", - "operation_name": "list-task-definitions", - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - assert!(!cmd.requires_acceptance()); - let cmd = use_aws! {{ - "service_name": "lambda", - "operation_name": "list-functions", - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - assert!(!cmd.requires_acceptance()); - let cmd = use_aws! {{ - "service_name": "s3", - "operation_name": "put-object", - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - assert!(cmd.requires_acceptance()); - } - - #[test] - fn test_use_aws_deser() { - let cmd = use_aws! 
{{ - "service_name": "s3", - "operation_name": "put-object", - "parameters": { - "TableName": "table-name", - "KeyConditionExpression": "PartitionKey = :pkValue" - }, - "region": "us-west-2", - "profile_name": "default", - "label": "" - }}; - let params = cmd.cli_parameters().unwrap(); - assert!( - params.iter().any(|p| p.0 == "--table-name" && p.1 == "table-name"), - "not found in {:?}", - params - ); - assert!( - params - .iter() - .any(|p| p.0 == "--key-condition-expression" && p.1 == "PartitionKey = :pkValue"), - "not found in {:?}", - params - ); - } - - #[tokio::test] - #[ignore = "not in ci"] - async fn test_aws_read_only() { - let ctx = Context::new_fake(); - - let v = serde_json::json!({ - "service_name": "s3", - "operation_name": "put-object", - // technically this wouldn't be a valid request with an empty parameter set but it's - // okay for this test - "parameters": {}, - "region": "us-west-2", - "profile_name": "default", - "label": "" - }); - - assert!( - serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut std::io::stdout()) - .await - .is_err() - ); - } - - #[tokio::test] - #[ignore = "not in ci"] - async fn test_aws_output() { - let ctx = Context::new_fake(); - - let v = serde_json::json!({ - "service_name": "s3", - "operation_name": "ls", - "parameters": {}, - "region": "us-west-2", - "profile_name": "default", - "label": "" - }); - let out = serde_json::from_value::(v) - .unwrap() - .invoke(&ctx, &mut std::io::stdout()) - .await - .unwrap(); - - if let OutputKind::Json(json) = out.output { - // depending on where the test is ran we might get different outcome here but it does - // not mean the tool is not working - let exit_status = json.get("exit_status").unwrap(); - if exit_status == 0 { - assert_eq!(json.get("stderr").unwrap(), ""); - } else { - assert_ne!(json.get("stderr").unwrap(), ""); - } - } else { - panic!("Expected JSON output"); - } - } -} diff --git a/crates/q_chat/src/util/issue.rs b/crates/q_chat/src/util/issue.rs deleted file mode 100644 index 05457308da..0000000000 --- a/crates/q_chat/src/util/issue.rs +++ /dev/null @@ -1,82 +0,0 @@ -use anstream::{ - eprintln, - println, -}; -use crossterm::style::Stylize; -use eyre::Result; -use fig_diagnostic::Diagnostics; -use fig_util::GITHUB_REPO_NAME; -use fig_util::system_info::is_remote; - -const TEMPLATE_NAME: &str = "1_bug_report_template.yml"; - -pub struct IssueCreator { - /// Issue title - pub title: Option, - /// Issue description - pub expected_behavior: Option, - /// Issue description - pub actual_behavior: Option, - /// Issue description - pub steps_to_reproduce: Option, - /// Issue description - pub additional_environment: Option, -} - -impl IssueCreator { - pub async fn create_url(&self) -> Result { - println!("Heading over to GitHub..."); - - let warning = |text: &String| { - format!("\n\n{text}") - }; - let diagnostics = Diagnostics::new().await; - - let os = match &diagnostics.system_info.os { - Some(os) => os.to_string(), - None => "None".to_owned(), - }; - - let diagnostic_info = match diagnostics.user_readable() { - Ok(diagnostics) => diagnostics, - Err(err) => { - eprintln!("Error getting diagnostics: {err}"); - "Error occurred while generating diagnostics".to_owned() - }, - }; - - let environment = match &self.additional_environment { - Some(ctx) => format!("{diagnostic_info}\n{ctx}"), - None => diagnostic_info, - }; - - let mut params = Vec::new(); - params.push(("template", TEMPLATE_NAME.to_string())); - params.push(("os", os)); - params.push(("environment", warning(&environment))); 
- - if let Some(t) = self.title.clone() { - params.push(("title", t)); - } - if let Some(t) = self.expected_behavior.as_ref() { - params.push(("expected", warning(t))); - } - if let Some(t) = self.actual_behavior.as_ref() { - params.push(("actual", warning(t))); - } - if let Some(t) = self.steps_to_reproduce.as_ref() { - params.push(("reproduce", warning(t))); - } - - let url = url::Url::parse_with_params( - &format!("https://github.com/{GITHUB_REPO_NAME}/issues/new"), - params.iter(), - )?; - - if is_remote() || fig_util::open_url_async(url.as_str()).await.is_err() { - println!("Issue Url: {}", url.as_str().underlined()); - } - - Ok(url) - } -} diff --git a/crates/q_chat/src/util/mod.rs b/crates/q_chat/src/util/mod.rs deleted file mode 100644 index f2b2d4f334..0000000000 --- a/crates/q_chat/src/util/mod.rs +++ /dev/null @@ -1,114 +0,0 @@ -pub mod issue; -pub mod shared_writer; -pub mod ui; - -use std::io::Write; -use std::time::Duration; - -use fig_util::system_info::in_cloudshell; - -use super::ChatError; - -const GOV_REGIONS: &[&str] = &["us-gov-east-1", "us-gov-west-1"]; - -pub fn region_check(capability: &'static str) -> eyre::Result<()> { - let Ok(region) = std::env::var("AWS_REGION") else { - return Ok(()); - }; - - if in_cloudshell() && GOV_REGIONS.contains(®ion.as_str()) { - eyre::bail!("AWS GovCloud ({region}) is not supported for {capability}."); - } - - Ok(()) -} - -pub fn truncate_safe(s: &str, max_bytes: usize) -> &str { - if s.len() <= max_bytes { - return s; - } - - let mut byte_count = 0; - let mut char_indices = s.char_indices(); - - for (byte_idx, _) in &mut char_indices { - if byte_count + (byte_idx - byte_count) > max_bytes { - break; - } - byte_count = byte_idx; - } - - &s[..byte_count] -} - -pub fn animate_output(output: &mut impl Write, bytes: &[u8]) -> Result<(), ChatError> { - for b in bytes.chunks(12) { - output.write_all(b)?; - std::thread::sleep(Duration::from_millis(16)); - } - Ok(()) -} - -/// Play the terminal bell notification sound -pub fn play_notification_bell(requires_confirmation: bool) { - // Don't play bell for tools that don't require confirmation - if !requires_confirmation { - return; - } - - // Check if we should play the bell based on terminal type - if should_play_bell() { - print!("\x07"); // ASCII bell character - std::io::stdout().flush().unwrap(); - } -} - -/// Determine if we should play the bell based on terminal type -fn should_play_bell() -> bool { - // Get the TERM environment variable - if let Ok(term) = std::env::var("TERM") { - // List of terminals known to handle bell character well - let bell_compatible_terms = [ - "xterm", - "xterm-256color", - "screen", - "screen-256color", - "tmux", - "tmux-256color", - "rxvt", - "rxvt-unicode", - "linux", - "konsole", - "gnome", - "gnome-256color", - "alacritty", - "iterm2", - ]; - - // Check if the current terminal is in the compatible list - for compatible_term in bell_compatible_terms.iter() { - if term.starts_with(compatible_term) { - return true; - } - } - - // For other terminals, don't play the bell - return false; - } - - // If TERM is not set, default to not playing the bell - false -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_truncate_safe() { - assert_eq!(truncate_safe("Hello World", 5), "Hello"); - assert_eq!(truncate_safe("Hello ", 5), "Hello"); - assert_eq!(truncate_safe("Hello World", 11), "Hello World"); - assert_eq!(truncate_safe("Hello World", 15), "Hello World"); - } -} diff --git a/crates/q_chat/src/util/shared_writer.rs 
b/crates/q_chat/src/util/shared_writer.rs deleted file mode 100644 index c5a2f55c41..0000000000 --- a/crates/q_chat/src/util/shared_writer.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::io::{ - self, - Write, -}; -use std::sync::{ - Arc, - Mutex, -}; - -/// A thread-safe wrapper for any Write implementation. -#[derive(Clone)] -pub struct SharedWriter { - inner: Arc<Mutex<Box<dyn Write + Send>>>, -} - -impl SharedWriter { - pub fn new<W>(writer: W) -> Self - where - W: Write + Send + 'static, - { - Self { - inner: Arc::new(Mutex::new(Box::new(writer))), - } - } - - pub fn stdout() -> Self { - Self::new(io::stdout()) - } - - pub fn stderr() -> Self { - Self::new(io::stderr()) - } - - pub fn null() -> Self { - Self::new(NullWriter {}) - } -} - -impl std::fmt::Debug for SharedWriter { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SharedWriter").finish() - } -} - -impl Write for SharedWriter { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.inner.lock().expect("Mutex poisoned").write(buf) - } - - fn flush(&mut self) -> io::Result<()> { - self.inner.lock().expect("Mutex poisoned").flush() - } -} - -#[derive(Debug, Clone)] -pub struct NullWriter {} - -impl Write for NullWriter { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} - -#[derive(Debug, Clone)] -pub struct TestWriterWithSink { - pub sink: Arc<Mutex<Vec<u8>>>, -} - -impl TestWriterWithSink { - #[allow(dead_code)] - pub fn get_content(&self) -> Vec<u8> { - self.sink.lock().unwrap().clone() - } -} - -impl Write for TestWriterWithSink { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.sink.lock().unwrap().append(&mut buf.to_vec()); - Ok(buf.len()) - } - - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } -} diff --git a/crates/q_chat/src/util/ui.rs b/crates/q_chat/src/util/ui.rs deleted file mode 100644 index 1c5bfd5da9..0000000000 --- a/crates/q_chat/src/util/ui.rs +++ /dev/null @@ -1,212 +0,0 @@ -use crossterm::style::{ - Color, - Stylize, -}; -use crossterm::terminal::{ - self, - ClearType, -}; -use crossterm::{ - cursor, - execute, - style, -}; -use eyre::Result; -use strip_ansi_escapes::strip_str; - -use super::shared_writer::SharedWriter; - -pub fn draw_box( - mut output: SharedWriter, - title: &str, - content: &str, - box_width: usize, - border_color: Color, -) -> Result<()> { - let inner_width = box_width - 4; // account for │ and padding - - // wrap the single line into multiple lines respecting inner width - // Manually wrap the text by splitting at word boundaries - let mut wrapped_lines = Vec::new(); - let mut line = String::new(); - - for word in content.split_whitespace() { - if line.len() + word.len() < inner_width { - if !line.is_empty() { - line.push(' '); - } - line.push_str(word); - } else { - // Here we need to account for words that are too long as well - if word.len() >= inner_width { - let mut start = 0_usize; - for (i, _) in word.chars().enumerate() { - if i - start >= inner_width { - wrapped_lines.push(word[start..i].to_string()); - start = i; - } - } - wrapped_lines.push(word[start..].to_string()); - line = String::new(); - } else { - wrapped_lines.push(line); - line = word.to_string(); - } - } - } - - if !line.is_empty() { - wrapped_lines.push(line); - } - - let side_len = (box_width.saturating_sub(title.len())) / 2; - let top_border = format!( - "{} {} {}", - style::style(format!("╭{}", "─".repeat(side_len - 2))).with(border_color), - title, - style::style(format!("{}╮", "─".repeat(box_width - side_len - title.len() -
2))).with(border_color) - ); - - execute!( - output, - terminal::Clear(ClearType::CurrentLine), - cursor::MoveToColumn(0), - style::Print(format!("{top_border}\n")), - )?; - - // Top vertical padding - let top_vertical_border = format!( - "{}", - style::style(format!("│{: ::new())); - let test_writer = TestWriterWithSink { sink: buf.clone() }; - let output = SharedWriter::new(test_writer.clone()); - - // Test with a short tip - let short_tip = "This is a short tip"; - draw_box( - output.clone(), - "Did you know?", - short_tip, - GREETING_BREAK_POINT, - Color::DarkGrey, - ) - .expect("Failed to draw tip box"); - - // Test with a longer tip that should wrap - let long_tip = "This is a much longer tip that should wrap to multiple lines because it exceeds the inner width of the tip box which is calculated based on the GREETING_BREAK_POINT constant"; - draw_box( - output.clone(), - "Did you know?", - long_tip, - GREETING_BREAK_POINT, - Color::DarkGrey, - ) - .expect("Failed to draw tip box"); - - // Test with a long tip with two long words that should wrap - let long_tip_with_one_long_word = { - let mut s = "a".repeat(200); - s.push(' '); - s.push_str(&"a".repeat(200)); - s - }; - draw_box( - output.clone(), - "Did you know?", - long_tip_with_one_long_word.as_str(), - GREETING_BREAK_POINT, - Color::DarkGrey, - ) - .expect("Failed to draw tip box"); - // Test with a long tip with two long words that should wrap - let long_tip_with_two_long_words = "a".repeat(200); - draw_box( - output.clone(), - "Did you know?", - long_tip_with_two_long_words.as_str(), - GREETING_BREAK_POINT, - Color::DarkGrey, - ) - .expect("Failed to draw tip box"); - - // Get the output and verify it contains expected formatting elements - let content = test_writer.get_content(); - let output_str = content.to_str_lossy(); - - // Check for box drawing characters - assert!(output_str.contains("╭"), "Output should contain top-left corner"); - assert!(output_str.contains("╮"), "Output should contain top-right corner"); - assert!(output_str.contains("│"), "Output should contain vertical lines"); - assert!(output_str.contains("╰"), "Output should contain bottom-left corner"); - assert!(output_str.contains("╯"), "Output should contain bottom-right corner"); - - // Check for the label - assert!( - output_str.contains("Did you know?"), - "Output should contain the 'Did you know?' 
label" - ); - - // Check that both tips are present - assert!(output_str.contains(short_tip), "Output should contain the short tip"); - - // For the long tip, we check for substrings since it will be wrapped - let long_tip_parts: Vec<&str> = long_tip.split_whitespace().collect(); - for part in long_tip_parts.iter().take(3) { - assert!(output_str.contains(part), "Output should contain parts of the long tip"); - } - } -} diff --git a/crates/q_cli/Cargo.toml b/crates/q_cli/Cargo.toml index fe1d00008e..4460fa8ce0 100644 --- a/crates/q_cli/Cargo.toml +++ b/crates/q_cli/Cargo.toml @@ -60,7 +60,6 @@ mcp_client.workspace = true mimalloc.workspace = true owo-colors = "4.2.0" parking_lot.workspace = true -q_chat.workspace = true rand.workspace = true regex.workspace = true semver.workspace = true diff --git a/crates/q_cli/src/cli/issue.rs b/crates/q_cli/src/cli/issue.rs index 028d7eec68..3974951e7a 100644 --- a/crates/q_cli/src/cli/issue.rs +++ b/crates/q_cli/src/cli/issue.rs @@ -41,15 +41,7 @@ impl IssueArgs { _ => joined_description, }; - let _ = q_chat::util::issue::IssueCreator { - title: Some(issue_title), - expected_behavior: None, - actual_behavior: None, - steps_to_reproduce: None, - additional_environment: None, - } - .create_url() - .await; + todo!(); Ok(ExitCode::SUCCESS) } diff --git a/crates/q_cli/src/cli/mod.rs b/crates/q_cli/src/cli/mod.rs index fc6244b28a..a0b407f34b 100644 --- a/crates/q_cli/src/cli/mod.rs +++ b/crates/q_cli/src/cli/mod.rs @@ -60,7 +60,6 @@ use fig_util::{ system_info, }; use internal::InternalSubcommand; -use q_chat::cli::Chat; use serde::Serialize; use tracing::{ Level, @@ -187,8 +186,11 @@ pub enum CliRootCommands { /// Open the dashboard Dashboard, /// AI assistant in your terminal - #[command(alias("q"))] - Chat(Chat), + Chat { + /// Args for the chat command + #[arg(trailing_var_arg = true, allow_hyphen_values = true)] + args: Vec, + }, /// Inline shell completions #[command(subcommand)] Inline(inline::InlineSubcommand), @@ -333,11 +335,11 @@ impl Cli { CliRootCommands::Telemetry(subcommand) => subcommand.execute().await, CliRootCommands::Version { changelog } => Self::print_version(changelog), CliRootCommands::Dashboard => launch_dashboard(false).await, - CliRootCommands::Chat(args) => q_chat::launch_chat(args).await, + CliRootCommands::Chat { args } => todo!(), CliRootCommands::Inline(subcommand) => subcommand.execute(&cli_context).await, }, // Root command - None => q_chat::launch_chat(q_chat::cli::Chat::default()).await, + None => todo!(), } } @@ -521,19 +523,6 @@ mod test { verbose: 0, help_all: true, }); - - assert_eq!(Cli::parse_from([CLI_BINARY_NAME, "chat", "-vv"]), Cli { - subcommand: Some(CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: false, - input: None, - profile: None, - trust_all_tools: false, - trust_tools: None, - })), - verbose: 2, - help_all: false, - }); } /// This test validates that the restart command maintains the same CLI facing definition @@ -673,109 +662,4 @@ mod test { changelog: Some("1.8.0".to_string()), }); } - - #[test] - fn test_chat_with_context_profile() { - assert_parse!( - ["chat", "--profile", "my-profile"], - CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: false, - input: None, - profile: Some("my-profile".to_string()), - trust_all_tools: false, - trust_tools: None, - }) - ); - } - - #[test] - fn test_chat_with_context_profile_and_input() { - assert_parse!( - ["chat", "--profile", "my-profile", "Hello"], - CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: false, 
- input: Some("Hello".to_string()), - profile: Some("my-profile".to_string()), - trust_all_tools: false, - trust_tools: None, - }) - ); - } - - #[test] - fn test_chat_with_context_profile_and_accept_all() { - assert_parse!( - ["chat", "--profile", "my-profile", "--accept-all"], - CliRootCommands::Chat(Chat { - accept_all: true, - no_interactive: false, - input: None, - profile: Some("my-profile".to_string()), - trust_all_tools: false, - trust_tools: None, - }) - ); - } - - #[test] - fn test_chat_with_no_interactive() { - assert_parse!( - ["chat", "--no-interactive"], - CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: true, - input: None, - profile: None, - trust_all_tools: false, - trust_tools: None, - }) - ); - } - - #[test] - fn test_chat_with_tool_trust_all() { - assert_parse!( - ["chat", "--trust-all-tools"], - CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: false, - input: None, - profile: None, - trust_all_tools: true, - trust_tools: None, - }) - ); - } - - #[test] - fn test_chat_with_tool_trust_none() { - assert_parse!( - ["chat", "--trust-tools="], - CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: false, - input: None, - profile: None, - trust_all_tools: false, - trust_tools: Some(vec!["".to_string()]), - }) - ); - } - - #[test] - fn test_chat_with_tool_trust_some() { - assert_parse!( - ["chat", "--trust-tools=fs_read,fs_write"], - CliRootCommands::Chat(Chat { - accept_all: false, - no_interactive: false, - input: None, - profile: None, - trust_all_tools: false, - trust_tools: Some(vec!["fs_read".to_string(), "fs_write".to_string()]), - }) - ); - } } diff --git a/typos.toml b/typos.toml index 40b3547a2b..4a2d6eada0 100644 --- a/typos.toml +++ b/typos.toml @@ -4,7 +4,7 @@ extend-exclude = [ "crates/amzn-codewhisperer-client", "crates/amzn-codewhisperer-streaming-client", "crates/amzn-consolas-client", - "crates/amzn-toolkit-telemetry", + "crates/amzn-toolkit-telemetry-client", "crates/amzn-qdeveloper-client", "crates/amzn-qdeveloper-streaming-client", "crates/aws-toolkit-telemetry-definitions/def.json", @@ -12,7 +12,7 @@ extend-exclude = [ "crates/zbus_names", "packages/fuzzysort", "packages/dashboard-app/public/license/NOTICE.txt", - "pnpm-lock.yaml" + "pnpm-lock.yaml", ] [default] From 4ecda42cb118bce5d22d9c80ae3c7a81e894d883 Mon Sep 17 00:00:00 2001 From: Chay Nabors Date: Sat, 3 May 2025 01:36:23 -0700 Subject: [PATCH 3/3] give temporary name --- Cargo.lock | 232 ++++++++++++++++++------------------- crates/kiro-cli/Cargo.toml | 2 +- 2 files changed, 117 insertions(+), 117 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9383ba0d20..60f7ec51ec 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1493,6 +1493,122 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" +[[package]] +name = "chat_cli" +version = "1.10.0" +dependencies = [ + "amzn-codewhisperer-client", + "amzn-codewhisperer-streaming-client", + "amzn-consolas-client", + "amzn-qdeveloper-streaming-client", + "amzn-toolkit-telemetry-client", + "anstream", + "arboard", + "assert_cmd", + "async-trait", + "aws-config", + "aws-credential-types", + "aws-runtime", + "aws-sdk-cognitoidentity", + "aws-sdk-ssooidc", + "aws-smithy-async", + "aws-smithy-runtime-api", + "aws-smithy-types", + "aws-types", + "base64 0.22.1", + "bitflags 2.9.0", + "bstr", + "bytes", + "camino", + "cfg-if", + "clap", + "clap_complete", + "clap_complete_fig", + "color-eyre", + 
"color-print", + "convert_case 0.8.0", + "cookie", + "criterion", + "crossterm", + "ctrlc", + "dialoguer", + "dirs 5.0.1", + "eyre", + "fd-lock", + "futures", + "glob", + "globset", + "hex", + "http 1.3.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "indicatif", + "indoc", + "insta", + "libc", + "mimalloc", + "mockito", + "nix 0.29.0", + "objc2 0.5.2", + "objc2-app-kit 0.2.2", + "objc2-foundation 0.2.2", + "owo-colors", + "parking_lot", + "paste", + "percent-encoding", + "predicates", + "prettyplease", + "quote", + "r2d2", + "r2d2_sqlite", + "rand 0.9.1", + "regex", + "reqwest", + "ring", + "rusqlite", + "rustls 0.23.26", + "rustls-native-certs 0.8.1", + "rustls-pemfile 2.2.0", + "rustyline", + "security-framework 3.2.0", + "self_update", + "semver", + "serde", + "serde_json", + "sha2", + "shell-color 1.0.0", + "shell-words", + "shellexpand", + "shlex", + "similar", + "skim", + "spinners", + "strip-ansi-escapes", + "strum 0.27.1", + "syn 2.0.101", + "syntect", + "sysinfo", + "tempfile", + "thiserror 2.0.12", + "time", + "tokio", + "tokio-tungstenite", + "tokio-util", + "toml", + "tracing", + "tracing-appender", + "tracing-subscriber", + "tracing-test", + "unicode-width 0.2.0", + "url", + "uuid", + "walkdir", + "webpki-roots", + "whoami", + "winnow 0.6.2", +] + [[package]] name = "chrono" version = "0.4.41" @@ -4849,122 +4965,6 @@ dependencies = [ "serde", ] -[[package]] -name = "kiro_cli" -version = "1.10.0" -dependencies = [ - "amzn-codewhisperer-client", - "amzn-codewhisperer-streaming-client", - "amzn-consolas-client", - "amzn-qdeveloper-streaming-client", - "amzn-toolkit-telemetry-client", - "anstream", - "arboard", - "assert_cmd", - "async-trait", - "aws-config", - "aws-credential-types", - "aws-runtime", - "aws-sdk-cognitoidentity", - "aws-sdk-ssooidc", - "aws-smithy-async", - "aws-smithy-runtime-api", - "aws-smithy-types", - "aws-types", - "base64 0.22.1", - "bitflags 2.9.0", - "bstr", - "bytes", - "camino", - "cfg-if", - "clap", - "clap_complete", - "clap_complete_fig", - "color-eyre", - "color-print", - "convert_case 0.8.0", - "cookie", - "criterion", - "crossterm", - "ctrlc", - "dialoguer", - "dirs 5.0.1", - "eyre", - "fd-lock", - "futures", - "glob", - "globset", - "hex", - "http 1.3.1", - "http-body-util", - "hyper 1.6.0", - "hyper-util", - "indicatif", - "indoc", - "insta", - "libc", - "mimalloc", - "mockito", - "nix 0.29.0", - "objc2 0.5.2", - "objc2-app-kit 0.2.2", - "objc2-foundation 0.2.2", - "owo-colors", - "parking_lot", - "paste", - "percent-encoding", - "predicates", - "prettyplease", - "quote", - "r2d2", - "r2d2_sqlite", - "rand 0.9.1", - "regex", - "reqwest", - "ring", - "rusqlite", - "rustls 0.23.26", - "rustls-native-certs 0.8.1", - "rustls-pemfile 2.2.0", - "rustyline", - "security-framework 3.2.0", - "self_update", - "semver", - "serde", - "serde_json", - "sha2", - "shell-color 1.0.0", - "shell-words", - "shellexpand", - "shlex", - "similar", - "skim", - "spinners", - "strip-ansi-escapes", - "strum 0.27.1", - "syn 2.0.101", - "syntect", - "sysinfo", - "tempfile", - "thiserror 2.0.12", - "time", - "tokio", - "tokio-tungstenite", - "tokio-util", - "toml", - "tracing", - "tracing-appender", - "tracing-subscriber", - "tracing-test", - "unicode-width 0.2.0", - "url", - "uuid", - "walkdir", - "webpki-roots", - "whoami", - "winnow 0.6.2", -] - [[package]] name = "kqueue" version = "1.0.8" diff --git a/crates/kiro-cli/Cargo.toml b/crates/kiro-cli/Cargo.toml index cc0e39aadc..b8668d6841 100644 --- a/crates/kiro-cli/Cargo.toml +++ b/crates/kiro-cli/Cargo.toml 
@@ -1,5 +1,5 @@ [package] -name = "kiro_cli" +name = "chat_cli" authors.workspace = true edition.workspace = true homepage.workspace = true
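Earlier in this series the `q_cli` crate's `Chat(Chat)` subcommand is replaced with a `Chat { args: Vec<String> }` variant that captures everything after `chat` verbatim via `trailing_var_arg` and `allow_hyphen_values`, with the handler left as `todo!()` for now. A minimal sketch of how that clap pattern behaves; the binary name `q` and the flags below are illustrative, not part of this patch:

    // Sketch of the forwarding pattern used by the new `Chat { args }` variant:
    // `trailing_var_arg` + `allow_hyphen_values` makes clap stop interpreting
    // options after the subcommand, so flag-like tokens are captured as-is and
    // can be re-parsed by the chat implementation later.
    use clap::{Parser, Subcommand};

    #[derive(Parser, Debug)]
    struct Cli {
        #[command(subcommand)]
        subcommand: Option<Command>,
    }

    #[derive(Subcommand, Debug)]
    enum Command {
        /// AI assistant in your terminal
        Chat {
            /// Raw args forwarded to the chat implementation
            #[arg(trailing_var_arg = true, allow_hyphen_values = true)]
            args: Vec<String>,
        },
    }

    fn main() {
        let cli = Cli::parse_from(["q", "chat", "--profile", "dev", "explain this error"]);
        if let Some(Command::Chat { args }) = cli.subcommand {
            // Everything after `chat` arrives untouched, including hyphenated flags.
            assert_eq!(args, ["--profile", "dev", "explain this error"]);
        }
    }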