Skip to content

Commit 2821d66

Browse files
committed
fix: stop spinning up multiple servers
The gob encoding implementation doesn't handle map key ordering so the hash we used wasn't consistent. This change switches to using slices of strings for headers and env. Also, fix the linting errors. Signed-off-by: Donnie Adams <donnie@acorn.io>
1 parent 54b4c1f commit 2821d66

File tree

2 files changed

+39
-30
lines changed

2 files changed

+39
-30
lines changed

pkg/loader/openapi_test.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestLoadOpenAPI(t *testing.T) {
2626
}
2727
datav3, err := os.ReadFile("testdata/openapi_v3.yaml")
2828
require.NoError(t, err)
29-
_, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "")
29+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "")
3030
require.NoError(t, err, "failed to read openapi v3")
3131
require.Equal(t, 3, numOpenAPITools(prgv3.ToolSet), "expected 3 openapi tools")
3232

@@ -35,7 +35,7 @@ func TestLoadOpenAPI(t *testing.T) {
3535
}
3636
datav2, err := os.ReadFile("testdata/openapi_v2.json")
3737
require.NoError(t, err)
38-
_, err = readTool(context.Background(), nil, &prgv2json, &source{Content: datav2}, "", "")
38+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2json, &source{Content: datav2}, "", "")
3939
require.NoError(t, err, "failed to read openapi v2")
4040
require.Equal(t, 3, numOpenAPITools(prgv2json.ToolSet), "expected 3 openapi tools")
4141

@@ -44,7 +44,7 @@ func TestLoadOpenAPI(t *testing.T) {
4444
}
4545
datav2, err = os.ReadFile("testdata/openapi_v2.yaml")
4646
require.NoError(t, err)
47-
_, err = readTool(context.Background(), nil, &prgv2yaml, &source{Content: datav2}, "", "")
47+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2yaml, &source{Content: datav2}, "", "")
4848
require.NoError(t, err, "failed to read openapi v2 (yaml)")
4949
require.Equal(t, 3, numOpenAPITools(prgv2yaml.ToolSet), "expected 3 openapi tools")
5050

@@ -57,7 +57,7 @@ func TestOpenAPIv3(t *testing.T) {
5757
}
5858
datav3, err := os.ReadFile("testdata/openapi_v3.yaml")
5959
require.NoError(t, err)
60-
_, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "")
60+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "")
6161
require.NoError(t, err)
6262

6363
autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi"))
@@ -69,7 +69,7 @@ func TestOpenAPIv3NoOperationIDs(t *testing.T) {
6969
}
7070
datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml")
7171
require.NoError(t, err)
72-
_, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "")
72+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "")
7373
require.NoError(t, err)
7474

7575
autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi"))
@@ -81,7 +81,7 @@ func TestOpenAPIv2(t *testing.T) {
8181
}
8282
datav2, err := os.ReadFile("testdata/openapi_v2.yaml")
8383
require.NoError(t, err)
84-
_, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "")
84+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2, &source{Content: datav2}, "", "")
8585
require.NoError(t, err)
8686

8787
autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi"))
@@ -94,7 +94,7 @@ func TestOpenAPIv3Revamp(t *testing.T) {
9494
}
9595
datav3, err := os.ReadFile("testdata/openapi_v3.yaml")
9696
require.NoError(t, err)
97-
_, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "")
97+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "")
9898
require.NoError(t, err)
9999

100100
autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi"))
@@ -107,7 +107,7 @@ func TestOpenAPIv3NoOperationIDsRevamp(t *testing.T) {
107107
}
108108
datav3, err := os.ReadFile("testdata/openapi_v3_no_operation_ids.yaml")
109109
require.NoError(t, err)
110-
_, err = readTool(context.Background(), nil, &prgv3, &source{Content: datav3}, "", "")
110+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv3, &source{Content: datav3}, "", "")
111111
require.NoError(t, err)
112112

113113
autogold.ExpectFile(t, prgv3.ToolSet, autogold.Dir("testdata/openapi"))
@@ -120,8 +120,14 @@ func TestOpenAPIv2Revamp(t *testing.T) {
120120
}
121121
datav2, err := os.ReadFile("testdata/openapi_v2.yaml")
122122
require.NoError(t, err)
123-
_, err = readTool(context.Background(), nil, &prgv2, &source{Content: datav2}, "", "")
123+
_, err = readTool(context.Background(), nil, fakeMCPLoader{}, &prgv2, &source{Content: datav2}, "", "")
124124
require.NoError(t, err)
125125

126126
autogold.ExpectFile(t, prgv2.ToolSet, autogold.Dir("testdata/openapi"))
127127
}
128+
129+
type fakeMCPLoader struct{}
130+
131+
func (fakeMCPLoader) Load(context.Context, types.Tool) ([]types.Tool, error) {
132+
return nil, nil
133+
}

pkg/mcp/loader.go

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ var (
2323
)
2424

2525
type Local struct {
26-
nextID int64
2726
lock sync.Mutex
2827
sessions map[string]*Session
2928
}
@@ -39,15 +38,17 @@ type Config struct {
3938
MCPServers map[string]ServerConfig `json:"mcpServers"`
4039
}
4140

41+
// ServerConfig represents an MCP server configuration for tools calls.
42+
// It is important that this type doesn't have any maps.
4243
type ServerConfig struct {
43-
DisableInstruction bool `json:"disableInstruction"`
44-
Command string `json:"command"`
45-
Args []string `json:"args"`
46-
Env map[string]string `json:"env"`
47-
Server string `json:"server"`
48-
URL string `json:"url"`
49-
BaseURL string `json:"baseURL,omitempty"`
50-
Headers map[string]string `json:"headers"`
44+
DisableInstruction bool `json:"disableInstruction"`
45+
Command string `json:"command"`
46+
Args []string `json:"args"`
47+
Env []string `json:"env"`
48+
Server string `json:"server"`
49+
URL string `json:"url"`
50+
BaseURL string `json:"baseURL,omitempty"`
51+
Headers []string `json:"headers"`
5152
}
5253

5354
func (s *ServerConfig) GetBaseURL() string {
@@ -62,12 +63,12 @@ func (s *ServerConfig) GetBaseURL() string {
6263

6364
func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool, _ error) {
6465
if !tool.IsMCP() {
65-
return []types.Tool{tool}, nil
66+
return nil, nil
6667
}
6768

6869
_, configData, _ := strings.Cut(tool.Instructions, "\n")
69-
var servers Config
7070

71+
var servers Config
7172
if err := json.Unmarshal([]byte(strings.TrimSpace(configData)), &servers); err != nil {
7273
return nil, fmt.Errorf("failed to parse MCP configuration: %w\n%s", err, configData)
7374
}
@@ -87,10 +88,10 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool,
8788
}
8889

8990
if len(servers.MCPServers) > 1 {
90-
return nil, fmt.Errorf("only a single MCP server definition is support")
91+
return nil, fmt.Errorf("only a single MCP server definition is supported")
9192
}
9293

93-
for _, server := range slices.Sorted(maps.Keys(servers.MCPServers)) {
94+
for server := range maps.Keys(servers.MCPServers) {
9495
session, err := l.loadSession(ctx, servers.MCPServers[server])
9596
if err != nil {
9697
return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err)
@@ -202,6 +203,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session,
202203
l.lock.Lock()
203204
existing, ok := l.sessions[id]
204205
l.lock.Unlock()
206+
205207
if ok {
206208
return existing, nil
207209
}
@@ -210,13 +212,8 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session,
210212
c client.MCPClient
211213
err error
212214
)
213-
214215
if server.Command != "" {
215-
env := make([]string, 0, len(server.Env))
216-
for k, v := range server.Env {
217-
env = append(env, fmt.Sprintf("%s=%s", k, v))
218-
}
219-
c, err = client.NewStdioMCPClient(server.Command, env, server.Args...)
216+
c, err = client.NewStdioMCPClient(server.Command, server.Env, server.Args...)
220217
if err != nil {
221218
return nil, fmt.Errorf("failed to create MCP stdio client: %w", err)
222219
}
@@ -225,7 +222,13 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session,
225222
if url == "" {
226223
url = server.Server
227224
}
228-
c, err = client.NewSSEMCPClient(url, client.WithHeaders(server.Headers))
225+
226+
headers := make(map[string]string, len(server.Headers))
227+
for _, h := range server.Headers {
228+
k, v, _ := strings.Cut(h, "=")
229+
headers[k] = v
230+
}
231+
c, err = client.NewSSEMCPClient(url, client.WithHeaders(headers))
229232
if err != nil {
230233
return nil, fmt.Errorf("failed to create MCP HTTP client: %w", err)
231234
}
@@ -252,7 +255,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session,
252255
l.lock.Lock()
253256
defer l.lock.Unlock()
254257

255-
if existing, ok := l.sessions[id]; ok {
258+
if existing, ok = l.sessions[id]; ok {
256259
return existing, c.Close()
257260
}
258261

0 commit comments

Comments
 (0)