Skip to content

Commit a8b15e1

Browse files
authored
feat: removes mcp servers from client config if autodiscovery is enabled (#133)
Signed-off-by: ChrisJBurns <29541485+ChrisJBurns@users.noreply.github.com>
1 parent de37d2b commit a8b15e1

File tree

4 files changed

+289
-12
lines changed

4 files changed

+289
-12
lines changed

cmd/thv/rm.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77

88
"github.com/spf13/cobra"
99

10+
"github.com/stacklok/toolhive/pkg/client"
11+
"github.com/stacklok/toolhive/pkg/config"
1012
"github.com/stacklok/toolhive/pkg/container"
1113
"github.com/stacklok/toolhive/pkg/labels"
1214
"github.com/stacklok/toolhive/pkg/logger"
@@ -29,6 +31,7 @@ func init() {
2931
rmCmd.Flags().BoolVarP(&rmForce, "force", "f", false, "Force removal of a running container")
3032
}
3133

34+
//nolint:gocyclo // This function is complex but manageable
3235
func rmCmdFunc(_ *cobra.Command, args []string) error {
3336
// Get container name
3437
containerName := args[0]
@@ -101,5 +104,47 @@ func rmCmdFunc(_ *cobra.Command, args []string) error {
101104
}
102105

103106
logger.Log.Info(fmt.Sprintf("Container %s removed", containerName))
107+
108+
if shouldRemoveClientConfig() {
109+
if err := removeClientConfigurations(containerName); err != nil {
110+
logger.Log.Error(fmt.Sprintf("Warning: Failed to remove client configurations: %v", err))
111+
} else {
112+
logger.Log.Info(fmt.Sprintf("Client configurations for %s removed", containerName))
113+
}
114+
}
115+
116+
return nil
117+
}
118+
119+
func shouldRemoveClientConfig() bool {
120+
c := config.GetConfig()
121+
return len(c.Clients.RegisteredClients) > 0 && c.Clients.AutoDiscovery
122+
}
123+
124+
// updateClientConfigurations updates client configuration files with the MCP server URL
125+
func removeClientConfigurations(containerName string) error {
126+
// Find client configuration files
127+
configs, err := client.FindClientConfigs()
128+
if err != nil {
129+
return fmt.Errorf("failed to find client configurations: %w", err)
130+
}
131+
132+
if len(configs) == 0 {
133+
logger.Log.Info("No client configuration files found")
134+
return nil
135+
}
136+
137+
for _, c := range configs {
138+
logger.Log.Info(fmt.Sprintf("Removing MCP server from client configuration: %s", c.Path))
139+
140+
// Remove the MCP server configuration with locking
141+
if err := c.DeleteConfigWithLock(containerName, c.Editor); err != nil {
142+
logger.Log.Warn(fmt.Sprintf("Warning: Failed to remove MCP server from client configurationn %s: %v", c.Path, err))
143+
continue
144+
}
145+
146+
logger.Log.Info(fmt.Sprintf("Successfully removed MCP server from client configuration: %s", c.Path))
147+
}
148+
104149
return nil
105150
}

pkg/client/config.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,67 @@ func (c *ConfigFile) SaveWithLock(serverName, url string, editor ConfigEditor) e
336336
return nil
337337
}
338338

339+
// DeleteConfigWithLock safely removes the MCP server configuration in the file
340+
// It acquires a lock, reads the latest content, applies the change, and saves the file
341+
func (c *ConfigFile) DeleteConfigWithLock(serverName string, editor ConfigEditor) error {
342+
// Create a lock file
343+
fileLock := flock.New(c.Path + ".lock")
344+
345+
// Create a context with timeout
346+
ctx, cancel := context.WithTimeout(context.Background(), lockTimeout)
347+
defer cancel()
348+
349+
// Try to acquire the lock with a timeout
350+
locked, err := fileLock.TryLockContext(ctx, 100*time.Millisecond)
351+
if err != nil {
352+
return fmt.Errorf("failed to acquire lock: %w", err)
353+
}
354+
if !locked {
355+
return fmt.Errorf("failed to acquire lock: timeout after %v", lockTimeout)
356+
}
357+
defer fileLock.Unlock()
358+
359+
// Read the latest content from the file
360+
latestConfig, err := readConfigFile(c.Path)
361+
if err != nil {
362+
return fmt.Errorf("failed to read latest config: %w", err)
363+
}
364+
365+
// Apply our change to the latest content
366+
if err := editor.RemoveServer(&latestConfig, serverName); err != nil {
367+
return fmt.Errorf("failed to update latest config: %w", err)
368+
}
369+
370+
// Determine format based on file extension
371+
ext := strings.ToLower(filepath.Ext(c.Path))
372+
373+
var data []byte
374+
375+
if IsYAML(ext) {
376+
// Marshal YAML
377+
data, err = yaml.Marshal(latestConfig.Contents)
378+
if err != nil {
379+
return fmt.Errorf("failed to marshal YAML: %w", err)
380+
}
381+
} else {
382+
// Default to JSON
383+
data, err = json.MarshalIndent(latestConfig.Contents, "", " ")
384+
if err != nil {
385+
return fmt.Errorf("failed to marshal JSON: %w", err)
386+
}
387+
}
388+
389+
// Write file
390+
if err := os.WriteFile(c.Path, data, 0600); err != nil {
391+
return fmt.Errorf("failed to write file: %w", err)
392+
}
393+
394+
// Update our in-memory representation to match the file
395+
c.Contents = latestConfig.Contents
396+
397+
return nil
398+
}
399+
339400
// GenerateMCPServerURL generates the URL for an MCP server
340401
func GenerateMCPServerURL(host string, port int, containerName string) string {
341402
// The URL format is: http://host:port/sse#container-name

pkg/client/config_test.go

Lines changed: 111 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"path/filepath"
1010
"testing"
1111

12+
"github.com/stacklok/toolhive/pkg/logger"
1213
"github.com/stacklok/toolhive/pkg/transport/ssecommon"
1314
)
1415

@@ -62,7 +63,7 @@ func setupTestConfig(t *testing.T, testName string) (string, string, ConfigFile)
6263
}
6364

6465
// getMCPServers reads the config file and returns the mcpServers map
65-
func getMCPServers(t *testing.T, configPath string) map[string]interface{} {
66+
func getMCPServers(t *testing.T, configPath string, editor ConfigEditor) map[string]interface{} {
6667
t.Helper()
6768

6869
// Read the config file
@@ -71,27 +72,48 @@ func getMCPServers(t *testing.T, configPath string) map[string]interface{} {
7172
t.Fatalf("Failed to read updated config file: %v", err)
7273
}
7374

74-
// Check if the servers were updated correctly
75-
mcpServers, ok := updatedConfig.Contents["mcpServers"].(map[string]interface{})
76-
if !ok {
77-
t.Fatalf("mcpServers is not a map")
75+
_, vsOk := editor.(*VSCodeConfigEditor)
76+
if vsOk {
77+
updatedConfig.Editor = editor
78+
mcpMap, ok := updatedConfig.Contents["mcp"].(map[string]interface{})
79+
if !ok {
80+
t.Fatalf("mcp is not a map")
81+
}
82+
83+
// Get servers child object
84+
mcpServers, ok := mcpMap["servers"]
85+
if !ok {
86+
t.Fatalf("mcpServers is not a map")
87+
}
88+
return mcpServers.(map[string]interface{})
89+
}
90+
91+
_, standardOk := editor.(*StandardConfigEditor)
92+
if standardOk {
93+
mcpServers, ok := updatedConfig.Contents["mcpServers"].(map[string]interface{})
94+
if !ok {
95+
t.Fatalf("mcpServers is not a map")
96+
}
97+
return mcpServers
7898
}
7999

80-
return mcpServers
100+
return nil
81101
}
82102

83103
// testUpdateExistingServer tests updating an existing server
84104
func testUpdateExistingServer(t *testing.T, config ConfigFile, configPath string) {
85105
t.Helper()
86106
// Test updating an existing server with lock
87107
expectedURL := "http://localhost:54321" + ssecommon.HTTPSSEEndpoint + "#test-container"
88-
err := config.SaveWithLock("existing-server", expectedURL, &StandardConfigEditor{})
108+
editor := &StandardConfigEditor{}
109+
110+
err := config.SaveWithLock("existing-server", expectedURL, editor)
89111
if err != nil {
90112
t.Fatalf("Failed to update MCP server config: %v", err)
91113
}
92114

93115
// Get the updated servers
94-
mcpServers := getMCPServers(t, configPath)
116+
mcpServers := getMCPServers(t, configPath, editor)
95117

96118
// Check existing server
97119
existingServer, ok := mcpServers["existing-server"].(map[string]interface{})
@@ -112,13 +134,14 @@ func testAddNewServer(t *testing.T, config ConfigFile, configPath string) {
112134
t.Helper()
113135
// Test adding a new server with lock
114136
expectedURL := "http://localhost:9876" + ssecommon.HTTPSSEEndpoint + "#new-container"
115-
err := config.SaveWithLock("new-server", expectedURL, &StandardConfigEditor{})
137+
editor := &StandardConfigEditor{}
138+
err := config.SaveWithLock("new-server", expectedURL, editor)
116139
if err != nil {
117140
t.Fatalf("Failed to add new MCP server config: %v", err)
118141
}
119142

120143
// Get the updated servers
121-
mcpServers := getMCPServers(t, configPath)
144+
mcpServers := getMCPServers(t, configPath, editor)
122145

123146
// Check new server
124147
newServer, ok := mcpServers["new-server"].(map[string]interface{})
@@ -134,11 +157,53 @@ func testAddNewServer(t *testing.T, config ConfigFile, configPath string) {
134157
}
135158
}
136159

160+
// testAddNewServer tests adding a new server
161+
func testRemovingServer(t *testing.T, config ConfigFile, configPath string, editor ConfigEditor) {
162+
t.Helper()
163+
// Test adding a new server with lock
164+
expectedURL := "http://localhost:9876" + ssecommon.HTTPSSEEndpoint + "#new-container"
165+
err := config.SaveWithLock("new-server", expectedURL, editor)
166+
if err != nil {
167+
t.Fatalf("Failed to add new MCP server config: %v", err)
168+
}
169+
170+
// Get the updated servers
171+
mcpServers := getMCPServers(t, configPath, editor)
172+
173+
// Check new server
174+
newServer, ok := mcpServers["new-server"].(map[string]interface{})
175+
if !ok {
176+
t.Fatalf("new-server is not a map")
177+
}
178+
newURL, ok := newServer["url"].(string)
179+
if !ok {
180+
t.Fatalf("url is not a string")
181+
}
182+
if newURL != expectedURL {
183+
t.Fatalf("Unexpected URL for new-server: %s, expected: %s", newURL, expectedURL)
184+
}
185+
186+
// Remove the server
187+
err = config.DeleteConfigWithLock("new-server", editor)
188+
if err != nil {
189+
t.Fatalf("Failed to remove MCP server config: %v", err)
190+
}
191+
192+
mcpServersNew := getMCPServers(t, configPath, editor)
193+
194+
// Check that the server was removed
195+
_, ok = mcpServersNew["new-server"]
196+
if ok {
197+
t.Fatalf("new-server is still in the config")
198+
}
199+
}
200+
137201
// testPreserveExistingConfig tests that existing configurations are preserved
138202
func testPreserveExistingConfig(t *testing.T, configPath string) {
139203
t.Helper()
140204
// Get the updated servers
141-
mcpServers := getMCPServers(t, configPath)
205+
editor := &StandardConfigEditor{}
206+
mcpServers := getMCPServers(t, configPath, editor)
142207

143208
// Check postgres server (should be unchanged)
144209
postgresServer, ok := mcpServers["postgres"].(map[string]interface{})
@@ -213,6 +278,41 @@ func TestUpdateMCPServerConfig(t *testing.T) {
213278

214279
testPreserveExistingConfig(t, configPath)
215280
})
281+
282+
}
283+
284+
func TestRemoveMCPServerConfig(t *testing.T) {
285+
t.Parallel()
286+
287+
logger.Initialize()
288+
289+
t.Run("RemoveExistingServerStandardEditor", func(t *testing.T) {
290+
t.Parallel()
291+
292+
// Setup test environment for this subtest
293+
tempDir, configPath, config := setupTestConfig(t, "remove")
294+
t.Cleanup(func() {
295+
if err := os.RemoveAll(tempDir); err != nil {
296+
t.Logf("Failed to remove temp dir: %v", err)
297+
}
298+
})
299+
300+
testRemovingServer(t, config, configPath, &StandardConfigEditor{})
301+
})
302+
303+
t.Run("RemoveExistingServerVSCodeEditor", func(t *testing.T) {
304+
t.Parallel()
305+
306+
// Setup test environment for this subtest
307+
tempDir, configPath, config := setupTestConfig(t, "remove")
308+
t.Cleanup(func() {
309+
if err := os.RemoveAll(tempDir); err != nil {
310+
t.Logf("Failed to remove temp dir: %v", err)
311+
}
312+
})
313+
314+
testRemovingServer(t, config, configPath, &VSCodeConfigEditor{})
315+
})
216316
}
217317

218318
func TestGenerateMCPServerURL(t *testing.T) {

0 commit comments

Comments
 (0)