diff --git a/README.md b/README.md index a35a3ebe0..f047c3f47 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,6 @@ package main import ( "context" - "errors" "fmt" "github.com/mark3labs/mcp-go/mcp" @@ -538,7 +537,7 @@ For examples, see the [`examples/`](examples/) directory. ### Transports -MCP-Go supports stdio, SSE and streamable-HTTP transport layers. +MCP-Go supports stdio, SSE and streamable-HTTP transport layers. For SSE transport, you can use `SetConnectionLostHandler()` to detect and handle HTTP/2 idle timeout disconnections (NO_ERROR) for implementing reconnection logic. ### Session Management diff --git a/client/client.go b/client/client.go index dd0e31a01..cda7665ef 100644 --- a/client/client.go +++ b/client/client.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "slices" "sync" "sync/atomic" @@ -22,6 +23,8 @@ type Client struct { requestID atomic.Int64 clientCapabilities mcp.ClientCapabilities serverCapabilities mcp.ServerCapabilities + protocolVersion string + samplingHandler SamplingHandler } type ClientOption func(*Client) @@ -33,6 +36,21 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption { } } +// WithSamplingHandler sets the sampling handler for the client. +// When set, the client will declare sampling capability during initialization. +func WithSamplingHandler(handler SamplingHandler) ClientOption { + return func(c *Client) { + c.samplingHandler = handler + } +} + +// WithSession assumes a MCP Session has already been initialized +func WithSession() ClientOption { + return func(c *Client) { + c.initialized = true + } +} + // NewClient creates a new MCP client with the given transport. // Usage: // @@ -71,6 +89,12 @@ func (c *Client) Start(ctx context.Context) error { handler(notification) } }) + + // Set up request handler for bidirectional communication (e.g., sampling) + if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok { + bidirectional.SetRequestHandler(c.handleIncomingRequest) + } + return nil } @@ -89,6 +113,17 @@ func (c *Client) OnNotification( c.notifications = append(c.notifications, handler) } +// OnConnectionLost registers a handler function to be called when the connection is lost. +// This is useful for handling HTTP2 idle timeout disconnections that should not be treated as errors. +func (c *Client) OnConnectionLost(handler func(error)) { + type connectionLostSetter interface { + SetConnectionLostHandler(func(error)) + } + if setter, ok := c.transport.(connectionLostSetter); ok { + setter.SetConnectionLostHandler(handler) + } +} + // sendRequest sends a JSON-RPC request to the server and waits for a response. // Returns the raw JSON response message or an error if the request fails. func (c *Client) sendRequest( @@ -111,7 +146,7 @@ func (c *Client) sendRequest( response, err := c.transport.SendRequest(ctx, request) if err != nil { - return nil, fmt.Errorf("transport error: %w", err) + return nil, transport.NewError(err) } if response.Error != nil { @@ -127,6 +162,12 @@ func (c *Client) Initialize( ctx context.Context, request mcp.InitializeRequest, ) (*mcp.InitializeResult, error) { + // Merge client capabilities with sampling capability if handler is configured + capabilities := request.Params.Capabilities + if c.samplingHandler != nil { + capabilities.Sampling = &struct{}{} + } + // Ensure we send a params object with all required fields params := struct { ProtocolVersion string `json:"protocolVersion"` @@ -135,7 +176,7 @@ func (c *Client) Initialize( }{ ProtocolVersion: request.Params.ProtocolVersion, ClientInfo: request.Params.ClientInfo, - Capabilities: request.Params.Capabilities, // Will be empty struct if not set + Capabilities: capabilities, } response, err := c.sendRequest(ctx, "initialize", params) @@ -148,8 +189,19 @@ func (c *Client) Initialize( return nil, fmt.Errorf("failed to unmarshal response: %w", err) } - // Store serverCapabilities + // Validate protocol version + if !slices.Contains(mcp.ValidProtocolVersions, result.ProtocolVersion) { + return nil, mcp.UnsupportedProtocolVersionError{Version: result.ProtocolVersion} + } + + // Store serverCapabilities and protocol version c.serverCapabilities = result.Capabilities + c.protocolVersion = result.ProtocolVersion + + // Set protocol version on HTTP transports + if httpConn, ok := c.transport.(transport.HTTPConnection); ok { + httpConn.SetProtocolVersion(result.ProtocolVersion) + } // Send initialized notification notification := mcp.JSONRPCNotification{ @@ -398,6 +450,64 @@ func (c *Client) Complete( return &result, nil } +// handleIncomingRequest processes incoming requests from the server. +// This is the main entry point for server-to-client requests like sampling. +func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + return c.handleSamplingRequestTransport(ctx, request) + default: + return nil, fmt.Errorf("unsupported request method: %s", request.Method) + } +} + +// handleSamplingRequestTransport handles sampling requests at the transport level. +func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + if c.samplingHandler == nil { + return nil, fmt.Errorf("no sampling handler configured") + } + + // Parse the request parameters + var params mcp.CreateMessageParams + if request.Params != nil { + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + return nil, fmt.Errorf("failed to marshal params: %w", err) + } + if err := json.Unmarshal(paramsBytes, ¶ms); err != nil { + return nil, fmt.Errorf("failed to unmarshal params: %w", err) + } + } + + // Create the MCP request + mcpRequest := mcp.CreateMessageRequest{ + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + CreateMessageParams: params, + } + + // Call the sampling handler + result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest) + if err != nil { + return nil, err + } + + // Marshal the result + resultBytes, err := json.Marshal(result) + if err != nil { + return nil, fmt.Errorf("failed to marshal result: %w", err) + } + + // Create the transport response + response := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Result: json.RawMessage(resultBytes), + } + + return response, nil +} func listByPage[T any]( ctx context.Context, client *Client, @@ -432,3 +542,17 @@ func (c *Client) GetServerCapabilities() mcp.ServerCapabilities { func (c *Client) GetClientCapabilities() mcp.ClientCapabilities { return c.clientCapabilities } + +// GetSessionId returns the session ID of the transport. +// If the transport does not support sessions, it returns an empty string. +func (c *Client) GetSessionId() string { + if c.transport == nil { + return "" + } + return c.transport.GetSessionId() +} + +// IsInitialized returns true if the client has been initialized. +func (c *Client) IsInitialized() bool { + return c.initialized +} diff --git a/client/http.go b/client/http.go index cb3be35d6..d001a1e63 100644 --- a/client/http.go +++ b/client/http.go @@ -13,5 +13,10 @@ func NewStreamableHttpClient(baseURL string, options ...transport.StreamableHTTP if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return NewClient(trans), nil + clientOptions := make([]ClientOption, 0) + sessionID := trans.GetSessionId() + if sessionID != "" { + clientOptions = append(clientOptions, WithSession()) + } + return NewClient(trans, clientOptions...), nil } diff --git a/client/http_test.go b/client/http_test.go index 3c2e6a3b7..514004857 100644 --- a/client/http_test.go +++ b/client/http_test.go @@ -3,10 +3,14 @@ package client import ( "context" "fmt" - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" + "sync" "testing" "time" + + "github.com/google/uuid" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" ) func TestHTTPClient(t *testing.T) { @@ -47,20 +51,47 @@ func TestHTTPClient(t *testing.T) { return nil, fmt.Errorf("failed to send notification: %w", err) } - return &mcp.CallToolResult{ - Content: []mcp.Content{ - mcp.TextContent{ - Type: "text", - Text: "notification sent successfully", - }, - }, - }, nil + return mcp.NewToolResultText("notification sent successfully"), nil }, ) + addServerToolfunc := func(name string) { + mcpServer.AddTool( + mcp.NewTool(name), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + server := server.ServerFromContext(ctx) + server.SendNotificationToAllClients("helloToEveryone", map[string]any{ + "message": "hello", + }) + return mcp.NewToolResultText("done"), nil + }, + ) + } + testServer := server.NewTestStreamableHTTPServer(mcpServer) defer testServer.Close() + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client2", + Version: "1.0.0", + }, + }, + } + + t.Run("Can Configure a server with a pre-existing session", func(t *testing.T) { + sessionID := uuid.NewString() + client, err := NewStreamableHttpClient(testServer.URL, transport.WithSession(sessionID)) + if err != nil { + t.Fatalf("create client failed %v", err) + } + if client.IsInitialized() != true { + t.Fatalf("Client is not initialized") + } + }) + t.Run("Can receive notification from server", func(t *testing.T) { client, err := NewStreamableHttpClient(testServer.URL) if err != nil { @@ -68,9 +99,9 @@ func TestHTTPClient(t *testing.T) { return } - notificationNum := 0 + notificationNum := NewSafeMap() client.OnNotification(func(notification mcp.JSONRPCNotification) { - notificationNum += 1 + notificationNum.Increment(notification.Method) }) ctx := context.Background() @@ -81,31 +112,122 @@ func TestHTTPClient(t *testing.T) { } // Initialize - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "test-client", - Version: "1.0.0", - } - _, err = client.Initialize(ctx, initRequest) if err != nil { t.Fatalf("Failed to initialize: %v\n", err) } - request := mcp.CallToolRequest{} - request.Params.Name = "notify" - result, err := client.CallTool(ctx, request) - if err != nil { - t.Fatalf("CallTool failed: %v", err) - } + t.Run("Can receive notifications related to the request", func(t *testing.T) { + request := mcp.CallToolRequest{} + request.Params.Name = "notify" + result, err := client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } - if len(result.Content) != 1 { - t.Errorf("Expected 1 content item, got %d", len(result.Content)) - } + if len(result.Content) != 1 { + t.Errorf("Expected 1 content item, got %d", len(result.Content)) + } + + if n := notificationNum.Get("notifications/progress"); n != 1 { + t.Errorf("Expected 1 progross notification item, got %d", n) + } + if n := notificationNum.Len(); n != 1 { + t.Errorf("Expected 1 type of notification, got %d", n) + } + }) + + t.Run("Can not receive global notifications from server by default", func(t *testing.T) { + addServerToolfunc("hello1") + time.Sleep(time.Millisecond * 50) + + helloNotifications := notificationNum.Get("hello1") + if helloNotifications != 0 { + t.Errorf("Expected 0 notification item, got %d", helloNotifications) + } + }) + + t.Run("Can receive global notifications from server when WithContinuousListening enabled", func(t *testing.T) { + + client, err := NewStreamableHttpClient(testServer.URL, + transport.WithContinuousListening()) + if err != nil { + t.Fatalf("create client failed %v", err) + return + } + defer client.Close() + + notificationNum := NewSafeMap() + client.OnNotification(func(notification mcp.JSONRPCNotification) { + notificationNum.Increment(notification.Method) + }) + + ctx := context.Background() + + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + return + } + + // Initialize + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v\n", err) + } + + // can receive normal notification + request := mcp.CallToolRequest{} + request.Params.Name = "notify" + _, err = client.CallTool(ctx, request) + if err != nil { + t.Fatalf("CallTool failed: %v", err) + } + + if n := notificationNum.Get("notifications/progress"); n != 1 { + t.Errorf("Expected 1 progross notification item, got %d", n) + } + if n := notificationNum.Len(); n != 1 { + t.Errorf("Expected 1 type of notification, got %d", n) + } + + // can receive global notification + addServerToolfunc("hello2") + time.Sleep(time.Millisecond * 50) // wait for the notification to be sent as upper action is async + + n := notificationNum.Get("notifications/tools/list_changed") + if n != 1 { + t.Errorf("Expected 1 notification item, got %d, %v", n, notificationNum) + } + }) - if notificationNum != 1 { - t.Errorf("Expected 1 notification item, got %d", notificationNum) - } }) } + +type SafeMap struct { + mu sync.RWMutex + data map[string]int +} + +func NewSafeMap() *SafeMap { + return &SafeMap{ + data: make(map[string]int), + } +} + +func (sm *SafeMap) Increment(key string) { + sm.mu.Lock() + defer sm.mu.Unlock() + sm.data[key]++ +} + +func (sm *SafeMap) Get(key string) int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return sm.data[key] +} + +func (sm *SafeMap) Len() int { + sm.mu.RLock() + defer sm.mu.RUnlock() + return len(sm.data) +} diff --git a/client/inprocess.go b/client/inprocess.go index 5d8559de2..62d28794d 100644 --- a/client/inprocess.go +++ b/client/inprocess.go @@ -1,7 +1,10 @@ package client import ( + "context" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" ) @@ -10,3 +13,26 @@ func NewInProcessClient(server *server.MCPServer) (*Client, error) { inProcessTransport := transport.NewInProcessTransport(server) return NewClient(inProcessTransport), nil } + +// NewInProcessClientWithSamplingHandler creates an in-process client with sampling support +func NewInProcessClientWithSamplingHandler(server *server.MCPServer, handler SamplingHandler) (*Client, error) { + // Create a wrapper that implements server.SamplingHandler + serverHandler := &inProcessSamplingHandlerWrapper{handler: handler} + + inProcessTransport := transport.NewInProcessTransportWithOptions(server, + transport.WithSamplingHandler(serverHandler)) + + client := NewClient(inProcessTransport) + client.samplingHandler = handler + + return client, nil +} + +// inProcessSamplingHandlerWrapper wraps client.SamplingHandler to implement server.SamplingHandler +type inProcessSamplingHandlerWrapper struct { + handler SamplingHandler +} + +func (w *inProcessSamplingHandlerWrapper) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + return w.handler.CreateMessage(ctx, request) +} diff --git a/client/inprocess_sampling_test.go b/client/inprocess_sampling_test.go new file mode 100644 index 000000000..087109e43 --- /dev/null +++ b/client/inprocess_sampling_test.go @@ -0,0 +1,148 @@ +package client + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MockSamplingHandler implements SamplingHandler for testing +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + return &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Mock response from sampling handler", + }, + }, + Model: "mock-model", + StopReason: "endTurn", + }, nil +} + +func TestInProcessSampling(t *testing.T) { + // Create server with sampling enabled + mcpServer := server.NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + + // Add a tool that uses sampling + mcpServer.AddTool(mcp.Tool{ + Name: "test_sampling", + Description: "Test sampling functionality", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "Message to send to LLM", + }, + }, + Required: []string{"message"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + message, err := request.RequireString("message") + if err != nil { + return nil, err + } + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: message, + }, + }, + }, + MaxTokens: 100, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Sampling failed: " + err.Error(), + }, + }, + IsError: true, + }, nil + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Sampling result: " + result.Content.(mcp.TextContent).Text, + }, + }, + }, nil + }) + + // Create client with sampling handler + mockHandler := &MockSamplingHandler{} + client, err := NewInProcessClientWithSamplingHandler(mcpServer, mockHandler) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Start the client + ctx := context.Background() + if err := client.Start(ctx); err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + } + + _, err = client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Failed to initialize: %v", err) + } + + // Call the tool that uses sampling + result, err := client.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "test_sampling", + Arguments: map[string]any{ + "message": "Hello, world!", + }, + }, + }) + if err != nil { + t.Fatalf("Tool call failed: %v", err) + } + + // Verify the result contains the mock response + if len(result.Content) == 0 { + t.Fatal("Expected content in result") + } + + textContent, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatal("Expected text content") + } + + expectedText := "Sampling result: Mock response from sampling handler" + if textContent.Text != expectedText { + t.Errorf("Expected %q, got %q", expectedText, textContent.Text) + } +} diff --git a/client/protocol_negotiation_test.go b/client/protocol_negotiation_test.go new file mode 100644 index 000000000..022b7fc6d --- /dev/null +++ b/client/protocol_negotiation_test.go @@ -0,0 +1,231 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// mockProtocolTransport implements transport.Interface for testing protocol negotiation +type mockProtocolTransport struct { + responses map[string]string + notificationHandler func(mcp.JSONRPCNotification) + started bool + closed bool +} + +func (m *mockProtocolTransport) Start(ctx context.Context) error { + m.started = true + return nil +} + +func (m *mockProtocolTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + responseStr, ok := m.responses[request.Method] + if !ok { + return nil, fmt.Errorf("no mock response for method %s", request.Method) + } + + return &transport.JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: json.RawMessage(responseStr), + }, nil +} + +func (m *mockProtocolTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + return nil +} + +func (m *mockProtocolTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { + m.notificationHandler = handler +} + +func (m *mockProtocolTransport) Close() error { + m.closed = true + return nil +} + +func (m *mockProtocolTransport) GetSessionId() string { + return "mock-session" +} + +func TestProtocolVersionNegotiation(t *testing.T) { + tests := []struct { + name string + serverVersion string + expectError bool + errorContains string + }{ + { + name: "supported latest version", + serverVersion: mcp.LATEST_PROTOCOL_VERSION, + expectError: false, + }, + { + name: "supported older version 2025-03-26", + serverVersion: "2025-03-26", + expectError: false, + }, + { + name: "supported older version 2024-11-05", + serverVersion: "2024-11-05", + expectError: false, + }, + { + name: "unsupported version", + serverVersion: "2023-01-01", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "unsupported future version", + serverVersion: "2030-01-01", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "empty protocol version", + serverVersion: "", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "malformed protocol version - invalid format", + serverVersion: "not-a-date", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "malformed protocol version - partial date", + serverVersion: "2025-06", + expectError: true, + errorContains: "unsupported protocol version", + }, + { + name: "malformed protocol version - just numbers", + serverVersion: "20250618", + expectError: true, + errorContains: "unsupported protocol version", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create mock transport that returns specific version + mockTransport := &mockProtocolTransport{ + responses: map[string]string{ + "initialize": fmt.Sprintf(`{ + "protocolVersion": "%s", + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"} + }`, tt.serverVersion), + }, + } + + client := NewClient(mockTransport) + + _, err := client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"}, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + + if tt.expectError { + if err == nil { + t.Errorf("expected error but got none") + } else if !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("expected error containing %q, got %q", tt.errorContains, err.Error()) + } + // Verify it's the correct error type + if !mcp.IsUnsupportedProtocolVersion(err) { + t.Errorf("expected UnsupportedProtocolVersionError, got %T", err) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + // Verify the protocol version was stored + if client.protocolVersion != tt.serverVersion { + t.Errorf("expected protocol version %q, got %q", tt.serverVersion, client.protocolVersion) + } + } + }) + } +} + +// mockHTTPTransport implements both transport.Interface and transport.HTTPConnection +type mockHTTPTransport struct { + mockProtocolTransport + protocolVersion string +} + +func (m *mockHTTPTransport) SetProtocolVersion(version string) { + m.protocolVersion = version +} + +func TestProtocolVersionHeaderSetting(t *testing.T) { + // Create mock HTTP transport + mockTransport := &mockHTTPTransport{ + mockProtocolTransport: mockProtocolTransport{ + responses: map[string]string{ + "initialize": fmt.Sprintf(`{ + "protocolVersion": "%s", + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"} + }`, mcp.LATEST_PROTOCOL_VERSION), + }, + }, + } + + client := NewClient(mockTransport) + + _, err := client.Initialize(context.Background(), mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{Name: "test-client", Version: "1.0"}, + Capabilities: mcp.ClientCapabilities{}, + }, + }) + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Verify SetProtocolVersion was called on HTTP transport + if mockTransport.protocolVersion != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("expected SetProtocolVersion to be called with %q, got %q", + mcp.LATEST_PROTOCOL_VERSION, mockTransport.protocolVersion) + } +} + +func TestUnsupportedProtocolVersionError_Is(t *testing.T) { + // Test that errors.Is works correctly with UnsupportedProtocolVersionError + err1 := mcp.UnsupportedProtocolVersionError{Version: "2023-01-01"} + err2 := mcp.UnsupportedProtocolVersionError{Version: "2024-01-01"} + + // Test Is method + if !err1.Is(err2) { + t.Error("expected UnsupportedProtocolVersionError.Is to return true for same error type") + } + + // Test with different error type + otherErr := fmt.Errorf("some other error") + if err1.Is(otherErr) { + t.Error("expected UnsupportedProtocolVersionError.Is to return false for different error type") + } + + // Test IsUnsupportedProtocolVersion helper + if !mcp.IsUnsupportedProtocolVersion(err1) { + t.Error("expected IsUnsupportedProtocolVersion to return true") + } + if mcp.IsUnsupportedProtocolVersion(otherErr) { + t.Error("expected IsUnsupportedProtocolVersion to return false for different error type") + } +} diff --git a/client/sampling.go b/client/sampling.go new file mode 100644 index 000000000..245e2c1f7 --- /dev/null +++ b/client/sampling.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +// Clients can implement this interface to provide LLM sampling capabilities to servers. +type SamplingHandler interface { + // CreateMessage handles a sampling request from the server and returns the generated message. + // The implementation should: + // 1. Validate the request parameters + // 2. Optionally prompt the user for approval (human-in-the-loop) + // 3. Select an appropriate model based on preferences + // 4. Generate the response using the selected model + // 5. Return the result with model information and stop reason + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} diff --git a/client/sampling_test.go b/client/sampling_test.go new file mode 100644 index 000000000..60f533221 --- /dev/null +++ b/client/sampling_test.go @@ -0,0 +1,274 @@ +package client + +import ( + "context" + "encoding/json" + "testing" + + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// mockSamplingHandler implements SamplingHandler for testing +type mockSamplingHandler struct { + result *mcp.CreateMessageResult + err error +} + +func (m *mockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestClient_HandleSamplingRequest(t *testing.T) { + tests := []struct { + name string + handler SamplingHandler + expectedError string + }{ + { + name: "no handler configured", + handler: nil, + expectedError: "no sampling handler configured", + }, + { + name: "successful sampling", + handler: &mockSamplingHandler{ + result: &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Hello, world!", + }, + }, + Model: "test-model", + StopReason: "endTurn", + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client := &Client{samplingHandler: tt.handler} + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{Type: "text", Text: "Hello"}, + }, + }, + MaxTokens: 100, + }, + } + + result, err := client.handleIncomingRequest(context.Background(), mockJSONRPCRequest(request)) + + if tt.expectedError != "" { + if err == nil { + t.Errorf("expected error %q, got nil", tt.expectedError) + } else if err.Error() != tt.expectedError { + t.Errorf("expected error %q, got %q", tt.expectedError, err.Error()) + } + } else { + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if result == nil { + t.Error("expected result, got nil") + } + } + }) + } +} + +func TestWithSamplingHandler(t *testing.T) { + handler := &mockSamplingHandler{} + client := &Client{} + + option := WithSamplingHandler(handler) + option(client) + + if client.samplingHandler != handler { + t.Error("sampling handler not set correctly") + } +} + +// mockTransport implements transport.Interface for testing +type mockTransport struct { + requestChan chan transport.JSONRPCRequest + responseChan chan *transport.JSONRPCResponse + started bool +} + +func newMockTransport() *mockTransport { + return &mockTransport{ + requestChan: make(chan transport.JSONRPCRequest, 1), + responseChan: make(chan *transport.JSONRPCResponse, 1), + } +} + +func (m *mockTransport) Start(ctx context.Context) error { + m.started = true + return nil +} + +func (m *mockTransport) SendRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) { + m.requestChan <- request + select { + case response := <-m.responseChan: + return response, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +func (m *mockTransport) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { + return nil +} + +func (m *mockTransport) SetNotificationHandler(handler func(notification mcp.JSONRPCNotification)) { +} + +func (m *mockTransport) Close() error { + return nil +} + +func (m *mockTransport) GetSessionId() string { + return "mock-session-id" +} + +func TestClient_Initialize_WithSampling(t *testing.T) { + handler := &mockSamplingHandler{ + result: &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Test response", + }, + }, + Model: "test-model", + StopReason: "endTurn", + }, + } + + // Create mock transport + mockTransport := newMockTransport() + + // Create client with sampling handler and mock transport + client := &Client{ + transport: mockTransport, + samplingHandler: handler, + } + + // Start the client + ctx := context.Background() + err := client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Prepare mock response for initialization + initResponse := &transport.JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(1), + Result: []byte(`{"protocolVersion":"2024-11-05","capabilities":{"logging":{},"prompts":{},"resources":{},"tools":{}},"serverInfo":{"name":"test-server","version":"1.0.0"}}`), + } + + // Send the response in a goroutine + go func() { + mockTransport.responseChan <- initResponse + }() + + // Call Initialize with appropriate parameters + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + Roots: &struct { + ListChanged bool `json:"listChanged,omitempty"` + }{ + ListChanged: true, + }, + }, + }, + } + + result, err := client.Initialize(ctx, initRequest) + if err != nil { + t.Fatalf("Initialize failed: %v", err) + } + + // Verify the result + if result == nil { + t.Fatal("Initialize result should not be nil") + } + + // Verify that the request was sent through the transport + select { + case request := <-mockTransport.requestChan: + // Verify the request method + if request.Method != "initialize" { + t.Errorf("Expected method 'initialize', got '%s'", request.Method) + } + + // Verify the request has the correct structure + if request.Params == nil { + t.Fatal("Request params should not be nil") + } + + // Parse the params to verify sampling capability is included + paramsBytes, err := json.Marshal(request.Params) + if err != nil { + t.Fatalf("Failed to marshal request params: %v", err) + } + + var params struct { + ProtocolVersion string `json:"protocolVersion"` + ClientInfo mcp.Implementation `json:"clientInfo"` + Capabilities mcp.ClientCapabilities `json:"capabilities"` + } + + err = json.Unmarshal(paramsBytes, ¶ms) + if err != nil { + t.Fatalf("Failed to unmarshal request params: %v", err) + } + + // Verify sampling capability is included in the request + if params.Capabilities.Sampling == nil { + t.Error("Sampling capability should be included in initialization request when handler is configured") + } + + // Verify other expected fields + if params.ProtocolVersion != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version '%s', got '%s'", mcp.LATEST_PROTOCOL_VERSION, params.ProtocolVersion) + } + + if params.ClientInfo.Name != "test-client" { + t.Errorf("Expected client name 'test-client', got '%s'", params.ClientInfo.Name) + } + + default: + t.Error("Expected initialization request to be sent through transport") + } +} + +// Helper function to create a mock JSON-RPC request for testing +func mockJSONRPCRequest(mcpRequest mcp.CreateMessageRequest) transport.JSONRPCRequest { + return transport.JSONRPCRequest{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: mcpRequest.CreateMessageParams, + } +} diff --git a/client/sse.go b/client/sse.go index ae2ebcaf0..07512a9be 100644 --- a/client/sse.go +++ b/client/sse.go @@ -23,12 +23,10 @@ func WithHTTPClient(httpClient *http.Client) transport.ClientOption { // NewSSEMCPClient creates a new SSE-based MCP client with the given base URL. // Returns an error if the URL is invalid. func NewSSEMCPClient(baseURL string, options ...transport.ClientOption) (*Client, error) { - sseTransport, err := transport.NewSSE(baseURL, options...) if err != nil { return nil, fmt.Errorf("failed to create SSE transport: %w", err) } - return NewClient(sseTransport), nil } diff --git a/client/stdio.go b/client/stdio.go index 100c08a7c..199ec14c3 100644 --- a/client/stdio.go +++ b/client/stdio.go @@ -19,10 +19,26 @@ func NewStdioMCPClient( env []string, args ...string, ) (*Client, error) { + return NewStdioMCPClientWithOptions(command, env, args) +} + +// NewStdioMCPClientWithOptions creates a new stdio-based MCP client that communicates with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command function. +// +// NOTICE: NewStdioMCPClientWithOptions automatically starts the underlying transport. +// Don't call the Start method manually. +// This is for backward compatibility. +func NewStdioMCPClientWithOptions( + command string, + env []string, + args []string, + opts ...transport.StdioOption, +) (*Client, error) { + stdioTransport := transport.NewStdioWithOptions(command, env, args, opts...) - stdioTransport := transport.NewStdio(command, env, args...) - err := stdioTransport.Start(context.Background()) - if err != nil { + if err := stdioTransport.Start(context.Background()); err != nil { return nil, fmt.Errorf("failed to start stdio transport: %w", err) } diff --git a/client/stdio_test.go b/client/stdio_test.go index b6faf9bfd..7eb6dd38a 100644 --- a/client/stdio_test.go +++ b/client/stdio_test.go @@ -12,6 +12,9 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) @@ -90,7 +93,7 @@ func TestStdioMCPClient(t *testing.T) { defer cancel() request := mcp.InitializeRequest{} - request.Params.ProtocolVersion = "1.0" + request.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION request.Params.ClientInfo = mcp.Implementation{ Name: "test-client", Version: "1.0.0", @@ -305,3 +308,43 @@ func TestStdioMCPClient(t *testing.T) { } }) } + +func TestStdio_NewStdioMCPClientWithOptions_CreatesAndStartsClient(t *testing.T) { + called := false + + fakeCmdFunc := func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + called = true + return exec.CommandContext(ctx, "echo", "started"), nil + } + + client, err := NewStdioMCPClientWithOptions( + "echo", + []string{"FOO=bar"}, + []string{"hello"}, + transport.WithCommandFunc(fakeCmdFunc), + ) + require.NoError(t, err) + require.NotNil(t, client) + t.Cleanup(func() { + _ = client.Close() + }) + require.True(t, called) +} + +func TestStdio_NewStdioMCPClientWithOptions_FailsToStart(t *testing.T) { + // Create a commandFunc that points to a nonexistent binary + badCmdFunc := func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + return exec.CommandContext(ctx, "/nonexistent/bar", args...), nil + } + + client, err := NewStdioMCPClientWithOptions( + "foo", + nil, + nil, + transport.WithCommandFunc(badCmdFunc), + ) + + require.Error(t, err) + require.EqualError(t, err, "failed to start stdio transport: failed to start command: fork/exec /nonexistent/bar: no such file or directory") + require.Nil(t, client) +} diff --git a/client/transport/constants.go b/client/transport/constants.go new file mode 100644 index 000000000..2fb503084 --- /dev/null +++ b/client/transport/constants.go @@ -0,0 +1,7 @@ +package transport + +// Common HTTP header constants used across transports +const ( + HeaderKeySessionID = "Mcp-Session-Id" + HeaderKeyProtocolVersion = "Mcp-Protocol-Version" +) diff --git a/client/transport/error.go b/client/transport/error.go new file mode 100644 index 000000000..1f029944a --- /dev/null +++ b/client/transport/error.go @@ -0,0 +1,22 @@ +package transport + +import "fmt" + +// Error wraps a low-level transport error in a concrete type. +type Error struct { + Err error +} + +func (e *Error) Error() string { + return fmt.Sprintf("transport error: %v", e.Err) +} + +func (e *Error) Unwrap() error { + return e.Err +} + +func NewError(err error) *Error { + return &Error{ + Err: err, + } +} diff --git a/client/transport/inprocess.go b/client/transport/inprocess.go index 90fc2fae1..59c70940b 100644 --- a/client/transport/inprocess.go +++ b/client/transport/inprocess.go @@ -11,19 +11,50 @@ import ( ) type InProcessTransport struct { - server *server.MCPServer + server *server.MCPServer + samplingHandler server.SamplingHandler + session *server.InProcessSession + sessionID string onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex } +type InProcessOption func(*InProcessTransport) + +func WithSamplingHandler(handler server.SamplingHandler) InProcessOption { + return func(t *InProcessTransport) { + t.samplingHandler = handler + } +} + func NewInProcessTransport(server *server.MCPServer) *InProcessTransport { return &InProcessTransport{ server: server, } } +func NewInProcessTransportWithOptions(server *server.MCPServer, opts ...InProcessOption) *InProcessTransport { + t := &InProcessTransport{ + server: server, + sessionID: server.GenerateInProcessSessionID(), + } + + for _, opt := range opts { + opt(t) + } + + return t +} + func (c *InProcessTransport) Start(ctx context.Context) error { + // Create and register session if we have a sampling handler + if c.samplingHandler != nil { + c.session = server.NewInProcessSession(c.sessionID, c.samplingHandler) + if err := c.server.RegisterSession(ctx, c.session); err != nil { + return fmt.Errorf("failed to register session: %w", err) + } + } return nil } @@ -34,6 +65,11 @@ func (c *InProcessTransport) SendRequest(ctx context.Context, request JSONRPCReq } requestBytes = append(requestBytes, '\n') + // Add session to context if available + if c.session != nil { + ctx = c.server.WithContext(ctx, c.session) + } + respMessage := c.server.HandleMessage(ctx, requestBytes) respByte, err := json.Marshal(respMessage) if err != nil { @@ -65,6 +101,13 @@ func (c *InProcessTransport) SetNotificationHandler(handler func(notification mc c.onNotification = handler } -func (*InProcessTransport) Close() error { +func (c *InProcessTransport) Close() error { + if c.session != nil { + c.server.UnregisterSession(context.Background(), c.sessionID) + } return nil } + +func (c *InProcessTransport) GetSessionId() string { + return "" +} diff --git a/client/transport/interface.go b/client/transport/interface.go index c83c7c65a..e6feeb742 100644 --- a/client/transport/interface.go +++ b/client/transport/interface.go @@ -29,6 +29,29 @@ type Interface interface { // Close the connection. Close() error + + // GetSessionId returns the session ID of the transport. + GetSessionId() string +} + +// RequestHandler defines a function that handles incoming requests from the server. +type RequestHandler func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) + +// BidirectionalInterface extends Interface to support incoming requests from the server. +// This is used for features like sampling where the server can send requests to the client. +type BidirectionalInterface interface { + Interface + + // SetRequestHandler sets the handler for incoming requests from the server. + // The handler should process the request and return a response. + SetRequestHandler(handler RequestHandler) +} + +// HTTPConnection is a Transport that runs over HTTP and supports +// protocol version headers. +type HTTPConnection interface { + Interface + SetProtocolVersion(version string) } type JSONRPCRequest struct { @@ -41,10 +64,10 @@ type JSONRPCRequest struct { type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID mcp.RequestId `json:"id"` - Result json.RawMessage `json:"result"` + Result json.RawMessage `json:"result,omitempty"` Error *struct { Code int `json:"code"` Message string `json:"message"` Data json.RawMessage `json:"data"` - } `json:"error"` + } `json:"error,omitempty"` } diff --git a/client/transport/oauth.go b/client/transport/oauth.go index aebbd316e..b7c81bace 100644 --- a/client/transport/oauth.go +++ b/client/transport/oauth.go @@ -115,7 +115,9 @@ type OAuthHandler struct { metadataFetchErr error metadataOnce sync.Once baseURL string - expectedState string // Expected state value for CSRF protection + + mu sync.RWMutex // Protects expectedState + expectedState string // Expected state value for CSRF protection } // NewOAuthHandler creates a new OAuth handler @@ -263,9 +265,27 @@ func (h *OAuthHandler) SetBaseURL(baseURL string) { // GetExpectedState returns the expected state value (for testing purposes) func (h *OAuthHandler) GetExpectedState() string { + h.mu.RLock() + defer h.mu.RUnlock() return h.expectedState } +// SetExpectedState sets the expected state value. +// +// This can be useful if you cannot maintain an OAuthHandler +// instance throughout the authentication flow; for example, if +// the initialization and callback steps are handled in different +// requests. +// +// In such cases, this should be called with the state value generated +// during the initial authentication request (e.g. by GenerateState) +// and included in the authorization URL. +func (h *OAuthHandler) SetExpectedState(expectedState string) { + h.mu.Lock() + defer h.mu.Unlock() + h.expectedState = expectedState +} + // OAuthError represents a standard OAuth 2.0 error response type OAuthError struct { ErrorCode string `json:"error"` @@ -547,18 +567,21 @@ var ErrInvalidState = errors.New("invalid state parameter, possible CSRF attack" // ProcessAuthorizationResponse processes the authorization response and exchanges the code for a token func (h *OAuthHandler) ProcessAuthorizationResponse(ctx context.Context, code, state, codeVerifier string) error { // Validate the state parameter to prevent CSRF attacks - if h.expectedState == "" { + h.mu.Lock() + expectedState := h.expectedState + if expectedState == "" { + h.mu.Unlock() return errors.New("no expected state found, authorization flow may not have been initiated properly") } - if state != h.expectedState { + if state != expectedState { + h.mu.Unlock() return ErrInvalidState } // Clear the expected state after validation - defer func() { - h.expectedState = "" - }() + h.expectedState = "" + h.mu.Unlock() metadata, err := h.getServerMetadata(ctx) if err != nil { @@ -629,7 +652,7 @@ func (h *OAuthHandler) GetAuthorizationURL(ctx context.Context, state, codeChall } // Store the state for later validation - h.expectedState = state + h.SetExpectedState(state) params := url.Values{} params.Set("response_type", "code") diff --git a/client/transport/oauth_test.go b/client/transport/oauth_test.go index 24dec6eff..701beddc6 100644 --- a/client/transport/oauth_test.go +++ b/client/transport/oauth_test.go @@ -300,3 +300,96 @@ func TestOAuthHandler_ProcessAuthorizationResponse_StateValidation(t *testing.T) t.Errorf("Got ErrInvalidState when expected a different error for empty expected state") } } + +func TestOAuthHandler_SetExpectedState_CrossRequestScenario(t *testing.T) { + // Simulate the scenario where different OAuthHandler instances are used + // for initialization and callback steps (different HTTP request handlers) + + config := OAuthConfig{ + ClientID: "test-client", + RedirectURI: "http://localhost:8085/callback", + Scopes: []string{"mcp.read", "mcp.write"}, + TokenStore: NewMemoryTokenStore(), + AuthServerMetadataURL: "http://example.com/.well-known/oauth-authorization-server", + PKCEEnabled: true, + } + + // Step 1: First handler instance (initialization request) + // This simulates the handler that generates the authorization URL + handler1 := NewOAuthHandler(config) + + // Mock the server metadata for the first handler + handler1.serverMetadata = &AuthServerMetadata{ + Issuer: "http://example.com", + AuthorizationEndpoint: "http://example.com/authorize", + TokenEndpoint: "http://example.com/token", + } + + // Generate state and get authorization URL (this would typically be done in the init handler) + testState := "generated-state-value-123" + _, err := handler1.GetAuthorizationURL(context.Background(), testState, "test-code-challenge") + if err != nil { + // We expect this to fail since we're not actually connecting to a server, + // but it should still store the expected state + if !strings.Contains(err.Error(), "connection") && !strings.Contains(err.Error(), "dial") { + t.Errorf("Expected connection error, got: %v", err) + } + } + + // Verify the state was stored in the first handler + if handler1.GetExpectedState() != testState { + t.Errorf("Expected state %s to be stored in first handler, got %s", testState, handler1.GetExpectedState()) + } + + // Step 2: Second handler instance (callback request) + // This simulates a completely separate handler instance that would be created + // in a different HTTP request handler for processing the OAuth callback + handler2 := NewOAuthHandler(config) + + // Mock the server metadata for the second handler + handler2.serverMetadata = &AuthServerMetadata{ + Issuer: "http://example.com", + AuthorizationEndpoint: "http://example.com/authorize", + TokenEndpoint: "http://example.com/token", + } + + // Initially, the second handler has no expected state + if handler2.GetExpectedState() != "" { + t.Errorf("Expected second handler to have empty state initially, got %s", handler2.GetExpectedState()) + } + + // Step 3: Transfer the state from the first handler to the second + // This is the key functionality being tested - setting the expected state + // in a different handler instance + handler2.SetExpectedState(testState) + + // Verify the state was transferred correctly + if handler2.GetExpectedState() != testState { + t.Errorf("Expected state %s to be set in second handler, got %s", testState, handler2.GetExpectedState()) + } + + // Step 4: Test that state validation works correctly in the second handler + + // Test with correct state - should pass validation but fail at token exchange + // (since we're not actually running a real OAuth server) + err = handler2.ProcessAuthorizationResponse(context.Background(), "test-code", testState, "test-code-verifier") + if err == nil { + t.Errorf("Expected error due to token exchange failure, got nil") + } + // Should NOT be ErrInvalidState since the state matches + if errors.Is(err, ErrInvalidState) { + t.Errorf("Got ErrInvalidState with matching state, should have failed at token exchange instead") + } + + // Verify state was cleared after processing (even though token exchange failed) + if handler2.GetExpectedState() != "" { + t.Errorf("Expected state to be cleared after processing, got %s", handler2.GetExpectedState()) + } + + // Step 5: Test with wrong state after resetting + handler2.SetExpectedState("different-state-value") + err = handler2.ProcessAuthorizationResponse(context.Background(), "test-code", testState, "test-code-verifier") + if !errors.Is(err, ErrInvalidState) { + t.Errorf("Expected ErrInvalidState with wrong state, got %v", err) + } +} diff --git a/client/transport/sse.go b/client/transport/sse.go index b22ff62d4..70a391905 100644 --- a/client/transport/sse.go +++ b/client/transport/sse.go @@ -16,6 +16,7 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) // SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE). @@ -33,10 +34,14 @@ type SSE struct { endpointChan chan struct{} headers map[string]string headerFunc HTTPHeaderFunc + logger util.Logger - started atomic.Bool - closed atomic.Bool - cancelSSEStream context.CancelFunc + started atomic.Bool + closed atomic.Bool + cancelSSEStream context.CancelFunc + protocolVersion atomic.Value // string + onConnectionLost func(error) + connectionLostMu sync.RWMutex // OAuth support oauthHandler *OAuthHandler @@ -44,6 +49,13 @@ type SSE struct { type ClientOption func(*SSE) +// WithSSELogger sets a custom logger for the SSE client. +func WithSSELogger(logger util.Logger) ClientOption { + return func(sc *SSE) { + sc.logger = logger + } +} + func WithHeaders(headers map[string]string) ClientOption { return func(sc *SSE) { sc.headers = headers @@ -82,6 +94,7 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { responses: make(map[string]chan *JSONRPCResponse), endpointChan: make(chan struct{}), headers: make(map[string]string), + logger: util.DefaultLogger(), } for _, opt := range options { @@ -101,7 +114,6 @@ func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) { // Start initiates the SSE connection to the server and waits for the endpoint information. // Returns an error if the connection fails or times out waiting for the endpoint. func (c *SSE) Start(ctx context.Context) error { - if c.started.Load() { return fmt.Errorf("has already started") } @@ -110,7 +122,6 @@ func (c *SSE) Start(ctx context.Context) error { c.cancelSSEStream = cancel req, err := http.NewRequestWithContext(ctx, "GET", c.baseURL.String(), nil) - if err != nil { return fmt.Errorf("failed to create request: %w", err) } @@ -203,8 +214,21 @@ func (c *SSE) readSSE(reader io.ReadCloser) { } break } + // Checking whether the connection was terminated due to NO_ERROR in HTTP2 based on RFC9113 + // Only handle NO_ERROR specially if onConnectionLost handler is set to maintain backward compatibility + if strings.Contains(err.Error(), "NO_ERROR") { + c.connectionLostMu.RLock() + handler := c.onConnectionLost + c.connectionLostMu.RUnlock() + + if handler != nil { + // This is not actually an error - HTTP2 idle timeout disconnection + handler(err) + return + } + } if !c.closed.Load() { - fmt.Printf("SSE stream error: %v\n", err) + c.logger.Errorf("SSE stream error: %v", err) } return } @@ -240,11 +264,11 @@ func (c *SSE) handleSSEEvent(event, data string) { case "endpoint": endpoint, err := c.baseURL.Parse(data) if err != nil { - fmt.Printf("Error parsing endpoint URL: %v\n", err) + c.logger.Errorf("Error parsing endpoint URL: %v", err) return } if endpoint.Host != c.baseURL.Host { - fmt.Printf("Endpoint origin does not match connection origin\n") + c.logger.Errorf("Endpoint origin does not match connection origin") return } c.endpoint = endpoint @@ -253,7 +277,7 @@ func (c *SSE) handleSSEEvent(event, data string) { case "message": var baseMessage JSONRPCResponse if err := json.Unmarshal([]byte(data), &baseMessage); err != nil { - fmt.Printf("Error unmarshaling message: %v\n", err) + c.logger.Errorf("Error unmarshaling message: %v", err) return } @@ -293,13 +317,18 @@ func (c *SSE) SetNotificationHandler(handler func(notification mcp.JSONRPCNotifi c.onNotification = handler } +func (c *SSE) SetConnectionLostHandler(handler func(error)) { + c.connectionLostMu.Lock() + defer c.connectionLostMu.Unlock() + c.onConnectionLost = handler +} + // SendRequest sends a JSON-RPC request to the server and waits for a response. // Returns the raw JSON response message or an error if the request fails. func (c *SSE) SendRequest( ctx context.Context, request JSONRPCRequest, ) (*JSONRPCResponse, error) { - if !c.started.Load() { return nil, fmt.Errorf("transport not started yet") } @@ -324,6 +353,12 @@ func (c *SSE) SendRequest( // Set headers req.Header.Set("Content-Type", "application/json") + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(HeaderKeyProtocolVersion, version) + } + } for k, v := range c.headers { req.Header.Set(k, v) } @@ -428,6 +463,17 @@ func (c *SSE) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since SSE does not maintain a session ID, it returns an empty string. +func (c *SSE) GetSessionId() string { + return "" +} + +// SetProtocolVersion sets the negotiated protocol version for this connection. +func (c *SSE) SetProtocolVersion(version string) { + c.protocolVersion.Store(version) +} + // SendNotification sends a JSON-RPC notification to the server without expecting a response. func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { if c.endpoint == nil { @@ -450,6 +496,12 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti } req.Header.Set("Content-Type", "application/json") + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(HeaderKeyProtocolVersion, version) + } + } // Set custom HTTP headers for k, v := range c.headers { req.Header.Set(k, v) diff --git a/client/transport/sse_test.go b/client/transport/sse_test.go index a672e02fe..31c70887f 100644 --- a/client/transport/sse_test.go +++ b/client/transport/sse_test.go @@ -4,17 +4,52 @@ import ( "context" "encoding/json" "errors" - "sync" - "testing" - "time" - "fmt" + "io" "net/http" "net/http/httptest" + "strings" + "sync" + "testing" + "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" ) +// mockReaderWithError is a mock io.ReadCloser that simulates reading some data +// and then returning a specific error +type mockReaderWithError struct { + data []byte + err error + position int + closed bool +} + +func (m *mockReaderWithError) Read(p []byte) (n int, err error) { + if m.closed { + return 0, io.EOF + } + + if m.position >= len(m.data) { + return 0, m.err + } + + n = copy(p, m.data[m.position:]) + m.position += n + + if m.position >= len(m.data) { + return n, m.err + } + + return n, nil +} + +func (m *mockReaderWithError) Close() error { + m.closed = true + return nil +} + // startMockSSEEchoServer starts a test HTTP server that implements // a minimal SSE-based echo server for testing purposes. // It returns the server URL and a function to close the server. @@ -115,7 +150,6 @@ func startMockSSEEchoServer() (string, func()) { flush() } }() - }) // Create a router to handle different endpoints @@ -228,7 +262,6 @@ func TestSSE(t *testing.T) { }) t.Run("SendNotification & NotificationHandler", func(t *testing.T) { - var wg sync.WaitGroup notificationChan := make(chan mcp.JSONRPCNotification, 1) @@ -368,7 +401,6 @@ func TestSSE(t *testing.T) { }) t.Run("ResponseError", func(t *testing.T) { - // Prepare a request request := JSONRPCRequest{ JSONRPC: "2.0", @@ -408,9 +440,9 @@ func TestSSE(t *testing.T) { t.Run("SSEEventWithoutEventField", func(t *testing.T) { // Test that SSE events with only data field (no event field) are processed correctly // This tests the fix for issue #369 - + var messageReceived chan struct{} - + // Create a custom mock server that sends SSE events without event field sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/event-stream") @@ -449,7 +481,7 @@ func TestSSE(t *testing.T) { messageHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusAccepted) - + // Signal that message was received close(messageReceived) }) @@ -508,6 +540,217 @@ func TestSSE(t *testing.T) { } }) + t.Run("NO_ERROR_WithoutConnectionLostHandler", func(t *testing.T) { + // Test that NO_ERROR without connection lost handler maintains backward compatibility + // When no connection lost handler is set, NO_ERROR should be treated as a regular error + + // Create a mock Reader that simulates NO_ERROR + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("connection closed: NO_ERROR"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // DO NOT set connection lost handler to test backward compatibility + + // Capture stderr to verify the error is printed (backward compatible behavior) + // Since we can't easily capture fmt.Printf output in tests, we'll just verify + // that the readSSE method returns without calling any handler + + // Directly test the readSSE method with our mock reader + go trans.readSSE(mockReader) + + // Wait for readSSE to complete + time.Sleep(100 * time.Millisecond) + + // The test passes if readSSE completes without panicking or hanging + // In backward compatibility mode, NO_ERROR should be treated as a regular error + t.Log("Backward compatibility test passed: NO_ERROR handled as regular error when no handler is set") + }) + + t.Run("NO_ERROR_ConnectionLost", func(t *testing.T) { + // Test that NO_ERROR in HTTP/2 connection loss is properly handled + // This test verifies that when a connection is lost in a way that produces + // an error message containing "NO_ERROR", the connection lost handler is called + + var connectionLostCalled bool + var connectionLostError error + var mu sync.Mutex + + // Create a mock Reader that simulates connection loss with NO_ERROR + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("http2: stream closed with error code NO_ERROR"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + // Set connection lost handler + trans.SetConnectionLostHandler(func(err error) { + mu.Lock() + defer mu.Unlock() + connectionLostCalled = true + connectionLostError = err + }) + + // Directly test the readSSE method with our mock reader that simulates NO_ERROR + go trans.readSSE(mockReader) + + // Wait for connection lost handler to be called + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + t.Fatal("Connection lost handler was not called within timeout for NO_ERROR connection loss") + case <-ticker.C: + mu.Lock() + called := connectionLostCalled + err := connectionLostError + mu.Unlock() + + if called { + if err == nil { + t.Fatal("Expected connection lost error, got nil") + } + + // Verify that the error contains "NO_ERROR" string + if !strings.Contains(err.Error(), "NO_ERROR") { + t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err) + } + + t.Logf("Connection lost handler called with NO_ERROR: %v", err) + return + } + } + } + }) + + t.Run("NO_ERROR_Handling", func(t *testing.T) { + // Test specific NO_ERROR string handling in readSSE method + // This tests the code path at line 209 where NO_ERROR is checked + + // Create a mock Reader that simulates an error containing "NO_ERROR" + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("connection closed: NO_ERROR"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + var connectionLostCalled bool + var connectionLostError error + var mu sync.Mutex + + // Set connection lost handler to verify it's called for NO_ERROR + trans.SetConnectionLostHandler(func(err error) { + mu.Lock() + defer mu.Unlock() + connectionLostCalled = true + connectionLostError = err + }) + + // Directly test the readSSE method with our mock reader + go trans.readSSE(mockReader) + + // Wait for connection lost handler to be called + timeout := time.After(1 * time.Second) + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-timeout: + t.Fatal("Connection lost handler was not called within timeout for NO_ERROR") + case <-ticker.C: + mu.Lock() + called := connectionLostCalled + err := connectionLostError + mu.Unlock() + + if called { + if err == nil { + t.Fatal("Expected connection lost error with NO_ERROR, got nil") + } + + // Verify that the error contains "NO_ERROR" string + if !strings.Contains(err.Error(), "NO_ERROR") { + t.Errorf("Expected error to contain 'NO_ERROR', got: %v", err) + } + + t.Logf("Successfully handled NO_ERROR: %v", err) + return + } + } + } + }) + + t.Run("RegularError_DoesNotTriggerConnectionLost", func(t *testing.T) { + // Test that regular errors (not containing NO_ERROR) do not trigger connection lost handler + + // Create a mock Reader that simulates a regular error + mockReader := &mockReaderWithError{ + data: []byte("event: endpoint\ndata: /message\n\n"), + err: errors.New("regular connection error"), + } + + // Create SSE transport + url, closeF := startMockSSEEchoServer() + defer closeF() + + trans, err := NewSSE(url) + if err != nil { + t.Fatal(err) + } + + var connectionLostCalled bool + var mu sync.Mutex + + // Set connection lost handler - this should NOT be called for regular errors + trans.SetConnectionLostHandler(func(err error) { + mu.Lock() + defer mu.Unlock() + connectionLostCalled = true + }) + + // Directly test the readSSE method with our mock reader + go trans.readSSE(mockReader) + + // Wait and verify connection lost handler is NOT called + time.Sleep(200 * time.Millisecond) + + mu.Lock() + called := connectionLostCalled + mu.Unlock() + + if called { + t.Error("Connection lost handler should not be called for regular errors") + } + }) } func TestSSEErrors(t *testing.T) { @@ -624,4 +867,49 @@ func TestSSEErrors(t *testing.T) { } }) + t.Run("SSEStreamErrorLogging", func(t *testing.T) { + logChan := make(chan string, 10) + testLogger := &testLogger{logChan: logChan} + + sseHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming unsupported", http.StatusInternalServerError) + return + } + + fmt.Fprintf(w, "event: endpoint\ndata: %s\n\n", "/message") + flusher.Flush() + + fmt.Fprintf(w, "event: message\ndata: {invalid json}\n\n") + flusher.Flush() + + time.Sleep(50 * time.Millisecond) + }) + + testServer := httptest.NewServer(sseHandler) + t.Cleanup(testServer.Close) + + trans, err := NewSSE(testServer.URL, WithSSELogger(testLogger)) + require.NoError(t, err) + + // Start the transport + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + err = trans.Start(ctx) + require.NoError(t, err) + t.Cleanup(func() { _ = trans.Close() }) + + // Wait for the error log message about unmarshaling + select { + case logMsg := <-logChan: + if !strings.Contains(logMsg, "Error unmarshaling message") { + t.Errorf("Expected error log about unmarshaling message, got: %s", logMsg) + } + case <-time.After(3 * time.Second): + t.Fatal("Timeout waiting for error log message") + } + }) } diff --git a/client/transport/stdio.go b/client/transport/stdio.go index c300c405f..488164c79 100644 --- a/client/transport/stdio.go +++ b/client/transport/stdio.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "io" "os" @@ -11,6 +12,7 @@ import ( "sync" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) // Stdio implements the transport layer of the MCP protocol using stdio communication. @@ -23,6 +25,7 @@ type Stdio struct { env []string cmd *exec.Cmd + cmdFunc CommandFunc stdin io.WriteCloser stdout *bufio.Reader stderr io.ReadCloser @@ -31,6 +34,36 @@ type Stdio struct { done chan struct{} onNotification func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + onRequest RequestHandler + requestMu sync.RWMutex + ctx context.Context + ctxMu sync.RWMutex + logger util.Logger +} + +// StdioOption defines a function that configures a Stdio transport instance. +// Options can be used to customize the behavior of the transport before it starts, +// such as setting a custom command function. +type StdioOption func(*Stdio) + +// CommandFunc is a factory function that returns a custom exec.Cmd used to launch the MCP subprocess. +// It can be used to apply sandboxing, custom environment control, working directories, etc. +type CommandFunc func(ctx context.Context, command string, env []string, args []string) (*exec.Cmd, error) + +// WithCommandFunc sets a custom command factory function for the stdio transport. +// The CommandFunc is responsible for constructing the exec.Cmd used to launch the subprocess, +// allowing control over attributes like environment, working directory, and system-level sandboxing. +func WithCommandFunc(f CommandFunc) StdioOption { + return func(s *Stdio) { + s.cmdFunc = f + } +} + +// WithCommandLogger sets a custom logger for the stdio transport. +func WithCommandLogger(logger util.Logger) StdioOption { + return func(s *Stdio) { + s.logger = logger + } } // NewIO returns a new stdio-based transport using existing input, output, and @@ -44,6 +77,8 @@ func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), + logger: util.DefaultLogger(), } } @@ -55,20 +90,44 @@ func NewStdio( env []string, args ...string, ) *Stdio { + return NewStdioWithOptions(command, env, args) +} - client := &Stdio{ +// NewStdioWithOptions creates a new stdio transport to communicate with a subprocess. +// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication. +// Returns an error if the subprocess cannot be started or the pipes cannot be created. +// Optional configuration functions can be provided to customize the transport before it starts, +// such as setting a custom command factory. +func NewStdioWithOptions( + command string, + env []string, + args []string, + opts ...StdioOption, +) *Stdio { + s := &Stdio{ command: command, args: args, env: env, responses: make(map[string]chan *JSONRPCResponse), done: make(chan struct{}), + ctx: context.Background(), + logger: util.DefaultLogger(), } - return client + for _, opt := range opts { + opt(s) + } + + return s } func (c *Stdio) Start(ctx context.Context) error { + // Store the context for use in request handling + c.ctxMu.Lock() + c.ctx = ctx + c.ctxMu.Unlock() + if err := c.spawnCommand(ctx); err != nil { return err } @@ -83,18 +142,25 @@ func (c *Stdio) Start(ctx context.Context) error { return nil } -// spawnCommand spawns a new process running c.command. +// spawnCommand spawns a new process running the configured command, args, and env. +// If an (optional) cmdFunc custom command factory function was configured, it will be used to construct the subprocess; +// otherwise, the default behavior uses exec.CommandContext with the merged environment. +// Initializes stdin, stdout, and stderr pipes for JSON-RPC communication. func (c *Stdio) spawnCommand(ctx context.Context) error { if c.command == "" { return nil } - cmd := exec.CommandContext(ctx, c.command, c.args...) - - mergedEnv := os.Environ() - mergedEnv = append(mergedEnv, c.env...) + var cmd *exec.Cmd + var err error - cmd.Env = mergedEnv + // Standard behavior if no command func present. + if c.cmdFunc == nil { + cmd = exec.CommandContext(ctx, c.command, c.args...) + cmd.Env = append(os.Environ(), c.env...) + } else if cmd, err = c.cmdFunc(ctx, c.command, c.env, c.args); err != nil { + return err + } stdin, err := cmd.StdinPipe() if err != nil { @@ -148,6 +214,12 @@ func (c *Stdio) Close() error { return nil } +// GetSessionId returns the session ID of the transport. +// Since stdio does not maintain a session ID, it returns an empty string. +func (c *Stdio) GetSessionId() string { + return "" +} + // SetNotificationHandler sets the handler function to be called when a notification is received. // Only one handler can be set at a time; setting a new one replaces the previous handler. func (c *Stdio) SetNotificationHandler( @@ -158,6 +230,14 @@ func (c *Stdio) SetNotificationHandler( c.onNotification = handler } +// SetRequestHandler sets the handler function to be called when a request is received from the server. +// This enables bidirectional communication for features like sampling. +func (c *Stdio) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.onRequest = handler +} + // readResponses continuously reads and processes responses from the server's stdout. // It handles both responses to requests and notifications, routing them appropriately. // Runs until the done channel is closed or an error occurs reading from stdout. @@ -169,19 +249,24 @@ func (c *Stdio) readResponses() { default: line, err := c.stdout.ReadString('\n') if err != nil { - if err != io.EOF { - fmt.Printf("Error reading response: %v\n", err) + if err != io.EOF && !errors.Is(err, context.Canceled) { + c.logger.Errorf("Error reading from stdout: %v", err) } return } - var baseMessage JSONRPCResponse + // First try to parse as a generic message to check for ID field + var baseMessage struct { + JSONRPC string `json:"jsonrpc"` + ID *mcp.RequestId `json:"id,omitempty"` + Method string `json:"method,omitempty"` + } if err := json.Unmarshal([]byte(line), &baseMessage); err != nil { continue } - // Handle notification - if baseMessage.ID.IsNil() { + // If it has a method but no ID, it's a notification + if baseMessage.Method != "" && baseMessage.ID == nil { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(line), ¬ification); err != nil { continue @@ -194,15 +279,30 @@ func (c *Stdio) readResponses() { continue } + // If it has a method and an ID, it's an incoming request + if baseMessage.Method != "" && baseMessage.ID != nil { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(line), &request); err == nil { + c.handleIncomingRequest(request) + continue + } + } + + // Otherwise, it's a response to our request + var response JSONRPCResponse + if err := json.Unmarshal([]byte(line), &response); err != nil { + continue + } + // Create string key for map lookup - idKey := baseMessage.ID.String() + idKey := response.ID.String() c.mu.RLock() ch, exists := c.responses[idKey] c.mu.RUnlock() if exists { - ch <- &baseMessage + ch <- &response c.mu.Lock() delete(c.responses, idKey) c.mu.Unlock() @@ -219,6 +319,13 @@ func (c *Stdio) SendRequest( ctx context.Context, request JSONRPCRequest, ) (*JSONRPCResponse, error) { + // Check if context is already canceled before doing any work + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + if c.stdin == nil { return nil, fmt.Errorf("stdio client not started") } @@ -281,6 +388,95 @@ func (c *Stdio) SendNotification( return nil } +// handleIncomingRequest processes incoming requests from the server. +// It calls the registered request handler and sends the response back to the server. +func (c *Stdio) handleIncomingRequest(request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.onRequest + c.requestMu.RUnlock() + + if handler == nil { + // Send error response if no handler is configured + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.METHOD_NOT_FOUND, + Message: "No request handler configured", + }, + } + c.sendResponse(errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking + go func() { + c.ctxMu.RLock() + ctx := c.ctx + c.ctxMu.RUnlock() + + // Check if context is already cancelled before processing + select { + case <-ctx.Done(): + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: ctx.Err().Error(), + }, + } + c.sendResponse(errorResponse) + return + default: + } + + response, err := handler(ctx, request) + if err != nil { + errorResponse := JSONRPCResponse{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: mcp.INTERNAL_ERROR, + Message: err.Error(), + }, + } + c.sendResponse(errorResponse) + return + } + + if response != nil { + c.sendResponse(*response) + } + }() +} + +// sendResponse sends a response back to the server. +func (c *Stdio) sendResponse(response JSONRPCResponse) { + responseBytes, err := json.Marshal(response) + if err != nil { + c.logger.Errorf("Error marshaling response: %v", err) + return + } + responseBytes = append(responseBytes, '\n') + + if _, err := c.stdin.Write(responseBytes); err != nil { + c.logger.Errorf("Error writing response: %v", err) + } +} + // Stderr returns a reader for the stderr output of the subprocess. // This can be used to capture error messages or logs from the subprocess. func (c *Stdio) Stderr() io.Reader { diff --git a/client/transport/stdio_test.go b/client/transport/stdio_test.go index 3eea5b23f..18aa932e8 100644 --- a/client/transport/stdio_test.go +++ b/client/transport/stdio_test.go @@ -3,15 +3,21 @@ package transport import ( "context" "encoding/json" + "errors" "fmt" + "io" "os" "os/exec" + "path/filepath" "runtime" + "strings" "sync" + "syscall" "testing" "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/require" ) func compileTestServer(outputPath string) error { @@ -148,7 +154,6 @@ func TestStdio(t *testing.T) { }) t.Run("SendNotification & NotificationHandler", func(t *testing.T) { - var wg sync.WaitGroup notificationChan := make(chan mcp.JSONRPCNotification, 1) @@ -181,11 +186,33 @@ func TestStdio(t *testing.T) { defer wg.Done() select { case nt := <-notificationChan: - // We received a notification - responseJson, _ := json.Marshal(nt.Params.AdditionalFields) - requestJson, _ := json.Marshal(notification) - if string(responseJson) != string(requestJson) { - t.Errorf("Notification handler did not send the expected notification: \ngot %s\nexpect %s", responseJson, requestJson) + // We received a notification from the mock server + // The mock server sends a notification with method "debug/test" and the original request as params + if nt.Method != "debug/test" { + t.Errorf("Expected notification method 'debug/test', got '%s'", nt.Method) + return + } + + // The mock server sends the original notification request as params + // We need to extract the original method from the nested structure + paramsJson, _ := json.Marshal(nt.Params) + var originalRequest struct { + Method string `json:"method"` + Params map[string]any `json:"params"` + } + if err := json.Unmarshal(paramsJson, &originalRequest); err != nil { + t.Errorf("Failed to unmarshal notification params: %v", err) + return + } + + if originalRequest.Method != "debug/echo_notification" { + t.Errorf("Expected original method 'debug/echo_notification', got '%s'", originalRequest.Method) + return + } + + // Check if the original params contain our test data + if testValue, ok := originalRequest.Params["test"]; !ok || testValue != "value" { + t.Errorf("Expected test param 'value', got %v", originalRequest.Params["test"]) } case <-time.After(1 * time.Second): @@ -380,7 +407,6 @@ func TestStdio(t *testing.T) { t.Errorf("Expected array with 3 items, got %v", result.Params["array"]) } }) - } func TestStdioErrors(t *testing.T) { @@ -484,4 +510,200 @@ func TestStdioErrors(t *testing.T) { } }) + t.Run("StdioResponseWritingErrorLogging", func(t *testing.T) { + logChan := make(chan string, 10) + testLogger := &testLogger{logChan: logChan} + + _, stdinWriter := io.Pipe() + stdoutReader, stdoutWriter := io.Pipe() + stderrReader, stderrWriter := io.Pipe() + t.Cleanup(func() { + _ = stdinWriter.Close() + _ = stdoutWriter.Close() + _ = stderrWriter.Close() + }) + + stdio := NewIO(stdoutReader, stdinWriter, stderrReader) + stdio.logger = testLogger + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + t.Cleanup(cancel) + + err := stdio.Start(ctx) + if err != nil { + t.Fatalf("Failed to start stdio transport: %v", err) + } + t.Cleanup(func() { _ = stdio.Close() }) + + stdio.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: json.RawMessage(`"test response"`), + }, nil + }) + + doneChan := make(chan struct{}) + go func() { + // Simulate a request coming from the server + request := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: "test/method", + } + requestBytes, _ := json.Marshal(request) + requestBytes = append(requestBytes, '\n') + _, _ = stdoutWriter.Write(requestBytes) + + // Close stdin to trigger a write error when the response is sent + time.Sleep(50 * time.Millisecond) // Give time for the request to be processed + _ = stdinWriter.Close() + doneChan <- struct{}{} + }() + + <-doneChan + + // Wait for the error log message + select { + case logMsg := <-logChan: + if !strings.Contains(logMsg, "Error writing response") { + t.Errorf("Expected error log about writing response, got: %s", logMsg) + } + case <-time.After(3 * time.Second): + t.Fatal("Timeout waiting for error log message") + } + }) +} + +func TestStdio_WithCommandFunc(t *testing.T) { + called := false + tmpDir := t.TempDir() + chrootDir := filepath.Join(tmpDir, "sandbox-root") + err := os.MkdirAll(chrootDir, 0o755) + require.NoError(t, err, "failed to create chroot dir") + + fakeCmdFunc := func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + called = true + + // Override the args inside our command func. + cmd := exec.CommandContext(ctx, command, "bonjour") + + // Simulate some security-related settings for test purposes. + cmd.Env = []string{"PATH=/usr/bin", "NODE_ENV=production"} + cmd.Dir = tmpDir + + cmd.SysProcAttr = &syscall.SysProcAttr{ + Credential: &syscall.Credential{ + Uid: 1001, + Gid: 1001, + }, + Chroot: chrootDir, + } + + return cmd, nil + } + + stdio := NewStdioWithOptions( + "echo", + []string{"foo=bar"}, + []string{"hello"}, + WithCommandFunc(fakeCmdFunc), + ) + require.NotNil(t, stdio) + require.NotNil(t, stdio.cmdFunc) + + // Manually call the cmdFunc passing the same values as in spawnCommand. + cmd, err := stdio.cmdFunc(context.Background(), "echo", nil, []string{"hello"}) + require.NoError(t, err) + require.True(t, called) + require.NotNil(t, cmd) + require.NotNil(t, cmd.SysProcAttr) + require.Equal(t, chrootDir, cmd.SysProcAttr.Chroot) + require.Equal(t, tmpDir, cmd.Dir) + require.Equal(t, uint32(1001), cmd.SysProcAttr.Credential.Uid) + require.Equal(t, "echo", filepath.Base(cmd.Path)) + require.Len(t, cmd.Args, 2) + require.Contains(t, cmd.Args, "bonjour") + require.Len(t, cmd.Env, 2) + require.Contains(t, cmd.Env, "PATH=/usr/bin") + require.Contains(t, cmd.Env, "NODE_ENV=production") +} + +func TestStdio_SpawnCommand(t *testing.T) { + ctx := context.Background() + t.Setenv("TEST_ENVIRON_VAR", "true") + + // Explicitly not passing any environment, so we can see if it + // is picked up by spawn command merging the os.Environ. + stdio := NewStdio("echo", nil, "hello") + require.NotNil(t, stdio) + + err := stdio.spawnCommand(ctx) + require.NoError(t, err) + + t.Cleanup(func() { + _ = stdio.cmd.Process.Kill() + }) + + require.Equal(t, "echo", filepath.Base(stdio.cmd.Path)) + require.Contains(t, stdio.cmd.Args, "hello") + require.Contains(t, stdio.cmd.Env, "TEST_ENVIRON_VAR=true") +} + +func TestStdio_SpawnCommand_UsesCommandFunc(t *testing.T) { + ctx := context.Background() + t.Setenv("TEST_ENVIRON_VAR", "true") + + stdio := NewStdioWithOptions( + "echo", + nil, + []string{"test"}, + WithCommandFunc(func(ctx context.Context, cmd string, args []string, env []string) (*exec.Cmd, error) { + c := exec.CommandContext(ctx, cmd, "hola") + c.Env = env + return c, nil + }), + ) + require.NotNil(t, stdio) + err := stdio.spawnCommand(ctx) + require.NoError(t, err) + t.Cleanup(func() { + _ = stdio.cmd.Process.Kill() + }) + + require.Equal(t, "echo", filepath.Base(stdio.cmd.Path)) + require.Contains(t, stdio.cmd.Args, "hola") + require.NotContains(t, stdio.cmd.Env, "TEST_ENVIRON_VAR=true") + require.NotNil(t, stdio.stdin) + require.NotNil(t, stdio.stdout) + require.NotNil(t, stdio.stderr) +} + +func TestStdio_SpawnCommand_UsesCommandFunc_Error(t *testing.T) { + ctx := context.Background() + + stdio := NewStdioWithOptions( + "echo", + nil, + []string{"test"}, + WithCommandFunc(func(ctx context.Context, cmd string, args []string, env []string) (*exec.Cmd, error) { + return nil, errors.New("test error") + }), + ) + require.NotNil(t, stdio) + err := stdio.spawnCommand(ctx) + require.Error(t, err) + require.EqualError(t, err, "test error") +} + +func TestStdio_NewStdioWithOptions_AppliesOptions(t *testing.T) { + configured := false + + opt := func(s *Stdio) { + configured = true + } + + stdio := NewStdioWithOptions("echo", nil, []string{"test"}, opt) + require.NotNil(t, stdio) + require.True(t, configured, "option was not applied") } diff --git a/client/transport/streamable_http.go b/client/transport/streamable_http.go index 50bde9c28..268aeb342 100644 --- a/client/transport/streamable_http.go +++ b/client/transport/streamable_http.go @@ -17,10 +17,24 @@ import ( "time" "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/util" ) type StreamableHTTPCOption func(*StreamableHTTP) +// WithContinuousListening enables receiving server-to-client notifications when no request is in flight. +// In particular, if you want to receive global notifications from the server (like ToolListChangedNotification), +// you should enable this option. +// +// It will establish a standalone long-live GET HTTP connection to the server. +// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server +// NOTICE: Even enabled, the server may not support this feature. +func WithContinuousListening() StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.getListeningEnabled = true + } +} + // WithHTTPClient sets a custom HTTP client on the StreamableHTTP transport. func WithHTTPBasicClient(client *http.Client) StreamableHTTPCOption { return func(sc *StreamableHTTP) { @@ -54,6 +68,25 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { } } +// WithHTTPLogger sets a custom logger for the StreamableHTTP transport. +func WithHTTPLogger(logger util.Logger) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.logger = logger + } +} + +// Deprecated: Use [WithHTTPLogger] instead. +func WithLogger(logger util.Logger) StreamableHTTPCOption { + return WithHTTPLogger(logger) +} + +// WithSession creates a client with a pre-configured session +func WithSession(sessionID string) StreamableHTTPCOption { + return func(sc *StreamableHTTP) { + sc.sessionID.Store(sessionID) + } +} + // StreamableHTTP implements Streamable HTTP transport. // // It transmits JSON-RPC messages over individual HTTP requests. One message per request. @@ -63,27 +96,34 @@ func WithHTTPOAuth(config OAuthConfig) StreamableHTTPCOption { // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports // // The current implementation does not support the following features: -// - batching -// - continuously listening for server notifications when no request is in flight -// (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server) // - resuming stream // (https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#resumability-and-redelivery) -// - server -> client request type StreamableHTTP struct { - serverURL *url.URL - httpClient *http.Client - headers map[string]string - headerFunc HTTPHeaderFunc + serverURL *url.URL + httpClient *http.Client + headers map[string]string + headerFunc HTTPHeaderFunc + logger util.Logger + getListeningEnabled bool + + sessionID atomic.Value // string + protocolVersion atomic.Value // string - sessionID atomic.Value // string + initialized chan struct{} + initializedOnce sync.Once notificationHandler func(mcp.JSONRPCNotification) notifyMu sync.RWMutex + // Request handler for incoming server-to-client requests (like sampling) + requestHandler RequestHandler + requestMu sync.RWMutex + closed chan struct{} // OAuth support oauthHandler *OAuthHandler + wg sync.WaitGroup } // NewStreamableHTTP creates a new Streamable HTTP transport with the given server URL. @@ -95,15 +135,19 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str } smc := &StreamableHTTP{ - serverURL: parsedURL, - httpClient: &http.Client{}, - headers: make(map[string]string), - closed: make(chan struct{}), + serverURL: parsedURL, + httpClient: &http.Client{}, + headers: make(map[string]string), + closed: make(chan struct{}), + logger: util.DefaultLogger(), + initialized: make(chan struct{}), } smc.sessionID.Store("") // set initial value to simplify later usage for _, opt := range options { - opt(smc) + if opt != nil { + opt(smc) + } } // If OAuth is configured, set the base URL for metadata discovery @@ -118,7 +162,20 @@ func NewStreamableHTTP(serverURL string, options ...StreamableHTTPCOption) (*Str // Start initiates the HTTP connection to the server. func (c *StreamableHTTP) Start(ctx context.Context) error { - // For Streamable HTTP, we don't need to establish a persistent connection + // For Streamable HTTP, we don't need to establish a persistent connection by default + if c.getListeningEnabled { + go func() { + select { + case <-c.initialized: + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + c.listenForever(ctx) + case <-c.closed: + return + } + }() + } + return nil } @@ -135,32 +192,40 @@ func (c *StreamableHTTP) Close() error { sessionId := c.sessionID.Load().(string) if sessionId != "" { c.sessionID.Store("") - + c.wg.Add(1) // notify server session closed go func() { + defer c.wg.Done() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodDelete, c.serverURL.String(), nil) if err != nil { - fmt.Printf("failed to create close request\n: %v", err) + c.logger.Errorf("failed to create close request: %v", err) return } - req.Header.Set(headerKeySessionID, sessionId) + req.Header.Set(HeaderKeySessionID, sessionId) + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(HeaderKeyProtocolVersion, version) + } + } res, err := c.httpClient.Do(req) if err != nil { - fmt.Printf("failed to send close request\n: %v", err) + c.logger.Errorf("failed to send close request: %v", err) return } res.Body.Close() }() } - + c.wg.Wait() return nil } -const ( - headerKeySessionID = "Mcp-Session-Id" -) +// SetProtocolVersion sets the negotiated protocol version for this connection. +func (c *StreamableHTTP) SetProtocolVersion(version string) { + c.protocolVersion.Store(version) +} // ErrOAuthAuthorizationRequired is a sentinel error for OAuth authorization required var ErrOAuthAuthorizationRequired = errors.New("no valid token available, authorization required") @@ -184,78 +249,36 @@ func (c *StreamableHTTP) SendRequest( ctx context.Context, request JSONRPCRequest, ) (*JSONRPCResponse, error) { - - // Create a combined context that could be canceled when the client is closed - newCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - select { - case <-c.closed: - cancel() - case <-newCtx.Done(): - // The original context was canceled, no need to do anything - } - }() - ctx = newCtx - // Marshal request requestBody, err := json.Marshal(request) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) } - // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - sessionID := c.sessionID.Load() - if sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if err.Error() == "no valid token available, authorization required" { - return nil, &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return nil, fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") + if err != nil { + if errors.Is(err, ErrSessionTerminated) && request.Method == string(mcp.MethodInitialize) { + // If the request is initialize, should not return a SessionTerminated error + // It should be a genuine endpoint-routing issue. + // ( Fall through to return StatusCode checking. ) + } else { + return nil, fmt.Errorf("failed to send request: %w", err) } } - // Send request - resp, err := c.httpClient.Do(req) - if err != nil { + // Only proceed if we have a valid response. + // When sendHTTP fails and resp is nil but method is mcp.MethodInitialize + // defer resp.Body.Close() fails with nil pointer dereference. + if resp == nil { return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() // Check if we got an error response if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { - // handle session closed - if resp.StatusCode == http.StatusNotFound { - c.sessionID.CompareAndSwap(sessionID, "") - return nil, fmt.Errorf("session terminated (404). need to re-initialize") - } // Handle OAuth unauthorized error if resp.StatusCode == http.StatusUnauthorized && c.oauthHandler != nil { @@ -276,9 +299,13 @@ func (c *StreamableHTTP) SendRequest( if request.Method == string(mcp.MethodInitialize) { // saved the received session ID in the response // empty session ID is allowed - if sessionID := resp.Header.Get(headerKeySessionID); sessionID != "" { + if sessionID := resp.Header.Get(HeaderKeySessionID); sessionID != "" { c.sessionID.Store(sessionID) } + + c.initializedOnce.Do(func() { + close(c.initialized) + }) } // Handle different response types @@ -300,35 +327,105 @@ func (c *StreamableHTTP) SendRequest( case "text/event-stream": // Server is using SSE for streaming responses - return c.handleSSEResponse(ctx, resp.Body) + return c.handleSSEResponse(ctx, resp.Body, false) default: return nil, fmt.Errorf("unexpected content type: %s", resp.Header.Get("Content-Type")) } } +func (c *StreamableHTTP) sendHTTP( + ctx context.Context, + method string, + body io.Reader, + acceptType string, +) (resp *http.Response, err error) { + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, method, c.serverURL.String(), body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", acceptType) + sessionID := c.sessionID.Load().(string) + if sessionID != "" { + req.Header.Set(HeaderKeySessionID, sessionID) + } + // Set protocol version header if negotiated + if v := c.protocolVersion.Load(); v != nil { + if version, ok := v.(string); ok && version != "" { + req.Header.Set(HeaderKeyProtocolVersion, version) + } + } + for k, v := range c.headers { + req.Header.Set(k, v) + } + + // Add OAuth authorization if configured + if c.oauthHandler != nil { + authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) + if err != nil { + // If we get an authorization error, return a specific error that can be handled by the client + if err.Error() == "no valid token available, authorization required" { + return nil, &OAuthAuthorizationRequiredError{ + Handler: c.oauthHandler, + } + } + return nil, fmt.Errorf("failed to get authorization header: %w", err) + } + req.Header.Set("Authorization", authHeader) + } + + if c.headerFunc != nil { + for k, v := range c.headerFunc(ctx) { + req.Header.Set(k, v) + } + } + + // Send request + resp, err = c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + + // universal handling for session terminated + if resp.StatusCode == http.StatusNotFound { + c.sessionID.CompareAndSwap(sessionID, "") + return nil, ErrSessionTerminated + } + + return resp, nil +} + // handleSSEResponse processes an SSE stream for a specific request. // It returns the final result for the request once received, or an error. -func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser) (*JSONRPCResponse, error) { - +// If ignoreResponse is true, it won't return when a response messge is received. This is for continuous listening. +func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCloser, ignoreResponse bool) (*JSONRPCResponse, error) { // Create a channel for this specific request responseChan := make(chan *JSONRPCResponse, 1) + // Add timeout context for request processing if not already set + if deadline, ok := ctx.Deadline(); !ok || time.Until(deadline) > 30*time.Second { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, 30*time.Second) + defer cancel() + } + ctx, cancel := context.WithCancel(ctx) defer cancel() // Start a goroutine to process the SSE stream go func() { - // only close responseChan after readingSSE() + // Ensure this goroutine respects the context defer close(responseChan) c.readSSE(ctx, reader, func(event, data string) { - - // (unsupported: batching) - + // Try to unmarshal as a response first var message JSONRPCResponse if err := json.Unmarshal([]byte(data), &message); err != nil { - fmt.Printf("failed to unmarshal message: %v\n", err) + c.logger.Errorf("failed to unmarshal message: %v", err) return } @@ -336,7 +433,7 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl if message.ID.IsNil() { var notification mcp.JSONRPCNotification if err := json.Unmarshal([]byte(data), ¬ification); err != nil { - fmt.Printf("failed to unmarshal notification: %v\n", err) + c.logger.Errorf("failed to unmarshal notification: %v", err) return } c.notifyMu.RLock() @@ -347,7 +444,22 @@ func (c *StreamableHTTP) handleSSEResponse(ctx context.Context, reader io.ReadCl return } - responseChan <- &message + // Check if this is actually a request from the server by looking for method field + var rawMessage map[string]json.RawMessage + if err := json.Unmarshal([]byte(data), &rawMessage); err == nil { + if _, hasMethod := rawMessage["method"]; hasMethod && !message.ID.IsNil() { + var request JSONRPCRequest + if err := json.Unmarshal([]byte(data), &request); err == nil { + // This is a request from the server + c.handleIncomingRequest(ctx, request) + return + } + } + } + + if !ignoreResponse { + responseChan <- &message + } }) }() @@ -393,7 +505,7 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand case <-ctx.Done(): return default: - fmt.Printf("SSE stream error: %v\n", err) + c.logger.Errorf("SSE stream error: %v", err) return } } @@ -424,7 +536,6 @@ func (c *StreamableHTTP) readSSE(ctx context.Context, reader io.ReadCloser, hand } func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp.JSONRPCNotification) error { - // Marshal request requestBody, err := json.Marshal(notification) if err != nil { @@ -432,44 +543,10 @@ func (c *StreamableHTTP) SendNotification(ctx context.Context, notification mcp. } // Create HTTP request - req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.serverURL.String(), bytes.NewReader(requestBody)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - // Set headers - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Accept", "application/json, text/event-stream") - if sessionID := c.sessionID.Load(); sessionID != "" { - req.Header.Set(headerKeySessionID, sessionID.(string)) - } - for k, v := range c.headers { - req.Header.Set(k, v) - } - - // Add OAuth authorization if configured - if c.oauthHandler != nil { - authHeader, err := c.oauthHandler.GetAuthorizationHeader(ctx) - if err != nil { - // If we get an authorization error, return a specific error that can be handled by the client - if errors.Is(err, ErrOAuthAuthorizationRequired) { - return &OAuthAuthorizationRequiredError{ - Handler: c.oauthHandler, - } - } - return fmt.Errorf("failed to get authorization header: %w", err) - } - req.Header.Set("Authorization", authHeader) - } - - if c.headerFunc != nil { - for k, v := range c.headerFunc(ctx) { - req.Header.Set(k, v) - } - } + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() - // Send request - resp, err := c.httpClient.Do(req) + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(requestBody), "application/json, text/event-stream") if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -500,6 +577,13 @@ func (c *StreamableHTTP) SetNotificationHandler(handler func(mcp.JSONRPCNotifica c.notificationHandler = handler } +// SetRequestHandler sets the handler for incoming requests from the server. +func (c *StreamableHTTP) SetRequestHandler(handler RequestHandler) { + c.requestMu.Lock() + defer c.requestMu.Unlock() + c.requestHandler = handler +} + func (c *StreamableHTTP) GetSessionId() string { return c.sessionID.Load().(string) } @@ -513,3 +597,203 @@ func (c *StreamableHTTP) GetOAuthHandler() *OAuthHandler { func (c *StreamableHTTP) IsOAuthEnabled() bool { return c.oauthHandler != nil } + +func (c *StreamableHTTP) listenForever(ctx context.Context) { + c.logger.Infof("listening to server forever") + for { + // Add timeout for individual connection attempts + connectCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + err := c.createGETConnectionToServer(connectCtx) + cancel() + + if errors.Is(err, ErrGetMethodNotAllowed) { + // server does not support listening + c.logger.Errorf("server does not support listening") + return + } + + select { + case <-ctx.Done(): + return + default: + } + + if err != nil { + c.logger.Errorf("failed to listen to server. retry in 1 second: %v", err) + } + + // Use context-aware sleep + select { + case <-time.After(retryInterval): + case <-ctx.Done(): + return + } + } +} + +var ( + ErrSessionTerminated = fmt.Errorf("session terminated (404). need to re-initialize") + ErrGetMethodNotAllowed = fmt.Errorf("GET method not allowed") + + retryInterval = 1 * time.Second // a variable is convenient for testing +) + +func (c *StreamableHTTP) createGETConnectionToServer(ctx context.Context) error { + resp, err := c.sendHTTP(ctx, http.MethodGet, nil, "text/event-stream") + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + // Check if we got an error response + if resp.StatusCode == http.StatusMethodNotAllowed { + return ErrGetMethodNotAllowed + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("request failed with status %d: %s", resp.StatusCode, body) + } + + // handle SSE response + contentType := resp.Header.Get("Content-Type") + if contentType != "text/event-stream" { + return fmt.Errorf("unexpected content type: %s", contentType) + } + + // When ignoreResponse is true, the function will never return expect context is done. + // NOTICE: Due to the ambiguity of the specification, other SDKs may use the GET connection to transfer the response + // messages. To be more compatible, we should handle this response, however, as the transport layer is message-based, + // currently, there is no convenient way to handle this response. + // So we ignore the response here. It's not a bug, but may be not compatible with other SDKs. + _, err = c.handleSSEResponse(ctx, resp.Body, true) + if err != nil { + return fmt.Errorf("failed to handle SSE response: %w", err) + } + + return nil +} + +// handleIncomingRequest processes requests from the server (like sampling requests) +func (c *StreamableHTTP) handleIncomingRequest(ctx context.Context, request JSONRPCRequest) { + c.requestMu.RLock() + handler := c.requestHandler + c.requestMu.RUnlock() + + if handler == nil { + c.logger.Errorf("received request from server but no handler set: %s", request.Method) + // Send method not found error + errorResponse := &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: -32601, // Method not found + Message: fmt.Sprintf("no handler configured for method: %s", request.Method), + }, + } + c.sendResponseToServer(ctx, errorResponse) + return + } + + // Handle the request in a goroutine to avoid blocking the SSE reader + go func() { + // Create a new context with timeout for request handling, respecting parent context + requestCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + response, err := handler(requestCtx, request) + if err != nil { + c.logger.Errorf("error handling request %s: %v", request.Method, err) + + // Determine appropriate JSON-RPC error code based on error type + var errorCode int + var errorMessage string + + // Check for specific sampling-related errors + if errors.Is(err, context.Canceled) { + errorCode = -32800 // Request cancelled + errorMessage = "request was cancelled" + } else if errors.Is(err, context.DeadlineExceeded) { + errorCode = -32800 // Request timeout + errorMessage = "request timed out" + } else { + // Generic error cases + switch request.Method { + case string(mcp.MethodSamplingCreateMessage): + errorCode = -32603 // Internal error + errorMessage = fmt.Sprintf("sampling request failed: %v", err) + default: + errorCode = -32603 // Internal error + errorMessage = err.Error() + } + } + + // Send error response + errorResponse := &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Error: &struct { + Code int `json:"code"` + Message string `json:"message"` + Data json.RawMessage `json:"data"` + }{ + Code: errorCode, + Message: errorMessage, + }, + } + c.sendResponseToServer(requestCtx, errorResponse) + return + } + + if response != nil { + c.sendResponseToServer(requestCtx, response) + } + }() +} + +// sendResponseToServer sends a response back to the server via HTTP POST +func (c *StreamableHTTP) sendResponseToServer(ctx context.Context, response *JSONRPCResponse) { + if response == nil { + c.logger.Errorf("cannot send nil response to server") + return + } + + responseBody, err := json.Marshal(response) + if err != nil { + c.logger.Errorf("failed to marshal response: %v", err) + return + } + + ctx, cancel := c.contextAwareOfClientClose(ctx) + defer cancel() + + resp, err := c.sendHTTP(ctx, http.MethodPost, bytes.NewReader(responseBody), "application/json") + if err != nil { + c.logger.Errorf("failed to send response to server: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted { + body, _ := io.ReadAll(resp.Body) + c.logger.Errorf("server rejected response with status %d: %s", resp.StatusCode, body) + } +} + +func (c *StreamableHTTP) contextAwareOfClientClose(ctx context.Context) (context.Context, context.CancelFunc) { + newCtx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-c.closed: + cancel() + case <-newCtx.Done(): + // The original context was canceled + cancel() + } + }() + return newCtx, cancel +} diff --git a/client/transport/streamable_http_sampling_test.go b/client/transport/streamable_http_sampling_test.go new file mode 100644 index 000000000..edba61eac --- /dev/null +++ b/client/transport/streamable_http_sampling_test.go @@ -0,0 +1,496 @@ +package transport + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// TestStreamableHTTP_SamplingFlow tests the complete sampling flow with HTTP transport +func TestStreamableHTTP_SamplingFlow(t *testing.T) { + // Create simple test server + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Just respond OK to any requests + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + // Create HTTP client transport + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Set up sampling request handler + var handledRequest *JSONRPCRequest + handlerCalled := make(chan struct{}) + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + handledRequest = &request + close(handlerCalled) + + // Simulate sampling handler response + result := map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": "Hello! How can I help you today?", + }, + "model": "test-model", + "stopReason": "stop_sequence", + } + + resultBytes, _ := json.Marshal(result) + + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: resultBytes, + }, nil + }) + + // Start the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Test direct request handling (simulating a sampling request) + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{ + "messages": []map[string]any{ + { + "role": "user", + "content": map[string]any{ + "type": "text", + "text": "Hello, world!", + }, + }, + }, + }, + } + + // Directly test request handling + client.handleIncomingRequest(ctx, samplingRequest) + + // Wait for handler to be called + select { + case <-handlerCalled: + // Handler was called + case <-time.After(1 * time.Second): + t.Fatal("Handler was not called within timeout") + } + + // Verify the request was handled + if handledRequest == nil { + t.Fatal("Sampling request was not handled") + } + + if handledRequest.Method != string(mcp.MethodSamplingCreateMessage) { + t.Errorf("Expected method %s, got %s", mcp.MethodSamplingCreateMessage, handledRequest.Method) + } +} + +// TestStreamableHTTP_SamplingErrorHandling tests error handling in sampling requests +func TestStreamableHTTP_SamplingErrorHandling(t *testing.T) { + var errorHandled sync.WaitGroup + errorHandled.Add(1) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusOK) + return + } + + // Check if this is an error response + if errorField, ok := body["error"]; ok { + errorMap := errorField.(map[string]any) + if code, ok := errorMap["code"].(float64); ok && code == -32603 { + errorHandled.Done() + w.WriteHeader(http.StatusOK) + return + } + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Set up request handler that returns an error + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + return nil, fmt.Errorf("sampling failed") + }) + + // Start the client + ctx := context.Background() + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Simulate incoming sampling request + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{}, + } + + // This should trigger error handling + client.handleIncomingRequest(ctx, samplingRequest) + + // Wait for error to be handled + errorHandled.Wait() +} + +// TestStreamableHTTP_NoSamplingHandler tests behavior when no sampling handler is set +func TestStreamableHTTP_NoSamplingHandler(t *testing.T) { + var errorReceived bool + errorReceivedChan := make(chan struct{}) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusOK) + return + } + + // Check if this is an error response with method not found + if errorField, ok := body["error"]; ok { + errorMap := errorField.(map[string]any) + if code, ok := errorMap["code"].(float64); ok && code == -32601 { + if message, ok := errorMap["message"].(string); ok && + strings.Contains(message, "no handler configured") { + errorReceived = true + close(errorReceivedChan) + } + } + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Don't set any request handler + + ctx := context.Background() + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Simulate incoming sampling request + samplingRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{}, + } + + // This should trigger "method not found" error + client.handleIncomingRequest(ctx, samplingRequest) + + // Wait for error to be received + select { + case <-errorReceivedChan: + // Error was received + case <-time.After(1 * time.Second): + t.Fatal("Method not found error was not received within timeout") + } + + if !errorReceived { + t.Error("Expected method not found error, but didn't receive it") + } +} + +// TestStreamableHTTP_BidirectionalInterface verifies the interface implementation +func TestStreamableHTTP_BidirectionalInterface(t *testing.T) { + client, err := NewStreamableHTTP("http://example.com") + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Verify it implements BidirectionalInterface + _, ok := any(client).(BidirectionalInterface) + if !ok { + t.Error("StreamableHTTP should implement BidirectionalInterface") + } + + // Test SetRequestHandler + handlerSet := false + handlerSetChan := make(chan struct{}) + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + handlerSet = true + close(handlerSetChan) + return nil, nil + }) + + // Verify handler was set by triggering it + ctx := context.Background() + client.handleIncomingRequest(ctx, JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(1), + Method: "test", + }) + + // Wait for handler to be called + select { + case <-handlerSetChan: + // Handler was called + case <-time.After(1 * time.Second): + t.Fatal("Handler was not called within timeout") + } + + if !handlerSet { + t.Error("Request handler was not properly set or called") + } +} + +// TestStreamableHTTP_ConcurrentSamplingRequests tests concurrent sampling requests +// where the second request completes faster than the first request +func TestStreamableHTTP_ConcurrentSamplingRequests(t *testing.T) { + var receivedResponses []map[string]any + var responseMutex sync.Mutex + responseComplete := make(chan struct{}, 2) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + var body map[string]any + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + t.Logf("Failed to decode body: %v", err) + w.WriteHeader(http.StatusBadRequest) + return + } + + // Check if this is a response from client (not a request) + if _, ok := body["result"]; ok { + responseMutex.Lock() + receivedResponses = append(receivedResponses, body) + responseMutex.Unlock() + responseComplete <- struct{}{} + } + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client, err := NewStreamableHTTP(server.URL) + if err != nil { + t.Fatalf("Failed to create client: %v", err) + } + defer client.Close() + + // Track which requests have been received and their completion order + var requestOrder []int + var orderMutex sync.Mutex + + // Set up request handler that simulates different processing times + client.SetRequestHandler(func(ctx context.Context, request JSONRPCRequest) (*JSONRPCResponse, error) { + // Extract request ID to determine processing time + requestIDValue := request.ID.Value() + + var delay time.Duration + var responseText string + var requestNum int + + // First request (ID 1) takes longer, second request (ID 2) completes faster + if requestIDValue == int64(1) { + delay = 100 * time.Millisecond + responseText = "Response from slow request 1" + requestNum = 1 + } else if requestIDValue == int64(2) { + delay = 10 * time.Millisecond + responseText = "Response from fast request 2" + requestNum = 2 + } else { + t.Errorf("Unexpected request ID: %v", requestIDValue) + return nil, fmt.Errorf("unexpected request ID") + } + + // Simulate processing time + time.Sleep(delay) + + // Record completion order + orderMutex.Lock() + requestOrder = append(requestOrder, requestNum) + orderMutex.Unlock() + + // Return response with correct request ID + result := map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": responseText, + }, + "model": "test-model", + "stopReason": "stop_sequence", + } + + resultBytes, _ := json.Marshal(result) + + return &JSONRPCResponse{ + JSONRPC: "2.0", + ID: request.ID, + Result: resultBytes, + }, nil + }) + + // Start the client + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + err = client.Start(ctx) + if err != nil { + t.Fatalf("Failed to start client: %v", err) + } + + // Create two sampling requests with different IDs + request1 := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(1)), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{ + "messages": []map[string]any{ + { + "role": "user", + "content": map[string]any{ + "type": "text", + "text": "Slow request 1", + }, + }, + }, + }, + } + + request2 := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(2)), + Method: string(mcp.MethodSamplingCreateMessage), + Params: map[string]any{ + "messages": []map[string]any{ + { + "role": "user", + "content": map[string]any{ + "type": "text", + "text": "Fast request 2", + }, + }, + }, + }, + } + + // Send both requests concurrently + go client.handleIncomingRequest(ctx, request1) + go client.handleIncomingRequest(ctx, request2) + + // Wait for both responses to complete + for i := 0; i < 2; i++ { + select { + case <-responseComplete: + // Response received + case <-time.After(2 * time.Second): + t.Fatal("Timeout waiting for response") + } + } + + // Verify completion order: request 2 should complete first + orderMutex.Lock() + defer orderMutex.Unlock() + + if len(requestOrder) != 2 { + t.Fatalf("Expected 2 completed requests, got %d", len(requestOrder)) + } + + if requestOrder[0] != 2 { + t.Errorf("Expected request 2 to complete first, but request %d completed first", requestOrder[0]) + } + + if requestOrder[1] != 1 { + t.Errorf("Expected request 1 to complete second, but request %d completed second", requestOrder[1]) + } + + // Verify responses are correctly associated + responseMutex.Lock() + defer responseMutex.Unlock() + + if len(receivedResponses) != 2 { + t.Fatalf("Expected 2 responses, got %d", len(receivedResponses)) + } + + // Find responses by ID + var response1, response2 map[string]any + for _, resp := range receivedResponses { + if id, ok := resp["id"]; ok { + switch id { + case int64(1), float64(1): + response1 = resp + case int64(2), float64(2): + response2 = resp + } + } + } + + if response1 == nil { + t.Error("Response for request 1 not found") + } + if response2 == nil { + t.Error("Response for request 2 not found") + } + + // Verify each response contains the correct content + if response1 != nil { + if result, ok := response1["result"].(map[string]any); ok { + if content, ok := result["content"].(map[string]any); ok { + if text, ok := content["text"].(string); ok { + if !strings.Contains(text, "slow request 1") { + t.Errorf("Response 1 should contain 'slow request 1', got: %s", text) + } + } + } + } + } + + if response2 != nil { + if result, ok := response2["result"].(map[string]any); ok { + if content, ok := result["content"].(map[string]any); ok { + if text, ok := content["text"].(string); ok { + if !strings.Contains(text, "fast request 2") { + t.Errorf("Response 2 should contain 'fast request 2', got: %s", text) + } + } + } + } + } +} \ No newline at end of file diff --git a/client/transport/streamable_http_test.go b/client/transport/streamable_http_test.go index 4cd5ad19e..5208cb9c3 100644 --- a/client/transport/streamable_http_test.go +++ b/client/transport/streamable_http_test.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "sync" "testing" "time" @@ -417,7 +418,7 @@ func TestStreamableHTTP(t *testing.T) { t.Run("SSEEventWithoutEventField", func(t *testing.T) { // Test that SSE events with only data field (no event field) are processed correctly // This tests the fix for issue #369 - + // Create a custom mock server that sends SSE events without event field handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -437,7 +438,7 @@ func TestStreamableHTTP(t *testing.T) { // This should be processed as a "message" event according to SSE spec w.Header().Set("Content-Type", "text/event-stream") w.WriteHeader(http.StatusOK) - + response := map[string]any{ "jsonrpc": "2.0", "id": request["id"], @@ -522,5 +523,259 @@ func TestStreamableHTTPErrors(t *testing.T) { t.Errorf("Expected error when sending request to non-existent URL, got nil") } }) +} + +// ---- continuous listening tests ---- + +// startMockStreamableWithGETSupport starts a test HTTP server that implements +// a minimal Streamable HTTP server for testing purposes with support for GET requests +// to test the continuous listening feature. +func startMockStreamableWithGETSupport(getSupport bool) (string, func(), chan bool, int) { + var sessionID string + var mu sync.Mutex + disconnectCh := make(chan bool, 1) + notificationCount := 0 + var notificationMu sync.Mutex + + sendNotification := func() { + notificationMu.Lock() + notificationCount++ + notificationMu.Unlock() + } + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Handle POST requests for initialization + if r.Method == http.MethodPost { + // Parse incoming JSON-RPC request + var request map[string]any + decoder := json.NewDecoder(r.Body) + if err := decoder.Decode(&request); err != nil { + http.Error(w, fmt.Sprintf("Invalid JSON: %v", err), http.StatusBadRequest) + return + } + + method := request["method"] + if method == "initialize" { + // Generate a new session ID + mu.Lock() + sessionID = fmt.Sprintf("test-session-%d", time.Now().UnixNano()) + mu.Unlock() + w.Header().Set("Mcp-Session-Id", sessionID) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusAccepted) + if err := json.NewEncoder(w).Encode(map[string]any{ + "jsonrpc": "2.0", + "id": request["id"], + "result": "initialized", + }); err != nil { + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + } + return + } + + // Handle GET requests for continuous listening + if r.Method == http.MethodGet { + if !getSupport { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + // Check session ID + if recvSessionID := r.Header.Get("Mcp-Session-Id"); recvSessionID != sessionID { + http.Error(w, "Invalid session ID", http.StatusNotFound) + return + } + + // Setup SSE connection + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + flusher, ok := w.(http.Flusher) + if !ok { + http.Error(w, "Streaming not supported", http.StatusInternalServerError) + return + } + + // Send a notification + notification := map[string]any{ + "jsonrpc": "2.0", + "method": "test/notification", + "params": map[string]any{"message": "Hello from server"}, + } + notificationData, _ := json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + flusher.Flush() + sendNotification() + + // Keep the connection open or disconnect as requested + select { + case <-disconnectCh: + // Force disconnect + return + case <-r.Context().Done(): + // Client disconnected + return + case <-time.After(50 * time.Millisecond): + // Send another notification + notification = map[string]any{ + "jsonrpc": "2.0", + "method": "test/notification", + "params": map[string]any{"message": "Second notification"}, + } + notificationData, _ = json.Marshal(notification) + fmt.Fprintf(w, "event: message\ndata: %s\n\n", notificationData) + flusher.Flush() + sendNotification() + return + } + } else { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + }) + + // Start test server + testServer := httptest.NewServer(handler) + + notificationMu.Lock() + defer notificationMu.Unlock() + + return testServer.URL, testServer.Close, disconnectCh, notificationCount +} + +func TestContinuousListening(t *testing.T) { + retryInterval = 10 * time.Millisecond + // Start mock server with GET support + url, closeServer, disconnectCh, _ := startMockStreamableWithGETSupport(true) + + // Create transport with continuous listening enabled + trans, err := NewStreamableHTTP(url, WithContinuousListening()) + if err != nil { + t.Fatal(err) + } + + // Ensure transport is closed before server to avoid connection refused errors + defer func() { + trans.Close() + closeServer() + }() + + // Setup notification handler + notificationReceived := make(chan struct{}, 10) + trans.SetNotificationHandler(func(notification mcp.JSONRPCNotification) { + notificationReceived <- struct{}{} + }) + + // Start the transport - this will launch listenForever in a goroutine + if err := trans.Start(context.Background()); err != nil { + t.Fatal(err) + } + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(0)), + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Wait for notifications to be received + notificationCount := 0 + for notificationCount < 2 { + select { + case <-notificationReceived: + notificationCount++ + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for notifications, received %d", notificationCount) + return + } + } + + // Test server disconnect and reconnect + disconnectCh <- true + time.Sleep(50 * time.Millisecond) // Allow time for reconnection + + // Verify reconnect occurred by receiving more notifications + reconnectNotificationCount := 0 + for reconnectNotificationCount < 2 { + select { + case <-notificationReceived: + reconnectNotificationCount++ + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for notifications after reconnect") + return + } + } +} + +func TestContinuousListeningMethodNotAllowed(t *testing.T) { + // Start a server that doesn't support GET + url, closeServer, _, _ := startMockStreamableWithGETSupport(false) + + // Setup logger to capture log messages + logChan := make(chan string, 10) + testLogger := &testLogger{logChan: logChan} + + // Create transport with continuous listening enabled and custom logger + trans, err := NewStreamableHTTP(url, WithContinuousListening(), WithLogger(testLogger)) + if err != nil { + t.Fatal(err) + } + + // Ensure transport is closed before server to avoid connection refused errors + defer func() { + trans.Close() + closeServer() + }() + + // Initialize the transport first + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Start the transport + if err := trans.Start(context.Background()); err != nil { + t.Fatal(err) + } + + initRequest := JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(int64(0)), + Method: "initialize", + } + + _, err = trans.SendRequest(ctx, initRequest) + if err != nil { + t.Fatal(err) + } + + // Wait for the error log message that server doesn't support listening + select { + case logMsg := <-logChan: + if !strings.Contains(logMsg, "server does not support listening") { + t.Errorf("Expected error log about server not supporting listening, got: %s", logMsg) + } + case <-time.After(5 * time.Second): + t.Fatal("Timeout waiting for log message") + } +} + +// testLogger is a simple logger for testing +type testLogger struct { + logChan chan string +} + +func (l *testLogger) Infof(format string, args ...any) { + // Intentionally left empty +} +func (l *testLogger) Errorf(format string, args ...any) { + l.logChan <- fmt.Sprintf(format, args...) } diff --git a/examples/everything/main.go b/examples/everything/main.go index 5489220c3..620f5936a 100644 --- a/examples/everything/main.go +++ b/examples/everything/main.go @@ -165,9 +165,41 @@ func NewMCPServer() *server.MCPServer { mcpServer.AddNotificationHandler("notification", handleNotification) + mcpServer.AddTool(mcp.NewTool("get_resource_link", + mcp.WithDescription("Returns a resource link example"), + mcp.WithString("resource_type", + mcp.Description("Type of resource to link to"), + mcp.DefaultString("document")), + ), handleGetResourceLinkTool) + return mcpServer } +func handleGetResourceLinkTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + resourceType := request.GetString("resource_type", "document") + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Here's a link to a %s resource:", resourceType), + }, + mcp.NewResourceLink( + fmt.Sprintf("file:///example/%s.pdf", resourceType), + fmt.Sprintf("Sample %s", resourceType), + fmt.Sprintf("A sample %s for demonstration", resourceType), + "application/pdf", + ), + mcp.TextContent{ + Type: "text", + Text: "You can access this resource using the provided URI.", + }, + }, + }, nil +} + func generateResources() []mcp.Resource { resources := make([]mcp.Resource, 100) for i := 0; i < 100; i++ { diff --git a/examples/in_process/main.go b/examples/in_process/main.go new file mode 100644 index 000000000..d01a5e808 --- /dev/null +++ b/examples/in_process/main.go @@ -0,0 +1,100 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// handleDummyTool is a simple tool that returns "foo bar" +func handleDummyTool( + ctx context.Context, + request mcp.CallToolRequest, +) (*mcp.CallToolResult, error) { + return mcp.NewToolResultText("foo bar"), nil +} + +func NewMCPServer() *server.MCPServer { + mcpServer := server.NewMCPServer( + "example-server", + "1.0.0", + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + ) + mcpServer.AddTool(mcp.NewTool("dummy_tool", + mcp.WithDescription("A dummy tool that returns foo bar"), + ), handleDummyTool) + + return mcpServer +} + +type MCPClient struct { + client *client.Client + serverInfo *mcp.InitializeResult +} + +// NewMCPClient creates a new MCP client with an in-process MCP server. +func NewMCPClient(ctx context.Context) (*MCPClient, error) { + srv := NewMCPServer() + client, err := client.NewInProcessClient(srv) + if err != nil { + return nil, fmt.Errorf("failed to create in-process client: %w", err) + } + + // Start the client with timeout context + ctxWithTimeout, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + if err := client.Start(ctxWithTimeout); err != nil { + return nil, fmt.Errorf("failed to start client: %w", err) + } + + // Initialize the client + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "Example MCP Client", + Version: "1.0.0", + } + initRequest.Params.Capabilities = mcp.ClientCapabilities{} + + serverInfo, err := client.Initialize(ctx, initRequest) + if err != nil { + return nil, fmt.Errorf("failed to initialize MCP client: %w", err) + } + + return &MCPClient{ + client: client, + serverInfo: serverInfo, + }, nil +} + +func main() { + ctx := context.Background() + client, err := NewMCPClient(ctx) + if err != nil { + log.Fatalf("Failed to create MCP client: %v", err) + } + + toolsRequest := mcp.ListToolsRequest{} + toolsResult, err := client.client.ListTools(ctx, toolsRequest) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + fmt.Println(toolsResult.Tools) + + request := mcp.CallToolRequest{} + request.Params.Name = "dummy_tool" + + result, err := client.client.CallTool(ctx, request) + if err != nil { + log.Fatalf("Failed to call tool: %v", err) + } + fmt.Println(result.Content) +} diff --git a/examples/inprocess_sampling/README.md b/examples/inprocess_sampling/README.md new file mode 100644 index 000000000..7776fd5ed --- /dev/null +++ b/examples/inprocess_sampling/README.md @@ -0,0 +1,39 @@ +# InProcess Sampling Example + +This example demonstrates how to use sampling with in-process MCP client/server communication. + +## Overview + +The example shows: +- Creating an MCP server with sampling enabled +- Adding a tool that uses sampling to request LLM completions +- Creating an in-process client with a sampling handler +- Making tool calls that trigger sampling requests + +## Key Components + +### Server Side +- `mcpServer.EnableSampling()` - Enables sampling capability +- Tool handler calls `mcpServer.RequestSampling()` to request LLM completions +- Sampling requests are handled directly by the client's sampling handler + +### Client Side +- `MockSamplingHandler` - Implements the `SamplingHandler` interface +- `NewInProcessClientWithSamplingHandler()` - Creates client with sampling support +- The handler receives sampling requests and returns mock LLM responses + +## Running the Example + +```bash +go run main.go +``` + +## Expected Output + +``` +Tool result: LLM Response (model: mock-llm-v1): Mock LLM response to: 'What is the capital of France?' +``` + +## Real LLM Integration + +To integrate with a real LLM service (OpenAI, Anthropic, etc.), replace the `MockSamplingHandler` with an implementation that calls your preferred LLM API. See the [client sampling documentation](https://mcp-go.dev/clients/advanced-sampling) for examples with real LLM providers. \ No newline at end of file diff --git a/examples/inprocess_sampling/main.go b/examples/inprocess_sampling/main.go new file mode 100644 index 000000000..a50ee6434 --- /dev/null +++ b/examples/inprocess_sampling/main.go @@ -0,0 +1,166 @@ +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MockSamplingHandler implements client.SamplingHandler for demonstration +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + var userMessage string + for _, msg := range request.Messages { + if msg.Role == mcp.RoleUser { + if textContent, ok := msg.Content.(mcp.TextContent); ok { + userMessage = textContent.Text + break + } + } + } + + // Generate a mock response + mockResponse := fmt.Sprintf("Mock LLM response to: '%s'", userMessage) + + return &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: mockResponse, + }, + }, + Model: "mock-llm-v1", + StopReason: "endTurn", + }, nil +} + +func main() { + // Create server with sampling enabled + mcpServer := server.NewMCPServer("inprocess-sampling-example", "1.0.0") + mcpServer.EnableSampling() + + // Add a tool that uses sampling + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", + result.Model, result.Content.(mcp.TextContent).Text), + }, + }, + }, nil + }) + + // Create client with sampling handler + mockHandler := &MockSamplingHandler{} + mcpClient, err := client.NewInProcessClientWithSamplingHandler(mcpServer, mockHandler) + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + defer mcpClient.Close() + + // Start the client + ctx := context.Background() + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Initialize + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "inprocess-sampling-client", + Version: "1.0.0", + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize: %v", err) + } + + // Call the tool that uses sampling + result, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "ask_llm", + Arguments: map[string]any{ + "question": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant.", + }, + }, + }) + if err != nil { + log.Fatalf("Tool call failed: %v", err) + } + + // Print the result + if len(result.Content) > 0 { + if textContent, ok := result.Content[0].(mcp.TextContent); ok { + fmt.Printf("Tool result: %s\n", textContent.Text) + } + } +} diff --git a/examples/sampling_client/README.md b/examples/sampling_client/README.md new file mode 100644 index 000000000..7a1d9cb3f --- /dev/null +++ b/examples/sampling_client/README.md @@ -0,0 +1,87 @@ +# MCP Sampling Example Client + +This example demonstrates how to implement an MCP client that supports sampling requests from servers. + +## Features + +- **Sampling Handler**: Implements the `SamplingHandler` interface to process sampling requests +- **Mock LLM**: Provides a mock LLM implementation for demonstration purposes +- **Capability Declaration**: Automatically declares sampling capability when a handler is configured +- **Bidirectional Communication**: Handles incoming requests from the server + +## Mock LLM Handler + +The `MockSamplingHandler` simulates an LLM by: +- Logging the received request parameters +- Generating a mock response that echoes the input +- Returning proper MCP sampling response format + +In a real implementation, you would: +- Integrate with actual LLM APIs (OpenAI, Anthropic, etc.) +- Implement proper model selection based on preferences +- Add human-in-the-loop approval mechanisms +- Handle rate limiting and error cases + +## Usage + +Build the client: + +```bash +go build -o sampling_client +``` + +Run with the sampling server: + +```bash +./sampling_client ../sampling_server/sampling_server +``` + +Or with any other MCP server that supports sampling: + +```bash +./sampling_client /path/to/your/mcp/server +``` + +## Implementation Details + +1. **Sampling Handler**: Implements `client.SamplingHandler` interface +2. **Client Configuration**: Uses `client.WithSamplingHandler()` to enable sampling +3. **Automatic Capability**: Sampling capability is automatically declared during initialization +4. **Request Processing**: Handles incoming `sampling/createMessage` requests from servers + +## Sample Output + +``` +Connected to server: sampling-example-server v1.0.0 +Available tools: + - ask_llm: Ask the LLM a question using sampling + - greet: Greet the user + +--- Testing greet tool --- +Greet result: Hello, Sampling Demo User! This server supports sampling - try using the ask_llm tool! + +--- Testing ask_llm tool (with sampling) --- +Mock LLM received: What is the capital of France? +System prompt: You are a helpful geography assistant. +Max tokens: 1000 +Temperature: 0.700000 +Ask LLM result: LLM Response (model: mock-llm-v1): Mock LLM response to: 'What is the capital of France?'. This is a simulated response from a mock LLM handler. +``` + +## Real LLM Integration + +To integrate with a real LLM, replace the `MockSamplingHandler` with an implementation that: + +```go +type RealSamplingHandler struct { + apiKey string + client *openai.Client // or other LLM client +} + +func (h *RealSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP request to LLM API format + // Call LLM API + // Convert response back to MCP format + // Return result +} +``` \ No newline at end of file diff --git a/examples/sampling_client/main.go b/examples/sampling_client/main.go new file mode 100644 index 000000000..093b59817 --- /dev/null +++ b/examples/sampling_client/main.go @@ -0,0 +1,210 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockSamplingHandler implements the SamplingHandler interface for demonstration. +// In a real implementation, this would integrate with an actual LLM API. +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + userMessage := request.Messages[0] + var userText string + + // Extract text from the content + switch content := userMessage.Content.(type) { + case mcp.TextContent: + userText = content.Text + case map[string]any: + // Handle case where content is unmarshaled as a map + if text, ok := content["text"].(string); ok { + userText = text + } else { + userText = fmt.Sprintf("%v", content) + } + default: + userText = fmt.Sprintf("%v", content) + } + + // Simulate LLM processing + log.Printf("Mock LLM received: %s", userText) + log.Printf("System prompt: %s", request.SystemPrompt) + log.Printf("Max tokens: %d", request.MaxTokens) + log.Printf("Temperature: %f", request.Temperature) + + // Generate a mock response + responseText := fmt.Sprintf("Mock LLM response to: '%s'. This is a simulated response from a mock LLM handler.", userText) + + log.Printf("Mock LLM generating response: %s", responseText) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + Model: "mock-llm-v1", + StopReason: "endTurn", + } + + log.Printf("Mock LLM returning result: %+v", result) + return result, nil +} + +func main() { + if len(os.Args) < 2 { + log.Fatal("Usage: sampling_client ") + } + + serverCommand := os.Args[1] + serverArgs := os.Args[2:] + + // Create stdio transport to communicate with the server + stdio := transport.NewStdio(serverCommand, nil, serverArgs...) + + // Create sampling handler + samplingHandler := &MockSamplingHandler{} + + // Create client with sampling capability + mcpClient := client.NewClient(stdio, client.WithSamplingHandler(samplingHandler)) + + ctx := context.Background() + + // Start the client + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Setup graceful shutdown + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + // Create a context that cancels on signal + ctx, cancel := context.WithCancel(ctx) + go func() { + <-sigChan + log.Println("Received shutdown signal, closing client...") + cancel() + }() + + // Move defer after error checking + defer func() { + if err := mcpClient.Close(); err != nil { + log.Printf("Error closing client: %v", err) + } + }() + + // Initialize the connection + initResult, err := mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "sampling-example-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by WithSamplingHandler + }, + }, + }) + if err != nil { + log.Fatalf("Failed to initialize: %v", err) + } + + log.Printf("Connected to server: %s v%s", initResult.ServerInfo.Name, initResult.ServerInfo.Version) + log.Printf("Server capabilities: %+v", initResult.Capabilities) + + // List available tools + toolsResult, err := mcpClient.ListTools(ctx, mcp.ListToolsRequest{}) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + + log.Printf("Available tools:") + for _, tool := range toolsResult.Tools { + log.Printf(" - %s: %s", tool.Name, tool.Description) + } + + // Test the greeting tool first + log.Println("\n--- Testing greet tool ---") + greetResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]any{ + "name": "Sampling Demo User", + }, + }, + }) + if err != nil { + log.Printf("Error calling greet tool: %v", err) + } else { + log.Printf("Greet result: %+v", greetResult) + for _, content := range greetResult.Content { + if textContent, ok := content.(mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + // Test the sampling tool + log.Println("\n--- Testing ask_llm tool (with sampling) ---") + askResult, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "ask_llm", + Arguments: map[string]any{ + "question": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant.", + }, + }, + }) + if err != nil { + log.Printf("Error calling ask_llm tool: %v", err) + } else { + log.Printf("Ask LLM result: %+v", askResult) + for _, content := range askResult.Content { + if textContent, ok := content.(mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + // Test another sampling request + log.Println("\n--- Testing ask_llm tool with different question ---") + askResult2, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "ask_llm", + Arguments: map[string]any{ + "question": "Explain quantum computing in simple terms.", + }, + }, + }) + if err != nil { + log.Printf("Error calling ask_llm tool: %v", err) + } else { + log.Printf("Ask LLM result 2: %+v", askResult2) + for _, content := range askResult2.Content { + if textContent, ok := content.(mcp.TextContent); ok { + log.Printf(" %s", textContent.Text) + } + } + } + + log.Println("\n--- Sampling demo completed ---") +} diff --git a/examples/sampling_http_client/README.md b/examples/sampling_http_client/README.md new file mode 100644 index 000000000..e4cf0ea4e --- /dev/null +++ b/examples/sampling_http_client/README.md @@ -0,0 +1,95 @@ +# HTTP Sampling Client Example + +This example demonstrates how to create an MCP client using HTTP transport that supports sampling requests from the server. + +## Overview + +This client: +- Connects to an MCP server via HTTP/HTTPS transport +- Declares sampling capability during initialization +- Handles incoming sampling requests from the server +- Uses a mock LLM to generate responses (replace with real LLM integration) + +## Usage + +1. Start an MCP server that supports sampling (e.g., using the `sampling_server` example) + +2. Update the server URL in `main.go`: + ```go + httpClient, err := client.NewStreamableHttpClient( + "http://your-server:port", // Replace with your server URL + ) + ``` + +3. Run the client: + ```bash + go run main.go + ``` + +## Key Features + +### HTTP Transport with Sampling +The client creates the HTTP transport directly and then wraps it with a client that supports sampling: + +```go +httpTransport, err := transport.NewStreamableHTTP("http://localhost:8080") +mcpClient := client.NewClient(httpTransport, client.WithSamplingHandler(samplingHandler)) +``` + +### Sampling Handler +The `MockSamplingHandler` implements the `client.SamplingHandler` interface: + +```go +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Process the sampling request and return LLM response + // In production, integrate with OpenAI, Anthropic, or other LLM APIs +} +``` + +### Client Configuration +The client is configured with sampling capabilities: + +```go +mcpClient := client.NewClient( + httpTransport, + client.WithSamplingHandler(samplingHandler), +) +// Sampling capability is automatically declared when a handler is provided +``` + +## Real Implementation + +For a production implementation, replace the `MockSamplingHandler` with a real LLM client: + +```go +type RealSamplingHandler struct { + client *openai.Client // or other LLM client +} + +func (h *RealSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP request to LLM API format + // Call LLM API + // Convert response back to MCP format + // Return the result +} +``` + +## HTTP-Specific Features + +The HTTP transport supports: +- Standard HTTP headers for authentication and customization +- OAuth 2.0 authentication (using `WithHTTPOAuth`) +- Custom headers (using `WithHTTPHeaders`) +- Server-side events (SSE) for bidirectional communication +- Proper error handling with HTTP status codes +- Session management via HTTP headers + +## Testing + +The implementation includes comprehensive tests in `client/transport/streamable_http_sampling_test.go` that verify: +- Sampling request handling +- Error scenarios +- Bidirectional interface compliance +- HTTP-specific error codes and responses \ No newline at end of file diff --git a/examples/sampling_http_client/main.go b/examples/sampling_http_client/main.go new file mode 100644 index 000000000..98817e6f8 --- /dev/null +++ b/examples/sampling_http_client/main.go @@ -0,0 +1,116 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "os/signal" + "syscall" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +// MockSamplingHandler implements client.SamplingHandler for demonstration. +// In a real implementation, this would integrate with an actual LLM API. +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + // Get the last user message + lastMessage := request.Messages[len(request.Messages)-1] + userText := "" + if textContent, ok := lastMessage.Content.(mcp.TextContent); ok { + userText = textContent.Text + } + + // Generate a mock response + responseText := fmt.Sprintf("Mock LLM response to: '%s'", userText) + + log.Printf("Mock LLM generating response: %s", responseText) + + result := &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: responseText, + }, + }, + Model: "mock-model-v1", + StopReason: "endTurn", + } + + return result, nil +} + +func main() { + // Create sampling handler + samplingHandler := &MockSamplingHandler{} + + // Create HTTP transport directly + httpTransport, err := transport.NewStreamableHTTP( + "http://localhost:8080", // Replace with your MCP server URL + // You can add HTTP-specific options here like headers, OAuth, etc. + ) + if err != nil { + log.Fatalf("Failed to create HTTP transport: %v", err) + } + defer httpTransport.Close() + + // Create client with sampling support + mcpClient := client.NewClient( + httpTransport, + client.WithSamplingHandler(samplingHandler), + ) + + // Start the client + ctx := context.Background() + err = mcpClient.Start(ctx) + if err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + // Initialize the MCP session + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + Capabilities: mcp.ClientCapabilities{ + // Sampling capability will be automatically added by the client + }, + ClientInfo: mcp.Implementation{ + Name: "sampling-http-client", + Version: "1.0.0", + }, + }, + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize MCP session: %v", err) + } + + log.Println("HTTP MCP client with sampling support started successfully!") + log.Println("The client is now ready to handle sampling requests from the server.") + log.Println("When the server sends a sampling request, the MockSamplingHandler will process it.") + + // In a real application, you would keep the client running to handle sampling requests + // For this example, we'll just demonstrate that it's working + + // Keep the client running (in a real app, you'd have your main application logic here) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, os.Interrupt, syscall.SIGTERM) + + select { + case <-ctx.Done(): + log.Println("Client context cancelled") + case <-sigChan: + log.Println("Received shutdown signal") + } +} \ No newline at end of file diff --git a/examples/sampling_http_server/README.md b/examples/sampling_http_server/README.md new file mode 100644 index 000000000..64be58c2c --- /dev/null +++ b/examples/sampling_http_server/README.md @@ -0,0 +1,138 @@ +# HTTP Sampling Server Example + +This example demonstrates how to create an MCP server using HTTP transport that can send sampling requests to clients. + +## Overview + +This server: +- Runs on HTTP transport (port 8080 by default) +- Declares sampling capability during initialization +- Can send sampling requests to connected clients via Server-Sent Events (SSE) +- Receives sampling responses from clients via HTTP POST +- Includes tools that demonstrate sampling functionality + +## Usage + +1. Start the server: + ```bash + go run main.go + ``` + +2. The server will be available at: `http://localhost:8080/mcp` + +3. Connect with an HTTP client that supports sampling (like the `sampling_http_client` example) + +## Tools Available + +### `ask_llm` +Demonstrates server-initiated sampling: +- Takes a question and optional system prompt +- Sends sampling request to client +- Returns the LLM's response + +### `echo` +Simple tool for testing basic functionality: +- Echoes back the input message +- Doesn't require sampling + +## How Sampling Works + +### Server → Client Flow +1. **Tool Invocation**: Client calls `ask_llm` tool +2. **Sampling Request**: Server creates sampling request with user's question +3. **SSE Transmission**: Server sends JSON-RPC request to client via SSE stream +4. **Client Processing**: Client's sampling handler processes the request +5. **HTTP Response**: Client sends JSON-RPC response back via HTTP POST +6. **Tool Response**: Server returns the LLM response to the original tool caller + +### Communication Architecture +``` +Client (HTTP + SSE) ←→ Server (HTTP) + │ │ + ├─ POST: Tool Call ──→ │ + │ │ + │ ←── SSE: Sampling ───┤ + │ Request │ + │ │ + ├─ POST: Sampling ───→ │ + │ Response │ + │ │ + │ ←── HTTP: Tool ──────┤ + Response +``` + +## Key Features + +### Bidirectional Communication +- **SSE Stream**: Server → Client requests (sampling, notifications) +- **HTTP POST**: Client → Server responses and requests + +### Session Management +- Session ID tracking for request/response correlation +- Proper session lifecycle management +- Session validation for security + +### Error Handling +- JSON-RPC error codes for different failure scenarios +- Timeout handling for sampling requests +- Queue overflow protection + +### HTTP-Specific Features +- Standard MCP headers (`Mcp-Session-Id`, `Mcp-Protocol-Version`) +- Content-Type validation +- Proper HTTP status codes +- SSE event formatting + +## Testing + +You can test the server using the `sampling_http_client` example: + +1. Start this server: + ```bash + go run examples/sampling_http_server/main.go + ``` + +2. In another terminal, start the client: + ```bash + go run examples/sampling_http_client/main.go + ``` + +3. The client will connect and be ready to handle sampling requests from the server. + +## Production Considerations + +### Security +- Implement proper authentication/authorization +- Use HTTPS in production +- Validate all incoming data +- Implement rate limiting + +### Scalability +- Consider connection pooling for multiple clients +- Implement proper session cleanup +- Monitor memory usage for long-running sessions +- Add metrics and monitoring + +### Reliability +- Implement request retries +- Add circuit breakers for failing clients +- Implement graceful degradation when sampling is unavailable +- Add comprehensive logging + +## Integration + +This server can be integrated into existing HTTP infrastructure: + +```go +// Custom HTTP server integration +mux := http.NewServeMux() +mux.Handle("/mcp", httpServer) +mux.Handle("/health", healthHandler) + +server := &http.Server{ + Addr: ":8080", + Handler: mux, +} +``` + +The sampling functionality works seamlessly with other MCP features like tools, resources, and prompts. \ No newline at end of file diff --git a/examples/sampling_http_server/main.go b/examples/sampling_http_server/main.go new file mode 100644 index 000000000..95a2bf29b --- /dev/null +++ b/examples/sampling_http_server/main.go @@ -0,0 +1,150 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create MCP server with sampling capability + mcpServer := server.NewMCPServer("sampling-http-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add a tool that uses sampling to get LLM responses + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling over HTTP", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt to provide context", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from the client with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + serverFromCtx := server.ServerFromContext(ctx) + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Extract response text safely + var responseText string + if textContent, ok := result.Content.(mcp.TextContent); ok { + responseText = textContent.Text + } else { + responseText = fmt.Sprintf("%v", result.Content) + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, responseText), + }, + }, + }, nil + }) + + // Add a simple echo tool for testing + mcpServer.AddTool(mcp.Tool{ + Name: "echo", + Description: "Echo back the input message", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "message": map[string]any{ + "type": "string", + "description": "The message to echo back", + }, + }, + Required: []string{"message"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + message := request.GetString("message", "") + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Echo: %s", message), + }, + }, + }, nil + }) + + // Create HTTP server + httpServer := server.NewStreamableHTTPServer(mcpServer) + + log.Println("Starting HTTP MCP server with sampling support on :8080") + log.Println("Endpoint: http://localhost:8080/mcp") + log.Println("") + log.Println("This server supports sampling over HTTP transport.") + log.Println("Clients must:") + log.Println("1. Initialize with sampling capability") + log.Println("2. Establish SSE connection for bidirectional communication") + log.Println("3. Handle incoming sampling requests from the server") + log.Println("4. Send responses back via HTTP POST") + log.Println("") + log.Println("Available tools:") + log.Println("- ask_llm: Ask the LLM a question (requires sampling)") + log.Println("- echo: Simple echo tool (no sampling required)") + + // Start the server + if err := httpServer.Start(":8080"); err != nil { + log.Fatalf("Server failed to start: %v", err) + } +} \ No newline at end of file diff --git a/examples/sampling_server/README.md b/examples/sampling_server/README.md new file mode 100644 index 000000000..53c823792 --- /dev/null +++ b/examples/sampling_server/README.md @@ -0,0 +1,52 @@ +# MCP Sampling Example Server + +This example demonstrates how to implement an MCP server that uses sampling to request LLM completions from clients. + +## Features + +- **Sampling Support**: The server can request LLM completions from clients that support sampling +- **Tool Integration**: Shows how to use sampling within tool implementations +- **Bidirectional Communication**: Demonstrates server-to-client requests + +## Tools + +### `ask_llm` +Asks the LLM a question using sampling. This tool demonstrates how servers can leverage client-side LLM capabilities. + +**Parameters:** +- `question` (required): The question to ask the LLM +- `system_prompt` (optional): System prompt to provide context + +### `greet` +A simple greeting tool that doesn't use sampling, for comparison. + +**Parameters:** +- `name` (required): Name of the person to greet + +## Usage + +Build and run the server: + +```bash +go build -o sampling_server +./sampling_server +``` + +The server communicates via stdio and expects to be connected to an MCP client that supports sampling. + +## Implementation Details + +1. **Enable Sampling**: The server calls `mcpServer.EnableSampling()` to declare sampling capability +2. **Request Sampling**: Tools use `mcpServer.RequestSampling(ctx, request)` to send sampling requests to the client +3. **Handle Responses**: The server receives and processes the LLM responses from the client via bidirectional stdio communication +4. **Response Routing**: Incoming responses are automatically routed to the correct pending request using request IDs + +## Testing + +Use the companion `sampling_client` example to test this server: + +```bash +cd ../sampling_client +go build -o sampling_client +./sampling_client ../sampling_server/sampling_server +``` \ No newline at end of file diff --git a/examples/sampling_server/main.go b/examples/sampling_server/main.go new file mode 100644 index 000000000..ea887c588 --- /dev/null +++ b/examples/sampling_server/main.go @@ -0,0 +1,145 @@ +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create a new MCP server + mcpServer := server.NewMCPServer("sampling-example-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add a tool that uses sampling + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt to provide context", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters using helper methods + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from the client + samplingCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + serverFromCtx := server.ServerFromContext(ctx) + result, err := serverFromCtx.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM's response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", result.Model, getTextFromContent(result.Content)), + }, + }, + }, nil + }) + + // Add a simple greeting tool + mcpServer.AddTool(mcp.Tool{ + Name: "greet", + Description: "Greet the user", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "name": map[string]any{ + "type": "string", + "description": "Name of the person to greet", + }, + }, + Required: []string{"name"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + name, err := request.RequireString("name") + if err != nil { + return nil, err + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Hello, %s! This server supports sampling - try using the ask_llm tool!", name), + }, + }, + }, nil + }) + + // Start the stdio server + log.Println("Starting sampling example server...") + if err := server.ServeStdio(mcpServer); err != nil { + log.Fatalf("Server error: %v", err) + } +} + +// Helper function to extract text from content +func getTextFromContent(content any) string { + switch c := content.(type) { + case mcp.TextContent: + return c.Text + case map[string]any: + // Handle JSON unmarshaled content + if text, ok := c["text"].(string); ok { + return text + } + return fmt.Sprintf("%v", content) + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} diff --git a/examples/simple_client/main.go b/examples/simple_client/main.go index 5deb99113..c0f48593a 100644 --- a/examples/simple_client/main.go +++ b/examples/simple_client/main.go @@ -54,6 +54,11 @@ func main() { // Create client with the transport c = client.NewClient(stdioTransport) + // Start the client + if err := c.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + // Set up logging for stderr if available if stderr, ok := client.GetStderr(c); ok { go func() { @@ -76,6 +81,12 @@ func main() { fmt.Println("Initializing HTTP client...") // Create HTTP transport httpTransport, err := transport.NewStreamableHTTP(*httpURL) + // NOTE: the default streamableHTTP transport is not 100% identical to the stdio client. + // By default, it could not receive global notifications (e.g. toolListChanged). + // You need to enable the `WithContinuousListening()` option to establish a long-live connection, + // and receive the notifications any time the server sends them. + // + // httpTransport, err := transport.NewStreamableHTTP(*httpURL, transport.WithContinuousListening()) if err != nil { log.Fatalf("Failed to create HTTP transport: %v", err) } @@ -84,11 +95,6 @@ func main() { c = client.NewClient(httpTransport) } - // Start the client - if err := c.Start(ctx); err != nil { - log.Fatalf("Failed to start client: %v", err) - } - // Set up notification handler c.OnNotification(func(notification mcp.JSONRPCNotification) { fmt.Printf("Received notification: %s\n", notification.Method) diff --git a/examples/structured_output/README.md b/examples/structured_output/README.md new file mode 100644 index 000000000..e2de01fcf --- /dev/null +++ b/examples/structured_output/README.md @@ -0,0 +1,46 @@ +# Structured Content Example + +This example shows how to return `structuredContent` in tool result with corresponding `OutputSchema`. + +Defined in the MCP spec here: https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + +## Usage + +Define a struct for your output: + +```go +type WeatherResponse struct { + Location string `json:"location" jsonschema_description:"The location"` + Temperature float64 `json:"temperature" jsonschema_description:"Current temperature"` + Conditions string `json:"conditions" jsonschema_description:"Weather conditions"` +} +``` + +Add it to your tool: + +```go +tool := mcp.NewTool("get_weather", + mcp.WithDescription("Get weather information"), + mcp.WithOutputSchema[WeatherResponse](), + mcp.WithString("location", mcp.Required()), +) +``` + +Return structured data in tool result: + +```go +func weatherHandler(ctx context.Context, request mcp.CallToolRequest, args WeatherRequest) (*mcp.CallToolResult, error) { + response := WeatherResponse{ + Location: args.Location, + Temperature: 25.0, + Conditions: "Cloudy", + } + + fallbackText := fmt.Sprintf("Weather in %s: %.1f°C, %s", + response.Location, response.Temperature, response.Conditions) + + return mcp.NewToolResultStructured(response, fallbackText), nil +} +``` + +See [main.go](./main.go) for more examples. \ No newline at end of file diff --git a/examples/structured_output/main.go b/examples/structured_output/main.go new file mode 100644 index 000000000..e7df04021 --- /dev/null +++ b/examples/structured_output/main.go @@ -0,0 +1,152 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// Note: The jsonschema_description tag is added to the JSON schema as description +// Ideally use better descriptions, this is just an example +type WeatherRequest struct { + Location string `json:"location" jsonschema_description:"City or location"` + Units string `json:"units,omitempty" jsonschema_description:"celsius or fahrenheit"` +} + +type WeatherResponse struct { + Location string `json:"location" jsonschema_description:"Location"` + Temperature float64 `json:"temperature" jsonschema_description:"Temperature"` + Units string `json:"units" jsonschema_description:"Units"` + Conditions string `json:"conditions" jsonschema_description:"Weather conditions"` + Timestamp time.Time `json:"timestamp" jsonschema_description:"When retrieved"` +} + +type UserProfile struct { + ID string `json:"id" jsonschema_description:"User ID"` + Name string `json:"name" jsonschema_description:"Full name"` + Email string `json:"email" jsonschema_description:"Email"` + Tags []string `json:"tags" jsonschema_description:"User tags"` +} + +type UserRequest struct { + UserID string `json:"userId" jsonschema_description:"User ID"` +} + +type Asset struct { + ID string `json:"id" jsonschema_description:"Asset identifier"` + Name string `json:"name" jsonschema_description:"Asset name"` + Value float64 `json:"value" jsonschema_description:"Current value"` + Currency string `json:"currency" jsonschema_description:"Currency code"` +} + +type AssetListRequest struct { + Limit int `json:"limit,omitempty" jsonschema_description:"Number of assets to return"` +} + +func main() { + s := server.NewMCPServer( + "Structured Output Example", + "1.0.0", + server.WithToolCapabilities(false), + ) + + // Example 1: Auto-generated schema from struct + weatherTool := mcp.NewTool("get_weather", + mcp.WithDescription("Get weather with structured output"), + mcp.WithOutputSchema[WeatherResponse](), + mcp.WithString("location", mcp.Required()), + mcp.WithString("units", mcp.Enum("celsius", "fahrenheit"), mcp.DefaultString("celsius")), + ) + s.AddTool(weatherTool, mcp.NewStructuredToolHandler(getWeatherHandler)) + + // Example 2: Nested struct schema + userTool := mcp.NewTool("get_user_profile", + mcp.WithDescription("Get user profile"), + mcp.WithOutputSchema[UserProfile](), + mcp.WithString("userId", mcp.Required()), + ) + s.AddTool(userTool, mcp.NewStructuredToolHandler(getUserProfileHandler)) + + // Example 3: Array output - direct array of objects + assetsTool := mcp.NewTool("get_assets", + mcp.WithDescription("Get list of assets as array"), + mcp.WithOutputSchema[[]Asset](), + mcp.WithNumber("limit", mcp.Min(1), mcp.Max(100), mcp.DefaultNumber(10)), + ) + s.AddTool(assetsTool, mcp.NewStructuredToolHandler(getAssetsHandler)) + + // Example 4: Manual result creation + manualTool := mcp.NewTool("manual_structured", + mcp.WithDescription("Manual structured result"), + mcp.WithOutputSchema[WeatherResponse](), + mcp.WithString("location", mcp.Required()), + ) + s.AddTool(manualTool, mcp.NewTypedToolHandler(manualWeatherHandler)) + + if err := server.ServeStdio(s); err != nil { + fmt.Printf("Server error: %v\n", err) + } +} + +func getWeatherHandler(ctx context.Context, request mcp.CallToolRequest, args WeatherRequest) (WeatherResponse, error) { + temp := 22.5 + if args.Units == "fahrenheit" { + temp = temp*9/5 + 32 + } + + return WeatherResponse{ + Location: args.Location, + Temperature: temp, + Units: args.Units, + Conditions: "Cloudy with a chance of meatballs", + Timestamp: time.Now(), + }, nil +} + +func getUserProfileHandler(ctx context.Context, request mcp.CallToolRequest, args UserRequest) (UserProfile, error) { + return UserProfile{ + ID: args.UserID, + Name: "John Doe", + Email: "john.doe@example.com", + Tags: []string{"developer", "golang"}, + }, nil +} + +func getAssetsHandler(ctx context.Context, request mcp.CallToolRequest, args AssetListRequest) ([]Asset, error) { + limit := args.Limit + if limit <= 0 { + limit = 10 + } + + assets := []Asset{ + {ID: "btc", Name: "Bitcoin", Value: 45000.50, Currency: "USD"}, + {ID: "eth", Name: "Ethereum", Value: 3200.75, Currency: "USD"}, + {ID: "ada", Name: "Cardano", Value: 0.85, Currency: "USD"}, + {ID: "sol", Name: "Solana", Value: 125.30, Currency: "USD"}, + {ID: "dot", Name: "Pottedot", Value: 18.45, Currency: "USD"}, + } + + if limit > len(assets) { + limit = len(assets) + } + + return assets[:limit], nil +} + +func manualWeatherHandler(ctx context.Context, request mcp.CallToolRequest, args WeatherRequest) (*mcp.CallToolResult, error) { + response := WeatherResponse{ + Location: args.Location, + Temperature: 25.0, + Units: "celsius", + Conditions: "Sunny, yesterday my life was filled with rain", + Timestamp: time.Now(), + } + + fallbackText := fmt.Sprintf("Weather in %s: %.1f°C, %s", + response.Location, response.Temperature, response.Conditions) + + return mcp.NewToolResultStructured(response, fallbackText), nil +} diff --git a/go.mod b/go.mod index 9b9fe2d48..5c8974549 100644 --- a/go.mod +++ b/go.mod @@ -4,13 +4,18 @@ go 1.23 require ( github.com/google/uuid v1.6.0 + github.com/invopop/jsonschema v0.13.0 github.com/spf13/cast v1.7.1 github.com/stretchr/testify v1.9.0 github.com/yosida95/uritemplate/v3 v3.0.2 ) require ( + github.com/bahlo/generic-list-go v0.2.0 // indirect + github.com/buger/jsonparser v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/mailru/easyjson v0.7.7 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 31ed86d18..70e9c33da 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,7 @@ +github.com/bahlo/generic-list-go v0.2.0 h1:5sz/EEAK+ls5wF+NeqDpk5+iNdMDXrh3z3nPnH1Wvgk= +github.com/bahlo/generic-list-go v0.2.0/go.mod h1:2KvAjgMlE5NNynlg/5iLrrCCZ2+5xWbdbCW3pNTGyYg= +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -6,10 +10,15 @@ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/invopop/jsonschema v0.13.0 h1:KvpoAJWEjR3uD9Kbm2HWJmqsEaHt8lBUpd0qHcIi21E= +github.com/invopop/jsonschema v0.13.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uON45H2qjYt+0= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= @@ -18,6 +27,8 @@ github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= +github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= diff --git a/mcp/errors.go b/mcp/errors.go new file mode 100644 index 000000000..01888bf5b --- /dev/null +++ b/mcp/errors.go @@ -0,0 +1,25 @@ +package mcp + +import "fmt" + +// UnsupportedProtocolVersionError is returned when the server responds with +// a protocol version that the client doesn't support. +type UnsupportedProtocolVersionError struct { + Version string +} + +func (e UnsupportedProtocolVersionError) Error() string { + return fmt.Sprintf("unsupported protocol version: %q", e.Version) +} + +// Is implements the errors.Is interface for better error handling +func (e UnsupportedProtocolVersionError) Is(target error) bool { + _, ok := target.(UnsupportedProtocolVersionError) + return ok +} + +// IsUnsupportedProtocolVersion checks if an error is an UnsupportedProtocolVersionError +func IsUnsupportedProtocolVersion(err error) bool { + _, ok := err.(UnsupportedProtocolVersionError) + return ok +} diff --git a/mcp/prompts.go b/mcp/prompts.go index a63a21450..9b0b48ed2 100644 --- a/mcp/prompts.go +++ b/mcp/prompts.go @@ -1,11 +1,14 @@ package mcp +import "net/http" + /* Prompts */ // ListPromptsRequest is sent from the client to request a list of prompts and // prompt templates the server has. type ListPromptsRequest struct { PaginatedRequest + Header http.Header `json:"-"` } // ListPromptsResult is the server's response to a prompts/list request from @@ -20,6 +23,7 @@ type ListPromptsResult struct { type GetPromptRequest struct { Request Params GetPromptParams `json:"params"` + Header http.Header `json:"-"` } type GetPromptParams struct { @@ -43,6 +47,8 @@ type GetPromptResult struct { // that requires argument values to be provided when calling prompts/get. // If Arguments is nil or empty, this is a static prompt that takes no arguments. type Prompt struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // The name of the prompt or prompt template. Name string `json:"name"` // An optional description of what this prompt provides diff --git a/mcp/tools.go b/mcp/tools.go index 5f3524b02..500503e2a 100644 --- a/mcp/tools.go +++ b/mcp/tools.go @@ -4,8 +4,11 @@ import ( "encoding/json" "errors" "fmt" + "net/http" "reflect" "strconv" + + "github.com/invopop/jsonschema" ) var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSchema, not both") @@ -14,6 +17,7 @@ var errToolSchemaConflict = errors.New("provide either InputSchema or RawInputSc // server has. type ListToolsRequest struct { PaginatedRequest + Header http.Header `json:"-"` } // ListToolsResult is the server's response to a tools/list request from the @@ -36,6 +40,10 @@ type ListToolsResult struct { type CallToolResult struct { Result Content []Content `json:"content"` // Can be TextContent, ImageContent, AudioContent, or EmbeddedResource + // Structured content returned as a JSON object in the structuredContent field of a result. + // For backwards compatibility, a tool that returns structured content SHOULD also return + // functionally equivalent unstructured content. + StructuredContent any `json:"structuredContent,omitempty"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -45,6 +53,7 @@ type CallToolResult struct { // CallToolRequest is used by the client to invoke a tool provided by the server. type CallToolRequest struct { Request + Header http.Header `json:"-"` // HTTP headers from the original request Params CallToolParams `json:"params"` } @@ -461,6 +470,72 @@ func (r CallToolRequest) RequireBoolSlice(key string) ([]bool, error) { return nil, fmt.Errorf("required argument %q not found", key) } +// MarshalJSON implements custom JSON marshaling for CallToolResult +func (r CallToolResult) MarshalJSON() ([]byte, error) { + m := make(map[string]any) + + // Marshal Meta if present + if r.Meta != nil { + m["_meta"] = r.Meta + } + + // Marshal Content array + content := make([]any, len(r.Content)) + for i, c := range r.Content { + content[i] = c + } + m["content"] = content + + // Marshal IsError if true + if r.IsError { + m["isError"] = r.IsError + } + + return json.Marshal(m) +} + +// UnmarshalJSON implements custom JSON unmarshaling for CallToolResult +func (r *CallToolResult) UnmarshalJSON(data []byte) error { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Unmarshal Meta + if meta, ok := raw["_meta"]; ok { + if metaMap, ok := meta.(map[string]any); ok { + r.Meta = NewMetaFromMap(metaMap) + } + } + + // Unmarshal Content array + if contentRaw, ok := raw["content"]; ok { + if contentArray, ok := contentRaw.([]any); ok { + r.Content = make([]Content, len(contentArray)) + for i, item := range contentArray { + itemBytes, err := json.Marshal(item) + if err != nil { + return err + } + content, err := UnmarshalContent(itemBytes) + if err != nil { + return err + } + r.Content[i] = content + } + } + } + + // Unmarshal IsError + if isError, ok := raw["isError"]; ok { + if isErrorBool, ok := isError.(bool); ok { + r.IsError = isErrorBool + } + } + + return nil +} + // ToolListChangedNotification is an optional notification from the server to // the client, informing it that the list of tools it offers has changed. This may // be issued by servers without any previous subscription from the client. @@ -470,6 +545,8 @@ type ToolListChangedNotification struct { // Tool represents the definition for a tool the client can call. type Tool struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // The name of the tool. Name string `json:"name"` // A human-readable description of the tool. @@ -478,6 +555,8 @@ type Tool struct { InputSchema ToolInputSchema `json:"inputSchema"` // Alternative to InputSchema - allows arbitrary JSON Schema to be provided RawInputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling + // Optional JSON Schema defining expected output structure + RawOutputSchema json.RawMessage `json:"-"` // Hide this from JSON marshaling // Optional properties describing tool behavior Annotations ToolAnnotation `json:"annotations"` } @@ -491,7 +570,7 @@ func (t Tool) GetName() string { // It handles marshaling either InputSchema or RawInputSchema based on which is set. func (t Tool) MarshalJSON() ([]byte, error) { // Create a map to build the JSON structure - m := make(map[string]any, 3) + m := make(map[string]any, 5) // Add the name and description m["name"] = t.Name @@ -499,7 +578,7 @@ func (t Tool) MarshalJSON() ([]byte, error) { m["description"] = t.Description } - // Determine which schema to use + // Determine which input schema to use if t.RawInputSchema != nil { if t.InputSchema.Type != "" { return nil, fmt.Errorf("tool %s has both InputSchema and RawInputSchema set: %w", t.Name, errToolSchemaConflict) @@ -510,12 +589,18 @@ func (t Tool) MarshalJSON() ([]byte, error) { m["inputSchema"] = t.InputSchema } + // Add output schema if present + if t.RawOutputSchema != nil { + m["outputSchema"] = t.RawOutputSchema + } + m["annotations"] = t.Annotations return json.Marshal(m) } type ToolInputSchema struct { + Defs map[string]any `json:"$defs,omitempty"` Type string `json:"type"` Properties map[string]any `json:"properties,omitempty"` Required []string `json:"required,omitempty"` @@ -526,6 +611,10 @@ func (tis ToolInputSchema) MarshalJSON() ([]byte, error) { m := make(map[string]any) m["type"] = tis.Type + if tis.Defs != nil { + m["$defs"] = tis.Defs + } + // Marshal Properties to '{}' rather than `nil` when its length equals zero if tis.Properties != nil { m["properties"] = tis.Properties @@ -615,6 +704,46 @@ func WithDescription(description string) ToolOption { } } +// WithOutputSchema creates a ToolOption that sets the output schema for a tool. +// It accepts any Go type, usually a struct, and automatically generates a JSON schema from it. +func WithOutputSchema[T any]() ToolOption { + return func(t *Tool) { + var zero T + + // Generate schema using invopop/jsonschema library + // Configure reflector to generate clean, MCP-compatible schemas + reflector := jsonschema.Reflector{ + DoNotReference: true, // Removes $defs map, outputs entire structure inline + Anonymous: true, // Hides auto-generated Schema IDs + AllowAdditionalProperties: true, // Removes additionalProperties: false + } + schema := reflector.Reflect(zero) + + // Clean up schema for MCP compliance + schema.Version = "" // Remove $schema field + + // Convert to raw JSON for MCP + mcpSchema, err := json.Marshal(schema) + if err != nil { + // Skip and maintain backward compatibility + return + } + + t.RawOutputSchema = json.RawMessage(mcpSchema) + } +} + +// WithRawOutputSchema sets a raw JSON schema for the tool's output. +// Use this when you need full control over the schema or when working with +// complex schemas that can't be generated from Go types. The jsonschema library +// can handle complex schemas and provides nice extension points, so be sure to +// check that out before using this. +func WithRawOutputSchema(schema json.RawMessage) ToolOption { + return func(t *Tool) { + t.RawOutputSchema = schema + } +} + // WithToolAnnotation adds optional hints about the Tool. func WithToolAnnotation(annotation ToolAnnotation) ToolOption { return func(t *Tool) { @@ -945,7 +1074,20 @@ func PropertyNames(schema map[string]any) PropertyOption { } } -// Items defines the schema for array items +// Items defines the schema for array items. +// Accepts any schema definition for maximum flexibility. +// +// Example: +// +// Items(map[string]any{ +// "type": "object", +// "properties": map[string]any{ +// "name": map[string]any{"type": "string"}, +// "age": map[string]any{"type": "number"}, +// }, +// }) +// +// For simple types, use ItemsString(), ItemsNumber(), ItemsBoolean() instead. func Items(schema any) PropertyOption { return func(schemaMap map[string]any) { schemaMap["items"] = schema @@ -972,3 +1114,94 @@ func UniqueItems(unique bool) PropertyOption { schema["uniqueItems"] = unique } } + +// WithStringItems configures an array's items to be of type string. +// +// Supported options: Description(), DefaultString(), Enum(), MaxLength(), MinLength(), Pattern() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("tags", mcp.WithStringItems()) +// mcp.WithArray("colors", mcp.WithStringItems(mcp.Enum("red", "green", "blue"))) +// mcp.WithArray("names", mcp.WithStringItems(mcp.MinLength(1), mcp.MaxLength(50))) +// +// Limitations: Only supports simple string arrays. Use Items() for complex objects. +func WithStringItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "string", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithStringEnumItems configures an array's items to be of type string with a specified enum. +// Example: +// +// mcp.WithArray("priority", mcp.WithStringEnumItems([]string{"low", "medium", "high"})) +// +// Limitations: Only supports string enums. Use WithStringItems(Enum(...)) for more flexibility. +func WithStringEnumItems(values []string) PropertyOption { + return func(schema map[string]any) { + schema["items"] = map[string]any{ + "type": "string", + "enum": values, + } + } +} + +// WithNumberItems configures an array's items to be of type number. +// +// Supported options: Description(), DefaultNumber(), Min(), Max(), MultipleOf() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("scores", mcp.WithNumberItems(mcp.Min(0), mcp.Max(100))) +// mcp.WithArray("prices", mcp.WithNumberItems(mcp.Min(0))) +// +// Limitations: Only supports simple number arrays. Use Items() for complex objects. +func WithNumberItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "number", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} + +// WithBooleanItems configures an array's items to be of type boolean. +// +// Supported options: Description(), DefaultBool() +// Note: Options like Required() are not valid for item schemas and will be ignored. +// +// Examples: +// +// mcp.WithArray("flags", mcp.WithBooleanItems()) +// mcp.WithArray("permissions", mcp.WithBooleanItems(mcp.Description("User permissions"))) +// +// Limitations: Only supports simple boolean arrays. Use Items() for complex objects. +func WithBooleanItems(opts ...PropertyOption) PropertyOption { + return func(schema map[string]any) { + itemSchema := map[string]any{ + "type": "boolean", + } + + for _, opt := range opts { + opt(itemSchema) + } + + schema["items"] = itemSchema + } +} diff --git a/mcp/tools_test.go b/mcp/tools_test.go index 7f2640b94..7beec31dd 100644 --- a/mcp/tools_test.go +++ b/mcp/tools_test.go @@ -528,3 +528,238 @@ func TestFlexibleArgumentsJSONMarshalUnmarshal(t *testing.T) { assert.Equal(t, "value1", args["key1"]) assert.Equal(t, float64(123), args["key2"]) // JSON numbers are unmarshaled as float64 } + +// TestToolWithOutputSchema tests that the WithOutputSchema function +// generates an MCP-compatible JSON output schema for a tool +func TestToolWithOutputSchema(t *testing.T) { + type TestOutput struct { + Name string `json:"name" jsonschema_description:"Person's name"` + Age int `json:"age" jsonschema_description:"Person's age"` + Email string `json:"email,omitempty" jsonschema_description:"Email address"` + } + + tool := NewTool("test_tool", + WithDescription("Test tool with output schema"), + WithOutputSchema[TestOutput](), + WithString("input", Required()), + ) + + // Check that RawOutputSchema was set + assert.NotNil(t, tool.RawOutputSchema) + + // Marshal and verify structure + data, err := json.Marshal(tool) + assert.NoError(t, err) + + var toolData map[string]any + err = json.Unmarshal(data, &toolData) + assert.NoError(t, err) + + // Verify outputSchema exists + outputSchema, exists := toolData["outputSchema"] + assert.True(t, exists) + assert.NotNil(t, outputSchema) +} + +// TestNewToolResultStructured tests that the NewToolResultStructured function +// creates a CallToolResult with both structured and text content +func TestNewToolResultStructured(t *testing.T) { + testData := map[string]any{ + "message": "Success", + "count": 42, + "active": true, + } + + result := NewToolResultStructured(testData, "Fallback text") + + assert.Len(t, result.Content, 1) + + textContent, ok := result.Content[0].(TextContent) + assert.True(t, ok) + assert.Equal(t, "Fallback text", textContent.Text) + assert.NotNil(t, result.StructuredContent) +} + +// TestNewItemsAPICompatibility tests that the new Items API functions +// generate the same schema as the original Items() function with manual schema objects +func TestNewItemsAPICompatibility(t *testing.T) { + tests := []struct { + name string + oldTool Tool + newTool Tool + }{ + { + name: "WithStringItems basic", + oldTool: NewTool("old-string-array", + WithDescription("Tool with string array using old API"), + WithArray("items", + Description("List of string items"), + Items(map[string]any{ + "type": "string", + }), + ), + ), + newTool: NewTool("new-string-array", + WithDescription("Tool with string array using new API"), + WithArray("items", + Description("List of string items"), + WithStringItems(), + ), + ), + }, + { + name: "WithStringEnumItems", + oldTool: NewTool("old-enum-array", + WithDescription("Tool with enum array using old API"), + WithArray("status", + Description("Filter by status"), + Items(map[string]any{ + "type": "string", + "enum": []string{"active", "inactive", "pending"}, + }), + ), + ), + newTool: NewTool("new-enum-array", + WithDescription("Tool with enum array using new API"), + WithArray("status", + Description("Filter by status"), + WithStringEnumItems([]string{"active", "inactive", "pending"}), + ), + ), + }, + { + name: "WithStringItems with options", + oldTool: NewTool("old-string-with-opts", + WithDescription("Tool with string array with options using old API"), + WithArray("names", + Description("List of names"), + Items(map[string]any{ + "type": "string", + "minLength": 1, + "maxLength": 50, + }), + ), + ), + newTool: NewTool("new-string-with-opts", + WithDescription("Tool with string array with options using new API"), + WithArray("names", + Description("List of names"), + WithStringItems(MinLength(1), MaxLength(50)), + ), + ), + }, + { + name: "WithNumberItems basic", + oldTool: NewTool("old-number-array", + WithDescription("Tool with number array using old API"), + WithArray("scores", + Description("List of scores"), + Items(map[string]any{ + "type": "number", + }), + ), + ), + newTool: NewTool("new-number-array", + WithDescription("Tool with number array using new API"), + WithArray("scores", + Description("List of scores"), + WithNumberItems(), + ), + ), + }, + { + name: "WithNumberItems with constraints", + oldTool: NewTool("old-number-with-constraints", + WithDescription("Tool with constrained number array using old API"), + WithArray("ratings", + Description("List of ratings"), + Items(map[string]any{ + "type": "number", + "minimum": 0.0, + "maximum": 10.0, + }), + ), + ), + newTool: NewTool("new-number-with-constraints", + WithDescription("Tool with constrained number array using new API"), + WithArray("ratings", + Description("List of ratings"), + WithNumberItems(Min(0), Max(10)), + ), + ), + }, + { + name: "WithBooleanItems basic", + oldTool: NewTool("old-boolean-array", + WithDescription("Tool with boolean array using old API"), + WithArray("flags", + Description("List of feature flags"), + Items(map[string]any{ + "type": "boolean", + }), + ), + ), + newTool: NewTool("new-boolean-array", + WithDescription("Tool with boolean array using new API"), + WithArray("flags", + Description("List of feature flags"), + WithBooleanItems(), + ), + ), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Marshal both tools to JSON + oldData, err := json.Marshal(tt.oldTool) + assert.NoError(t, err) + + newData, err := json.Marshal(tt.newTool) + assert.NoError(t, err) + + // Unmarshal to maps for comparison + var oldResult, newResult map[string]any + err = json.Unmarshal(oldData, &oldResult) + assert.NoError(t, err) + + err = json.Unmarshal(newData, &newResult) + assert.NoError(t, err) + + // Compare the inputSchema properties (ignoring tool names and descriptions) + oldSchema := oldResult["inputSchema"].(map[string]any) + newSchema := newResult["inputSchema"].(map[string]any) + + oldProperties := oldSchema["properties"].(map[string]any) + newProperties := newSchema["properties"].(map[string]any) + + // Get the array property (should be the only one in these tests) + var oldArrayProp, newArrayProp map[string]any + for _, prop := range oldProperties { + if propMap, ok := prop.(map[string]any); ok && propMap["type"] == "array" { + oldArrayProp = propMap + break + } + } + for _, prop := range newProperties { + if propMap, ok := prop.(map[string]any); ok && propMap["type"] == "array" { + newArrayProp = propMap + break + } + } + + assert.NotNil(t, oldArrayProp, "Old tool should have array property") + assert.NotNil(t, newArrayProp, "New tool should have array property") + + // Compare the items schema - this is the critical part + oldItems := oldArrayProp["items"] + newItems := newArrayProp["items"] + + assert.Equal(t, oldItems, newItems, "Items schema should be identical between old and new API") + + // Also compare other array properties like description + assert.Equal(t, oldArrayProp["description"], newArrayProp["description"], "Array descriptions should match") + assert.Equal(t, oldArrayProp["type"], newArrayProp["type"], "Array types should match") + }) + } +} diff --git a/mcp/typed_tools.go b/mcp/typed_tools.go index 68d8cdd1f..a03a19dd7 100644 --- a/mcp/typed_tools.go +++ b/mcp/typed_tools.go @@ -8,6 +8,9 @@ import ( // TypedToolHandlerFunc is a function that handles a tool call with typed arguments type TypedToolHandlerFunc[T any] func(ctx context.Context, request CallToolRequest, args T) (*CallToolResult, error) +// StructuredToolHandlerFunc is a function that handles a tool call with typed arguments and returns structured output +type StructuredToolHandlerFunc[TArgs any, TResult any] func(ctx context.Context, request CallToolRequest, args TArgs) (TResult, error) + // NewTypedToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { @@ -18,3 +21,22 @@ func NewTypedToolHandler[T any](handler TypedToolHandlerFunc[T]) func(ctx contex return handler(ctx, request, args) } } + +// NewStructuredToolHandler creates a ToolHandlerFunc that automatically binds arguments to a typed struct +// and returns structured output. It automatically creates both structured and +// text content (from the structured output) for backwards compatibility. +func NewStructuredToolHandler[TArgs any, TResult any](handler StructuredToolHandlerFunc[TArgs, TResult]) func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + return func(ctx context.Context, request CallToolRequest) (*CallToolResult, error) { + var args TArgs + if err := request.BindArguments(&args); err != nil { + return NewToolResultError(fmt.Sprintf("failed to bind arguments: %v", err)), nil + } + + result, err := handler(ctx, request, args) + if err != nil { + return NewToolResultError(fmt.Sprintf("tool execution failed: %v", err)), nil + } + + return NewToolResultStructuredOnly(result), nil + } +} diff --git a/mcp/types.go b/mcp/types.go index 0091d2e42..344924992 100644 --- a/mcp/types.go +++ b/mcp/types.go @@ -9,6 +9,7 @@ import ( "strconv" "github.com/yosida95/uritemplate/v3" + "net/http" ) type MCPMethod string @@ -96,12 +97,13 @@ func (t *URITemplate) UnmarshalJSON(data []byte) error { type JSONRPCMessage any // LATEST_PROTOCOL_VERSION is the most recent version of the MCP protocol. -const LATEST_PROTOCOL_VERSION = "2025-03-26" +const LATEST_PROTOCOL_VERSION = "2025-06-18" // ValidProtocolVersions lists all known valid MCP protocol versions. var ValidProtocolVersions = []string{ - "2024-11-05", LATEST_PROTOCOL_VERSION, + "2025-03-26", + "2024-11-05", } // JSONRPC_VERSION is the version of JSON-RPC used by MCP. @@ -150,6 +152,18 @@ func (m *Meta) UnmarshalJSON(data []byte) error { return nil } +func NewMetaFromMap(m map[string]any) *Meta { + progressToken := m["progressToken"] + if progressToken != nil { + delete(m, "progressToken") + } + + return &Meta{ + ProgressToken: progressToken, + AdditionalFields: m, + } +} + type Request struct { Method string `json:"method"` Params RequestParams `json:"params,omitempty"` @@ -231,7 +245,7 @@ func (p *NotificationParams) UnmarshalJSON(data []byte) error { type Result struct { // This result property is reserved by the protocol to allow clients and // servers to attach additional metadata to their responses. - Meta map[string]any `json:"_meta,omitempty"` + Meta *Meta `json:"_meta,omitempty"` } // RequestId is a uniquely identifying ID for a request in JSON-RPC. @@ -399,6 +413,7 @@ type CancelledNotificationParams struct { type InitializeRequest struct { Request Params InitializeParams `json:"params"` + Header http.Header `json:"-"` } type InitializeParams struct { @@ -469,6 +484,8 @@ type ServerCapabilities struct { // list. ListChanged bool `json:"listChanged,omitempty"` } `json:"resources,omitempty"` + // Present if the server supports sending sampling requests to clients. + Sampling *struct{} `json:"sampling,omitempty"` // Present if the server offers any tools to call. Tools *struct { // Whether this server supports notifications for changes to the tool list. @@ -489,6 +506,7 @@ type Implementation struct { // or else may be disconnected. type PingRequest struct { Request + Header http.Header `json:"-"` } /* Progress notifications */ @@ -541,6 +559,7 @@ type PaginatedResult struct { // the server has. type ListResourcesRequest struct { PaginatedRequest + Header http.Header `json:"-"` } // ListResourcesResult is the server's response to a resources/list request @@ -554,6 +573,7 @@ type ListResourcesResult struct { // resource templates the server has. type ListResourceTemplatesRequest struct { PaginatedRequest + Header http.Header `json:"-"` } // ListResourceTemplatesResult is the server's response to a @@ -567,6 +587,7 @@ type ListResourceTemplatesResult struct { // specific resource URI. type ReadResourceRequest struct { Request + Header http.Header `json:"-"` Params ReadResourceParams `json:"params"` } @@ -598,6 +619,7 @@ type ResourceListChangedNotification struct { type SubscribeRequest struct { Request Params SubscribeParams `json:"params"` + Header http.Header `json:"-"` } type SubscribeParams struct { @@ -612,6 +634,7 @@ type SubscribeParams struct { type UnsubscribeRequest struct { Request Params UnsubscribeParams `json:"params"` + Header http.Header `json:"-"` } type UnsubscribeParams struct { @@ -635,6 +658,8 @@ type ResourceUpdatedNotificationParams struct { // Resource represents a known resource that the server is capable of reading. type Resource struct { Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // The URI of this resource. URI string `json:"uri"` // A human-readable name for this resource. @@ -659,6 +684,8 @@ func (r Resource) GetName() string { // on the server. type ResourceTemplate struct { Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // A URI template (according to RFC 6570) that can be used to construct // resource URIs. URITemplate *URITemplate `json:"uriTemplate"` @@ -688,6 +715,8 @@ type ResourceContents interface { } type TextResourceContents struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // The URI of this resource. URI string `json:"uri"` // The MIME type of this resource, if known. @@ -700,6 +729,8 @@ type TextResourceContents struct { func (TextResourceContents) isResourceContents() {} type BlobResourceContents struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // The URI of this resource. URI string `json:"uri"` // The MIME type of this resource, if known. @@ -717,6 +748,7 @@ func (BlobResourceContents) isResourceContents() {} type SetLevelRequest struct { Request Params SetLevelParams `json:"params"` + Header http.Header `json:"-"` } type SetLevelParams struct { @@ -761,8 +793,33 @@ const ( LoggingLevelEmergency LoggingLevel = "emergency" ) +var levelToInt = map[LoggingLevel]int{ + LoggingLevelDebug: 0, + LoggingLevelInfo: 1, + LoggingLevelNotice: 2, + LoggingLevelWarning: 3, + LoggingLevelError: 4, + LoggingLevelCritical: 5, + LoggingLevelAlert: 6, + LoggingLevelEmergency: 7, +} + +func (l LoggingLevel) ShouldSendTo(minLevel LoggingLevel) bool { + ia, oka := levelToInt[l] + ib, okb := levelToInt[minLevel] + if !oka || !okb { + return false + } + return ia >= ib +} + /* Sampling */ +const ( + // MethodSamplingCreateMessage allows servers to request LLM completions from clients + MethodSamplingCreateMessage MCPMethod = "sampling/createMessage" +) + // CreateMessageRequest is a request from the server to sample an LLM via the // client. The client has full discretion over which model to select. The client // should also inform the user before beginning sampling, to allow them to inspect @@ -832,6 +889,8 @@ type Content interface { // It must have Type set to "text". type TextContent struct { Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` Type string `json:"type"` // Must be "text" // The text content of the message. Text string `json:"text"` @@ -843,6 +902,8 @@ func (TextContent) isContent() {} // It must have Type set to "image". type ImageContent struct { Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` Type string `json:"type"` // Must be "image" // The base64-encoded image data. Data string `json:"data"` @@ -856,6 +917,8 @@ func (ImageContent) isContent() {} // It must have Type set to "audio". type AudioContent struct { Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` Type string `json:"type"` // Must be "audio" // The base64-encoded audio data. Data string `json:"data"` @@ -865,12 +928,30 @@ type AudioContent struct { func (AudioContent) isContent() {} +// ResourceLink represents a link to a resource that the client can access. +type ResourceLink struct { + Annotated + Type string `json:"type"` // Must be "resource_link" + // The URI of the resource. + URI string `json:"uri"` + // The name of the resource. + Name string `json:"name"` + // The description of the resource. + Description string `json:"description"` + // The MIME type of the resource. + MIMEType string `json:"mimeType"` +} + +func (ResourceLink) isContent() {} + // EmbeddedResource represents the contents of a resource, embedded into a prompt or tool call result. // // It is up to the client how best to render embedded resources for the // benefit of the LLM and/or the user. type EmbeddedResource struct { Annotated + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` Type string `json:"type"` Resource ResourceContents `json:"resource"` } @@ -939,6 +1020,7 @@ type ModelHint struct { type CompleteRequest struct { Request Params CompleteParams `json:"params"` + Header http.Header `json:"-"` } type CompleteParams struct { @@ -991,6 +1073,7 @@ type PromptReference struct { // structure or access specific locations that the client has permission to read from. type ListRootsRequest struct { Request + Header http.Header `json:"-"` } // ListRootsResult is the client's response to a roots/list request from the server. @@ -1003,6 +1086,8 @@ type ListRootsResult struct { // Root represents a root directory or file that the server can operate on. type Root struct { + // Meta is a metadata object that is reserved by MCP for storing additional information. + Meta *Meta `json:"_meta,omitempty"` // The URI identifying the root. This *must* start with file:// for now. // This restriction may be relaxed in future versions of the protocol to allow // other URI schemes. @@ -1042,3 +1127,46 @@ type ServerResult any type Named interface { GetName() string } + +// MarshalJSON implements custom JSON marshaling for Content interface +func MarshalContent(content Content) ([]byte, error) { + return json.Marshal(content) +} + +// UnmarshalContent implements custom JSON unmarshaling for Content interface +func UnmarshalContent(data []byte) (Content, error) { + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return nil, err + } + + contentType, ok := raw["type"].(string) + if !ok { + return nil, fmt.Errorf("missing or invalid type field") + } + + switch contentType { + case "text": + var content TextContent + err := json.Unmarshal(data, &content) + return content, err + case "image": + var content ImageContent + err := json.Unmarshal(data, &content) + return content, err + case "audio": + var content AudioContent + err := json.Unmarshal(data, &content) + return content, err + case "resource_link": + var content ResourceLink + err := json.Unmarshal(data, &content) + return content, err + case "resource": + var content EmbeddedResource + err := json.Unmarshal(data, &content) + return content, err + default: + return nil, fmt.Errorf("unknown content type: %s", contentType) + } +} diff --git a/mcp/types_test.go b/mcp/types_test.go index 526e1ac1e..c1453de60 100644 --- a/mcp/types_test.go +++ b/mcp/types_test.go @@ -68,3 +68,73 @@ func TestMetaMarshalling(t *testing.T) { }) } } + +func TestResourceLinkSerialization(t *testing.T) { + resourceLink := NewResourceLink( + "file:///example/document.pdf", + "Sample Document", + "A sample document for testing", + "application/pdf", + ) + + // Test marshaling + data, err := json.Marshal(resourceLink) + require.NoError(t, err) + + // Test unmarshaling + var unmarshaled ResourceLink + err = json.Unmarshal(data, &unmarshaled) + require.NoError(t, err) + + // Verify fields + assert.Equal(t, "resource_link", unmarshaled.Type) + assert.Equal(t, "file:///example/document.pdf", unmarshaled.URI) + assert.Equal(t, "Sample Document", unmarshaled.Name) + assert.Equal(t, "A sample document for testing", unmarshaled.Description) + assert.Equal(t, "application/pdf", unmarshaled.MIMEType) +} + +func TestCallToolResultWithResourceLink(t *testing.T) { + result := &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: "Here's a resource link:", + }, + NewResourceLink( + "file:///example/test.pdf", + "Test Document", + "A test document", + "application/pdf", + ), + }, + IsError: false, + } + + // Test marshaling + data, err := json.Marshal(result) + require.NoError(t, err) + + // Test unmarshalling + var unmarshalled CallToolResult + err = json.Unmarshal(data, &unmarshalled) + require.NoError(t, err) + + // Verify content + require.Len(t, unmarshalled.Content, 2) + + // Check first content (TextContent) + textContent, ok := unmarshalled.Content[0].(TextContent) + require.True(t, ok) + assert.Equal(t, "text", textContent.Type) + assert.Equal(t, "Here's a resource link:", textContent.Text) + + // Check second content (ResourceLink) + resourceLink, ok := unmarshalled.Content[1].(ResourceLink) + require.True(t, ok) + assert.Equal(t, "resource_link", resourceLink.Type) + assert.Equal(t, "file:///example/test.pdf", resourceLink.URI) + assert.Equal(t, "Test Document", resourceLink.Name) + assert.Equal(t, "A test document", resourceLink.Description) + assert.Equal(t, "application/pdf", resourceLink.MIMEType) +} diff --git a/mcp/utils.go b/mcp/utils.go index 55bef7a99..4d2b170b4 100644 --- a/mcp/utils.go +++ b/mcp/utils.go @@ -222,6 +222,17 @@ func NewAudioContent(data, mimeType string) AudioContent { } } +// Helper function to create a new ResourceLink +func NewResourceLink(uri, name, description, mimeType string) ResourceLink { + return ResourceLink{ + Type: "resource_link", + URI: uri, + Name: name, + Description: description, + MIMEType: mimeType, + } +} + // Helper function to create a new EmbeddedResource func NewEmbeddedResource(resource ResourceContents) EmbeddedResource { return EmbeddedResource{ @@ -242,6 +253,44 @@ func NewToolResultText(text string) *CallToolResult { } } +// NewToolResultStructured creates a new CallToolResult with structured content. +// It includes both the structured content and a text representation for backward compatibility. +func NewToolResultStructured(structured any, fallbackText string) *CallToolResult { + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fallbackText, + }, + }, + StructuredContent: structured, + } +} + +// NewToolResultStructuredOnly creates a new CallToolResult with structured +// content and creates a JSON string fallback for backwards compatibility. +// This is useful when you want to provide structured data without any specific text fallback. +func NewToolResultStructuredOnly(structured any) *CallToolResult { + var fallbackText string + // Convert to JSON string for backward compatibility + jsonBytes, err := json.Marshal(structured) + if err != nil { + fallbackText = fmt.Sprintf("Error serializing structured content: %v", err) + } else { + fallbackText = string(jsonBytes) + } + + return &CallToolResult{ + Content: []Content{ + TextContent{ + Type: "text", + Text: fallbackText, + }, + }, + StructuredContent: structured, + } +} + // NewToolResultImage creates a new CallToolResult with both text and image content func NewToolResultImage(text, imageData, mimeType string) *CallToolResult { return &CallToolResult{ @@ -476,6 +525,16 @@ func ParseContent(contentMap map[string]any) (Content, error) { } return NewAudioContent(data, mimeType), nil + case "resource_link": + uri := ExtractString(contentMap, "uri") + name := ExtractString(contentMap, "name") + description := ExtractString(contentMap, "description") + mimeType := ExtractString(contentMap, "mimeType") + if uri == "" || name == "" { + return nil, fmt.Errorf("resource_link uri or name is missing") + } + return NewResourceLink(uri, name, description, mimeType), nil + case "resource": resourceMap := ExtractMap(contentMap, "resource") if resourceMap == nil { @@ -508,7 +567,7 @@ func ParseGetPromptResult(rawMessage *json.RawMessage) (*GetPromptResult, error) meta, ok := jsonContent["_meta"] if ok { if metaMap, ok := meta.(map[string]any); ok { - result.Meta = metaMap + result.Meta = NewMetaFromMap(metaMap) } } @@ -574,7 +633,7 @@ func ParseCallToolResult(rawMessage *json.RawMessage) (*CallToolResult, error) { meta, ok := jsonContent["_meta"] if ok { if metaMap, ok := meta.(map[string]any); ok { - result.Meta = metaMap + result.Meta = NewMetaFromMap(metaMap) } } @@ -656,7 +715,7 @@ func ParseReadResourceResult(rawMessage *json.RawMessage) (*ReadResourceResult, meta, ok := jsonContent["_meta"] if ok { if metaMap, ok := meta.(map[string]any); ok { - result.Meta = metaMap + result.Meta = NewMetaFromMap(metaMap) } } diff --git a/mcptest/mcptest.go b/mcptest/mcptest.go index 232eac5df..df85753f2 100644 --- a/mcptest/mcptest.go +++ b/mcptest/mcptest.go @@ -20,9 +20,10 @@ import ( type Server struct { name string - tools []server.ServerTool - prompts []server.ServerPrompt - resources []server.ServerResource + tools []server.ServerTool + prompts []server.ServerPrompt + resources []server.ServerResource + resourceTemplates []server.ServerResourceTemplate cancel func() @@ -106,6 +107,19 @@ func (s *Server) AddResources(resources ...server.ServerResource) { s.resources = append(s.resources, resources...) } +// AddResourceTemplate adds a resource template to an unstarted server. +func (s *Server) AddResourceTemplate(template mcp.ResourceTemplate, handler server.ResourceTemplateHandlerFunc) { + s.resourceTemplates = append(s.resourceTemplates, server.ServerResourceTemplate{ + Template: template, + Handler: handler, + }) +} + +// AddResourceTemplates adds multiple resource templates to an unstarted server. +func (s *Server) AddResourceTemplates(templates ...server.ServerResourceTemplate) { + s.resourceTemplates = append(s.resourceTemplates, templates...) +} + // Start starts the server in a goroutine. Make sure to defer Close() after Start(). // When using NewServer(), the returned server is already started. func (s *Server) Start(ctx context.Context) error { @@ -122,6 +136,7 @@ func (s *Server) Start(ctx context.Context) error { mcpServer.AddTools(s.tools...) mcpServer.AddPrompts(s.prompts...) mcpServer.AddResources(s.resources...) + mcpServer.AddResourceTemplates(s.resourceTemplates...) logger := log.New(&s.logBuffer, "", 0) diff --git a/mcptest/mcptest_test.go b/mcptest/mcptest_test.go index 0ab9b276e..18922cb84 100644 --- a/mcptest/mcptest_test.go +++ b/mcptest/mcptest_test.go @@ -187,3 +187,79 @@ func TestServerWithResource(t *testing.T) { t.Errorf("Got %q, want %q", textContent.Text, want) } } + +func TestServerWithResourceTemplate(t *testing.T) { + ctx := context.Background() + + srv := mcptest.NewUnstartedServer(t) + defer srv.Close() + + template := mcp.NewResourceTemplate( + "file://users/{userId}/documents/{docId}", + "User Document", + mcp.WithTemplateDescription("A user's document"), + mcp.WithTemplateMIMEType("text/plain"), + ) + + handler := func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + if request.Params.Arguments == nil { + return nil, fmt.Errorf("expected arguments to be populated from URI template") + } + + userIds, ok := request.Params.Arguments["userId"].([]string) + if !ok { + return nil, fmt.Errorf("expected userId argument to be populated from URI template") + } + if len(userIds) != 1 { + return nil, fmt.Errorf("expected userId to have one value, but got %d", len(userIds)) + } + if userIds[0] != "john" { + return nil, fmt.Errorf("expected userId argument to be 'john', got %s", userIds[0]) + } + + docIds, ok := request.Params.Arguments["docId"].([]string) + if !ok { + return nil, fmt.Errorf("expected docId argument to be populated from URI template") + } + if len(docIds) != 1 { + return nil, fmt.Errorf("expected docId to have one value, but got %d", len(docIds)) + } + if docIds[0] != "readme.txt" { + return nil, fmt.Errorf("expected docId argument to be 'readme.txt', got %v", docIds) + } + + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: request.Params.URI, + MIMEType: "text/plain", + Text: fmt.Sprintf("Document %s for user %s", docIds[0], userIds[0]), + }, + }, nil + } + + srv.AddResourceTemplate(template, handler) + + err := srv.Start(ctx) + if err != nil { + t.Fatal(err) + } + + // Test reading a resource that matches the template + var readReq mcp.ReadResourceRequest + readReq.Params.URI = "file://users/john/documents/readme.txt" + readResult, err := srv.Client().ReadResource(ctx, readReq) + if err != nil { + t.Fatal("ReadResource:", err) + } + if len(readResult.Contents) != 1 { + t.Fatalf("Expected 1 content, got %d", len(readResult.Contents)) + } + textContent, ok := readResult.Contents[0].(mcp.TextResourceContents) + if !ok { + t.Fatalf("Expected TextResourceContents, got %T", readResult.Contents[0]) + } + want := "Document readme.txt for user john" + if textContent.Text != want { + t.Errorf("Got %q, want %q", textContent.Text, want) + } +} diff --git a/server/constants.go b/server/constants.go new file mode 100644 index 000000000..e071b2ef4 --- /dev/null +++ b/server/constants.go @@ -0,0 +1,7 @@ +package server + +// Common HTTP header constants used across server transports +const ( + HeaderKeySessionID = "Mcp-Session-Id" + HeaderKeyProtocolVersion = "Mcp-Protocol-Version" +) diff --git a/server/ctx.go b/server/ctx.go new file mode 100644 index 000000000..43f01bb68 --- /dev/null +++ b/server/ctx.go @@ -0,0 +1,8 @@ +package server + +type contextKey int + +const ( + // This const is used as key for context value lookup + requestHeader contextKey = iota +) diff --git a/server/errors.go b/server/errors.go index ecbe91e5f..3864f36f7 100644 --- a/server/errors.go +++ b/server/errors.go @@ -21,7 +21,7 @@ var ( // Notification-related errors ErrNotificationNotInitialized = errors.New("notification channel not initialized") - ErrNotificationChannelBlocked = errors.New("notification channel full or blocked") + ErrNotificationChannelBlocked = errors.New("notification channel queue is full - client may not be processing notifications fast enough") ) // ErrDynamicPathConfig is returned when attempting to use static path methods with dynamic path configuration diff --git a/server/inprocess_session.go b/server/inprocess_session.go new file mode 100644 index 000000000..daaf28a5c --- /dev/null +++ b/server/inprocess_session.go @@ -0,0 +1,115 @@ +package server + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// SamplingHandler defines the interface for handling sampling requests from servers. +type SamplingHandler interface { + CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} + +type InProcessSession struct { + sessionID string + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value + clientCapabilities atomic.Value + samplingHandler SamplingHandler + mu sync.RWMutex +} + +func NewInProcessSession(sessionID string, samplingHandler SamplingHandler) *InProcessSession { + return &InProcessSession{ + sessionID: sessionID, + notifications: make(chan mcp.JSONRPCNotification, 100), + samplingHandler: samplingHandler, + } +} + +func (s *InProcessSession) SessionID() string { + return s.sessionID +} + +func (s *InProcessSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return s.notifications +} + +func (s *InProcessSession) Initialize() { + s.loggingLevel.Store(mcp.LoggingLevelError) + s.initialized.Store(true) +} + +func (s *InProcessSession) Initialized() bool { + return s.initialized.Load() +} + +func (s *InProcessSession) GetClientInfo() mcp.Implementation { + if value := s.clientInfo.Load(); value != nil { + if clientInfo, ok := value.(mcp.Implementation); ok { + return clientInfo + } + } + return mcp.Implementation{} +} + +func (s *InProcessSession) SetClientInfo(clientInfo mcp.Implementation) { + s.clientInfo.Store(clientInfo) +} + +func (s *InProcessSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *InProcessSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *InProcessSession) SetLogLevel(level mcp.LoggingLevel) { + s.loggingLevel.Store(level) +} + +func (s *InProcessSession) GetLogLevel() mcp.LoggingLevel { + level := s.loggingLevel.Load() + if level == nil { + return mcp.LoggingLevelError + } + return level.(mcp.LoggingLevel) +} + +func (s *InProcessSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + handler := s.samplingHandler + s.mu.RUnlock() + + if handler == nil { + return nil, fmt.Errorf("no sampling handler available") + } + + return handler.CreateMessage(ctx, request) +} + +// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func GenerateInProcessSessionID() string { + return fmt.Sprintf("inprocess-%d", time.Now().UnixNano()) +} + +// Ensure interface compliance +var ( + _ ClientSession = (*InProcessSession)(nil) + _ SessionWithLogging = (*InProcessSession)(nil) + _ SessionWithClientInfo = (*InProcessSession)(nil) + _ SessionWithSampling = (*InProcessSession)(nil) +) diff --git a/server/internal/gen/request_handler.go.tmpl b/server/internal/gen/request_handler.go.tmpl index 7e4a68a05..70600f3d8 100644 --- a/server/internal/gen/request_handler.go.tmpl +++ b/server/internal/gen/request_handler.go.tmpl @@ -72,6 +72,14 @@ func (s *MCPServer) HandleMessage( ) } + // Get request header from ctx + h := ctx.Value(requestHeader) + headers, ok := h.(http.Header) + + if headers == nil || !ok { + headers = make(http.Header) + } + switch baseMessage.Method { {{- range .}} case mcp.{{.MethodName}}: @@ -90,6 +98,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.before{{.HookName}}(ctx, baseMessage.ID, &request) result, err = s.{{.HandlerFunc}}(ctx, baseMessage.ID, request) } diff --git a/server/request_handler.go b/server/request_handler.go index 25f6ef14f..b9175dc4e 100644 --- a/server/request_handler.go +++ b/server/request_handler.go @@ -6,6 +6,7 @@ import ( "context" "encoding/json" "fmt" + "net/http" "github.com/mark3labs/mcp-go/mcp" ) @@ -71,6 +72,14 @@ func (s *MCPServer) HandleMessage( ) } + // Get request header from ctx + h := ctx.Value(requestHeader) + headers, ok := h.(http.Header) + + if headers == nil || !ok { + headers = make(http.Header) + } + switch baseMessage.Method { case mcp.MethodInitialize: var request mcp.InitializeRequest @@ -82,6 +91,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeInitialize(ctx, baseMessage.ID, &request) result, err = s.handleInitialize(ctx, baseMessage.ID, request) } @@ -101,6 +111,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforePing(ctx, baseMessage.ID, &request) result, err = s.handlePing(ctx, baseMessage.ID, request) } @@ -126,6 +137,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeSetLevel(ctx, baseMessage.ID, &request) result, err = s.handleSetLevel(ctx, baseMessage.ID, request) } @@ -151,6 +163,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeListResources(ctx, baseMessage.ID, &request) result, err = s.handleListResources(ctx, baseMessage.ID, request) } @@ -176,6 +189,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeListResourceTemplates(ctx, baseMessage.ID, &request) result, err = s.handleListResourceTemplates(ctx, baseMessage.ID, request) } @@ -201,6 +215,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeReadResource(ctx, baseMessage.ID, &request) result, err = s.handleReadResource(ctx, baseMessage.ID, request) } @@ -226,6 +241,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeListPrompts(ctx, baseMessage.ID, &request) result, err = s.handleListPrompts(ctx, baseMessage.ID, request) } @@ -251,6 +267,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeGetPrompt(ctx, baseMessage.ID, &request) result, err = s.handleGetPrompt(ctx, baseMessage.ID, request) } @@ -276,6 +293,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeListTools(ctx, baseMessage.ID, &request) result, err = s.handleListTools(ctx, baseMessage.ID, request) } @@ -301,6 +319,7 @@ func (s *MCPServer) HandleMessage( err: &UnparsableMessageError{message: message, err: unmarshalErr, method: baseMessage.Method}, } } else { + request.Header = headers s.hooks.beforeCallTool(ctx, baseMessage.ID, &request) result, err = s.handleToolCall(ctx, baseMessage.ID, request) } diff --git a/server/sampling.go b/server/sampling.go new file mode 100644 index 000000000..4423ccf5f --- /dev/null +++ b/server/sampling.go @@ -0,0 +1,61 @@ +package server + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" +) + +// EnableSampling enables sampling capabilities for the server. +// This allows the server to send sampling requests to clients that support it. +func (s *MCPServer) EnableSampling() { + s.capabilitiesMu.Lock() + defer s.capabilitiesMu.Unlock() + + enabled := true + s.capabilities.sampling = &enabled +} + +// RequestSampling sends a sampling request to the client. +// The client must have declared sampling capability during initialization. +func (s *MCPServer) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + session := ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + // Check if the session supports sampling requests + if samplingSession, ok := session.(SessionWithSampling); ok { + return samplingSession.RequestSampling(ctx, request) + } + + // Check for inprocess sampling handler in context + if handler := InProcessSamplingHandlerFromContext(ctx); handler != nil { + return handler.CreateMessage(ctx, request) + } + + return nil, fmt.Errorf("session does not support sampling") +} + +// SessionWithSampling extends ClientSession to support sampling requests. +type SessionWithSampling interface { + ClientSession + RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) +} + +// inProcessSamplingHandlerKey is the context key for storing inprocess sampling handler +type inProcessSamplingHandlerKey struct{} + +// WithInProcessSamplingHandler adds a sampling handler to the context for inprocess clients +func WithInProcessSamplingHandler(ctx context.Context, handler SamplingHandler) context.Context { + return context.WithValue(ctx, inProcessSamplingHandlerKey{}, handler) +} + +// InProcessSamplingHandlerFromContext retrieves the inprocess sampling handler from context +func InProcessSamplingHandlerFromContext(ctx context.Context) SamplingHandler { + if handler, ok := ctx.Value(inProcessSamplingHandlerKey{}).(SamplingHandler); ok { + return handler + } + return nil +} diff --git a/server/sampling_test.go b/server/sampling_test.go new file mode 100644 index 000000000..fbecdd70d --- /dev/null +++ b/server/sampling_test.go @@ -0,0 +1,154 @@ +package server + +import ( + "context" + "testing" + + "github.com/mark3labs/mcp-go/mcp" +) + +func TestMCPServer_RequestSampling_NoSession(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.EnableSampling() + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + {Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "Test"}}, + }, + MaxTokens: 100, + }, + } + + _, err := server.RequestSampling(context.Background(), request) + + if err == nil { + t.Error("expected error when no session available") + } + + expectedError := "no active session" + if err.Error() != expectedError { + t.Errorf("expected error %q, got %q", expectedError, err.Error()) + } +} + +// mockSession implements ClientSession for testing +type mockSession struct { + sessionID string +} + +func (m *mockSession) SessionID() string { + return m.sessionID +} + +func (m *mockSession) NotificationChannel() chan<- mcp.JSONRPCNotification { + return make(chan mcp.JSONRPCNotification, 1) +} + +func (m *mockSession) Initialize() {} + +func (m *mockSession) Initialized() bool { + return true +} + +// mockSamplingSession implements SessionWithSampling for testing +type mockSamplingSession struct { + mockSession + result *mcp.CreateMessageResult + err error +} + +func (m *mockSamplingSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + if m.err != nil { + return nil, m.err + } + return m.result, nil +} + +func TestMCPServer_RequestSampling_Success(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + server.EnableSampling() + + // Create a mock sampling session + mockSession := &mockSamplingSession{ + mockSession: mockSession{sessionID: "test-session"}, + result: &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Test response", + }, + }, + Model: "test-model", + StopReason: "endTurn", + }, + } + + // Create context with session + ctx := context.Background() + ctx = server.WithContext(ctx, mockSession) + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + {Role: mcp.RoleUser, Content: mcp.TextContent{Type: "text", Text: "Test"}}, + }, + MaxTokens: 100, + }, + } + + result, err := server.RequestSampling(ctx, request) + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if result == nil { + t.Error("expected result, got nil") + return + } + + if result.Model != "test-model" { + t.Errorf("expected model %q, got %q", "test-model", result.Model) + } +} + +func TestMCPServer_EnableSampling_SetsCapability(t *testing.T) { + server := NewMCPServer("test", "1.0.0") + + // Verify sampling capability is not set initially + ctx := context.Background() + initRequest := mcp.InitializeRequest{ + Params: mcp.InitializeParams{ + ProtocolVersion: "2025-03-26", + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + Capabilities: mcp.ClientCapabilities{}, + }, + } + + result, err := server.handleInitialize(ctx, 1, initRequest) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.Capabilities.Sampling != nil { + t.Error("sampling capability should not be set before EnableSampling() is called") + } + + // Enable sampling + server.EnableSampling() + + // Verify sampling capability is now set + result, err = server.handleInitialize(ctx, 2, initRequest) + if err != nil { + t.Fatalf("unexpected error after EnableSampling(): %v", err) + } + + if result.Capabilities.Sampling == nil { + t.Error("sampling capability should be set after EnableSampling() is called") + } +} diff --git a/server/server.go b/server/server.go index 46e6d9c57..9f04e9478 100644 --- a/server/server.go +++ b/server/server.go @@ -64,6 +64,12 @@ type ServerResource struct { Handler ResourceHandlerFunc } +// ServerResourceTemplate combines a ResourceTemplate with its handler function. +type ServerResourceTemplate struct { + Template mcp.ResourceTemplate + Handler ResourceTemplateHandlerFunc +} + // serverKey is the context key for storing the server instance type serverKey struct{} @@ -175,6 +181,7 @@ type serverCapabilities struct { resources *resourceCapabilities prompts *promptCapabilities logging *bool + sampling *bool } // resourceCapabilities defines the supported resource-related features @@ -317,6 +324,11 @@ func NewMCPServer( return s } +// GenerateInProcessSessionID generates a unique session ID for inprocess clients +func (s *MCPServer) GenerateInProcessSessionID() string { + return GenerateInProcessSessionID() +} + // AddResources registers multiple resources at once func (s *MCPServer) AddResources(resources ...ServerResource) { s.implicitlyRegisterResourceCapabilities() @@ -337,6 +349,14 @@ func (s *MCPServer) AddResources(resources ...ServerResource) { } } +// SetResources replaces all existing resources with the provided list +func (s *MCPServer) SetResources(resources ...ServerResource) { + s.resourcesMu.Lock() + s.resources = make(map[string]resourceEntry, len(resources)) + s.resourcesMu.Unlock() + s.AddResources(resources...) +} + // AddResource registers a new resource and its handler func (s *MCPServer) AddResource( resource mcp.Resource, @@ -360,17 +380,16 @@ func (s *MCPServer) RemoveResource(uri string) { } } -// AddResourceTemplate registers a new resource template and its handler -func (s *MCPServer) AddResourceTemplate( - template mcp.ResourceTemplate, - handler ResourceTemplateHandlerFunc, -) { +// AddResourceTemplates registers multiple resource templates at once +func (s *MCPServer) AddResourceTemplates(resourceTemplates ...ServerResourceTemplate) { s.implicitlyRegisterResourceCapabilities() s.resourcesMu.Lock() - s.resourceTemplates[template.URITemplate.Raw()] = resourceTemplateEntry{ - template: template, - handler: handler, + for _, entry := range resourceTemplates { + s.resourceTemplates[entry.Template.URITemplate.Raw()] = resourceTemplateEntry{ + template: entry.Template, + handler: entry.Handler, + } } s.resourcesMu.Unlock() @@ -381,6 +400,22 @@ func (s *MCPServer) AddResourceTemplate( } } +// SetResourceTemplates replaces all existing resource templates with the provided list +func (s *MCPServer) SetResourceTemplates(templates ...ServerResourceTemplate) { + s.resourcesMu.Lock() + s.resourceTemplates = make(map[string]resourceTemplateEntry, len(templates)) + s.resourcesMu.Unlock() + s.AddResourceTemplates(templates...) +} + +// AddResourceTemplate registers a new resource template and its handler +func (s *MCPServer) AddResourceTemplate( + template mcp.ResourceTemplate, + handler ResourceTemplateHandlerFunc, +) { + s.AddResourceTemplates(ServerResourceTemplate{Template: template, Handler: handler}) +} + // AddPrompts registers multiple prompts at once func (s *MCPServer) AddPrompts(prompts ...ServerPrompt) { s.implicitlyRegisterPromptCapabilities() @@ -404,6 +439,15 @@ func (s *MCPServer) AddPrompt(prompt mcp.Prompt, handler PromptHandlerFunc) { s.AddPrompts(ServerPrompt{Prompt: prompt, Handler: handler}) } +// SetPrompts replaces all existing prompts with the provided list +func (s *MCPServer) SetPrompts(prompts ...ServerPrompt) { + s.promptsMu.Lock() + s.prompts = make(map[string]mcp.Prompt, len(prompts)) + s.promptHandlers = make(map[string]PromptHandlerFunc, len(prompts)) + s.promptsMu.Unlock() + s.AddPrompts(prompts...) +} + // DeletePrompts removes prompts from the server func (s *MCPServer) DeletePrompts(names ...string) { s.promptsMu.Lock() @@ -562,6 +606,10 @@ func (s *MCPServer) handleInitialize( capabilities.Logging = &struct{}{} } + if s.capabilities.sampling != nil && *s.capabilities.sampling { + capabilities.Sampling = &struct{}{} + } + result := mcp.InitializeResult{ ProtocolVersion: s.protocolVersion(request.Params.ProtocolVersion), ServerInfo: mcp.Implementation{ @@ -578,12 +626,22 @@ func (s *MCPServer) handleInitialize( // Store client info if the session supports it if sessionWithClientInfo, ok := session.(SessionWithClientInfo); ok { sessionWithClientInfo.SetClientInfo(request.Params.ClientInfo) + sessionWithClientInfo.SetClientCapabilities(request.Params.Capabilities) } } + return &result, nil } func (s *MCPServer) protocolVersion(clientVersion string) string { + // For backwards compatibility, if the server does not receive an MCP-Protocol-Version header, + // and has no other way to identify the version - for example, by relying on the protocol version negotiated + // during initialization - the server SHOULD assume protocol version 2025-03-26 + // https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#protocol-version-header + if len(clientVersion) == 0 { + clientVersion = "2025-03-26" + } + if slices.Contains(mcp.ValidProtocolVersions, clientVersion) { return clientVersion } @@ -1018,12 +1076,12 @@ func (s *MCPServer) handleToolCall( s.middlewareMu.RLock() mw := s.toolHandlerMiddlewares - s.middlewareMu.RUnlock() // Apply middlewares in reverse order for i := len(mw) - 1; i >= 0; i-- { finalHandler = mw[i](finalHandler) } + s.middlewareMu.RUnlock() result, err := finalHandler(ctx, request) if err != nil { diff --git a/server/server_test.go b/server/server_test.go index 1c81d18dd..aca99ef60 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "reflect" + "slices" "sort" "testing" "time" @@ -41,7 +42,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -69,7 +70,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -106,7 +107,7 @@ func TestMCPServer_Capabilities(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -406,7 +407,7 @@ func TestMCPServer_HandleValidMessages(t *testing.T) { assert.Equal( t, - mcp.LATEST_PROTOCOL_VERSION, + "2025-03-26", // Backward compatibility: no protocol version provided, initResult.ProtocolVersion, ) assert.Equal(t, "test-server", initResult.ServerInfo.Name) @@ -1456,6 +1457,54 @@ func TestMCPServer_ResourceTemplates(t *testing.T) { assert.Equal(t, "text/plain", resultContent.MIMEType) assert.Equal(t, "test content: something", resultContent.Text) }) + + server.AddResourceTemplates( + ServerResourceTemplate{ + Template: mcp.NewResourceTemplate( + "test://test-another-resource-1", + "Another Resource 1", + ), + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + }, + ServerResourceTemplate{ + Template: mcp.NewResourceTemplate( + "test://test-another-resource-2", + "Another Resource 2", + ), + Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { + return []mcp.ResourceContents{}, nil + }, + }, + ) + + t.Run("Check bulk add resource templates", func(t *testing.T) { + assert.Equal(t, 3, len(server.resourceTemplates)) + }) + + t.Run("Get resource template again", func(t *testing.T) { + response := server.HandleMessage( + context.Background(), + []byte(listMessage), + ) + assert.NotNil(t, response) + + resp, ok := response.(mcp.JSONRPCResponse) + assert.True(t, ok) + listResult, ok := resp.Result.(mcp.ListResourceTemplatesResult) + assert.True(t, ok) + assert.Len(t, listResult.ResourceTemplates, 3) + + // resource templates are stored in a map, so the order is not guaranteed + for _, rt := range listResult.ResourceTemplates { + assert.True(t, slices.Contains([]string{ + "My Resource", + "Another Resource 1", + "Another Resource 2", + }, rt.Name)) + } + }) } func createTestServer() *MCPServer { diff --git a/server/session.go b/server/session.go index a79da22ca..11ee8a2f1 100644 --- a/server/session.go +++ b/server/session.go @@ -46,6 +46,10 @@ type SessionWithClientInfo interface { GetClientInfo() mcp.Implementation // SetClientInfo sets the client information for this session SetClientInfo(clientInfo mcp.Implementation) + // GetClientCapabilities returns the client capabilities for this session + GetClientCapabilities() mcp.ClientCapabilities + // SetClientCapabilities sets the client capabilities for this session + SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) } // SessionWithStreamableHTTPConfig extends ClientSession to support streamable HTTP transport configurations @@ -96,35 +100,38 @@ func (s *MCPServer) RegisterSession( return nil } -// UnregisterSession removes from storage session that is shut down. -func (s *MCPServer) UnregisterSession( - ctx context.Context, - sessionID string, -) { - sessionValue, ok := s.sessions.LoadAndDelete(sessionID) - if !ok { - return - } - if session, ok := sessionValue.(ClientSession); ok { - s.hooks.UnregisterSession(ctx, session) - } -} - -// SendNotificationToAllClients sends a notification to all the currently active clients. -func (s *MCPServer) SendNotificationToAllClients( - method string, - params map[string]any, -) { - notification := mcp.JSONRPCNotification{ +func (s *MCPServer) buildLogNotification(notification mcp.LoggingMessageNotification) mcp.JSONRPCNotification { + return mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ - Method: method, + Method: notification.Method, Params: mcp.NotificationParams{ - AdditionalFields: params, + AdditionalFields: map[string]any{ + "level": notification.Params.Level, + "logger": notification.Params.Logger, + "data": notification.Params.Data, + }, }, }, } +} +func (s *MCPServer) SendLogMessageToClient(ctx context.Context, notification mcp.LoggingMessageNotification) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationCore(ctx, session, s.buildLogNotification(notification)) +} + +func (s *MCPServer) sendNotificationToAllClients(notification mcp.JSONRPCNotification) { s.sessions.Range(func(k, v any) bool { if session, ok := v.(ClientSession); ok && session.Initialized() { select { @@ -140,7 +147,7 @@ func (s *MCPServer) SendNotificationToAllClients( ctx := context.Background() // Use the error hook to report the blocked channel hooks.onError(ctx, nil, "notification", map[string]any{ - "method": method, + "method": notification.Method, "sessionID": sessionID, }, fmt.Errorf("notification channel blocked for session %s: %w", sessionID, err)) }(session.SessionID(), hooks) @@ -151,22 +158,71 @@ func (s *MCPServer) SendNotificationToAllClients( }) } -// SendNotificationToClient sends a notification to the current client -func (s *MCPServer) SendNotificationToClient( - ctx context.Context, - method string, - params map[string]any, -) error { - session := ClientSessionFromContext(ctx) - if session == nil || !session.Initialized() { - return ErrNotificationNotInitialized - } - +func (s *MCPServer) sendNotificationToSpecificClient(session ClientSession, notification mcp.JSONRPCNotification) error { // upgrades the client-server communication to SSE stream when the server sends notifications to the client if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() } + select { + case session.NotificationChannel() <- notification: + return nil + default: + // Channel is blocked, if there's an error hook, use it + if s.hooks != nil && len(s.hooks.OnError) > 0 { + err := ErrNotificationChannelBlocked + ctx := context.Background() + // Copy hooks pointer to local variable to avoid race condition + hooks := s.hooks + go func(sID string, hooks *Hooks) { + // Use the error hook to report the blocked channel + hooks.onError(ctx, nil, "notification", map[string]any{ + "method": notification.Method, + "sessionID": sID, + }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) + }(session.SessionID(), hooks) + } + return ErrNotificationChannelBlocked + } +} + +func (s *MCPServer) SendLogMessageToSpecificClient(sessionID string, notification mcp.LoggingMessageNotification) error { + sessionValue, ok := s.sessions.Load(sessionID) + if !ok { + return ErrSessionNotFound + } + session, ok := sessionValue.(ClientSession) + if !ok || !session.Initialized() { + return ErrSessionNotInitialized + } + sessionLogging, ok := session.(SessionWithLogging) + if !ok { + return ErrSessionDoesNotSupportLogging + } + if !notification.Params.Level.ShouldSendTo(sessionLogging.GetLogLevel()) { + return nil + } + return s.sendNotificationToSpecificClient(session, s.buildLogNotification(notification)) +} + +// UnregisterSession removes from storage session that is shut down. +func (s *MCPServer) UnregisterSession( + ctx context.Context, + sessionID string, +) { + sessionValue, ok := s.sessions.LoadAndDelete(sessionID) + if !ok { + return + } + if session, ok := sessionValue.(ClientSession); ok { + s.hooks.UnregisterSession(ctx, session) + } +} +// SendNotificationToAllClients sends a notification to all the currently active clients. +func (s *MCPServer) SendNotificationToAllClients( + method string, + params map[string]any, +) { notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -176,13 +232,26 @@ func (s *MCPServer) SendNotificationToClient( }, }, } + s.sendNotificationToAllClients(notification) +} +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) sendNotificationCore( + ctx context.Context, + session ClientSession, + notification mcp.JSONRPCNotification, +) error { + // upgrades the client-server communication to SSE stream when the server sends notifications to the client + if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { + sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() + } select { case session.NotificationChannel() <- notification: return nil default: // Channel is blocked, if there's an error hook, use it if s.hooks != nil && len(s.hooks.OnError) > 0 { + method := notification.Method err := ErrNotificationChannelBlocked // Copy hooks pointer to local variable to avoid race condition hooks := s.hooks @@ -198,6 +267,28 @@ func (s *MCPServer) SendNotificationToClient( } } +// SendNotificationToClient sends a notification to the current client +func (s *MCPServer) SendNotificationToClient( + ctx context.Context, + method string, + params map[string]any, +) error { + session := ClientSessionFromContext(ctx) + if session == nil || !session.Initialized() { + return ErrNotificationNotInitialized + } + notification := mcp.JSONRPCNotification{ + JSONRPC: mcp.JSONRPC_VERSION, + Notification: mcp.Notification{ + Method: method, + Params: mcp.NotificationParams{ + AdditionalFields: params, + }, + }, + } + return s.sendNotificationCore(ctx, session, notification) +} + // SendNotificationToSpecificClient sends a notification to a specific client by session ID func (s *MCPServer) SendNotificationToSpecificClient( sessionID string, @@ -208,17 +299,10 @@ func (s *MCPServer) SendNotificationToSpecificClient( if !ok { return ErrSessionNotFound } - session, ok := sessionValue.(ClientSession) if !ok || !session.Initialized() { return ErrSessionNotInitialized } - - // upgrades the client-server communication to SSE stream when the server sends notifications to the client - if sessionWithStreamableHTTPConfig, ok := session.(SessionWithStreamableHTTPConfig); ok { - sessionWithStreamableHTTPConfig.UpgradeToSSEWhenReceiveNotification() - } - notification := mcp.JSONRPCNotification{ JSONRPC: mcp.JSONRPC_VERSION, Notification: mcp.Notification{ @@ -228,27 +312,7 @@ func (s *MCPServer) SendNotificationToSpecificClient( }, }, } - - select { - case session.NotificationChannel() <- notification: - return nil - default: - // Channel is blocked, if there's an error hook, use it - if s.hooks != nil && len(s.hooks.OnError) > 0 { - err := ErrNotificationChannelBlocked - ctx := context.Background() - // Copy hooks pointer to local variable to avoid race condition - hooks := s.hooks - go func(sID string, hooks *Hooks) { - // Use the error hook to report the blocked channel - hooks.onError(ctx, nil, "notification", map[string]any{ - "method": method, - "sessionID": sID, - }, fmt.Errorf("notification channel blocked for session %s: %w", sID, err)) - }(sessionID, hooks) - } - return ErrNotificationChannelBlocked - } + return s.sendNotificationToSpecificClient(session, notification) } // AddSessionTool adds a tool for a specific session diff --git a/server/session_test.go b/server/session_test.go index 3067f4e9c..04334487b 100644 --- a/server/session_test.go +++ b/server/session_test.go @@ -106,6 +106,7 @@ type sessionTestClientWithClientInfo struct { notificationChannel chan mcp.JSONRPCNotification initialized bool clientInfo atomic.Value + clientCapabilities atomic.Value } func (f *sessionTestClientWithClientInfo) SessionID() string { @@ -137,6 +138,19 @@ func (f *sessionTestClientWithClientInfo) SetClientInfo(clientInfo mcp.Implement f.clientInfo.Store(clientInfo) } +func (f *sessionTestClientWithClientInfo) GetClientCapabilities() mcp.ClientCapabilities { + if value := f.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (f *sessionTestClientWithClientInfo) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + f.clientCapabilities.Store(clientCapabilities) +} + // sessionTestClientWithTools implements the SessionWithLogging interface for testing type sessionTestClientWithLogging struct { sessionID string @@ -1099,10 +1113,14 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { Version: "1.0.0", } + clientCapability := mcp.ClientCapabilities{ + Sampling: &struct{}{}, + } + initRequest := mcp.InitializeRequest{} initRequest.Params.ClientInfo = clientInfo initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.Capabilities = mcp.ClientCapabilities{} + initRequest.Params.Capabilities = clientCapability sessionCtx := server.WithContext(context.Background(), session) @@ -1125,4 +1143,389 @@ func TestSessionWithClientInfo_Integration(t *testing.T) { assert.Equal(t, clientInfo.Name, storedClientInfo.Name, "Client name should match") assert.Equal(t, clientInfo.Version, storedClientInfo.Version, "Client version should match") + + storedClientCapabilities := sessionWithClientInfo.GetClientCapabilities() + + assert.Equal(t, clientCapability, storedClientCapabilities, "Client capability should match") +} + +// New test function to cover log notification functionality +func TestMCPServer_SendLogMessageToClient(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create a session that supports logging + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.Initialize() + + // Set log level to Info + session.SetLogLevel(mcp.LoggingLevelInfo) + + // Register session + err := server.RegisterSession(ctx, session) + require.NoError(t, err) + + // Create session context + sessionCtx := server.WithContext(ctx, session) + + // Test cases + tests := []struct { + name string + level mcp.LoggingLevel + expectSent bool + expectError bool + }{ + { + name: "higher level log should be sent", + level: mcp.LoggingLevelWarning, // Higher than Info + expectSent: true, + }, + { + name: "same level log should be sent", + level: mcp.LoggingLevelInfo, + expectSent: true, + }, + { + name: "lower level log should not be sent", + level: mcp.LoggingLevelDebug, // Lower than Info + expectSent: false, + }, + { + name: "uninitialized session should return error", + level: mcp.LoggingLevelError, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.expectError { + // Create uninitialized session + uninitSession := &sessionTestClientWithLogging{ + sessionID: "uninit-session", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + uninitCtx := server.WithContext(ctx, uninitSession) + notification := mcp.NewLoggingMessageNotification(tt.level, "test-logger", "test message") + err := server.SendLogMessageToClient(uninitCtx, notification) + require.Error(t, err) + assert.Equal(t, ErrNotificationNotInitialized, err) + return + } + notification := mcp.NewLoggingMessageNotification(tt.level, "test-logger", "test message") + err := server.SendLogMessageToClient(sessionCtx, notification) + require.NoError(t, err) + + if tt.expectSent { + select { + case notif := <-sessionChan: + assert.Equal(t, "notifications/message", notif.Method) + assert.Equal(t, tt.level, notif.Params.AdditionalFields["level"]) + assert.Equal(t, "test-logger", notif.Params.AdditionalFields["logger"]) + assert.Equal(t, "test message", notif.Params.AdditionalFields["data"]) + case <-time.After(500 * time.Millisecond): + t.Error("Expected log notification not received") + } + } else { + select { + case <-sessionChan: + t.Error("Unexpected log notification received") + case <-time.After(50 * time.Millisecond): + // No notification expected + } + } + }) + } +} + +func TestMCPServer_SendLogMessageToSpecificClient(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create two sessions + session1Chan := make(chan mcp.JSONRPCNotification, 10) + session1 := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: session1Chan, + } + session1.Initialize() + session1.SetLogLevel(mcp.LoggingLevelInfo) + + session2Chan := make(chan mcp.JSONRPCNotification, 10) + session2 := &sessionTestClientWithLogging{ + sessionID: "session-2", + notificationChannel: session2Chan, + } + session2.Initialize() + session2.SetLogLevel(mcp.LoggingLevelWarning) // Higher log level + + // Register sessions + require.NoError(t, server.RegisterSession(ctx, session1)) + require.NoError(t, server.RegisterSession(ctx, session2)) + + // Test cases + tests := []struct { + name string + sessionID string + level mcp.LoggingLevel + expectSent bool + expectError bool + errorType error + }{ + { + name: "valid session and level should be sent", + sessionID: session1.SessionID(), + level: mcp.LoggingLevelInfo, + expectSent: true, + }, + { + name: "log below session level should not be sent", + sessionID: session1.SessionID(), + level: mcp.LoggingLevelDebug, + expectSent: false, + }, + { + name: "valid session with higher level should be sent", + sessionID: session2.SessionID(), + level: mcp.LoggingLevelError, + expectSent: true, + }, + { + name: "non-existent session should return error", + sessionID: "non-existent", + level: mcp.LoggingLevelError, + expectError: true, + errorType: ErrSessionNotFound, + }, + { + name: "uninitialized session should return error", + sessionID: "uninitialized-session", + level: mcp.LoggingLevelError, + expectError: true, + errorType: ErrSessionNotInitialized, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.sessionID == "uninitialized-session" { + uninitSession := &sessionTestClientWithLogging{ + sessionID: "uninitialized-session", + notificationChannel: make(chan mcp.JSONRPCNotification, 10), + initialized: false, + } + require.NoError(t, server.RegisterSession(ctx, uninitSession)) + } + + notification := mcp.NewLoggingMessageNotification(tt.level, "test-logger", "test message") + + err := server.SendLogMessageToSpecificClient(tt.sessionID, notification) + + if tt.expectError { + require.Error(t, err) + if tt.errorType != nil { + assert.ErrorIs(t, err, tt.errorType) + } + return + } + + require.NoError(t, err) + + var targetChan chan mcp.JSONRPCNotification + if tt.sessionID == session1.SessionID() { + targetChan = session1Chan + } else if tt.sessionID == session2.SessionID() { + targetChan = session2Chan + } + + if tt.expectSent && targetChan != nil { + select { + case notif := <-targetChan: + assert.Equal(t, "notifications/message", notif.Method) + assert.Equal(t, tt.level, notif.Params.AdditionalFields["level"]) + assert.Equal(t, "test-logger", notif.Params.AdditionalFields["logger"]) + assert.Equal(t, "test message", notif.Params.AdditionalFields["data"]) + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + } else if targetChan != nil { + select { + case <-targetChan: + t.Error("Unexpected log notification received") + case <-time.After(50 * time.Millisecond): + // No notification expected + } + } + }) + } +} + +func TestMCPServer_LoggingWithUnsupportedSessions(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create three types of sessions: + // 1. Logging-supported session + // 2. Logging-unsupported session + // 3. Uninitialized session + + // Logging-supported session + loggingSessionChan := make(chan mcp.JSONRPCNotification, 10) + loggingSession := &sessionTestClientWithLogging{ + sessionID: "logging-session", + notificationChannel: loggingSessionChan, + } + loggingSession.Initialize() + loggingSession.SetLogLevel(mcp.LoggingLevelInfo) + + // Logging-unsupported session + nonLoggingSessionChan := make(chan mcp.JSONRPCNotification, 10) + nonLoggingSession := &sessionTestClient{ + sessionID: "non-logging-session", + notificationChannel: nonLoggingSessionChan, + } + nonLoggingSession.Initialize() + + // Uninitialized session + uninitializedSessionChan := make(chan mcp.JSONRPCNotification, 10) + uninitializedSession := &sessionTestClientWithLogging{ + sessionID: "uninitialized-session", + notificationChannel: uninitializedSessionChan, + initialized: false, + } + + // Register all sessions + require.NoError(t, server.RegisterSession(ctx, loggingSession)) + require.NoError(t, server.RegisterSession(ctx, nonLoggingSession)) + require.NoError(t, server.RegisterSession(ctx, uninitializedSession)) + + // Info-level log notification + notification := mcp.NewLoggingMessageNotification(mcp.LoggingLevelInfo, "test-logger", "test message for ") + + t.Run("SendLogMessageToClient", func(t *testing.T) { + // Logging-supported session + loggingCtx := server.WithContext(ctx, loggingSession) + err := server.SendLogMessageToClient(loggingCtx, notification) + require.NoError(t, err) + select { + case notif := <-loggingSessionChan: + assert.Equal(t, "notifications/message", notif.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + + // Logging-unsupported session + nonLoggingCtx := server.WithContext(ctx, nonLoggingSession) + err = server.SendLogMessageToClient(nonLoggingCtx, notification) + require.Error(t, err) + assert.Equal(t, ErrSessionDoesNotSupportLogging, err) + + // Uninitialized session + uninitCtx := server.WithContext(ctx, uninitializedSession) + err = server.SendLogMessageToClient(uninitCtx, notification) + require.Error(t, err) + assert.Equal(t, ErrNotificationNotInitialized, err) + }) + + t.Run("SendLogMessageToSpecificClient", func(t *testing.T) { + err := server.SendLogMessageToSpecificClient(loggingSession.SessionID(), notification) + require.NoError(t, err) + select { + case notif := <-loggingSessionChan: + assert.Equal(t, "notifications/message", notif.Method) + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + + err = server.SendLogMessageToSpecificClient(nonLoggingSession.SessionID(), notification) + require.Error(t, err) + assert.Equal(t, ErrSessionDoesNotSupportLogging, err) + + err = server.SendLogMessageToSpecificClient(uninitializedSession.SessionID(), notification) + require.Error(t, err) + assert.Equal(t, ErrSessionNotInitialized, err) + }) +} + +func TestMCPServer_LoggingNotificationFormat(t *testing.T) { + server := NewMCPServer("test-server", "1.0.0", WithLogging()) + ctx := context.Background() + + // Create a session + sessionChan := make(chan mcp.JSONRPCNotification, 10) + session := &sessionTestClientWithLogging{ + sessionID: "session-1", + notificationChannel: sessionChan, + } + session.Initialize() + session.SetLogLevel(mcp.LoggingLevelDebug) + + // Register session + require.NoError(t, server.RegisterSession(ctx, session)) + + // Send log messages with different formats + testCases := []struct { + name string + data any + expected any + }{ + { + name: "string data", + data: "simple log message", + expected: "simple log message", + }, + { + name: "structured data", + data: map[string]any{"key": "value", "num": 42}, + expected: map[string]any{"key": "value", "num": 42}, + }, + { + name: "error data", + data: errors.New("error message"), + expected: errors.New("error message"), + }, + { + name: "nil data", + data: nil, + expected: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + notification := mcp.NewLoggingMessageNotification(mcp.LoggingLevelInfo, "test-logger", tc.data) + + err := server.SendLogMessageToSpecificClient(session.SessionID(), notification) + require.NoError(t, err) + + select { + case notif := <-sessionChan: + assert.Equal(t, "notifications/message", notif.Method) + assert.Equal(t, mcp.LoggingLevelInfo, notif.Params.AdditionalFields["level"]) + assert.Equal(t, "test-logger", notif.Params.AdditionalFields["logger"]) + + // Validate log data format + dataField := notif.Params.AdditionalFields["data"] + switch expected := tc.expected.(type) { + case string: + assert.Equal(t, expected, dataField) + case map[string]any: + assert.IsType(t, map[string]any{}, dataField) + dataMap := dataField.(map[string]any) + for k, v := range expected { + assert.Equal(t, v, dataMap[k]) + } + case nil: + assert.Nil(t, dataField) + } + case <-time.After(100 * time.Millisecond): + t.Error("Expected log notification not received") + } + }) + } } diff --git a/server/sse.go b/server/sse.go index 416995730..9c9766cf3 100644 --- a/server/sse.go +++ b/server/sse.go @@ -30,6 +30,7 @@ type sseSession struct { loggingLevel atomic.Value tools sync.Map // stores session-specific tools clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities } // SSEContextFunc is a function that takes an existing context and the current @@ -108,6 +109,19 @@ func (s *sseSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *sseSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + +func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + var ( _ ClientSession = (*sseSession)(nil) _ SessionWithTools = (*sseSession)(nil) @@ -504,7 +518,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusAccepted) // Create a new context for handling the message that will be canceled when the message handling is done - messageCtx, cancel := context.WithCancel(detachedCtx) + messageCtx := context.WithValue(detachedCtx, requestHeader, r.Header) + messageCtx, cancel := context.WithCancel(messageCtx) go func(ctx context.Context) { defer cancel() diff --git a/server/sse_test.go b/server/sse_test.go index 96912be49..de8e29d33 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -1257,7 +1257,7 @@ func TestSSEServer(t *testing.T) { WithHooks(&Hooks{ OnAfterInitialize: []OnAfterInitializeFunc{ func(ctx context.Context, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) { - result.Meta = map[string]any{"invalid": func() {}} // marshal will fail + result.Meta = mcp.NewMetaFromMap(map[string]any{"invalid": func() {}}) // marshal will fail }, }, }), @@ -1443,6 +1443,169 @@ func TestSSEServer(t *testing.T) { t.Fatal("Shutdown did not return in time (likely deadlocked)") } }) + + t.Run("Headers are passed through to tool requests", func(t *testing.T) { + hooks := &Hooks{} + headerVerified := make(chan struct{}) + hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + if message.Params.Name == "verify-headers" { + select { + case <-headerVerified: + default: + close(headerVerified) + } + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + addHeaderVerificationTool(mcpServer) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // First establish SSE connection + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + endpointEvent, err := readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send request with custom header + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "verify-headers", + }, + } + requestBody, err := json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Create request with custom header + messageReq, err := http.NewRequest("POST", messageURL, bytes.NewReader(requestBody)) + if err != nil { + t.Fatalf("Failed to create message request: %v", err) + } + messageReq.Header.Set("Content-Type", "application/json") + messageReq.Header.Set("X-Custom-Header", "test-value") + + resp, err := http.DefaultClient.Do(messageReq) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + + // Wait for hook to be called + select { + case <-headerVerified: + case <-time.After(1 * time.Second): + t.Error("Header verification hook was not called within timeout") + } + }) + + t.Run("Headers are not nil when no headers are set", func(t *testing.T) { + hooks := &Hooks{} + headersChecked := make(chan struct{}) + hooks.AddAfterCallTool(func(ctx context.Context, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) { + if message.Params.Name == "check-headers-not-nil" { + select { + case <-headersChecked: + default: + close(headersChecked) + } + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks)) + mcpServer.AddTool( + mcp.NewTool("check-headers-not-nil"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // This will panic if headers are nil + _ = request.Header.Get("Any-Header") + // Also verify we can iterate over headers safely + for key := range request.Header { + _ = request.Header.Get(key) + } + return mcp.NewToolResultText("headers not nil"), nil + }, + ) + testServer := NewTestServer(mcpServer) + defer testServer.Close() + + // First establish SSE connection + sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL)) + if err != nil { + t.Fatalf("Failed to connect to SSE endpoint: %v", err) + } + defer sseResp.Body.Close() + + // Read the endpoint event + endpointEvent, err := readSSEEvent(sseResp) + if err != nil { + t.Fatalf("Failed to read SSE response: %v", err) + } + messageURL := strings.TrimSpace( + strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0], + ) + + // Send request without any headers at all + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": map[string]any{ + "name": "check-headers-not-nil", + }, + } + requestBody, err := json.Marshal(toolRequest) + if err != nil { + t.Fatalf("Failed to marshal request: %v", err) + } + + // Use a custom transport to avoid default headers + transport := &http.Transport{} + client := &http.Client{Transport: transport} + + // Create a completely headerless request + req, err := http.NewRequest("POST", messageURL, bytes.NewReader(requestBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + // Clear all headers to ensure absolutely no headers are sent + req.Header = make(http.Header) + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusAccepted { + t.Errorf("Expected status 202, got %d", resp.StatusCode) + } + + // Wait for hook to be called + select { + case <-headersChecked: + case <-time.After(1 * time.Second): + t.Error("Headers check hook was not called within timeout") + } + }) } func readSSEEvent(sseResp *http.Response) (string, error) { @@ -1453,3 +1616,16 @@ func readSSEEvent(sseResp *http.Response) (string, error) { } return string(buf[:n]), nil } + +// addHeaderVerificationTool adds a tool that verifies HTTP headers are passed correctly +func addHeaderVerificationTool(mcpServer *MCPServer) { + mcpServer.AddTool( + mcp.NewTool("verify-headers"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + if request.Header.Get("X-Custom-Header") != "test-value" { + return nil, fmt.Errorf("expected X-Custom-Header to be test-value, got %s", request.Header.Get("X-Custom-Header")) + } + return mcp.NewToolResultText("headers verified"), nil + }, + ) +} diff --git a/server/stdio.go b/server/stdio.go index 746a7d96f..4d567d8cb 100644 --- a/server/stdio.go +++ b/server/stdio.go @@ -9,6 +9,7 @@ import ( "log" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -51,10 +52,22 @@ func WithStdioContextFunc(fn StdioContextFunc) StdioOption { // stdioSession is a static client session, since stdio has only one client. type stdioSession struct { - notifications chan mcp.JSONRPCNotification - initialized atomic.Bool - loggingLevel atomic.Value - clientInfo atomic.Value // stores session-specific client info + notifications chan mcp.JSONRPCNotification + initialized atomic.Bool + loggingLevel atomic.Value + clientInfo atomic.Value // stores session-specific client info + clientCapabilities atomic.Value // stores session-specific client capabilities + writer io.Writer // for sending requests to client + requestID atomic.Int64 // for generating unique request IDs + mu sync.RWMutex // protects writer + pendingRequests map[int64]chan *samplingResponse // for tracking pending sampling requests + pendingMu sync.RWMutex // protects pendingRequests +} + +// samplingResponse represents a response to a sampling request +type samplingResponse struct { + result *mcp.CreateMessageResult + err error } func (s *stdioSession) SessionID() string { @@ -88,6 +101,19 @@ func (s *stdioSession) SetClientInfo(clientInfo mcp.Implementation) { s.clientInfo.Store(clientInfo) } +func (s *stdioSession) GetClientCapabilities() mcp.ClientCapabilities { + if value := s.clientCapabilities.Load(); value != nil { + if clientCapabilities, ok := value.(mcp.ClientCapabilities); ok { + return clientCapabilities + } + } + return mcp.ClientCapabilities{} +} + +func (s *stdioSession) SetClientCapabilities(clientCapabilities mcp.ClientCapabilities) { + s.clientCapabilities.Store(clientCapabilities) +} + func (s *stdioSession) SetLogLevel(level mcp.LoggingLevel) { s.loggingLevel.Store(level) } @@ -100,14 +126,86 @@ func (s *stdioSession) GetLogLevel() mcp.LoggingLevel { return level.(mcp.LoggingLevel) } +// RequestSampling sends a sampling request to the client and waits for the response. +func (s *stdioSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + s.mu.RLock() + writer := s.writer + s.mu.RUnlock() + + if writer == nil { + return nil, fmt.Errorf("no writer available for sending requests") + } + + // Generate a unique request ID + id := s.requestID.Add(1) + + // Create a response channel for this request + responseChan := make(chan *samplingResponse, 1) + s.pendingMu.Lock() + s.pendingRequests[id] = responseChan + s.pendingMu.Unlock() + + // Cleanup function to remove the pending request + cleanup := func() { + s.pendingMu.Lock() + delete(s.pendingRequests, id) + s.pendingMu.Unlock() + } + defer cleanup() + + // Create the JSON-RPC request + jsonRPCRequest := struct { + JSONRPC string `json:"jsonrpc"` + ID int64 `json:"id"` + Method string `json:"method"` + Params mcp.CreateMessageParams `json:"params"` + }{ + JSONRPC: mcp.JSONRPC_VERSION, + ID: id, + Method: string(mcp.MethodSamplingCreateMessage), + Params: request.CreateMessageParams, + } + + // Marshal and send the request + requestBytes, err := json.Marshal(jsonRPCRequest) + if err != nil { + return nil, fmt.Errorf("failed to marshal sampling request: %w", err) + } + requestBytes = append(requestBytes, '\n') + + if _, err := writer.Write(requestBytes); err != nil { + return nil, fmt.Errorf("failed to write sampling request: %w", err) + } + + // Wait for the response or context cancellation + select { + case <-ctx.Done(): + return nil, ctx.Err() + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + } +} + +// SetWriter sets the writer for sending requests to the client. +func (s *stdioSession) SetWriter(writer io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + s.writer = writer +} + var ( _ ClientSession = (*stdioSession)(nil) _ SessionWithLogging = (*stdioSession)(nil) _ SessionWithClientInfo = (*stdioSession)(nil) + _ SessionWithSampling = (*stdioSession)(nil) ) var stdioSessionInstance = stdioSession{ - notifications: make(chan mcp.JSONRPCNotification, 100), + notifications: make(chan mcp.JSONRPCNotification, 100), + pendingRequests: make(map[int64]chan *samplingResponse), } // NewStdioServer creates a new stdio server wrapper around an MCPServer. @@ -224,6 +322,9 @@ func (s *StdioServer) Listen( defer s.server.UnregisterSession(ctx, stdioSessionInstance.SessionID()) ctx = s.server.WithContext(ctx, &stdioSessionInstance) + // Set the writer for sending requests to the client + stdioSessionInstance.SetWriter(stdout) + // Add in any custom context. if s.contextFunc != nil { ctx = s.contextFunc(ctx) @@ -256,7 +357,29 @@ func (s *StdioServer) processMessage( return s.writeResponse(response, writer) } - // Handle the message using the wrapped server + // Check if this is a response to a sampling request + if s.handleSamplingResponse(rawMessage) { + return nil + } + + // Check if this is a tool call that might need sampling (and thus should be processed concurrently) + var baseMessage struct { + Method string `json:"method"` + } + if json.Unmarshal(rawMessage, &baseMessage) == nil && baseMessage.Method == "tools/call" { + // Process tool calls concurrently to avoid blocking on sampling requests + go func() { + response := s.server.HandleMessage(ctx, rawMessage) + if response != nil { + if err := s.writeResponse(response, writer); err != nil { + s.errLogger.Printf("Error writing tool response: %v", err) + } + } + }() + return nil + } + + // Handle other messages synchronously response := s.server.HandleMessage(ctx, rawMessage) // Only write response if there is one (not for notifications) @@ -269,6 +392,65 @@ func (s *StdioServer) processMessage( return nil } +// handleSamplingResponse checks if the message is a response to a sampling request +// and routes it to the appropriate pending request channel. +func (s *StdioServer) handleSamplingResponse(rawMessage json.RawMessage) bool { + return stdioSessionInstance.handleSamplingResponse(rawMessage) +} + +// handleSamplingResponse handles incoming sampling responses for this session +func (s *stdioSession) handleSamplingResponse(rawMessage json.RawMessage) bool { + // Try to parse as a JSON-RPC response + var response struct { + JSONRPC string `json:"jsonrpc"` + ID json.Number `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *struct { + Code int `json:"code"` + Message string `json:"message"` + } `json:"error,omitempty"` + } + + if err := json.Unmarshal(rawMessage, &response); err != nil { + return false + } + // Parse the ID as int64 + idInt64, err := response.ID.Int64() + if err != nil || (response.Result == nil && response.Error == nil) { + return false + } + + // Look for a pending request with this ID + s.pendingMu.RLock() + responseChan, exists := s.pendingRequests[idInt64] + s.pendingMu.RUnlock() + + if !exists { + return false + } // Parse and send the response + samplingResp := &samplingResponse{} + + if response.Error != nil { + samplingResp.err = fmt.Errorf("sampling request failed: %s", response.Error.Message) + } else { + var result mcp.CreateMessageResult + if err := json.Unmarshal(response.Result, &result); err != nil { + samplingResp.err = fmt.Errorf("failed to unmarshal sampling response: %w", err) + } else { + samplingResp.result = &result + } + } + + // Send the response (non-blocking) + select { + case responseChan <- samplingResp: + default: + // Channel is full or closed, ignore + } + + return true +} + // writeResponse marshals and writes a JSON-RPC response message followed by a newline. // Returns an error if marshaling or writing fails. func (s *StdioServer) writeResponse( diff --git a/server/streamable_http.go b/server/streamable_http.go index e9a011fb1..24ec1c95a 100644 --- a/server/streamable_http.go +++ b/server/streamable_http.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "io" + "mime" "net/http" "net/http/httptest" "strings" @@ -40,7 +41,9 @@ func WithEndpointPath(endpointPath string) StreamableHTTPOption { // to StatelessSessionIdManager. func WithStateLess(stateLess bool) StreamableHTTPOption { return func(s *StreamableHTTPServer) { - s.sessionIdManager = &StatelessSessionIdManager{} + if stateLess { + s.sessionIdManager = &StatelessSessionIdManager{} + } } } @@ -112,12 +115,12 @@ func WithLogger(logger util.Logger) StreamableHTTPOption { // or `hooks.onRegisterSession` will not be triggered for POST messages. // // The current implementation does not support the following features from the specification: -// - Batching of requests/notifications/responses in arrays. // - Stream Resumability type StreamableHTTPServer struct { server *MCPServer sessionTools *sessionToolsStore sessionRequestIDs sync.Map // sessionId --> last requestID(*atomic.Int64) + activeSessions sync.Map // sessionId --> *streamableHttpSession (for sampling responses) httpServer *http.Server mu sync.RWMutex @@ -127,6 +130,7 @@ type StreamableHTTPServer struct { sessionIdManager SessionIdManager listenHeartbeatInterval time.Duration logger util.Logger + sessionLogLevels *sessionLogLevelsStore } // NewStreamableHTTPServer creates a new streamable-http server instance @@ -134,6 +138,7 @@ func NewStreamableHTTPServer(server *MCPServer, opts ...StreamableHTTPOption) *S s := &StreamableHTTPServer{ server: server, sessionTools: newSessionToolsStore(), + sessionLogLevels: newSessionLogLevelsStore(), endpointPath: "/mcp", sessionIdManager: &InsecureStatefulSessionIdManager{}, logger: util.DefaultLogger(), @@ -202,16 +207,13 @@ func (s *StreamableHTTPServer) Shutdown(ctx context.Context) error { // --- internal methods --- -const ( - headerKeySessionID = "Mcp-Session-Id" -) - func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request) { // post request carry request/notification message // Check content type contentType := r.Header.Get("Content-Type") - if contentType != "application/json" { + mediaType, _, err := mime.ParseMediaType(contentType) + if err != nil || mediaType != "application/json" { http.Error(w, "Invalid content type: must be 'application/json'", http.StatusBadRequest) return } @@ -222,14 +224,32 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, fmt.Sprintf("read request body error: %v", err)) return } - var baseMessage struct { - Method mcp.MCPMethod `json:"method"` + // First, try to parse as a response (sampling responses don't have a method field) + var jsonMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` } - if err := json.Unmarshal(rawData, &baseMessage); err != nil { + if err := json.Unmarshal(rawData, &jsonMessage); err != nil { s.writeJSONRPCError(w, nil, mcp.PARSE_ERROR, "request body is not valid json") return } - isInitializeRequest := baseMessage.Method == mcp.MethodInitialize + + // Check if this is a sampling response (has result/error but no method) + isSamplingResponse := jsonMessage.Method == "" && jsonMessage.ID != nil && + (jsonMessage.Result != nil || jsonMessage.Error != nil) + + isInitializeRequest := jsonMessage.Method == mcp.MethodInitialize + + // Handle sampling responses separately + if isSamplingResponse { + if err := s.handleSamplingResponse(w, r, jsonMessage); err != nil { + s.logger.Errorf("Failed to handle sampling response: %v", err) + http.Error(w, "Failed to handle sampling response", http.StatusInternalServerError) + } + return + } // Prepare the session for the mcp server // The session is ephemeral. Its life is the same as the request. It's only created @@ -241,7 +261,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } else { // Get session ID from header. // Stateful servers need the client to carry the session ID. - sessionID = r.Header.Get(headerKeySessionID) + sessionID = r.Header.Get(HeaderKeySessionID) isTerminated, err := s.sessionIdManager.Validate(sessionID) if err != nil { http.Error(w, "Invalid session ID", http.StatusBadRequest) @@ -253,7 +273,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request } } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) // Set the client context before handling the message ctx := s.server.WithContext(r.Context(), session) @@ -266,6 +286,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request upgradedHeader := false done := make(chan struct{}) + ctx = context.WithValue(ctx, requestHeader, r.Header) go func() { for { select { @@ -291,7 +312,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) upgradedHeader = true } err := writeSSEEvent(w, nt) @@ -330,7 +351,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Connection", "keep-alive") w.Header().Set("Cache-Control", "no-cache") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) upgradedHeader = true } if err := writeSSEEvent(w, response); err != nil { @@ -340,7 +361,7 @@ func (s *StreamableHTTPServer) handlePost(w http.ResponseWriter, r *http.Request w.Header().Set("Content-Type", "application/json") if isInitializeRequest && sessionID != "" { // send the session ID back to the client - w.Header().Set(headerKeySessionID, sessionID) + w.Header().Set(HeaderKeySessionID, sessionID) } w.WriteHeader(http.StatusOK) err := json.NewEncoder(w).Encode(response) @@ -354,7 +375,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) // get request is for listening to notifications // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server - sessionID := r.Header.Get(headerKeySessionID) + sessionID := r.Header.Get(HeaderKeySessionID) // the specification didn't say we should validate the session id if sessionID == "" { @@ -363,18 +384,22 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) sessionID = uuid.New().String() } - session := newStreamableHttpSession(sessionID, s.sessionTools) + session := newStreamableHttpSession(sessionID, s.sessionTools, s.sessionLogLevels) if err := s.server.RegisterSession(r.Context(), session); err != nil { http.Error(w, fmt.Sprintf("Session registration failed: %v", err), http.StatusBadRequest) return } defer s.server.UnregisterSession(r.Context(), sessionID) + + // Register session for sampling response delivery + s.activeSessions.Store(sessionID, session) + defer s.activeSessions.Delete(sessionID) // Set the client context before handling the message w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.WriteHeader(http.StatusAccepted) + w.WriteHeader(http.StatusOK) flusher, ok := w.(http.Flusher) if !ok { @@ -397,6 +422,21 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) case <-done: return } + case samplingReq := <-session.samplingRequestChan: + // Send sampling request to client via SSE + jsonrpcRequest := mcp.JSONRPCRequest{ + JSONRPC: "2.0", + ID: mcp.NewRequestId(samplingReq.requestID), + Request: mcp.Request{ + Method: string(mcp.MethodSamplingCreateMessage), + }, + Params: samplingReq.request.CreateMessageParams, + } + select { + case writeChan <- jsonrpcRequest: + case <-done: + return + } case <-done: return } @@ -453,7 +493,7 @@ func (s *StreamableHTTPServer) handleGet(w http.ResponseWriter, r *http.Request) func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Request) { // delete request terminate the session - sessionID := r.Header.Get(headerKeySessionID) + sessionID := r.Header.Get(HeaderKeySessionID) notAllowed, err := s.sessionIdManager.Terminate(sessionID) if err != nil { http.Error(w, fmt.Sprintf("Session termination failed: %v", err), http.StatusInternalServerError) @@ -466,7 +506,7 @@ func (s *StreamableHTTPServer) handleDelete(w http.ResponseWriter, r *http.Reque // remove the session relateddata from the sessionToolsStore s.sessionTools.delete(sessionID) - + s.sessionLogLevels.delete(sessionID) // remove current session's requstID information s.sessionRequestIDs.Delete(sessionID) @@ -485,6 +525,114 @@ func writeSSEEvent(w io.Writer, data any) error { return nil } +// handleSamplingResponse processes incoming sampling responses from clients +func (s *StreamableHTTPServer) handleSamplingResponse(w http.ResponseWriter, r *http.Request, responseMessage struct { + ID json.RawMessage `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Method mcp.MCPMethod `json:"method,omitempty"` +}) error { + // Get session ID from header + sessionID := r.Header.Get(HeaderKeySessionID) + if sessionID == "" { + http.Error(w, "Missing session ID for sampling response", http.StatusBadRequest) + return fmt.Errorf("missing session ID") + } + + // Validate session + isTerminated, err := s.sessionIdManager.Validate(sessionID) + if err != nil { + http.Error(w, "Invalid session ID", http.StatusBadRequest) + return err + } + if isTerminated { + http.Error(w, "Session terminated", http.StatusNotFound) + return fmt.Errorf("session terminated") + } + + // Parse the request ID + var requestID int64 + if err := json.Unmarshal(responseMessage.ID, &requestID); err != nil { + http.Error(w, "Invalid request ID in sampling response", http.StatusBadRequest) + return err + } + + // Create the sampling response item + response := samplingResponseItem{ + requestID: requestID, + } + + // Parse result or error + if responseMessage.Error != nil { + // Parse error + var jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + } + if err := json.Unmarshal(responseMessage.Error, &jsonrpcError); err != nil { + response.err = fmt.Errorf("failed to parse error: %v", err) + } else { + response.err = fmt.Errorf("sampling error %d: %s", jsonrpcError.Code, jsonrpcError.Message) + } + } else if responseMessage.Result != nil { + // Parse result + var result mcp.CreateMessageResult + if err := json.Unmarshal(responseMessage.Result, &result); err != nil { + response.err = fmt.Errorf("failed to parse sampling result: %v", err) + } else { + response.result = &result + } + } else { + response.err = fmt.Errorf("sampling response has neither result nor error") + } + + // Find the corresponding session and deliver the response + // The response is delivered to the specific session identified by sessionID + if err := s.deliverSamplingResponse(sessionID, response); err != nil { + s.logger.Errorf("Failed to deliver sampling response: %v", err) + http.Error(w, "Failed to deliver response", http.StatusInternalServerError) + return err + } + + // Acknowledge receipt + w.WriteHeader(http.StatusOK) + return nil +} + +// deliverSamplingResponse delivers a sampling response to the appropriate session +func (s *StreamableHTTPServer) deliverSamplingResponse(sessionID string, response samplingResponseItem) error { + // Look up the active session + sessionInterface, ok := s.activeSessions.Load(sessionID) + if !ok { + return fmt.Errorf("no active session found for session %s", sessionID) + } + + session, ok := sessionInterface.(*streamableHttpSession) + if !ok { + return fmt.Errorf("invalid session type for session %s", sessionID) + } + + // Look up the dedicated response channel for this specific request + responseChannelInterface, exists := session.samplingRequests.Load(response.requestID) + if !exists { + return fmt.Errorf("no pending request found for session %s, request %d", sessionID, response.requestID) + } + + responseChan, ok := responseChannelInterface.(chan samplingResponseItem) + if !ok { + return fmt.Errorf("invalid response channel type for session %s, request %d", sessionID, response.requestID) + } + + // Attempt to deliver the response with timeout to prevent indefinite blocking + select { + case responseChan <- response: + s.logger.Infof("Delivered sampling response for session %s, request %d", sessionID, response.requestID) + return nil + default: + return fmt.Errorf("failed to deliver sampling response for session %s, request %d: channel full or blocked", sessionID, response.requestID) + } +} + // writeJSONRPCError writes a JSON-RPC error response with the given error details. func (s *StreamableHTTPServer) writeJSONRPCError( w http.ResponseWriter, @@ -509,6 +657,38 @@ func (s *StreamableHTTPServer) nextRequestID(sessionID string) int64 { } // --- session --- +type sessionLogLevelsStore struct { + mu sync.RWMutex + logs map[string]mcp.LoggingLevel +} + +func newSessionLogLevelsStore() *sessionLogLevelsStore { + return &sessionLogLevelsStore{ + logs: make(map[string]mcp.LoggingLevel), + } +} + +func (s *sessionLogLevelsStore) get(sessionID string) mcp.LoggingLevel { + s.mu.RLock() + defer s.mu.RUnlock() + val, ok := s.logs[sessionID] + if !ok { + return mcp.LoggingLevelError + } + return val +} + +func (s *sessionLogLevelsStore) set(sessionID string, level mcp.LoggingLevel) { + s.mu.Lock() + defer s.mu.Unlock() + s.logs[sessionID] = level +} + +func (s *sessionLogLevelsStore) delete(sessionID string) { + s.mu.Lock() + defer s.mu.Unlock() + delete(s.logs, sessionID) +} type sessionToolsStore struct { mu sync.RWMutex @@ -539,6 +719,19 @@ func (s *sessionToolsStore) delete(sessionID string) { delete(s.tools, sessionID) } +// Sampling support types for HTTP transport +type samplingRequestItem struct { + requestID int64 + request mcp.CreateMessageRequest + response chan samplingResponseItem +} + +type samplingResponseItem struct { + requestID int64 + result *mcp.CreateMessageResult + err error +} + // streamableHttpSession is a session for streamable-http transport // When in POST handlers(request/notification), it's ephemeral, and only exists in the life of the request handler. // When in GET handlers(listening), it's a real session, and will be registered in the MCP server. @@ -547,14 +740,23 @@ type streamableHttpSession struct { notificationChannel chan mcp.JSONRPCNotification // server -> client notifications tools *sessionToolsStore upgradeToSSE atomic.Bool + logLevels *sessionLogLevelsStore + + // Sampling support for bidirectional communication + samplingRequestChan chan samplingRequestItem // server -> client sampling requests + samplingRequests sync.Map // requestID -> pending sampling request context + requestIDCounter atomic.Int64 // for generating unique request IDs } -func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore) *streamableHttpSession { - return &streamableHttpSession{ - sessionID: sessionID, - notificationChannel: make(chan mcp.JSONRPCNotification, 100), - tools: toolStore, +func newStreamableHttpSession(sessionID string, toolStore *sessionToolsStore, levels *sessionLogLevelsStore) *streamableHttpSession { + s := &streamableHttpSession{ + sessionID: sessionID, + notificationChannel: make(chan mcp.JSONRPCNotification, 100), + tools: toolStore, + logLevels: levels, + samplingRequestChan: make(chan samplingRequestItem, 10), } + return s } func (s *streamableHttpSession) SessionID() string { @@ -575,6 +777,14 @@ func (s *streamableHttpSession) Initialized() bool { return true } +func (s *streamableHttpSession) SetLogLevel(level mcp.LoggingLevel) { + s.logLevels.set(s.sessionID, level) +} + +func (s *streamableHttpSession) GetLogLevel() mcp.LoggingLevel { + return s.logLevels.get(s.sessionID) +} + var _ ClientSession = (*streamableHttpSession)(nil) func (s *streamableHttpSession) GetSessionTools() map[string]ServerTool { @@ -585,7 +795,10 @@ func (s *streamableHttpSession) SetSessionTools(tools map[string]ServerTool) { s.tools.set(s.sessionID, tools) } -var _ SessionWithTools = (*streamableHttpSession)(nil) +var ( + _ SessionWithTools = (*streamableHttpSession)(nil) + _ SessionWithLogging = (*streamableHttpSession)(nil) +) func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { s.upgradeToSSE.Store(true) @@ -593,6 +806,49 @@ func (s *streamableHttpSession) UpgradeToSSEWhenReceiveNotification() { var _ SessionWithStreamableHTTPConfig = (*streamableHttpSession)(nil) +// RequestSampling implements SessionWithSampling interface for HTTP transport +func (s *streamableHttpSession) RequestSampling(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Generate unique request ID + requestID := s.requestIDCounter.Add(1) + + // Create response channel for this specific request + responseChan := make(chan samplingResponseItem, 1) + + // Create the sampling request item + samplingRequest := samplingRequestItem{ + requestID: requestID, + request: request, + response: responseChan, + } + + // Store the pending request + s.samplingRequests.Store(requestID, responseChan) + defer s.samplingRequests.Delete(requestID) + + // Send the sampling request via the channel (non-blocking) + select { + case s.samplingRequestChan <- samplingRequest: + // Request queued successfully + case <-ctx.Done(): + return nil, ctx.Err() + default: + return nil, fmt.Errorf("sampling request queue is full - server overloaded") + } + + // Wait for response or context cancellation + select { + case response := <-responseChan: + if response.err != nil { + return nil, response.err + } + return response.result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +var _ SessionWithSampling = (*streamableHttpSession)(nil) + // --- session id manager --- type SessionIdManager interface { @@ -613,10 +869,12 @@ type StatelessSessionIdManager struct{} func (s *StatelessSessionIdManager) Generate() string { return "" } + func (s *StatelessSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { // In stateless mode, ignore session IDs completely - don't validate or reject them return false, nil } + func (s *StatelessSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, nil } @@ -631,6 +889,7 @@ const idPrefix = "mcp-session-" func (s *InsecureStatefulSessionIdManager) Generate() string { return idPrefix + uuid.New().String() } + func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTerminated bool, err error) { // validate the session id is a valid uuid if !strings.HasPrefix(sessionID, idPrefix) { @@ -641,6 +900,7 @@ func (s *InsecureStatefulSessionIdManager) Validate(sessionID string) (isTermina } return false, nil } + func (s *InsecureStatefulSessionIdManager) Terminate(sessionID string) (isNotAllowed bool, err error) { return false, nil } diff --git a/server/streamable_http_sampling_test.go b/server/streamable_http_sampling_test.go new file mode 100644 index 000000000..4cf57838c --- /dev/null +++ b/server/streamable_http_sampling_test.go @@ -0,0 +1,216 @@ +package server + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/mark3labs/mcp-go/mcp" +) + +// TestStreamableHTTPServer_SamplingBasic tests basic sampling session functionality +func TestStreamableHTTPServer_SamplingBasic(t *testing.T) { + // Create MCP server with sampling enabled + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + + // Create HTTP server + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + // Test session creation and interface implementation + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) + + // Verify it implements SessionWithSampling + _, ok := any(session).(SessionWithSampling) + if !ok { + t.Error("streamableHttpSession should implement SessionWithSampling") + } + + // Test that sampling request channels are initialized + if session.samplingRequestChan == nil { + t.Error("samplingRequestChan should be initialized") + } +} + +// TestStreamableHTTPServer_SamplingErrorHandling tests error scenarios +func TestStreamableHTTPServer_SamplingErrorHandling(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + client := &http.Client{} + baseURL := testServer.URL + + tests := []struct { + name string + sessionID string + body map[string]any + expectedStatus int + }{ + { + name: "missing session ID", + sessionID: "", + body: map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": "Test response", + }, + }, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "invalid request ID", + sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + body: map[string]any{ + "jsonrpc": "2.0", + "id": "invalid-id", + "result": map[string]any{ + "role": "assistant", + "content": map[string]any{ + "type": "text", + "text": "Test response", + }, + }, + }, + expectedStatus: http.StatusBadRequest, + }, + { + name: "malformed result", + sessionID: "mcp-session-550e8400-e29b-41d4-a716-446655440000", + body: map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "result": "invalid-result", + }, + expectedStatus: http.StatusInternalServerError, // Now correctly returns 500 due to no active session + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload, _ := json.Marshal(tt.body) + req, err := http.NewRequest("POST", baseURL, bytes.NewReader(payload)) + if err != nil { + t.Errorf("Failed to create request: %v", err) + return + } + req.Header.Set("Content-Type", "application/json") + if tt.sessionID != "" { + req.Header.Set("Mcp-Session-Id", tt.sessionID) + } + + resp, err := client.Do(req) + if err != nil { + t.Errorf("Failed to send request: %v", err) + return + } + defer resp.Body.Close() + + if resp.StatusCode != tt.expectedStatus { + t.Errorf("Expected status %d, got %d", tt.expectedStatus, resp.StatusCode) + } + }) + } +} + +// TestStreamableHTTPServer_SamplingInterface verifies interface implementation +func TestStreamableHTTPServer_SamplingInterface(t *testing.T) { + mcpServer := NewMCPServer("test-server", "1.0.0") + mcpServer.EnableSampling() + httpServer := NewStreamableHTTPServer(mcpServer) + testServer := httptest.NewServer(httpServer) + defer testServer.Close() + + // Create a session + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, httpServer.sessionTools, httpServer.sessionLogLevels) + + // Verify it implements SessionWithSampling + _, ok := any(session).(SessionWithSampling) + if !ok { + t.Error("streamableHttpSession should implement SessionWithSampling") + } + + // Test RequestSampling with timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "Test message", + }, + }, + }, + }, + } + + _, err := session.RequestSampling(ctx, request) + if err == nil { + t.Error("Expected timeout error, but got nil") + } + + if !strings.Contains(err.Error(), "context deadline exceeded") { + t.Errorf("Expected timeout error, got: %v", err) + } +} + +// TestStreamableHTTPServer_SamplingQueueFull tests queue overflow scenarios +func TestStreamableHTTPServer_SamplingQueueFull(t *testing.T) { + sessionID := "test-session" + session := newStreamableHttpSession(sessionID, nil, nil) + + // Fill the sampling request queue + for i := 0; i < cap(session.samplingRequestChan); i++ { + session.samplingRequestChan <- samplingRequestItem{ + requestID: int64(i), + request: mcp.CreateMessageRequest{}, + response: make(chan samplingResponseItem, 1), + } + } + + // Try to add another request (should fail) + ctx := context.Background() + request := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "Test message", + }, + }, + }, + }, + } + + _, err := session.RequestSampling(ctx, request) + if err == nil { + t.Error("Expected queue full error, but got nil") + } + + if !strings.Contains(err.Error(), "queue is full") { + t.Errorf("Expected queue full error, got: %v", err) + } +} \ No newline at end of file diff --git a/server/streamable_http_test.go b/server/streamable_http_test.go index aad48fc3a..105fd18ce 100644 --- a/server/streamable_http_test.go +++ b/server/streamable_http_test.go @@ -28,8 +28,7 @@ var initRequest = map[string]any{ "id": 1, "method": "initialize", "params": map[string]any{ - "protocolVersion": "2025-03-26", - "clientInfo": map[string]any{ + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, "clientInfo": map[string]any{ "name": "test-client", "version": "1.0.0", }, @@ -146,12 +145,12 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { t.Fatalf("Failed to unmarshal response: %v", err) } - if responseMessage.Result["protocolVersion"] != "2025-03-26" { - t.Errorf("Expected protocol version 2025-03-26, got %s", responseMessage.Result["protocolVersion"]) + if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"]) } // get session id from header - sessionID = resp.Header.Get(headerKeySessionID) + sessionID = resp.Header.Get(HeaderKeySessionID) if sessionID == "" { t.Fatalf("Expected session id in header, got %s", sessionID) } @@ -171,7 +170,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := server.Client().Do(req) if err != nil { @@ -208,7 +207,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { Notification: mcp.Notification{ Method: "testNotification", Params: mcp.NotificationParams{ - AdditionalFields: map[string]interface{}{"param1": "value1"}, + AdditionalFields: map[string]any{"param1": "value1"}, }, }, } @@ -216,7 +215,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { req, _ := http.NewRequest(http.MethodPost, server.URL, bytes.NewBuffer(rawNotification)) req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := server.Client().Do(req) if err != nil { t.Fatalf("Failed to send message: %v", err) @@ -246,7 +245,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, "dummy-session-id") + req.Header.Set(HeaderKeySessionID, "dummy-session-id") resp, err := server.Client().Do(req) if err != nil { @@ -275,7 +274,7 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, sessionID) + req.Header.Set(HeaderKeySessionID, sessionID) resp, err := server.Client().Do(req) if err != nil { @@ -283,8 +282,8 @@ func TestStreamableHTTP_POST_SendAndReceive(t *testing.T) { } defer resp.Body.Close() - if resp.StatusCode != http.StatusAccepted { - t.Errorf("Expected status 202, got %d", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("content-type") != "text/event-stream" { t.Errorf("Expected content-type text/event-stream, got %s", resp.Header.Get("content-type")) @@ -339,12 +338,12 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { t.Fatalf("Failed to unmarshal response: %v", err) } - if responseMessage.Result["protocolVersion"] != "2025-03-26" { - t.Errorf("Expected protocol version 2025-03-26, got %s", responseMessage.Result["protocolVersion"]) + if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"]) } // no session id from header - sessionID := resp.Header.Get(headerKeySessionID) + sessionID := resp.Header.Get(HeaderKeySessionID) if sessionID != "" { t.Fatalf("Expected no session id in header, got %s", sessionID) } @@ -396,7 +395,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { Notification: mcp.Notification{ Method: "testNotification", Params: mcp.NotificationParams{ - AdditionalFields: map[string]interface{}{"param1": "value1"}, + AdditionalFields: map[string]any{"param1": "value1"}, }, }, } @@ -433,7 +432,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, "dummy-session-id") + req.Header.Set(HeaderKeySessionID, "dummy-session-id") resp, err := server.Client().Do(req) if err != nil { @@ -473,7 +472,7 @@ func TestStreamableHTTP_POST_SendAndReceive_stateless(t *testing.T) { t.Fatalf("Failed to create request: %v", err) } req.Header.Set("Content-Type", "application/json") - req.Header.Set(headerKeySessionID, "mcp-session-2c44d701-fd50-44ce-92b8-dec46185a741") + req.Header.Set(HeaderKeySessionID, "mcp-session-2c44d701-fd50-44ce-92b8-dec46185a741") resp, err := server.Client().Do(req) if err != nil { @@ -529,8 +528,8 @@ func TestStreamableHTTP_GET(t *testing.T) { } defer resp.Body.Close() - if resp.StatusCode != http.StatusAccepted { - t.Errorf("Expected status 202, got %d", resp.StatusCode) + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) } if resp.Header.Get("content-type") != "text/event-stream" { @@ -565,8 +564,7 @@ func TestStreamableHTTP_HttpHandler(t *testing.T) { "id": 1, "method": "initialize", "params": map[string]any{ - "protocolVersion": "2025-03-26", - "clientInfo": map[string]any{ + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, "clientInfo": map[string]any{ "name": "test-client", "version": "1.0.0", }, @@ -586,8 +584,8 @@ func TestStreamableHTTP_HttpHandler(t *testing.T) { if err := json.Unmarshal(bodyBytes, &responseMessage); err != nil { t.Fatalf("Failed to unmarshal response: %v", err) } - if responseMessage.Result["protocolVersion"] != "2025-03-26" { - t.Errorf("Expected protocol version 2025-03-26, got %s", responseMessage.Result["protocolVersion"]) + if responseMessage.Result["protocolVersion"] != mcp.LATEST_PROTOCOL_VERSION { + t.Errorf("Expected protocol version %s, got %s", mcp.LATEST_PROTOCOL_VERSION, responseMessage.Result["protocolVersion"]) } }) } @@ -725,6 +723,74 @@ func TestStreamableHTTP_SessionWithTools(t *testing.T) { }) } +func TestStreamableHTTP_SessionWithLogging(t *testing.T) { + t.Run("SessionWithLogging implementation", func(t *testing.T) { + hooks := &Hooks{} + var logSession *streamableHttpSession + var mu sync.Mutex + + hooks.AddAfterSetLevel(func(ctx context.Context, id any, message *mcp.SetLevelRequest, result *mcp.EmptyResult) { + if s, ok := ClientSessionFromContext(ctx).(*streamableHttpSession); ok { + mu.Lock() + logSession = s + mu.Unlock() + } + }) + + mcpServer := NewMCPServer("test", "1.0.0", WithHooks(hooks), WithLogging()) + testServer := NewTestStreamableHTTPServer(mcpServer) + defer testServer.Close() + + // obtain a valid session ID first + initResp, err := postJSON(testServer.URL, initRequest) + if err != nil { + t.Fatalf("Failed to send init request: %v", err) + } + defer initResp.Body.Close() + sessionID := initResp.Header.Get(HeaderKeySessionID) + if sessionID == "" { + t.Fatal("Expected session id in header") + } + + setLevelRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 1, + "method": "logging/setLevel", + "params": map[string]any{ + "level": mcp.LoggingLevelCritical, + }, + } + + reqBody, _ := json.Marshal(setLevelRequest) + req, err := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewBuffer(reqBody)) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set(HeaderKeySessionID, sessionID) + + resp, err := testServer.Client().Do(req) + if err != nil { + t.Fatalf("Failed to send message: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + mu.Lock() + if logSession == nil { + mu.Unlock() + t.Fatal("Session was not captured") + } + if logSession.GetLogLevel() != mcp.LoggingLevelCritical { + t.Errorf("Expected critical level, got %v", logSession.GetLogLevel()) + } + mu.Unlock() + }) +} + func TestStreamableHTTPServer_WithOptions(t *testing.T) { t.Run("WithStreamableHTTPServer sets httpServer field", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") @@ -775,6 +841,59 @@ func TestStreamableHTTPServer_WithOptions(t *testing.T) { }) } +func TestStreamableHTTP_HeaderPassthrough(t *testing.T) { + mcpServer := NewMCPServer("test-mcp-server", "1.0") + + var receivedHeaders struct { + contentType string + customHeader string + } + mcpServer.AddTool( + mcp.NewTool("check-headers"), + func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + receivedHeaders.contentType = request.Header.Get("Content-Type") + receivedHeaders.customHeader = request.Header.Get("X-Custom-Header") + return mcp.NewToolResultText("ok"), nil + }, + ) + + server := NewTestStreamableHTTPServer(mcpServer) + defer server.Close() + + // Initialize to get session + resp, _ := postJSON(server.URL, initRequest) + sessionID := resp.Header.Get(HeaderKeySessionID) + resp.Body.Close() + + // Test header passthrough + toolRequest := map[string]any{ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": map[string]any{ + "name": "check-headers", + }, + } + toolBody, _ := json.Marshal(toolRequest) + req, _ := http.NewRequest("POST", server.URL, bytes.NewReader(toolBody)) + + const expectedContentType = "application/json" + const expectedCustomHeader = "test-value" + req.Header.Set("Content-Type", expectedContentType) + req.Header.Set("X-Custom-Header", expectedCustomHeader) + req.Header.Set(HeaderKeySessionID, sessionID) + + resp, _ = server.Client().Do(req) + resp.Body.Close() + + if receivedHeaders.contentType != expectedContentType { + t.Errorf("Expected Content-Type header '%s', got '%s'", expectedContentType, receivedHeaders.contentType) + } + if receivedHeaders.customHeader != expectedCustomHeader { + t.Errorf("Expected X-Custom-Header '%s', got '%s'", expectedCustomHeader, receivedHeaders.customHeader) + } +} + func postJSON(url string, bodyObject any) (*http.Response, error) { jsonBody, _ := json.Marshal(bodyObject) req, _ := http.NewRequest(http.MethodPost, url, bytes.NewBuffer(jsonBody)) diff --git a/testdata/mockstdio_server.go b/testdata/mockstdio_server.go index f561285e9..30bf0c001 100644 --- a/testdata/mockstdio_server.go +++ b/testdata/mockstdio_server.go @@ -52,7 +52,7 @@ func handleRequest(request JSONRPCRequest) JSONRPCResponse { switch request.Method { case "initialize": response.Result = map[string]any{ - "protocolVersion": "1.0", + "protocolVersion": mcp.LATEST_PROTOCOL_VERSION, "serverInfo": map[string]any{ "name": "mock-server", "version": "1.0.0", diff --git a/www/docs/pages/clients/advanced-sampling.mdx b/www/docs/pages/clients/advanced-sampling.mdx new file mode 100644 index 000000000..81a4cc9aa --- /dev/null +++ b/www/docs/pages/clients/advanced-sampling.mdx @@ -0,0 +1,497 @@ +# Sampling + +Learn how to implement MCP clients that can handle sampling requests from servers, enabling bidirectional communication where clients provide LLM capabilities to servers. + +## Overview + +Sampling allows MCP clients to respond to LLM completion requests from servers. When a server needs to generate content, answer questions, or perform reasoning tasks, it can send a sampling request to the client, which then processes it using an LLM and returns the result. + +## Implementing a Sampling Handler + +Create a sampling handler by implementing the `SamplingHandler` interface: + +```go +package main + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" + "github.com/mark3labs/mcp-go/mcp" +) + +type MySamplingHandler struct { + // Add fields for your LLM client (OpenAI, Anthropic, etc.) +} + +func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract request parameters + messages := request.Messages + systemPrompt := request.SystemPrompt + maxTokens := request.MaxTokens + temperature := request.Temperature + + // Process with your LLM + response, err := h.callLLM(ctx, messages, systemPrompt, maxTokens, temperature) + if err != nil { + return nil, fmt.Errorf("LLM call failed: %w", err) + } + + // Return MCP-formatted result + return &mcp.CreateMessageResult{ + Model: "your-model-name", + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: response, + }, + StopReason: "endTurn", + }, nil +} + +func (h *MySamplingHandler) callLLM(ctx context.Context, messages []mcp.SamplingMessage, systemPrompt string, maxTokens int, temperature float64) (string, error) { + // Implement your LLM integration here + // This is where you'd call OpenAI, Anthropic, or other LLM APIs + return "Your LLM response here", nil +} +``` + +## Configuring the Client + +Enable sampling by providing a handler when creating the client: + +```go +func main() { + // Create sampling handler + samplingHandler := &MySamplingHandler{} + + // Create stdio transport + stdioTransport := transport.NewStdio("/path/to/mcp/server", nil) + + // Create client with sampling support + mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(samplingHandler)) + + // Start the client + ctx := context.Background() + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + defer mcpClient.Close() + if err := mcpClient.Connect(ctx); err != nil { + log.Fatalf("Failed to connect: %v", err) + } + + // The client will now automatically handle sampling requests + // from the server using your handler +} +``` + +## Mock Implementation Example + +Here's a complete mock implementation for testing: + +```go +package main + +import ( +import ( + "context" + "fmt" + "log" + "strings" + "os" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" +) +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Log the request for debugging + log.Printf("Mock LLM received sampling request:") + log.Printf(" System prompt: %s", request.SystemPrompt) + log.Printf(" Max tokens: %d", request.MaxTokens) + log.Printf(" Temperature: %f", request.Temperature) + + // Extract the user's message + var userMessage string + for _, msg := range request.Messages { + if msg.Role == mcp.RoleUser { + if textContent, ok := msg.Content.(mcp.TextContent); ok { + userMessage = textContent.Text + log.Printf(" User message: %s", userMessage) + break + } + } + } + + // Generate a mock response + mockResponse := fmt.Sprintf( + "Mock LLM response to: '%s'. This is a simulated response from a mock LLM handler.", + userMessage, + ) + + return &mcp.CreateMessageResult{ + Model: "mock-llm-v1", + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: mockResponse, + }, + StopReason: "endTurn", + }, nil +} + +func main() { + if len(os.Args) < 2 { + log.Fatal("Usage: sampling_client ") + } + + serverPath := os.Args[1] + + // Create stdio transport + stdioTransport := transport.NewStdio(serverPath, nil) + + // Create client with mock sampling handler + mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(&MockSamplingHandler{})) + + // Start the client + ctx := context.Background() + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + defer mcpClient.Close() + if err := mcpClient.Connect(ctx); err != nil { + log.Fatalf("Failed to connect: %v", err) + } + + // Test server tools that use sampling + result, err := mcpClient.CallTool(ctx, "ask_llm", map[string]any{ + "question": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant.", + }) + if err != nil { + log.Fatalf("Tool call failed: %v", err) + } + + fmt.Printf("Tool result: %+v\\n", result) +} +``` + +## Real LLM Integration + +### OpenAI Integration + +```go +import ( + "github.com/sashabaranov/go-openai" +) + +type OpenAISamplingHandler struct { + client *openai.Client +} + +func NewOpenAISamplingHandler(apiKey string) *OpenAISamplingHandler { + return &OpenAISamplingHandler{ + client: openai.NewClient(apiKey), + } +} + +func (h *OpenAISamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP messages to OpenAI format + var messages []openai.ChatCompletionMessage + + // Add system message if provided + if request.SystemPrompt != "" { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, + Content: request.SystemPrompt, + }) + } + + // Convert MCP messages + for _, msg := range request.Messages { + var role string + switch msg.Role { + case mcp.RoleUser: + role = openai.ChatMessageRoleUser + case mcp.RoleAssistant: + role = openai.ChatMessageRoleAssistant + } + + if textContent, ok := msg.Content.(mcp.TextContent); ok { + messages = append(messages, openai.ChatCompletionMessage{ + Role: role, + Content: textContent.Text, + }) + } + } + + // Create OpenAI request + req := openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: messages, + MaxTokens: request.MaxTokens, + Temperature: float32(request.Temperature), + } + + // Call OpenAI API + resp, err := h.client.CreateChatCompletion(ctx, req) + if err != nil { + return nil, fmt.Errorf("OpenAI API call failed: %w", err) + } + + if len(resp.Choices) == 0 { + return nil, fmt.Errorf("no response from OpenAI") + } + + choice := resp.Choices[0] + + // Convert stop reason + var stopReason string + switch choice.FinishReason { + case "stop": + stopReason = "endTurn" + case "length": + stopReason = "maxTokens" + default: + stopReason = "other" + } + + return &mcp.CreateMessageResult{ + Model: resp.Model, + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: choice.Message.Content, + }, + StopReason: stopReason, + }, nil +} +``` + +### Anthropic Integration + +```go +import ( + "bytes" + "encoding/json" + "net/http" +) + +type AnthropicSamplingHandler struct { + apiKey string + client *http.Client +} + +func NewAnthropicSamplingHandler(apiKey string) *AnthropicSamplingHandler { + return &AnthropicSamplingHandler{ + apiKey: apiKey, + client: &http.Client{}, + } +} + +func (h *AnthropicSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert to Anthropic format + anthropicReq := map[string]any{ + "model": "claude-3-sonnet-20240229", + "max_tokens": request.MaxTokens, + "messages": h.convertMessages(request.Messages), + } + + if request.SystemPrompt != "" { + anthropicReq["system"] = request.SystemPrompt + } + + if request.Temperature > 0 { + anthropicReq["temperature"] = request.Temperature + } + + // Make API call + reqBody, _ := json.Marshal(anthropicReq) + httpReq, _ := http.NewRequestWithContext(ctx, "POST", + "https://api.anthropic.com/v1/messages", bytes.NewBuffer(reqBody)) + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("x-api-key", h.apiKey) + httpReq.Header.Set("anthropic-version", "2023-06-01") + + resp, err := h.client.Do(httpReq) + if err != nil { + return nil, fmt.Errorf("Anthropic API call failed: %w", err) + } + defer resp.Body.Close() + + var anthropicResp struct { + Content []struct { + Text string `json:"text"` + Type string `json:"type"` + } `json:"content"` + Model string `json:"model"` + StopReason string `json:"stop_reason"` + } + + if err := json.NewDecoder(resp.Body).Decode(&anthropicResp); err != nil { + return nil, fmt.Errorf("failed to decode Anthropic response: %w", err) + } + + // Extract text content + var text string + for _, content := range anthropicResp.Content { + if content.Type == "text" { + text += content.Text + } + } + + return &mcp.CreateMessageResult{ + Model: anthropicResp.Model, + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: text, + }, + StopReason: anthropicResp.StopReason, + }, nil +} + +func (h *AnthropicSamplingHandler) convertMessages(messages []mcp.SamplingMessage) []map[string]any { + var result []map[string]any + for _, msg := range messages { + if textContent, ok := msg.Content.(mcp.TextContent); ok { + result = append(result, map[string]any{ + "role": string(msg.Role), + "content": textContent.Text, + }) + } + } + return result +} +``` + +## Automatic Capability Declaration + +When you provide a sampling handler, the client automatically declares the sampling capability during initialization: + +```go +// This automatically adds sampling capability +stdioTransport := transport.NewStdio(serverPath, nil) +mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(handler)) +``` + +The client will include this in the initialization request: + +```json +{ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": { + "sampling": {} + }, + "clientInfo": { + "name": "your-client", + "version": "1.0.0" + } + } +} +``` + +## Error Handling + +Handle errors gracefully in your sampling handler: + +```go +func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Validate request + if len(request.Messages) == 0 { + return nil, fmt.Errorf("no messages provided") + } + + // Check for context cancellation + if err := ctx.Err(); err != nil { + return nil, fmt.Errorf("request cancelled: %w", err) + } + + // Call LLM with error handling + response, err := h.callLLM(ctx, request) + if err != nil { + // Log error for debugging + log.Printf("LLM call failed: %v", err) + + // Return appropriate error + if strings.Contains(err.Error(), "rate limit") { + return nil, fmt.Errorf("rate limit exceeded, please try again later") + } + return nil, fmt.Errorf("LLM service unavailable: %w", err) + } + + return response, nil +} +``` + +## Best Practices + +1. **Implement Proper Error Handling**: Always handle LLM API errors gracefully +2. **Respect Rate Limits**: Implement rate limiting and backoff strategies +3. **Validate Inputs**: Check message content and parameters before processing +4. **Use Context**: Respect context cancellation and timeouts +5. **Log Appropriately**: Log requests for debugging but avoid logging sensitive data +6. **Model Selection**: Allow configuration of which LLM model to use +7. **Content Filtering**: Implement content filtering if required by your use case + +## Testing Your Implementation + +Test your sampling handler with the sampling server example: + +```bash +# Build the sampling server +cd examples/sampling_server +go build -o sampling_server + +# Build your client +go build -o my_client + +# Test the integration +./my_client ./sampling_server +``` + +## Transport Support + +Sampling is available on the following transports: + +### STDIO Transport + +For STDIO clients, create the transport and client separately: + +```go +stdioTransport := transport.NewStdio("/path/to/server", nil) +mcpClient := client.NewClient(stdioTransport, client.WithSamplingHandler(&MySamplingHandler{})) +``` + +### In-Process Transport + +For in-process clients, use the dedicated constructor: + +```go +mcpClient, err := client.NewInProcessClientWithSamplingHandler( + mcpServer, + &MySamplingHandler{}, +) +``` + +In-process sampling uses direct method calls instead of JSON-RPC serialization. + +### Unsupported Transports + +SSE and StreamableHTTP transports do not support sampling due to their one-way or stateless nature. + +## Next Steps + +- Learn about [server-side sampling implementation](/servers/advanced-sampling) +- Explore [client operations](/clients/operations) +- Check out the [sampling examples](https://github.com/mark3labs/mcp-go/tree/main/examples/sampling_client) +- See [in-process transport documentation](/transports/inprocess) for embedded scenarios \ No newline at end of file diff --git a/www/docs/pages/clients/operations.mdx b/www/docs/pages/clients/operations.mdx index 3cea1e3ec..8734486d1 100644 --- a/www/docs/pages/clients/operations.mdx +++ b/www/docs/pages/clients/operations.mdx @@ -909,6 +909,52 @@ func demonstrateSubscriptionManager(c client.Client) { } ``` +## Advanced: Sampling Support + +Sampling is an advanced feature that allows clients to respond to LLM completion requests from servers. This enables servers to leverage client-side LLM capabilities for content generation and reasoning. + +> **Note**: Sampling is an advanced feature that most clients don't need. Only implement sampling if you're building a client that provides LLM capabilities to servers. + +### When to Implement Sampling + +Consider implementing sampling when your client: +- Has access to LLM APIs (OpenAI, Anthropic, etc.) +- Wants to provide LLM capabilities to servers +- Needs to support servers that generate dynamic content + +### Basic Implementation + +```go +import "github.com/mark3labs/mcp-go/client" + +// Implement the SamplingHandler interface +type MySamplingHandler struct { + // Add your LLM client here +} + +func (h *MySamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Process the request with your LLM + // Return the result in MCP format + return &mcp.CreateMessageResult{ + Model: "your-model", + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "Your LLM response here", + }, + StopReason: "endTurn", + }, nil +} + +// Create client with sampling support +mcpClient, err := client.NewStdioClient( + "/path/to/server", + client.WithSamplingHandler(&MySamplingHandler{}), +) +``` + +For complete sampling documentation, see **[Client Sampling Guide](/clients/advanced-sampling)**. + ## Next Steps - **[Client Transports](/clients/transports)** - Learn transport-specific client features diff --git a/www/docs/pages/clients/transports.mdx b/www/docs/pages/clients/transports.mdx index efef67cf8..1a2e6ddcf 100644 --- a/www/docs/pages/clients/transports.mdx +++ b/www/docs/pages/clients/transports.mdx @@ -6,12 +6,12 @@ Learn about transport-specific client implementations and how to choose the righ MCP-Go provides client implementations for all supported transports. Each transport has different characteristics and is optimized for specific scenarios. -| Transport | Best For | Connection | Real-time | Multi-client | -|-----------|----------|------------|-----------|--------------| -| **STDIO** | CLI tools, desktop apps | Process pipes | No | No | -| **StreamableHTTP** | Web services, APIs | HTTP requests | No | Yes | -| **SSE** | Web apps, real-time | HTTP + EventSource | Yes | Yes | -| **In-Process** | Testing, embedded | Direct calls | Yes | No | +| Transport | Best For | Connection | Real-time | Multi-client | +| ------------------ | ----------------------- | ------------------ | --------- | ------------ | +| **STDIO** | CLI tools, desktop apps | Process pipes | No | No | +| **StreamableHTTP** | Web services, APIs | HTTP requests | No | Yes | +| **SSE** | Web apps, real-time | HTTP + EventSource | Yes | Yes | +| **In-Process** | Testing, embedded | Direct calls | Yes | No | ## STDIO Client @@ -65,6 +65,42 @@ func createStdioClient() { } ``` +### STDIO Client with Custom Configuration + +```go +func createCustomStdioClient() { + // Create custom logger for debugging + logger := myCustomLogger{} + + // Create STDIO client with custom options + c, err := client.NewStdioMCPClientWithOptions( + "go", + []string{"GOCACHE=/tmp/gocache"}, // Custom environment + []string{"run", "/path/to/server/main.go"}, + transport.WithCommandLogger(logger), + transport.WithCommandFunc(func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + cmd := exec.CommandContext(ctx, command, args...) + cmd.Env = append(os.Environ(), env...) + cmd.Dir = "/path/to/working/directory" + return cmd, nil + }), + ) + if err != nil { + log.Fatal(err) + } + defer c.Close() + + ctx := context.Background() + + // Initialize connection + if err := c.Initialize(ctx); err != nil { + log.Fatal(err) + } + + // Use the client... +} +``` + ### STDIO Error Handling ```go @@ -175,7 +211,7 @@ func (msc *ManagedStdioClient) monitorProcess() { return case <-msc.restartChan: log.Println("Restarting STDIO client...") - + if msc.client != nil { msc.client.Close() } @@ -219,11 +255,11 @@ func (msc *ManagedStdioClient) CallTool(ctx context.Context, req mcp.CallToolReq func (msc *ManagedStdioClient) Close() error { msc.cancel() msc.wg.Wait() - + if msc.client != nil { return msc.client.Close() } - + return nil } @@ -277,8 +313,12 @@ func createStreamableHTTPClient() { ```go func createCustomStreamableHTTPClient() { + // Create custom logger for debugging + logger := myCustomLogger{} + // Create StreamableHTTP client with options c := client.NewStreamableHttpClient("https://api.example.com/mcp", + transport.WithLogger(logger), transport.WithHTTPTimeout(30*time.Second), transport.WithHTTPHeaders(map[string]string{ "User-Agent": "MyApp/1.0", @@ -389,6 +429,31 @@ func (pool *StreamableHTTPClientPool) CallTool(ctx context.Context, req mcp.Call } ``` +### StreamableHTTP With Preconfigured Session + +You can also create a StreamableHTTP client with a preconfigured session, which allows you to reuse the same session across multiple requests + +```go +func createStreamableHTTPClientWithSession() { + // Create StreamableHTTP client with options + sessionID := // fetch existing session ID + c := client.NewStreamableHttpClient("https://api.example.com/mcp", + transport.WithSession(sessionID), + ) + defer c.Close() + + ctx := context.Background() + // Use client... + _, err := c.ListTools(ctx) + // If the session is terminated, you must reinitialize the client + if errors.Is(err, transport.ErrSessionTerminated) { + c.Initialize(ctx) // Reinitialize if session is terminated + // The session ID should change after reinitialization + sessionID = c.GetSessionId() // Update session ID + } +} +``` + ## SSE Client SSE (Server-Sent Events) clients provide real-time communication with servers. @@ -434,6 +499,40 @@ func createSSEClient() { } ``` +### SSE Client with Custom Configuration + +```go +func createCustomSSEClient() { + // Create custom logger for debugging + logger := myCustomLogger{} + + // Create SSE client with custom options + c, err := client.NewSSEMCPClient("http://localhost:8080/mcp/sse", + transport.WithSSELogger(logger), + transport.WithHeaders(map[string]string{ + "Authorization": "Bearer your-token", + "User-Agent": "MyApp/1.0", + }), + transport.WithHTTPClient(&http.Client{ + Timeout: 30 * time.Second, + }), + ) + if err != nil { + log.Fatal(err) + } + defer c.Close() + + ctx := context.Background() + + // Initialize + if err := c.Initialize(ctx); err != nil { + log.Fatal(err) + } + + // Use client... +} +``` + ### SSE Client with Reconnection ```go @@ -477,7 +576,7 @@ func (rsc *ResilientSSEClient) connect() error { } client := client.NewSSEClient(rsc.baseURL) - + // Set headers for key, value := range rsc.headers { client.SetHeader(key, value) @@ -498,11 +597,11 @@ func (rsc *ResilientSSEClient) reconnectLoop() { return case <-rsc.reconnectCh: log.Println("Reconnecting SSE client...") - + for attempt := 1; attempt <= 5; attempt++ { if err := rsc.connect(); err != nil { log.Printf("Reconnection attempt %d failed: %v", attempt, err) - + backoff := time.Duration(attempt) * time.Second select { case <-time.After(backoff): @@ -554,14 +653,14 @@ func (rsc *ResilientSSEClient) Subscribe(ctx context.Context) (<-chan mcp.Notifi func (rsc *ResilientSSEClient) Close() error { rsc.cancel() - + rsc.mutex.Lock() defer rsc.mutex.Unlock() - + if rsc.client != nil { return rsc.client.Close() } - + return nil } @@ -604,7 +703,7 @@ func (seh *SSEEventHandler) Start() error { seh.wg.Add(1) go func() { defer seh.wg.Done() - + for { select { case notification := <-notifications: @@ -642,7 +741,7 @@ func (seh *SSEEventHandler) OnToolUpdate(handler func(mcp.Notification)) { func (seh *SSEEventHandler) addHandler(method string, handler func(mcp.Notification)) { seh.mutex.Lock() defer seh.mutex.Unlock() - + seh.handlers[method] = append(seh.handlers[method], handler) } @@ -667,7 +766,7 @@ In-process clients provide direct communication with servers in the same process func createInProcessClient() { // Create server s := server.NewMCPServer("Test Server", "1.0.0") - + // Add tools to server s.AddTool( mcp.NewTool("test_tool", @@ -805,16 +904,16 @@ func SelectTransport(req TransportRequirements) string { switch { case !req.NetworkRequired && req.Performance == "high": return "inprocess" - + case !req.NetworkRequired && !req.MultiClient: return "stdio" - + case req.RealTime && req.MultiClient: return "sse" - + case req.NetworkRequired && req.MultiClient: return "streamablehttp" - + default: return "stdio" // Default fallback } @@ -911,12 +1010,12 @@ func (cf *ClientFactory) CreateClient(transport string) (client.Client, error) { if !ok { return nil, fmt.Errorf("streamablehttp config not set") } - + options := []transport.StreamableHTTPCOption{} if len(config.Headers) > 0 { options = append(options, transport.WithHTTPHeaders(config.Headers)) } - + return client.NewStreamableHttpClient(config.BaseURL, options...), nil case "sse": @@ -927,12 +1026,12 @@ func (cf *ClientFactory) CreateClient(transport string) (client.Client, error) { if !ok { return nil, fmt.Errorf("sse config not set") } - + options := []transport.ClientOption{} if len(config.Headers) > 0 { options = append(options, transport.WithHeaders(config.Headers)) } - + return client.NewSSEMCPClient(config.BaseURL, options...) default: @@ -943,7 +1042,7 @@ func (cf *ClientFactory) CreateClient(transport string) (client.Client, error) { // Usage func demonstrateClientFactory() { factory := NewClientFactory() - + // Configure transports factory.SetStdioConfig("go", "run", "server.go") factory.SetStreamableHTTPConfig("http://localhost:8080/mcp", map[string]string{ @@ -969,3 +1068,19 @@ func demonstrateClientFactory() { } ``` +## Logging Configuration + +All client transports support custom logging. +Each transport provides a logger option that accepts any implementation of the `util.Logger` interface. + +```go +type myCustomLogger struct {} + +func (myCustomLogger) Infof(format string, args ...any) { + // TODO +} + +func (myCustomLogger) Errorf(format string, args ...any) { + // TODO +} +``` diff --git a/www/docs/pages/quick-start.mdx b/www/docs/pages/quick-start.mdx index 074c0965e..d805c0f3a 100644 --- a/www/docs/pages/quick-start.mdx +++ b/www/docs/pages/quick-start.mdx @@ -11,7 +11,6 @@ package main import ( "context" - "errors" "fmt" "github.com/mark3labs/mcp-go/mcp" @@ -23,7 +22,7 @@ func main() { s := server.NewMCPServer( "Hello World Server", "1.0.0", - server.WithToolCapabilities(false), + server.WithToolCapabilities(true), ) // Define a simple tool @@ -45,12 +44,28 @@ func main() { } func helloHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name, err := request.RequireString("name") - if err != nil { - return mcp.NewToolResultError(err.Error()), nil + arguments := request.GetArguments() + name, ok := arguments["name"].(string) + if !ok { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Error: name parameter is required and must be a string", + }, + }, + IsError: true, + }, nil } - return mcp.NewToolResultText(fmt.Sprintf("Hello, %s! 👋", name)), nil + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Hello, %s! 👋", name), + }, + }, + }, nil } ``` @@ -114,57 +129,77 @@ import ( "context" "fmt" "log" + "time" "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) func main() { - // Create a stdio client that connects to another MCP server - // NTOE: NewStdioMCPClient will start the connection automatically. Don't call the Start method manually - c, err := client.NewStdioMCPClient( - "go", "run", "path/to/server/main.go", - ) - if err != nil { - log.Fatal(err) - } - defer c.Close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - ctx := context.Background() + // Create stdio transport + stdioTransport := transport.NewStdio("go", nil, "run", "path/to/server/main.go") - // Initialize the connection - if err := c.Initialize(ctx); err != nil { - log.Fatal(err) - } + // Create client with the transport + c := client.NewClient(stdioTransport) - // List available tools - tools, err := c.ListTools(ctx) - if err != nil { - log.Fatal(err) + // Start the client + if err := c.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) } + defer c.Close() - fmt.Printf("Available tools: %d\n", len(tools.Tools)) - for _, tool := range tools.Tools { - fmt.Printf("- %s: %s\n", tool.Name, tool.Description) + // Initialize the client + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "Hello World Client", + Version: "1.0.0", } + initRequest.Params.Capabilities = mcp.ClientCapabilities{} - // Call a tool - result, err := c.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolParams{ - Name: "hello_world", - Arguments: map[string]interface{}{ - "name": "World", - }, - }, - }) + serverInfo, err := c.Initialize(ctx, initRequest) if err != nil { - log.Fatal(err) + log.Fatalf("Failed to initialize: %v", err) } - // Print the result - for _, content := range result.Content { - if content.Type == "text" { - fmt.Printf("Result: %s\n", content.Text) + fmt.Printf("Connected to server: %s (version %s)\n", + serverInfo.ServerInfo.Name, + serverInfo.ServerInfo.Version) + + // List available tools + if serverInfo.Capabilities.Tools != nil { + toolsRequest := mcp.ListToolsRequest{} + toolsResult, err := c.ListTools(ctx, toolsRequest) + if err != nil { + log.Fatalf("Failed to list tools: %v", err) + } + + fmt.Printf("Available tools: %d\n", len(toolsResult.Tools)) + for _, tool := range toolsResult.Tools { + fmt.Printf("- %s: %s\n", tool.Name, tool.Description) + } + + // Call a tool + callRequest := mcp.CallToolRequest{} + callRequest.Params.Name = "hello_world" + callRequest.Params.Arguments = map[string]interface{}{ + "name": "World", + } + + result, err := c.CallTool(ctx, callRequest) + if err != nil { + log.Fatalf("Failed to call tool: %v", err) + } + + // Print the result + for _, content := range result.Content { + if textContent, ok := content.(mcp.TextContent); ok { + fmt.Printf("Result: %s\n", textContent.Text) + } } } } @@ -181,37 +216,60 @@ import ( "context" "fmt" "log" + "time" "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/client/transport" "github.com/mark3labs/mcp-go/mcp" ) func main() { - // Create a StreamableHTTP client - c := client.NewStreamableHttpClient("http://localhost:8080/mcp") - defer c.Close() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() - ctx := context.Background() + // Create HTTP transport + httpTransport, err := transport.NewStreamableHTTP("http://localhost:8080/mcp") + if err != nil { + log.Fatalf("Failed to create HTTP transport: %v", err) + } + + // Create client with the transport + c := client.NewClient(httpTransport) + defer c.Close() - // Initialize and use the client - if err := c.Initialize(ctx); err != nil { - log.Fatal(err) + // Initialize the client + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "StreamableHTTP Client", + Version: "1.0.0", } + initRequest.Params.Capabilities = mcp.ClientCapabilities{} - // Call a tool - result, err := c.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ - Name: "hello_world", - Arguments: map[string]interface{}{ - "name": "StreamableHTTP World", - }, - }, - }) + serverInfo, err := c.Initialize(ctx, initRequest) if err != nil { - log.Fatal(err) + log.Fatalf("Failed to initialize: %v", err) } - fmt.Printf("Tool result: %+v\n", result) + fmt.Printf("Connected to server: %s (version %s)\n", + serverInfo.ServerInfo.Name, + serverInfo.ServerInfo.Version) + + // Call a tool + if serverInfo.Capabilities.Tools != nil { + callRequest := mcp.CallToolRequest{} + callRequest.Params.Name = "hello_world" + callRequest.Params.Arguments = map[string]interface{}{ + "name": "StreamableHTTP World", + } + + result, err := c.CallTool(ctx, callRequest) + if err != nil { + log.Fatalf("Failed to call tool: %v", err) + } + + fmt.Printf("Tool result: %+v\n", result) + } } ``` diff --git a/www/docs/pages/servers/advanced-sampling.mdx b/www/docs/pages/servers/advanced-sampling.mdx new file mode 100644 index 000000000..1bc05eb6e --- /dev/null +++ b/www/docs/pages/servers/advanced-sampling.mdx @@ -0,0 +1,399 @@ +# Sampling + +Learn how to implement MCP servers that can request LLM completions from clients using the sampling capability. + +## Overview + +Sampling allows MCP servers to request LLM completions from clients, enabling bidirectional communication where servers can leverage client-side LLM capabilities. This is particularly useful for tools that need to generate content, answer questions, or perform reasoning tasks. + +## Enabling Sampling + +To enable sampling in your server, call `EnableSampling()` during server setup: + +```go +package main + +import ( + "context" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create server + mcpServer := server.NewMCPServer("my-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add tools that use sampling... + + // Start server + server.ServeStdio(mcpServer) +} +``` + +## Requesting Sampling + +Use `RequestSampling()` within tool handlers to request LLM completions: + +```go +mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt", + }, + }, + Required: []string{"question"}, + }, +}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Extract parameters + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response: %s", getTextFromContent(result.Content)), + }, + }, + }, nil +}) +``` + +## Sampling Request Parameters + +The `CreateMessageRequest` supports various parameters to control LLM behavior: + +```go +samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + // Required: Messages to send to the LLM + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, // or mcp.RoleAssistant + Content: mcp.TextContent{ // or mcp.ImageContent + Type: "text", + Text: "Your message here", + }, + }, + }, + + // Optional: System prompt for context + SystemPrompt: "You are a helpful assistant.", + + // Optional: Maximum tokens to generate + MaxTokens: 1000, + + // Optional: Temperature for randomness (0.0 to 1.0) + Temperature: 0.7, + + // Optional: Top-p sampling parameter + TopP: 0.9, + + // Optional: Stop sequences + StopSequences: []string{"\\n\\n"}, + }, +} +``` + +## Message Types + +Sampling supports different message roles and content types: + +### Message Roles + +```go +// User message +{ + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: "What is the capital of France?", + }, +} + +// Assistant message (for conversation context) +{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: "The capital of France is Paris.", + }, +} +``` + +### Content Types + +#### Text Content + +```go +mcp.TextContent{ + Type: "text", + Text: "Your text content here", +} +``` + +#### Image Content + +```go +mcp.ImageContent{ + Type: "image", + Data: "base64-encoded-image-data", + MimeType: "image/jpeg", +} +``` + +## Error Handling + +Always handle sampling errors gracefully: + +```go +result, err := mcpServer.RequestSampling(ctx, samplingRequest) +if err != nil { + // Log the error + log.Printf("Sampling request failed: %v", err) + + // Return appropriate error response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Sorry, I couldn't process your request at this time.", + }, + }, + IsError: true, + }, nil +} +``` + +## Context and Timeouts + +Use context for timeout control: + +```go +// Set a timeout for the sampling request +ctx, cancel := context.WithTimeout(ctx, 30*time.Second) +defer cancel() + +result, err := mcpServer.RequestSampling(ctx, samplingRequest) +``` + +## Best Practices + +1. **Enable Sampling Early**: Call `EnableSampling()` during server initialization +2. **Handle Timeouts**: Set appropriate timeouts for sampling requests +3. **Graceful Errors**: Always provide meaningful error messages to users +4. **Content Extraction**: Use helper functions to extract text from responses +5. **System Prompts**: Use clear system prompts to guide LLM behavior +6. **Parameter Validation**: Validate tool parameters before making sampling requests + +## Complete Example + +Here's a complete example server with sampling: + +```go +package main + +import ( + "context" + "fmt" + "log" + "time" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + // Create server + mcpServer := server.NewMCPServer("sampling-example-server", "1.0.0") + + // Enable sampling capability + mcpServer.EnableSampling() + + // Add sampling tool + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling with timeout + samplingCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + result, err := mcpServer.RequestSampling(samplingCtx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", + result.Model, getTextFromContent(result.Content)), + }, + }, + }, nil + }) + + // Start server + log.Println("Starting sampling example server...") + if err := server.ServeStdio(mcpServer); err != nil { + log.Fatalf("Server error: %v", err) + } +} + +// Helper function to extract text from content +func getTextFromContent(content interface{}) string { + switch c := content.(type) { + case mcp.TextContent: + return c.Text + case string: + return c + default: + return fmt.Sprintf("%v", content) + } +} +``` + +## Transport Support + +Sampling is supported on the following transports: + +### STDIO Transport + +STDIO transport provides full sampling support with JSON-RPC message passing: + +```go +// Start STDIO server with sampling +server.ServeStdio(mcpServer) +``` + +The client must implement a `SamplingHandler` and declare sampling capability during initialization. + +### In-Process Transport + +In-process transport offers the most efficient sampling implementation with direct method calls: + +```go +// Create in-process client with sampling handler +mcpClient, err := client.NewInProcessClientWithSamplingHandler(mcpServer, samplingHandler) +if err != nil { + log.Fatal(err) +} +``` + +**Benefits of in-process sampling:** +- **Direct Method Calls**: No JSON-RPC serialization overhead +- **Type Safety**: Compile-time type checking + +### Unsupported Transports + +The following transports do not currently support sampling: +- **SSE Transport**: One-way streaming nature prevents bidirectional sampling +- **StreamableHTTP Transport**: Stateless HTTP requests don't support sampling callbacks + +For these transports, consider implementing LLM integration directly in your tool handlers rather than using sampling. + +## Next Steps + +- Learn about [client-side sampling implementation](/clients/advanced-sampling) +- Explore [advanced server features](/servers/advanced) +- Check out the [sampling examples](https://github.com/mark3labs/mcp-go/tree/main/examples/sampling_server) +- See [in-process sampling documentation](/transports/inprocess#sampling-support) for embedded scenarios \ No newline at end of file diff --git a/www/docs/pages/servers/advanced.mdx b/www/docs/pages/servers/advanced.mdx index d2e3b4a8f..990599b05 100644 --- a/www/docs/pages/servers/advanced.mdx +++ b/www/docs/pages/servers/advanced.mdx @@ -821,6 +821,144 @@ func startWithGracefulShutdown(s *server.MCPServer) { } ``` +## Client Capability Based Filtering + +```go +package main + +import ( + "context" + "fmt" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +func main() { + s := server.NewMCPServer("Typed Server", "1.0.0", + server.WithToolCapabilities(true), + ) + + s.AddTool( + mcp.NewTool("calculate", + mcp.WithDescription("Perform basic mathematical calculations"), + mcp.WithString("operation", + mcp.Required(), + mcp.Enum("add", "subtract", "multiply", "divide"), + mcp.Description("The operation to perform"), + ), + mcp.WithNumber("x", mcp.Required(), mcp.Description("First number")), + mcp.WithNumber("y", mcp.Required(), mcp.Description("Second number")), + ), + handleCalculate, + ) + + server.ServeStdio(s) +} + +func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + session := server.ClientSessionFromContext(ctx) + if session == nil { + return nil, fmt.Errorf("no active session") + } + + if clientSession, ok := session.(server.SessionWithClientInfo); ok { + clientCapabilities := clientSession.GetClientCapabilities() + if clientCapabilities.Sampling == nil { + fmt.Println("sampling is not enabled in client") + } + } + + // TODO: implement calculation logic + return mcp.NewToolResultError("not implemented"), nil +} +``` + +## Sampling (Advanced) + +Sampling is an advanced feature that allows servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities for content generation, reasoning, and question answering. + +> **Note**: Sampling is an advanced feature that most servers don't need. Only implement sampling if your server specifically needs to generate content using the client's LLM. + +### When to Use Sampling + +Consider sampling when your server needs to: +- Generate content based on user input +- Answer questions using LLM reasoning +- Perform text analysis or summarization +- Create dynamic responses that require LLM capabilities + +### Basic Implementation + +```go +// Enable sampling capability +mcpServer.EnableSampling() + +// Add a tool that uses sampling +mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + }, + Required: []string{"question"}, + }, +}, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: "You are a helpful assistant.", + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error: %v", err), + }, + }, + IsError: true, + }, nil + } + + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response: %s", result.Content), + }, + }, + }, nil +}) +``` + +For complete sampling documentation, see **[Server Sampling Guide](/servers/advanced-sampling)**. + ## Next Steps - **[Client Development](/clients)** - Learn to build MCP clients diff --git a/www/docs/pages/servers/index.mdx b/www/docs/pages/servers/index.mdx index e0249172a..d0983d13a 100644 --- a/www/docs/pages/servers/index.mdx +++ b/www/docs/pages/servers/index.mdx @@ -12,7 +12,7 @@ MCP servers expose tools, resources, and prompts to LLM clients. MCP-Go makes it - **[Resources](/servers/resources)** - Exposing data to LLMs - **[Tools](/servers/tools)** - Providing functionality to LLMs - **[Prompts](/servers/prompts)** - Creating reusable interaction templates -- **[Advanced Features](/servers/advanced)** - Typed tools, middleware, hooks, and more +- **[Advanced Features](/servers/advanced)** - Typed tools, middleware, hooks, sampling, and more ## Quick Example diff --git a/www/docs/pages/servers/tools.mdx b/www/docs/pages/servers/tools.mdx index 4d1263d1d..e329bd1c1 100644 --- a/www/docs/pages/servers/tools.mdx +++ b/www/docs/pages/servers/tools.mdx @@ -529,6 +529,99 @@ func handleMultiContentTool(ctx context.Context, req mcp.CallToolRequest) (*mcp. } ``` +### Resource Links + +Tools can return resource links that reference other resources in your MCP server. This is useful when you want to point to existing data without duplicating content: + +```go +func handleGetResourceLinkTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + resourceID, err := req.RequireString("resource_id") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Create a resource link pointing to an existing resource + uri := fmt.Sprintf("file://documents/%s", resourceID) + resourceLink := mcp.NewResourceLink(uri, "Document", "The requested document", "application/pdf") + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Found the requested document:"), + resourceLink, + }, + }, nil +} +``` + +### Mixed Content with Resource Links + +You can combine different content types including resource links in a single tool result: + +```go +func handleSearchDocumentsTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + query, err := req.RequireString("query") + if err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + + // Simulate document search + foundDocs := []string{"doc1.pdf", "doc2.txt", "doc3.md"} + + content := []mcp.Content{ + mcp.NewTextContent(fmt.Sprintf("Found %d documents matching '%s':", len(foundDocs), query)), + } + + // Add resource links for each found document + for _, doc := range foundDocs { + uri := fmt.Sprintf("file://documents/%s", doc) + parts := strings.SplitN(doc, ".", 2) + name := parts[0] + mimeType := "application/octet-stream" // default + if len(parts) > 1 { + // Map extension to MIME type (simplified) + switch parts[1] { + case "pdf": + mimeType = "application/pdf" + case "txt": + mimeType = "text/plain" + case "md": + mimeType = "text/markdown" + } + } + resourceLink := mcp.NewResourceLink(uri, name, fmt.Sprintf("Document: %s", doc), mimeType) + content = append(content, resourceLink) + } + + return &mcp.CallToolResult{ + Content: content, + }, nil +} +``` + +### Resource Link with Annotations + +Resource links can include additional metadata through annotations: + +```go +func handleGetAnnotatedResourceTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + docType := req.GetString("type", "general") + // Create resource link with annotations + annotated := mcp.Annotated{ + Annotations: &mcp.Annotations{ + Audience: []mcp.Role{mcp.RoleUser}, + }, + } + url := "file://documents/test.pdf" + resourceLink := mcp.NewResourceLink(url, "Test Document", fmt.Sprintf("A %s document", docType), "application/pdf") + resourceLink.Annotated = annotated + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.NewTextContent("Here's the important document you requested:"), + resourceLink, + }, + }, nil +} +``` + ### Error Results ```go diff --git a/www/docs/pages/transports/http.mdx b/www/docs/pages/transports/http.mdx index 1d0430d6a..9d7e308bc 100644 --- a/www/docs/pages/transports/http.mdx +++ b/www/docs/pages/transports/http.mdx @@ -43,7 +43,7 @@ import ( func main() { s := server.NewMCPServer("StreamableHTTP API Server", "1.0.0", server.WithToolCapabilities(true), - server.WithResourceCapabilities(true), + server.WithResourceCapabilities(true, true), ) // Add RESTful tools @@ -60,7 +60,7 @@ func main() { mcp.WithDescription("Create a new user"), mcp.WithString("name", mcp.Required()), mcp.WithString("email", mcp.Required()), - mcp.WithInteger("age", mcp.Minimum(0)), + mcp.WithNumber("age", mcp.Min(0)), ), handleCreateUser, ) @@ -69,8 +69,8 @@ func main() { mcp.NewTool("search_users", mcp.WithDescription("Search users with filters"), mcp.WithString("query", mcp.Description("Search query")), - mcp.WithInteger("limit", mcp.Default(10), mcp.Maximum(100)), - mcp.WithInteger("offset", mcp.Default(0), mcp.Minimum(0)), + mcp.WithNumber("limit", mcp.DefaultNumber(10), mcp.Max(100)), + mcp.WithNumber("offset", mcp.DefaultNumber(0), mcp.Min(0)), ), handleSearchUsers, ) @@ -95,7 +95,10 @@ func main() { } func handleGetUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - userID := req.Params.Arguments["user_id"].(string) + userID := req.GetString("user_id", "") + if userID == "" { + return nil, fmt.Errorf("user_id is required") + } // Simulate database lookup user, err := getUserFromDB(userID) @@ -103,13 +106,18 @@ func handleGetUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolR return nil, fmt.Errorf("user not found: %s", userID) } - return mcp.NewToolResultJSON(user), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"id":"%s","name":"%s","email":"%s","age":%d}`, + user.ID, user.Name, user.Email, user.Age)), nil } func handleCreateUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - name := req.Params.Arguments["name"].(string) - email := req.Params.Arguments["email"].(string) - age := int(req.Params.Arguments["age"].(float64)) + name := req.GetString("name", "") + email := req.GetString("email", "") + age := req.GetInt("age", 0) + + if name == "" || email == "" { + return nil, fmt.Errorf("name and email are required") + } // Validate input if !isValidEmail(email) { @@ -129,11 +137,8 @@ func handleCreateUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo return nil, fmt.Errorf("failed to create user: %w", err) } - return mcp.NewToolResultJSON(map[string]interface{}{ - "id": user.ID, - "message": "User created successfully", - "user": user, - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"id":"%s","message":"User created successfully","user":{"id":"%s","name":"%s","email":"%s","age":%d}}`, + user.ID, user.ID, user.Name, user.Email, user.Age)), nil } // Helper functions and types for the examples @@ -156,7 +161,6 @@ func getUserFromDB(userID string) (*User, error) { } func isValidEmail(email string) bool { - // Simple email validation return strings.Contains(email, "@") && strings.Contains(email, ".") } @@ -171,9 +175,9 @@ func saveUserToDB(user *User) error { } func handleSearchUsers(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - query := getStringParam(req.Params.Arguments, "query", "") - limit := int(getFloatParam(req.Params.Arguments, "limit", 10)) - offset := int(getFloatParam(req.Params.Arguments, "offset", 0)) + query := req.GetString("query", "") + limit := req.GetInt("limit", 10) + offset := req.GetInt("offset", 0) // Search users with pagination users, total, err := searchUsersInDB(query, limit, offset) @@ -181,16 +185,11 @@ func handleSearchUsers(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallT return nil, fmt.Errorf("search failed: %w", err) } - return mcp.NewToolResultJSON(map[string]interface{}{ - "users": users, - "total": total, - "limit": limit, - "offset": offset, - "query": query, - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"users":[{"id":"1","name":"John Doe","email":"john@example.com","age":30},{"id":"2","name":"Jane Smith","email":"jane@example.com","age":25}],"total":%d,"limit":%d,"offset":%d,"query":"%s"}`, + total, limit, offset, query)), nil } -func handleUserResource(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { +func handleUserResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { userID := extractUserIDFromURI(req.Params.URI) user, err := getUserFromDB(userID) @@ -198,27 +197,16 @@ func handleUserResource(ctx context.Context, req mcp.ReadResourceRequest) (*mcp. return nil, fmt.Errorf("user not found: %s", userID) } - return mcp.NewResourceResultJSON(user), nil -} - -// Additional helper functions for parameter handling -func getStringParam(args map[string]interface{}, key, defaultValue string) string { - if val, ok := args[key]; ok && val != nil { - if str, ok := val.(string); ok { - return str - } - } - return defaultValue + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "application/json", + Text: fmt.Sprintf(`{"id":"%s","name":"%s","email":"%s","age":%d}`, user.ID, user.Name, user.Email, user.Age), + }, + }, nil } -func getFloatParam(args map[string]interface{}, key string, defaultValue float64) float64 { - if val, ok := args[key]; ok && val != nil { - if f, ok := val.(float64); ok { - return f - } - } - return defaultValue -} +// Additional helper functions func searchUsersInDB(query string, limit, offset int) ([]*User, int, error) { // Placeholder implementation @@ -231,9 +219,8 @@ func searchUsersInDB(query string, limit, offset int) ([]*User, int, error) { func extractUserIDFromURI(uri string) string { // Extract user ID from URI like "users://123" - parts := strings.Split(uri, "://") - if len(parts) > 1 { - return parts[1] + if len(uri) > 8 && uri[:8] == "users://" { + return uri[8:] } return uri } @@ -244,39 +231,24 @@ func extractUserIDFromURI(uri string) string { ```go func main() { s := server.NewMCPServer("Advanced StreamableHTTP Server", "1.0.0", - server.WithAllCapabilities(), - server.WithRecovery(), - server.WithHooks(&server.Hooks{ - OnToolCall: logToolCall, - OnResourceRead: logResourceRead, - }), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), ) - // Configure StreamableHTTP-specific options - streamableHTTPOptions := server.StreamableHTTPOptions{ - BasePath: "/api/v1/mcp", - ReadTimeout: 30 * time.Second, - WriteTimeout: 30 * time.Second, - IdleTimeout: 60 * time.Second, - MaxBodySize: 10 * 1024 * 1024, // 10MB - EnableCORS: true, - AllowedOrigins: []string{"https://myapp.com", "http://localhost:3000"}, - AllowedMethods: []string{"GET", "POST", "OPTIONS"}, - AllowedHeaders: []string{"Content-Type", "Authorization"}, - EnableGzip: true, - TrustedProxies: []string{"10.0.0.0/8", "172.16.0.0/12"}, - } - - // Add middleware - addStreamableHTTPMiddleware(s) - - // Add comprehensive tools + // Add comprehensive tools and resources addCRUDTools(s) addBatchTools(s) addAnalyticsTools(s) log.Println("Starting advanced StreamableHTTP server on :8080") - httpServer := server.NewStreamableHTTPServer(s, streamableHTTPOptions...) + httpServer := server.NewStreamableHTTPServer(s, + server.WithEndpointPath("/api/v1/mcp"), + server.WithHeartbeatInterval(30*time.Second), + server.WithStateLess(false), + ) + if err := httpServer.Start(":8080"); err != nil { log.Fatal(err) } @@ -680,7 +652,42 @@ type Claims struct { } ``` +### Request Headers + +The StreamableHTTP transport now passes HTTP request headers to MCP handlers. This allows you to access the original HTTP headers that were sent with the request in your tool and resource handlers. + +#### Accessing Headers in Handlers + +Headers are available in all MCP request objects: + +```go +func handleGetUser(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Access request headers + headers := req.Header + + // Use headers for authentication, tracing, etc. + authToken := headers.Get("Authorization") + if authToken == "" { + return nil, fmt.Errorf("authentication required") + } + + // Access other headers + requestID := headers.Get("X-Request-ID") + userAgent := headers.Get("User-Agent") + + // Rest of your handler code... +} +``` + +This works for all MCP request types including: +- `CallToolRequest` +- `ReadResourceRequest` +- `ListToolsRequest` +- `ListResourcesRequest` +- `InitializeRequest` +- And other MCP request types +The headers are automatically populated by the transport layer and are available in your handlers without any additional configuration. ## Next Steps diff --git a/www/docs/pages/transports/index.mdx b/www/docs/pages/transports/index.mdx index 1150a4d02..e8ec5eae4 100644 --- a/www/docs/pages/transports/index.mdx +++ b/www/docs/pages/transports/index.mdx @@ -13,12 +13,12 @@ Transport layers handle the communication between MCP clients and servers. Each ## Transport Comparison -| Transport | Use Case | Pros | Cons | -|-----------|----------|------|------| -| **STDIO** | CLI tools, desktop apps | Simple, secure, no network | Single client, local only | -| **SSE** | Web apps, real-time | Multi-client, real-time, web-friendly | HTTP overhead, one-way streaming | -| **StreamableHTTP** | Web services, APIs | Standard protocol, caching, load balancing | No real-time, more complex | -| **In-Process** | Embedded, testing | No serialization, fastest | Same process only | +| Transport | Use Case | Pros | Cons | Sampling Support | +|-----------|----------|------|------|------------------| +| **STDIO** | CLI tools, desktop apps | Simple, secure, no network | Single client, local only | ✅ Full support | +| **SSE** | Web apps, real-time | Multi-client, real-time, web-friendly | HTTP overhead, one-way streaming | ❌ Not supported | +| **StreamableHTTP** | Web services, APIs | Standard protocol, caching, load balancing | No real-time, more complex | ❌ Not supported | +| **In-Process** | Embedded, testing | No serialization, fastest | Same process only | ✅ Full support | ## Quick Example @@ -131,12 +131,14 @@ func handleEcho(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResu - Testing and development - High-performance scenarios - Library integrations +- Sampling-enabled applications **Example use cases:** - Testing MCP implementations - Embedded analytics engines - High-frequency trading systems - Real-time game servers +- LLM-powered applications with bidirectional communication ## Transport Configuration diff --git a/www/docs/pages/transports/inprocess.mdx b/www/docs/pages/transports/inprocess.mdx index dce982357..e4b2c29b7 100644 --- a/www/docs/pages/transports/inprocess.mdx +++ b/www/docs/pages/transports/inprocess.mdx @@ -56,19 +56,34 @@ func main() { ) // Create in-process client - client := client.NewInProcessClient(s) - defer client.Close() + mcpClient, err := client.NewInProcessClient(s) + if err != nil { + log.Fatal(err) + } + defer mcpClient.Close() ctx := context.Background() // Initialize - if err := client.Initialize(ctx); err != nil { + _, err = mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeRequestParams{ + ProtocolVersion: "2024-11-05", + Capabilities: mcp.ClientCapabilities{ + Tools: &mcp.ToolsCapability{}, + }, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + }, + }) + if err != nil { log.Fatal(err) } // Use the calculator - result, err := client.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ + result, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ Name: "calculate", Arguments: map[string]interface{}{ "operation": "add", @@ -81,13 +96,18 @@ func main() { log.Fatal(err) } - fmt.Printf("Result: %s\n", result.Content[0].Text) + // Extract text from the first content item + if len(result.Content) > 0 { + if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { + fmt.Printf("Result: %s\n", textContent.Text) + } + } } func handleCalculate(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - operation := req.Params.Arguments["operation"].(string) - x := req.Params.Arguments["x"].(float64) - y := req.Params.Arguments["y"].(float64) + operation := req.GetString("operation", "") + x := req.GetFloat("x", 0) + y := req.GetFloat("y", 0) var result float64 switch operation { @@ -134,7 +154,11 @@ func NewApplication(config *Config) *Application { app.addApplicationTools() // Create in-process client for internal use - app.mcpClient = client.NewInProcessClient(app.mcpServer) + var err error + app.mcpClient, err = client.NewInProcessClient(app.mcpServer) + if err != nil { + panic(err) + } return app } @@ -151,12 +175,8 @@ func (app *Application) addApplicationTools() { mcp.WithDescription("Get current application status"), ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - status := map[string]interface{}{ - "app_name": app.config.AppName, - "debug": app.config.Debug, - "status": "running", - } - return mcp.NewToolResultJSON(status), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"app_name":"%s","debug":%t,"status":"running"}`, + app.config.AppName, app.config.Debug)), nil }, ) @@ -168,8 +188,8 @@ func (app *Application) addApplicationTools() { mcp.WithString("value", mcp.Required()), ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - key := req.Params.Arguments["key"].(string) - value := req.Params.Arguments["value"].(string) + key := req.GetString("key", "") + value := req.GetString("value", "") // Update configuration based on key switch key { @@ -189,7 +209,7 @@ func (app *Application) addApplicationTools() { func (app *Application) ProcessWithMCP(ctx context.Context, operation string) (interface{}, error) { // Use MCP tools internally for processing result, err := app.mcpClient.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ + Params: mcp.CallToolParams{ Name: "calculate", Arguments: map[string]interface{}{ "operation": operation, @@ -202,7 +222,14 @@ func (app *Application) ProcessWithMCP(ctx context.Context, operation string) (i return nil, err } - return result.Content[0].Text, nil + // Extract text from the first content item + if len(result.Content) > 0 { + if textContent, ok := mcp.AsTextContent(result.Content[0]); ok { + return textContent.Text, nil + } + } + + return "no result", nil } // Usage example @@ -216,7 +243,19 @@ func main() { ctx := context.Background() // Initialize the embedded MCP client - if err := app.mcpClient.Initialize(ctx); err != nil { + _, err := app.mcpClient.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeRequestParams{ + ProtocolVersion: "2024-11-05", + Capabilities: mcp.ClientCapabilities{ + Tools: &mcp.ToolsCapability{}, + }, + ClientInfo: mcp.Implementation{ + Name: "embedded-client", + Version: "1.0.0", + }, + }, + }) + if err != nil { log.Fatal(err) } @@ -230,8 +269,356 @@ func main() { } ``` +## Sampling Support + +In-process transport supports sampling, allowing servers to request LLM completions from clients. This enables bidirectional communication where servers can leverage client-side LLM capabilities. + +### Enabling Sampling + +To enable sampling in your in-process server, call `EnableSampling()` during server setup: + +```go +package main + +import ( + "context" + "fmt" + "log" + + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// MockSamplingHandler implements client.SamplingHandler for demonstration +type MockSamplingHandler struct{} + +func (h *MockSamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Extract the user's message + var userMessage string + for _, msg := range request.Messages { + if msg.Role == mcp.RoleUser { + if textContent, ok := msg.Content.(mcp.TextContent); ok { + userMessage = textContent.Text + break + } + } + } + + // Generate a mock response + mockResponse := fmt.Sprintf("Mock LLM response to: '%s'", userMessage) + + return &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: mockResponse, + }, + }, + Model: "mock-llm-v1", + StopReason: "endTurn", + }, nil +} + +func main() { + // Create server with sampling enabled + mcpServer := server.NewMCPServer("sampling-server", "1.0.0") + mcpServer.EnableSampling() + + // Add a tool that uses sampling + mcpServer.AddTool(mcp.Tool{ + Name: "ask_llm", + Description: "Ask the LLM a question using sampling", + InputSchema: mcp.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "question": map[string]any{ + "type": "string", + "description": "The question to ask the LLM", + }, + "system_prompt": map[string]any{ + "type": "string", + "description": "Optional system prompt", + }, + }, + Required: []string{"question"}, + }, + }, func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) { + question, err := request.RequireString("question") + if err != nil { + return nil, err + } + + systemPrompt := request.GetString("system_prompt", "You are a helpful assistant.") + + // Create sampling request + samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, + Content: mcp.TextContent{ + Type: "text", + Text: question, + }, + }, + }, + SystemPrompt: systemPrompt, + MaxTokens: 1000, + Temperature: 0.7, + }, + } + + // Request sampling from client + result, err := mcpServer.RequestSampling(ctx, samplingRequest) + if err != nil { + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("Error requesting sampling: %v", err), + }, + }, + IsError: true, + }, nil + } + + // Return the LLM response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: fmt.Sprintf("LLM Response (model: %s): %s", + result.Model, result.Content.(mcp.TextContent).Text), + }, + }, + }, nil + }) + + // Create client with sampling handler + mockHandler := &MockSamplingHandler{} + mcpClient, err := client.NewInProcessClientWithSamplingHandler(mcpServer, mockHandler) + if err != nil { + log.Fatalf("Failed to create client: %v", err) + } + defer mcpClient.Close() + + // Start and initialize the client + ctx := context.Background() + if err := mcpClient.Start(ctx); err != nil { + log.Fatalf("Failed to start client: %v", err) + } + + initRequest := mcp.InitializeRequest{} + initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION + initRequest.Params.ClientInfo = mcp.Implementation{ + Name: "sampling-client", + Version: "1.0.0", + } + + _, err = mcpClient.Initialize(ctx, initRequest) + if err != nil { + log.Fatalf("Failed to initialize: %v", err) + } + + // Call the tool that uses sampling + result, err := mcpClient.CallTool(ctx, mcp.CallToolRequest{ + Params: mcp.CallToolParams{ + Name: "ask_llm", + Arguments: map[string]any{ + "question": "What is the capital of France?", + "system_prompt": "You are a helpful geography assistant.", + }, + }, + }) + if err != nil { + log.Fatalf("Tool call failed: %v", err) + } + + // Print the result + if len(result.Content) > 0 { + if textContent, ok := result.Content[0].(mcp.TextContent); ok { + fmt.Printf("Tool result: %s\n", textContent.Text) + } + } +} +``` + +### Real LLM Integration + +For production use, replace the mock handler with a real LLM integration: + +#### OpenAI Integration + +```go +import ( + "github.com/sashabaranov/go-openai" +) + +type OpenAISamplingHandler struct { + client *openai.Client +} + +func NewOpenAISamplingHandler(apiKey string) *OpenAISamplingHandler { + return &OpenAISamplingHandler{ + client: openai.NewClient(apiKey), + } +} + +func (h *OpenAISamplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) { + // Convert MCP messages to OpenAI format + var messages []openai.ChatCompletionMessage + + // Add system message if provided + if request.SystemPrompt != "" { + messages = append(messages, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleSystem, + Content: request.SystemPrompt, + }) + } + + // Convert MCP messages + for _, msg := range request.Messages { + var role string + switch msg.Role { + case mcp.RoleUser: + role = openai.ChatMessageRoleUser + case mcp.RoleAssistant: + role = openai.ChatMessageRoleAssistant + } + + if textContent, ok := msg.Content.(mcp.TextContent); ok { + messages = append(messages, openai.ChatCompletionMessage{ + Role: role, + Content: textContent.Text, + }) + } + } + + // Create OpenAI request + req := openai.ChatCompletionRequest{ + Model: openai.GPT3Dot5Turbo, + Messages: messages, + MaxTokens: request.MaxTokens, + Temperature: float32(request.Temperature), + } + + // Call OpenAI API + resp, err := h.client.CreateChatCompletion(ctx, req) + if err != nil { + return nil, fmt.Errorf("OpenAI API call failed: %w", err) + } + + if len(resp.Choices) == 0 { + return nil, fmt.Errorf("no response from OpenAI") + } + + choice := resp.Choices[0] + + // Convert stop reason + var stopReason string + switch choice.FinishReason { + case "stop": + stopReason = "endTurn" + case "length": + stopReason = "maxTokens" + default: + stopReason = "other" + } + + return &mcp.CreateMessageResult{ + SamplingMessage: mcp.SamplingMessage{ + Role: mcp.RoleAssistant, + Content: mcp.TextContent{ + Type: "text", + Text: choice.Message.Content, + }, + }, + Model: resp.Model, + StopReason: stopReason, + }, nil +} + +// Usage +func main() { + // Create OpenAI handler + openaiHandler := NewOpenAISamplingHandler("your-api-key") + + // Create client with OpenAI sampling + mcpClient, err := client.NewInProcessClientWithSamplingHandler(mcpServer, openaiHandler) + if err != nil { + log.Fatal(err) + } + + // ... rest of the setup +} +``` + +### Sampling Request Parameters + +The `CreateMessageRequest` supports various parameters to control LLM behavior: + +```go +samplingRequest := mcp.CreateMessageRequest{ + CreateMessageParams: mcp.CreateMessageParams{ + // Required: Messages to send to the LLM + Messages: []mcp.SamplingMessage{ + { + Role: mcp.RoleUser, // or mcp.RoleAssistant + Content: mcp.TextContent{ // or mcp.ImageContent + Type: "text", + Text: "Your message here", + }, + }, + }, + + // Optional: System prompt for context + SystemPrompt: "You are a helpful assistant.", + + // Optional: Maximum tokens to generate + MaxTokens: 1000, + + // Optional: Temperature for randomness (0.0 to 1.0) + Temperature: 0.7, + + // Optional: Top-p sampling parameter + TopP: 0.9, + + // Optional: Stop sequences + StopSequences: []string{"\\n\\n"}, + }, +} +``` + +### Error Handling + +Always handle sampling errors gracefully: + +```go +result, err := mcpServer.RequestSampling(ctx, samplingRequest) +if err != nil { + // Log the error + log.Printf("Sampling request failed: %v", err) + + // Return appropriate error response + return &mcp.CallToolResult{ + Content: []mcp.Content{ + mcp.TextContent{ + Type: "text", + Text: "Sorry, I couldn't process your request at this time.", + }, + }, + IsError: true, + }, nil +} +``` + + + ## Next Steps - **[Client Development](/clients)** - Build MCP clients for all transports - **[HTTP Transport](/transports/http)** - Learn about web-based scenarios -- **[Server Advanced Features](/servers/advanced)** - Explore production-ready features \ No newline at end of file +- **[Server Advanced Features](/servers/advanced)** - Explore production-ready features +- **[Client Sampling](/clients/advanced-sampling)** - Learn more about client-side sampling implementation \ No newline at end of file diff --git a/www/docs/pages/transports/sse.mdx b/www/docs/pages/transports/sse.mdx index 930bb0eba..81fabba43 100644 --- a/www/docs/pages/transports/sse.mdx +++ b/www/docs/pages/transports/sse.mdx @@ -39,7 +39,7 @@ import ( func main() { s := server.NewMCPServer("SSE Server", "1.0.0", server.WithToolCapabilities(true), - server.WithResourceCapabilities(true), + server.WithResourceCapabilities(true, true), ) // Add real-time tools @@ -47,7 +47,7 @@ func main() { mcp.NewTool("stream_data", mcp.WithDescription("Stream data with real-time updates"), mcp.WithString("source", mcp.Required()), - mcp.WithInteger("count", mcp.Default(10)), + mcp.WithNumber("count", mcp.DefaultNumber(10)), ), handleStreamData, ) @@ -55,7 +55,7 @@ func main() { s.AddTool( mcp.NewTool("monitor_system", mcp.WithDescription("Monitor system metrics in real-time"), - mcp.WithInteger("duration", mcp.Default(60)), + mcp.WithNumber("duration", mcp.DefaultNumber(60)), ), handleSystemMonitor, ) @@ -73,19 +73,31 @@ func main() { // Start SSE server log.Println("Starting SSE server on :8080") - if err := server.ServeSSE(s, ":8080"); err != nil { + sseServer := server.NewSSEServer(s) + if err := sseServer.Start(":8080"); err != nil { log.Fatal(err) } } func handleStreamData(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - source := req.Params.Arguments["source"].(string) - count := int(req.Params.Arguments["count"].(float64)) + // Access request headers + headers := req.Header - // Get notifier for real-time updates (hypothetical functions) - // Note: These functions would be provided by the SSE transport implementation - notifier := getNotifierFromContext(ctx) // Hypothetical function - sessionID := getSessionIDFromContext(ctx) // Hypothetical function + // Use headers for authentication, tracing, etc. + authToken := headers.Get("Authorization") + if authToken == "" { + return nil, fmt.Errorf("authentication required") + } + + // Access other headers + requestID := headers.Get("X-Request-ID") + userAgent := headers.Get("User-Agent") + + source := req.GetString("source", "") + count := req.GetInt("count", 10) + + // Get server from context for notifications + mcpServer := server.ServerFromContext(ctx) // Stream data with progress updates var results []map[string]interface{} @@ -102,23 +114,22 @@ func handleStreamData(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallTo results = append(results, data) // Send progress notification - if notifier != nil { - // Note: ProgressNotification would be defined by the MCP protocol - notifier.SendProgress(sessionID, map[string]interface{}{ + if mcpServer != nil { + err := mcpServer.SendNotificationToClient(ctx, "notifications/progress", map[string]interface{}{ "progress": i + 1, "total": count, "message": fmt.Sprintf("Processed %d/%d items from %s", i+1, count, source), }) + if err != nil { + log.Printf("Failed to send notification: %v", err) + } } time.Sleep(100 * time.Millisecond) } - return mcp.NewToolResultJSON(map[string]interface{}{ - "source": source, - "results": results, - "count": len(results), - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"source":"%s","results":%v,"count":%d}`, + source, results, len(results))), nil } // Helper functions for the examples @@ -130,21 +141,10 @@ func generateData(source string, index int) map[string]interface{} { } } -func getNotifierFromContext(ctx context.Context) interface{} { - // Placeholder implementation - would be provided by SSE transport - return nil -} - -func getSessionIDFromContext(ctx context.Context) string { - // Placeholder implementation - would be provided by SSE transport - return "session_123" -} - func handleSystemMonitor(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - duration := int(req.Params.Arguments["duration"].(float64)) + duration := req.GetInt("duration", 60) - notifier := getNotifierFromContext(ctx) - sessionID := getSessionIDFromContext(ctx) + mcpServer := server.ServerFromContext(ctx) // Monitor system for specified duration ticker := time.NewTicker(5 * time.Second) @@ -158,20 +158,19 @@ func handleSystemMonitor(ctx context.Context, req mcp.CallToolRequest) (*mcp.Cal case <-ctx.Done(): return nil, ctx.Err() case <-timeout: - return mcp.NewToolResultJSON(map[string]interface{}{ - "duration": duration, - "metrics": metrics, - "samples": len(metrics), - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"duration":%d,"metrics":%v,"samples":%d}`, + duration, metrics, len(metrics))), nil case <-ticker.C: // Collect current metrics currentMetrics := collectSystemMetrics() metrics = append(metrics, currentMetrics) // Send real-time update - if notifier != nil { - // Note: SendCustom would be a method on the notifier interface - // notifier.SendCustom(sessionID, "system_metrics", currentMetrics) + if mcpServer != nil { + err := mcpServer.SendNotificationToClient(ctx, "system_metrics", currentMetrics) + if err != nil { + log.Printf("Failed to send system metrics notification: %v", err) + } } } } @@ -186,9 +185,15 @@ func collectSystemMetrics() map[string]interface{} { } } -func handleCurrentMetrics(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { +func handleCurrentMetrics(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { metrics := collectSystemMetrics() - return mcp.NewResourceResultJSON(metrics), nil + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: "application/json", + Text: fmt.Sprintf(`{"cpu":%.1f,"memory":%.1f,"disk":%.1f}`, metrics["cpu"], metrics["memory"], metrics["disk"]), + }, + }, nil } ``` @@ -197,50 +202,29 @@ func handleCurrentMetrics(ctx context.Context, req mcp.ReadResourceRequest) (*mc ```go func main() { s := server.NewMCPServer("Advanced SSE Server", "1.0.0", - server.WithAllCapabilities(), - server.WithRecovery(), - server.WithHooks(&server.Hooks{ - OnSessionStart: func(sessionID string) { - log.Printf("SSE client connected: %s", sessionID) - broadcastUserCount() - }, - OnSessionEnd: func(sessionID string) { - log.Printf("SSE client disconnected: %s", sessionID) - broadcastUserCount() - }, - }), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), ) - // Configure SSE-specific options - sseOptions := server.SSEOptions{ - BasePath: "/mcp", - AllowedOrigins: []string{"http://localhost:3000", "https://myapp.com"}, - HeartbeatInterval: 30 * time.Second, - MaxConnections: 100, - ConnectionTimeout: 5 * time.Minute, - EnableCompression: true, - } - // Add collaborative tools addCollaborativeTools(s) addRealTimeResources(s) log.Println("Starting advanced SSE server on :8080") - if err := server.ServeSSEWithOptions(s, ":8080", sseOptions); err != nil { + sseServer := server.NewSSEServer(s, + server.WithStaticBasePath("/mcp"), + server.WithKeepAliveInterval(30*time.Second), + server.WithBaseURL("http://localhost:8080"), + ) + + if err := sseServer.Start(":8080"); err != nil { log.Fatal(err) } } // Helper functions for the advanced example -func broadcastUserCount() { - // Placeholder implementation - log.Println("Broadcasting user count update") -} - -func addCollaborativeToolsPlaceholder(s *server.MCPServer) { - // Placeholder implementation - would add collaborative tools -} - func addRealTimeResources(s *server.MCPServer) { // Placeholder implementation - would add real-time resources } @@ -262,7 +246,7 @@ func addCollaborativeTools(s *server.MCPServer) { mcp.NewTool("send_message", mcp.WithDescription("Send a message to all connected clients"), mcp.WithString("message", mcp.Required()), - mcp.WithString("channel", mcp.Default("general")), + mcp.WithString("channel", mcp.DefaultString("general")), ), handleSendMessage, ) @@ -281,241 +265,113 @@ func addCollaborativeTools(s *server.MCPServer) { ## Configuration -### Base URLs and Paths +### SSE Server Options + +The SSE server can be configured with various options: ```go -// Custom SSE endpoint configuration -sseOptions := server.SSEOptions{ - BasePath: "/api/mcp", // SSE endpoint will be /api/mcp/sse +sseServer := server.NewSSEServer(s, + // Set the base path for SSE endpoints + server.WithStaticBasePath("/api/mcp"), - // Additional HTTP endpoints - HealthPath: "/api/health", - MetricsPath: "/api/metrics", - StatusPath: "/api/status", -} - -// Start server with custom paths -server.ServeSSEWithOptions(s, ":8080", sseOptions) + // Configure keep-alive interval + server.WithKeepAliveInterval(30*time.Second), + + // Set base URL for client connections + server.WithBaseURL("http://localhost:8080"), + + // Configure SSE and message endpoints + server.WithSSEEndpoint("/sse"), + server.WithMessageEndpoint("/message"), + + // Add context function for request processing + server.WithSSEContextFunc(func(ctx context.Context, r *http.Request) context.Context { + // Add custom context values from headers + return ctx + }), +) ``` **Resulting endpoints:** - SSE stream: `http://localhost:8080/api/mcp/sse` -- Health check: `http://localhost:8080/api/health` -- Metrics: `http://localhost:8080/api/metrics` -- Status: `http://localhost:8080/api/status` - -### CORS Configuration +- Message endpoint: `http://localhost:8080/api/mcp/message` -```go -sseOptions := server.SSEOptions{ - // Allow specific origins - AllowedOrigins: []string{ - "http://localhost:3000", - "https://myapp.com", - "https://*.myapp.com", - }, - - // Allow all origins (development only) - AllowAllOrigins: true, - - // Custom CORS headers - AllowedHeaders: []string{ - "Authorization", - "Content-Type", - "X-API-Key", - }, - - // Allow credentials - AllowCredentials: true, -} -``` +## Real-Time Notifications -### Connection Management +SSE transport enables real-time server-to-client communication through notifications. Use the server context to send notifications: ```go -sseOptions := server.SSEOptions{ - // Connection limits - MaxConnections: 100, - MaxConnectionsPerIP: 10, +func handleRealtimeTool(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Get the MCP server from context + mcpServer := server.ServerFromContext(ctx) - // Timeouts - ConnectionTimeout: 5 * time.Minute, - WriteTimeout: 30 * time.Second, - ReadTimeout: 30 * time.Second, - - // Heartbeat to keep connections alive - HeartbeatInterval: 30 * time.Second, - - // Buffer sizes - WriteBufferSize: 4096, - ReadBufferSize: 4096, + // Send a notification to the client + if mcpServer != nil { + err := mcpServer.SendNotificationToClient(ctx, "custom_event", map[string]interface{}{ + "message": "Real-time update", + "timestamp": time.Now().Unix(), + }) + if err != nil { + log.Printf("Failed to send notification: %v", err) + } + } - // Compression - EnableCompression: true, - CompressionLevel: 6, + return mcp.NewToolResultText(`{"status":"notification_sent"}`), nil } ``` -## Session Handling +### Session Management -### Multi-Client State Management +The SSE server automatically handles session management. You can send events to specific sessions using the server's notification methods: ```go -type SessionManager struct { - sessions map[string]*ClientSession - mutex sync.RWMutex - notifier *SSENotifier -} +// Send notification to current client session +mcpServer.SendNotificationToClient(ctx, "progress_update", progressData) -type ClientSession struct { - ID string - UserID string - ConnectedAt time.Time - LastSeen time.Time - Subscriptions map[string]bool - Metadata map[string]interface{} -} - -func NewSessionManager() *SessionManager { - return &SessionManager{ - sessions: make(map[string]*ClientSession), - notifier: NewSSENotifier(), - } -} +// Send notification to all connected clients (if supported) +// Note: Check the server implementation for broadcast capabilities +``` -func (sm *SessionManager) OnSessionStart(sessionID string) { - sm.mutex.Lock() - defer sm.mutex.Unlock() - - session := &ClientSession{ - ID: sessionID, - ConnectedAt: time.Now(), - LastSeen: time.Now(), - Subscriptions: make(map[string]bool), - Metadata: make(map[string]interface{}), - } - - sm.sessions[sessionID] = session - - // Notify other clients - sm.notifier.BroadcastExcept(sessionID, "user_joined", map[string]interface{}{ - "session_id": sessionID, - "timestamp": time.Now().Unix(), - }) -} +### Request Headers -func (sm *SessionManager) OnSessionEnd(sessionID string) { - sm.mutex.Lock() - defer sm.mutex.Unlock() - - delete(sm.sessions, sessionID) - - // Notify other clients - sm.notifier.Broadcast("user_left", map[string]interface{}{ - "session_id": sessionID, - "timestamp": time.Now().Unix(), - }) -} +Like the StreamableHTTP transport, the SSE transport passes HTTP request headers to MCP handlers. This allows you to access the original HTTP headers that were sent with the SSE connection in your tool and resource handlers. -func (sm *SessionManager) GetActiveSessions() []ClientSession { - sm.mutex.RLock() - defer sm.mutex.RUnlock() - - var sessions []ClientSession - for _, session := range sm.sessions { - sessions = append(sessions, *session) - } - - return sessions -} -``` +#### Accessing Headers in Handlers -### Real-Time Notifications +Headers from the SSE connection are available in all MCP request objects: ```go -type SSENotifier struct { - clients map[string]chan mcp.Notification - mutex sync.RWMutex -} +func handleStreamData(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Access request headers + headers := req.Header -func NewSSENotifier() *SSENotifier { - return &SSENotifier{ - clients: make(map[string]chan mcp.Notification), + // Use headers for authentication, tracing, etc. + authToken := headers.Get("Authorization") + if authToken == "" { + return nil, fmt.Errorf("authentication required") } -} - -func (n *SSENotifier) RegisterClient(sessionID string) <-chan mcp.Notification { - n.mutex.Lock() - defer n.mutex.Unlock() - ch := make(chan mcp.Notification, 100) - n.clients[sessionID] = ch - return ch -} - -func (n *SSENotifier) UnregisterClient(sessionID string) { - n.mutex.Lock() - defer n.mutex.Unlock() + // Access other headers + requestID := headers.Get("X-Request-ID") + userAgent := headers.Get("User-Agent") - if ch, exists := n.clients[sessionID]; exists { - close(ch) - delete(n.clients, sessionID) - } + // Rest of your handler code... + mcpServer := server.ServerFromContext(ctx) + // ... } +``` -func (n *SSENotifier) SendToClient(sessionID string, notification mcp.Notification) { - n.mutex.RLock() - defer n.mutex.RUnlock() - - if ch, exists := n.clients[sessionID]; exists { - select { - case ch <- notification: - default: - // Channel full, drop notification - } - } -} +This works for all MCP request types including: +- `CallToolRequest` +- `ReadResourceRequest` +- `ListToolsRequest` +- `ListResourcesRequest` +- `InitializeRequest` +- And other MCP request types -func (n *SSENotifier) Broadcast(eventType string, data interface{}) { - notification := mcp.Notification{ - Type: eventType, - Data: data, - } - - n.mutex.RLock() - defer n.mutex.RUnlock() - - for _, ch := range n.clients { - select { - case ch <- notification: - default: - // Channel full, skip this client - } - } -} +The headers are automatically populated by the SSE transport layer from the initial SSE connection and are available in your handlers without any additional configuration. -func (n *SSENotifier) BroadcastExcept(excludeSessionID, eventType string, data interface{}) { - notification := mcp.Notification{ - Type: eventType, - Data: data, - } - - n.mutex.RLock() - defer n.mutex.RUnlock() - - for sessionID, ch := range n.clients { - if sessionID == excludeSessionID { - continue - } - - select { - case ch <- notification: - default: - // Channel full, skip this client - } - } -} -``` +Note: Since SSE maintains a persistent connection, the headers are captured when the connection is established and remain the same for all requests during that connection's lifetime. ## Next Steps diff --git a/www/docs/pages/transports/stdio.mdx b/www/docs/pages/transports/stdio.mdx index 6dea7adcf..3f609690f 100644 --- a/www/docs/pages/transports/stdio.mdx +++ b/www/docs/pages/transports/stdio.mdx @@ -40,7 +40,7 @@ import ( func main() { s := server.NewMCPServer("File Tools", "1.0.0", server.WithToolCapabilities(true), - server.WithResourceCapabilities(true), + server.WithResourceCapabilities(true, true), ) // Add file listing tool @@ -52,7 +52,7 @@ func main() { mcp.Description("Directory path to list"), ), mcp.WithBoolean("recursive", - mcp.Default(false), + mcp.DefaultBool(false), mcp.Description("List files recursively"), ), ), @@ -98,15 +98,11 @@ func handleListFiles(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToo return mcp.NewToolResultError(fmt.Sprintf("failed to list files: %v", err)), nil } - return mcp.NewToolResultJSON(map[string]interface{}{ - "path": path, - "files": files, - "count": len(files), - "recursive": recursive, - }), nil + return mcp.NewToolResultText(fmt.Sprintf(`{"path":"%s","files":%v,"count":%d,"recursive":%t}`, + path, files, len(files), recursive)), nil } -func handleFileContent(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { +func handleFileContent(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) { // Extract path from URI: "file:///path/to/file" -> "/path/to/file" path := extractPathFromURI(req.Params.URI) @@ -119,13 +115,11 @@ func handleFileContent(ctx context.Context, req mcp.ReadResourceRequest) (*mcp.R return nil, fmt.Errorf("failed to read file: %w", err) } - return &mcp.ReadResourceResult{ - Contents: []mcp.ResourceContent{ - { - URI: req.Params.URI, - MIMEType: detectMIMEType(path), - Text: string(content), - }, + return []mcp.ResourceContents{ + mcp.TextResourceContents{ + URI: req.Params.URI, + MIMEType: detectMIMEType(path), + Text: string(content), }, }, nil } @@ -211,16 +205,10 @@ import ( func main() { s := server.NewMCPServer("Advanced CLI Tool", "1.0.0", - server.WithAllCapabilities(), - server.WithRecovery(), - server.WithHooks(&server.Hooks{ - OnSessionStart: func(sessionID string) { - logToFile(fmt.Sprintf("Session started: %s", sessionID)) - }, - OnSessionEnd: func(sessionID string) { - logToFile(fmt.Sprintf("Session ended: %s", sessionID)) - }, - }), + server.WithResourceCapabilities(true, true), + server.WithPromptCapabilities(true), + server.WithToolCapabilities(true), + server.WithLogging(), ) // Add comprehensive tools @@ -325,29 +313,42 @@ package main import ( "context" "log" + "time" + "github.com/mark3labs/mcp-go/client" + "github.com/mark3labs/mcp-go/mcp" ) func main() { // Create STDIO client c, err := client.NewStdioClient( - "go", "run", "/path/to/server/main.go", + "go", nil /* inherit env */, "run", "/path/to/server/main.go", ) if err != nil { log.Fatal(err) } defer c.Close() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() // Initialize connection - if err := c.Initialize(ctx); err != nil { + _, err = c.Initialize(ctx, mcp.InitializeRequest{ + Params: mcp.InitializeRequestParams{ + ProtocolVersion: mcp.LATEST_PROTOCOL_VERSION, + ClientInfo: mcp.Implementation{ + Name: "test-client", + Version: "1.0.0", + }, + }, + }) + if err != nil { log.Fatal(err) } // List available tools - tools, err := c.ListTools(ctx) + tools, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { log.Fatal(err) } @@ -359,7 +360,7 @@ func main() { // Call a tool result, err := c.CallTool(ctx, mcp.CallToolRequest{ - Params: mcp.CallToolRequestParams{ + Params: mcp.CallToolParams{ Name: "list_files", Arguments: map[string]interface{}{ "path": ".", @@ -375,6 +376,58 @@ func main() { } ``` +#### Customizing Subprocess Execution + +If you need more control over how a sub-process is spawned when creating a new STDIO client, you can use +`NewStdioMCPClientWithOptions` instead of `NewStdioMCPClient`. + +By passing the `WithCommandFunc` option, you can supply a custom factory function to create the `exec.Cmd` that launches +the server. This allows configuration of environment variables, working directories, and system-level process attributes. + +Referring to the previous example, we can replace the line that creates the client: + +```go +c, err := client.NewStdioClient( + "go", nil, "run", "/path/to/server/main.go", +) +``` + +With the options-aware version: + +```go +c, err := client.NewStdioMCPClientWithOptions( + "go", + nil, + []string {"run", "/path/to/server/main.go"}, + transport.WithCommandFunc(func(ctx context.Context, command string, args []string, env []string) (*exec.Cmd, error) { + cmd := exec.CommandContext(ctx, command, args...) + cmd.Env = env // Explicit environment for the subprocess. + cmd.Dir = "/var/sandbox/mcp-server" // Working directory (not isolated unless paired with chroot or namespace). + + // Apply low-level process isolation and privilege dropping. + cmd.SysProcAttr = &syscall.SysProcAttr{ + // Drop to non-root user (e.g., user/group ID 1001) + Credential: &syscall.Credential{ + Uid: 1001, + Gid: 1001, + }, + // File system isolation: only works if running as root. + Chroot: "/var/sandbox/mcp-server", + + // Linux namespace isolation (Linux only): + // Prevents access to other processes, mounts, IPC, networks, etc. + Cloneflags: syscall.CLONE_NEWIPC | // Isolate inter-process comms + syscall.CLONE_NEWNS | // Isolate filesystem mounts + syscall.CLONE_NEWPID | // Isolate PID namespace (child sees itself as PID 1) + syscall.CLONE_NEWUTS | // Isolate hostname + syscall.CLONE_NEWNET, // Isolate networking (optional) + } + + return cmd, nil + }), +) +``` + ## Debugging ### Command Line Testing @@ -445,18 +498,7 @@ func main() { s := server.NewMCPServer("Debug Server", "1.0.0", server.WithToolCapabilities(true), - server.WithHooks(&server.Hooks{ - OnSessionStart: func(sessionID string) { - logger.Printf("Session started: %s", sessionID) - }, - OnToolCall: func(sessionID, toolName string, duration time.Duration, err error) { - if err != nil { - logger.Printf("Tool %s failed: %v", toolName, err) - } else { - logger.Printf("Tool %s completed in %v", toolName, duration) - } - }, - }), + server.WithLogging(), ) // Add tools with debug logging @@ -466,7 +508,7 @@ func main() { mcp.WithString("message", mcp.Required()), ), func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - message := req.Params.Arguments["message"].(string) + message := req.GetString("message", "") logger.Printf("Echo tool called with message: %s", message) return mcp.NewToolResultText(fmt.Sprintf("Echo: %s", message)), nil }, @@ -504,8 +546,8 @@ This opens a web interface where you can: ```go func handleToolWithErrors(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { // Validate required parameters - path, ok := req.Params.Arguments["path"].(string) - if !ok { + path, err := req.RequireString("path") + if err != nil { return nil, fmt.Errorf("path parameter is required and must be a string") } @@ -543,7 +585,7 @@ func handleToolWithErrors(ctx context.Context, req mcp.CallToolRequest) (*mcp.Ca return nil, fmt.Errorf("operation failed: %w", err) } - return mcp.NewToolResultJSON(result), nil + return mcp.NewToolResultText(fmt.Sprintf("%v", result)), nil } ``` @@ -638,7 +680,7 @@ func getCachedFile(path string) (string, bool) { ```go func handleLargeFile(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { - path := req.Params.Arguments["path"].(string) + path := req.GetString("path", "") // Stream large files instead of loading into memory file, err := os.Open(path) diff --git a/www/vocs.config.ts b/www/vocs.config.ts index 0706755d7..dc745b385 100644 --- a/www/vocs.config.ts +++ b/www/vocs.config.ts @@ -100,6 +100,20 @@ export default defineConfig({ }, ], }, + { + text: 'Advanced', + collapsed: true, + items: [ + { + text: 'Server Sampling', + link: '/servers/advanced-sampling', + }, + { + text: 'Client Sampling', + link: '/clients/advanced-sampling', + }, + ], + }, ], socials: [ {