// Source: ollama/server/mcp_client.go (811 lines, 21 KiB, Go)

package server
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
"os/exec"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/ollama/ollama/api"
)
// MCPClient manages communication with a single MCP server via JSON-RPC over stdio.
//
// A client is created with NewMCPClient, started with Start (or Initialize,
// which calls Start and then performs the handshake), and shut down with Close.
type MCPClient struct {
// Server configuration as passed to NewMCPClient. command holds the value
// returned by the command resolver, which may differ from the caller's string.
name string
command string
args []string
env map[string]string
// Process management
// cmd is nil until Start and is reset to nil by Close.
cmd *exec.Cmd
// stdin is the child's standard input; JSON-RPC messages are written to it.
stdin io.WriteCloser
// stdout and stderr are buffered readers over stdoutPipe and stderrPipe,
// consumed by handleResponses and handleErrors respectively.
stdout *bufio.Reader
stderr *bufio.Reader
// State
// mu is held when reading or writing initialized, tools and responses, and
// for the whole of Start and Close.
mu sync.RWMutex
// initialized is set once the initialize handshake has completed.
initialized bool
// tools is the tool list cached by ListTools.
tools []api.Tool
// requestID is the last JSON-RPC id handed out; it is incremented atomically.
requestID int64
// responses maps the id of each in-flight request to the channel its caller
// is waiting on.
responses map[int64]chan *jsonRPCResponse
// Lifecycle
// ctx is cancelled by Close and is the parent of the per-call timeouts.
ctx context.Context
cancel context.CancelFunc
// done is closed at the end of Close.
done chan struct{}
// Pipe handles (for clean shutdown)
stdoutPipe io.ReadCloser
stderrPipe io.ReadCloser
// Dependencies (injectable for testing)
commandResolver CommandResolverInterface
}
// =============================================================================
// MCPClient Options (Functional Options Pattern)
// =============================================================================
// MCPClientOption is a functional option that NewMCPClient applies to the
// client while it is being constructed.
type MCPClientOption func(*MCPClient)

// WithCommandResolver returns an option that replaces the client's command
// resolver. It is mainly intended for tests that must not depend on what is
// installed on the host. NewMCPClient falls back to DefaultCommandResolver
// when the resolver ends up nil.
//
// Example:
//
//	mockResolver := &MockCommandResolver{...}
//	client := NewMCPClient("test", "npx", args, env, WithCommandResolver(mockResolver))
func WithCommandResolver(resolver CommandResolverInterface) MCPClientOption {
	apply := func(client *MCPClient) {
		client.commandResolver = resolver
	}
	return apply
}
// JSON-RPC 2.0 message types.
type (
	// jsonRPCRequest is an outgoing request or notification. ID is a pointer
	// with omitempty so that a notification (nil ID) carries no "id" member.
	jsonRPCRequest struct {
		JSONRPC string `json:"jsonrpc"`
		ID      *int64 `json:"id,omitempty"`
		Method  string `json:"method"`
		Params  any    `json:"params,omitempty"`
	}

	// jsonRPCResponse is a message read from the server. Result is kept as raw
	// JSON so the caller can decode it into the type it expects.
	jsonRPCResponse struct {
		JSONRPC string          `json:"jsonrpc"`
		ID      *int64          `json:"id"`
		Result  json.RawMessage `json:"result,omitempty"`
		Error   *jsonRPCError   `json:"error,omitempty"`
	}

	// jsonRPCError is the "error" member of a failed response.
	jsonRPCError struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Data    any    `json:"data,omitempty"`
	}
)
// MCP protocol message types. These are the params/result payloads carried
// inside the JSON-RPC envelope.
type (
	// mcpInitializeRequest is the params object of the "initialize" request.
	mcpInitializeRequest struct {
		ProtocolVersion string         `json:"protocolVersion"`
		Capabilities    map[string]any `json:"capabilities"`
		ClientInfo      mcpClientInfo  `json:"clientInfo"`
	}

	// mcpClientInfo identifies this client to the server.
	mcpClientInfo struct {
		Name    string `json:"name"`
		Version string `json:"version"`
	}

	// mcpInitializeResponse is the result object of the "initialize" request.
	mcpInitializeResponse struct {
		ProtocolVersion string         `json:"protocolVersion"`
		Capabilities    map[string]any `json:"capabilities"`
		ServerInfo      mcpServerInfo  `json:"serverInfo"`
	}

	// mcpServerInfo identifies the server, as reported in the initialize result.
	mcpServerInfo struct {
		Name    string `json:"name"`
		Version string `json:"version"`
	}

	// mcpListToolsRequest is the (empty) params object of "tools/list".
	mcpListToolsRequest struct{}

	// mcpListToolsResponse is the result object of "tools/list".
	mcpListToolsResponse struct {
		Tools []mcpTool `json:"tools"`
	}

	// mcpTool describes one tool; InputSchema is left as a generic JSON object.
	mcpTool struct {
		Name        string         `json:"name"`
		Description string         `json:"description,omitempty"`
		InputSchema map[string]any `json:"inputSchema"`
	}

	// mcpCallToolRequest is the params object of "tools/call".
	mcpCallToolRequest struct {
		Name      string         `json:"name"`
		Arguments map[string]any `json:"arguments"`
	}

	// mcpCallToolResponse is the result object of "tools/call".
	mcpCallToolResponse struct {
		Content []mcpContent `json:"content"`
		IsError bool         `json:"isError,omitempty"`
	}

	// mcpContent is one content item of a tool result.
	mcpContent struct {
		Type string `json:"type"`
		Text string `json:"text"`
	}
)
// NewMCPClient builds a client for the given MCP server configuration. The
// server process is not spawned here; that happens in Start.
//
// opts are applied in order after the defaults are set (e.g. to swap the
// command resolver in tests). Once the options have run, command is passed
// through the resolver's ResolveForEnvironment and the result is stored as the
// command to execute.
func NewMCPClient(name, command string, args []string, env map[string]string, opts ...MCPClientOption) *MCPClient {
	client := &MCPClient{
		name:            name,
		command:         command, // replaced by the resolved command below
		args:            args,
		env:             env,
		responses:       make(map[int64]chan *jsonRPCResponse),
		done:            make(chan struct{}),
		commandResolver: DefaultCommandResolver,
	}
	client.ctx, client.cancel = context.WithCancel(context.Background())

	for i := range opts {
		opts[i](client)
	}

	// An option such as WithCommandResolver(nil) may have cleared the resolver;
	// restore the default so the call below cannot hit a nil interface.
	if client.commandResolver == nil {
		client.commandResolver = DefaultCommandResolver
	}

	resolved := client.commandResolver.ResolveForEnvironment(command)
	client.command = resolved
	return client
}
// Start spawns the MCP server process and initializes communication.
//
// The configured command is split on whitespace so that multi-word commands
// (like "pnpm dlx") work; the extra words are prepended to the configured
// arguments. Once the process is running, handleResponses and handleErrors
// are started as goroutines to read its stdout and stderr.
//
// Client state (cmd, stdin and the pipe handles) is committed only after the
// process has actually been started, and cmd/stdin are cleared again when the
// process is found to have exited immediately. A failed Start therefore leaves
// the client in the "not started" state: Start can be retried, and Close
// returns early instead of waiting on a process that never ran or was already
// reaped.
//
// It returns an error if the client is already started, if a pipe cannot be
// created, if the process cannot be started, or if the process is detected to
// have exited right after starting.
//
// SECURITY REVIEW: This function executes external processes. Security controls:
//   - Command must pass validation in MCPManager.validateServerConfig()
//   - Environment is filtered via buildSecureEnvironment()
//   - Process runs in isolated process group (Setpgid)
//   - Context-based cancellation for cleanup
func (c *MCPClient) Start() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.cmd != nil {
		return errors.New("MCP client already started")
	}

	// Handle commands that might have spaces (like "pnpm dlx")
	cmdName, cmdArgs := c.command, c.args
	if parts := strings.Fields(c.command); len(parts) > 1 {
		cmdName = parts[0]
		cmdArgs = append(parts[1:], c.args...)
	}

	// SECURITY: Create command with context for cancellation control
	cmd := exec.CommandContext(c.ctx, cmdName, cmdArgs...)

	// SECURITY: Apply filtered environment (see buildSecureEnvironment)
	cmd.Env = c.buildSecureEnvironment()

	// SECURITY: Process isolation via process group
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Setpgid: true, // Isolate in own process group
		Pgid:    0,    // New process group
		// Future: Consider adding privilege dropping for root users
		// Credential: &syscall.Credential{Uid: 65534, Gid: 65534}
	}

	// Pipes opened so far; every failure path below closes all of them.
	var opened []io.Closer
	closePipes := func() {
		for _, p := range opened {
			p.Close()
		}
	}

	stdin, err := cmd.StdinPipe()
	if err != nil {
		return fmt.Errorf("failed to create stdin pipe: %w", err)
	}
	opened = append(opened, stdin)

	stdout, err := cmd.StdoutPipe()
	if err != nil {
		closePipes()
		return fmt.Errorf("failed to create stdout pipe: %w", err)
	}
	opened = append(opened, stdout)

	stderr, err := cmd.StderrPipe()
	if err != nil {
		closePipes()
		return fmt.Errorf("failed to create stderr pipe: %w", err)
	}
	opened = append(opened, stderr)

	// Start the process
	if err := cmd.Start(); err != nil {
		closePipes()
		return fmt.Errorf("failed to start MCP server: %w", err)
	}

	// The process is running: commit the state the rest of the client uses.
	// On success the pipes remain open for the lifetime of the client.
	c.cmd = cmd
	c.stdin = stdin
	c.stdoutPipe = stdout
	c.stdout = bufio.NewReader(stdout)
	c.stderrPipe = stderr
	c.stderr = bufio.NewReader(stderr)

	slog.Info("MCP server started", "name", c.name, "pid", cmd.Process.Pid)

	// Start message handling goroutines
	go c.handleResponses()
	go c.handleErrors()

	// Check if the process is still running after a brief delay.
	//
	// NOTE(review): ProcessState is only populated by Wait, which has not been
	// called at this point, and on Unix signal 0 presumably still succeeds for
	// a child that has exited but not yet been reaped - confirm whether this
	// probe can actually observe an early exit.
	processCheckDone := make(chan bool, 1)
	go func() {
		time.Sleep(100 * time.Millisecond)

		// Non-blocking check if process has exited
		if cmd.ProcessState != nil {
			processCheckDone <- false
			return
		}

		// On Unix systems, signal 0 can be used to check process existence
		if cmd.Process != nil {
			if err := cmd.Process.Signal(syscall.Signal(0)); err != nil {
				processCheckDone <- false
				return
			}
		}
		processCheckDone <- true
	}()

	select {
	case alive := <-processCheckDone:
		if !alive {
			// Process died immediately - reap it and undo the committed state.
			waitErr := cmd.Wait()
			closePipes()
			c.cmd = nil
			c.stdin = nil
			if waitErr == nil {
				// Exited with status 0; there is no error to wrap.
				return errors.New("MCP server exited immediately")
			}
			return fmt.Errorf("MCP server exited immediately: %w", waitErr)
		}
	case <-time.After(200 * time.Millisecond):
		// Process seems to be running, continue
	}

	return nil
}
// Initialize starts the server process and performs the MCP handshake: an
// "initialize" request followed by a "notifications/initialized" notification.
// The request is bounded by a 10-second timeout. On success the client is
// marked initialized, which ListTools and CallTool require.
//
// It returns the error from Start unchanged, or a wrapped error if either
// handshake step fails.
func (c *MCPClient) Initialize() error {
	if startErr := c.Start(); startErr != nil {
		return startErr
	}

	handshake := mcpInitializeRequest{
		ProtocolVersion: "2024-11-05",
		Capabilities:    map[string]interface{}{"tools": map[string]interface{}{}},
		ClientInfo:      mcpClientInfo{Name: "ollama", Version: "0.1.0"},
	}

	// Bound the handshake so a server that never answers cannot hang the caller.
	initCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
	defer cancel()

	var serverResp mcpInitializeResponse
	err := c.callWithContext(initCtx, "initialize", handshake, &serverResp)
	if err != nil {
		return fmt.Errorf("MCP initialize failed: %w", err)
	}

	// Tell the server the client is ready.
	err = c.notify("notifications/initialized", nil)
	if err != nil {
		return fmt.Errorf("MCP initialized notification failed: %w", err)
	}

	c.mu.Lock()
	c.initialized = true
	c.mu.Unlock()

	slog.Info("MCP client initialized", "name", c.name, "server", serverResp.ServerInfo.Name)
	return nil
}
// ListTools discovers the tools offered by the MCP server and returns them in
// Ollama API format.
//
// The first successful call issues "tools/list" and caches the converted
// result; later calls are served from the cache as long as it is non-empty.
// Every call returns a fresh slice (a shallow copy), so callers may reorder or
// truncate the result without affecting the cache.
//
// It returns an error if the client has not been initialized or if the
// "tools/list" request fails.
func (c *MCPClient) ListTools() ([]api.Tool, error) {
	c.mu.RLock()
	initialized := c.initialized
	var cached []api.Tool
	if initialized && len(c.tools) > 0 {
		cached = make([]api.Tool, len(c.tools))
		copy(cached, c.tools)
	}
	c.mu.RUnlock()

	if !initialized {
		return nil, errors.New("MCP client not initialized")
	}
	if cached != nil {
		return cached, nil
	}

	var resp mcpListToolsResponse
	if err := c.call("tools/list", mcpListToolsRequest{}, &resp); err != nil {
		return nil, fmt.Errorf("failed to list MCP tools: %w", err)
	}

	// Convert MCP tools to Ollama API format
	tools := make([]api.Tool, 0, len(resp.Tools))
	for _, discovered := range resp.Tools {
		tools = append(tools, c.convertMCPTool(discovered))
	}

	// Cache a separate slice so the one handed to the caller is not shared
	// with the cache (the cached path above already returns copies).
	cache := make([]api.Tool, len(tools))
	copy(cache, tools)
	c.mu.Lock()
	c.tools = cache
	c.mu.Unlock()

	slog.Debug("MCP tools discovered", "name", c.name, "count", len(tools))
	return tools, nil
}

// convertMCPTool translates one MCP tool description into an api.Tool. The
// tool name is namespaced as "<server name>:<tool name>". Schema entries that
// do not have the expected JSON shape are skipped with a debug log rather than
// failing the whole conversion.
func (c *MCPClient) convertMCPTool(src mcpTool) api.Tool {
	tool := api.Tool{
		Type: "function",
		Function: api.ToolFunction{
			Name:        fmt.Sprintf("%s:%s", c.name, src.Name), // Namespace with server name
			Description: src.Description,
			Parameters: api.ToolFunctionParameters{
				Type:       "object",
				Properties: make(map[string]api.ToolProperty),
				Required:   []string{},
			},
		},
	}

	// Convert input schema to tool parameters
	if props, ok := src.InputSchema["properties"].(map[string]interface{}); ok {
		for propName, propDef := range props {
			propDefMap, ok := propDef.(map[string]interface{})
			if !ok {
				slog.Debug("MCP schema: property definition not a map", "tool", src.Name, "property", propName)
				continue
			}

			toolProp := api.ToolProperty{
				Description: getStringFromMap(propDefMap, "description"),
			}
			if propType, ok := propDefMap["type"].(string); ok {
				toolProp.Type = api.PropertyType{propType}
			} else {
				slog.Debug("MCP schema: property type not a string", "tool", src.Name, "property", propName)
			}

			// Preserve items schema for array types (needed for context injection)
			if items, ok := propDefMap["items"]; ok {
				toolProp.Items = items
			}

			tool.Function.Parameters.Properties[propName] = toolProp
		}
	} else if src.InputSchema["properties"] != nil {
		slog.Debug("MCP schema: properties not a map", "tool", src.Name)
	}

	if required, ok := src.InputSchema["required"].([]interface{}); ok {
		for _, req := range required {
			if reqStr, ok := req.(string); ok {
				tool.Function.Parameters.Required = append(tool.Function.Parameters.Required, reqStr)
			} else {
				slog.Debug("MCP schema: required item not a string", "tool", src.Name)
			}
		}
	} else if src.InputSchema["required"] != nil {
		slog.Debug("MCP schema: required not an array", "tool", src.Name)
	}

	return tool
}
// CallTool invokes a tool on the MCP server and returns the concatenation of
// all "text" content items in the result.
//
// name may carry the "<server name>:" namespace prefix added by ListTools; it
// is stripped before the request is sent. A nil args map is replaced by an
// empty one so that "arguments" is always sent as a JSON object. The call is
// bounded by a 30-second timeout.
//
// It returns an error if the client is not initialized, if the request fails,
// or if the server flags the result as an error.
func (c *MCPClient) CallTool(name string, args map[string]interface{}) (string, error) {
	c.mu.RLock()
	ready := c.initialized
	c.mu.RUnlock()
	if !ready {
		return "", errors.New("MCP client not initialized")
	}

	// Strip the namespace prefix. A name that consists of the prefix alone is
	// left untouched, hence the strict length comparison.
	toolName := name
	if prefix := c.name + ":"; len(name) > len(prefix) && strings.HasPrefix(name, prefix) {
		toolName = name[len(prefix):]
	}

	// MCP requires an object for "arguments", never null.
	if args == nil {
		args = map[string]interface{}{}
	}

	// Set timeout for tool execution
	ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
	defer cancel()

	var resp mcpCallToolResponse
	payload := mcpCallToolRequest{Name: toolName, Arguments: args}
	if err := c.callWithContext(ctx, "tools/call", payload, &resp); err != nil {
		return "", fmt.Errorf("MCP tool call failed: %w", err)
	}

	slog.Debug("MCP tool response", "name", name, "is_error", resp.IsError, "content_count", len(resp.Content))

	if resp.IsError {
		// Log only the item count; the content itself may be sensitive.
		slog.Error("MCP tool execution error", "name", name, "content_count", len(resp.Content))
		return "", fmt.Errorf("MCP tool returned error")
	}

	// Join every text item; other content types are ignored.
	var text strings.Builder
	for i := range resp.Content {
		if resp.Content[i].Type == "text" {
			text.WriteString(resp.Content[i].Text)
		}
	}
	return text.String(), nil
}
// GetTools returns the tools currently cached for this MCP server (populated
// by ListTools). It does not contact the server.
//
// The result is a shallow copy of the cache, so callers can reorder or modify
// the returned slice without racing with, or corrupting, the client's own
// state. It returns nil when nothing has been cached yet.
func (c *MCPClient) GetTools() []api.Tool {
	c.mu.RLock()
	defer c.mu.RUnlock()

	if c.tools == nil {
		return nil
	}
	snapshot := make([]api.Tool, len(c.tools))
	copy(snapshot, c.tools)
	return snapshot
}
// Close shuts down the MCP client and terminates the server process.
//
// Shutdown sequence:
//  1. Close the server's stdin, which is how a stdio server is asked to exit.
//  2. Wait up to 5 seconds for the process to exit on its own.
//  3. If it has not exited, kill it and wait for it to be reaped.
//  4. Cancel the client context, close the remaining pipes and close done so
//     that pending calls and the reader goroutines stop.
//
// The context is deliberately cancelled only after the wait: the process was
// created with exec.CommandContext, so cancelling first would kill it at once
// and the grace period would never apply.
//
// The client lock is held for the whole sequence. Close is a no-op if the
// client was never started or has already been closed, and always returns nil.
func (c *MCPClient) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	cmd := c.cmd
	if cmd == nil {
		return nil
	}

	slog.Info("Shutting down MCP client", "name", c.name)

	// Close stdin to signal shutdown
	if c.stdin != nil {
		c.stdin.Close()
	}

	// Wait for process to exit gracefully
	waitDone := make(chan error, 1)
	go func() {
		waitDone <- cmd.Wait()
	}()

	grace := time.NewTimer(5 * time.Second)
	defer grace.Stop()

	select {
	case err := <-waitDone:
		if err != nil {
			slog.Warn("MCP server exited with error", "name", c.name, "error", err)
		}
	case <-grace.C:
		// Force kill if not responding
		slog.Warn("Force killing unresponsive MCP server", "name", c.name)
		if cmd.Process != nil {
			// Best effort: the process may have exited in the meantime.
			_ = cmd.Process.Kill()
		}
		<-waitDone
	}

	// Cancel context to stop pending calls
	c.cancel()

	// Close stdout/stderr pipes to unblock handleResponses/handleErrors goroutines
	if c.stdoutPipe != nil {
		c.stdoutPipe.Close()
	}
	if c.stderrPipe != nil {
		c.stderrPipe.Close()
	}

	c.cmd = nil
	c.initialized = false

	// done must be closed exactly once, even if the client was restarted and
	// closed again; the lock held above makes this check-then-close safe.
	select {
	case <-c.done:
	default:
		close(c.done)
	}

	return nil
}
// call sends a JSON-RPC request and waits for the response, giving up after a
// default timeout of 30 seconds (or earlier if the client context is
// cancelled). Without the bound, a server that never answers would block the
// caller until the client is closed. Use callWithContext to supply a different
// deadline.
//
// result, if non-nil, receives the decoded "result" member of the response.
func (c *MCPClient) call(method string, params interface{}, result interface{}) error {
	const defaultCallTimeout = 30 * time.Second

	ctx, cancel := context.WithTimeout(c.ctx, defaultCallTimeout)
	defer cancel()

	return c.callWithContext(ctx, method, params, result)
}
// callWithContext sends a JSON-RPC request and blocks until one of three
// things happens: the matching response arrives, ctx is done, or the client is
// closed.
//
// A fresh request id is taken from the atomic counter, and a one-slot reply
// channel is registered under that id before the request is written, so the
// reader goroutine can always find it. The registration is removed and the
// channel closed when the call returns.
//
// result, if non-nil, receives the decoded "result" member. A JSON-RPC error
// response is returned as an error carrying its code and message.
func (c *MCPClient) callWithContext(ctx context.Context, method string, params interface{}, result interface{}) error {
	reqID := atomic.AddInt64(&c.requestID, 1)

	// Register the reply slot first so a fast response cannot be missed.
	reply := make(chan *jsonRPCResponse, 1)
	c.mu.Lock()
	c.responses[reqID] = reply
	c.mu.Unlock()

	defer func() {
		// Unregister under the lock before closing, so the reader never sends
		// on a closed channel.
		c.mu.Lock()
		delete(c.responses, reqID)
		c.mu.Unlock()
		close(reply)
	}()

	outgoing := jsonRPCRequest{
		JSONRPC: "2.0",
		ID:      &reqID,
		Method:  method,
		Params:  params,
	}
	if sendErr := c.sendRequest(outgoing); sendErr != nil {
		return sendErr
	}

	// Wait for response
	var resp *jsonRPCResponse
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-c.done:
		return errors.New("MCP client closed")
	case resp = <-reply:
	}

	if resp.Error != nil {
		return fmt.Errorf("JSON-RPC error %d: %s", resp.Error.Code, resp.Error.Message)
	}
	if result == nil || resp.Result == nil {
		return nil
	}
	if err := json.Unmarshal(resp.Result, result); err != nil {
		return fmt.Errorf("failed to unmarshal response: %w", err)
	}
	return nil
}
// notify sends a JSON-RPC notification. The request is built without an id,
// so no response is expected and none is waited for; the only error returned
// is a failure to write the message.
func (c *MCPClient) notify(method string, params interface{}) error {
	return c.sendRequest(jsonRPCRequest{
		JSONRPC: "2.0",
		Method:  method,
		Params:  params,
	})
}
// sendRequest serializes req as a single line of JSON, terminated by a
// newline, and writes it to the server's stdin.
//
// It returns an error if the client has no stdin pipe, if req cannot be
// marshalled, or if the write fails.
func (c *MCPClient) sendRequest(req jsonRPCRequest) error {
	w := c.stdin
	if w == nil {
		return fmt.Errorf("client not started")
	}

	payload, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	// One message per line: append the delimiter and write in a single call.
	payload = append(payload, '\n')
	if _, err = w.Write(payload); err != nil {
		return fmt.Errorf("failed to write request: %w", err)
	}
	return nil
}
// handleResponses reads newline-delimited JSON-RPC messages from the server's
// stdout and hands each one to the caller waiting on its id. It runs as a
// goroutine started by Start and returns when the client is closed, when
// stdout reaches EOF, or when reading fails.
//
// Lines that are not valid JSON are skipped with a warning (without logging
// their content). Messages without an id, or with an id nobody is waiting on,
// are dropped. A single line may be at most 1MB; a longer line ends the loop
// with a read error.
//
// A read error caused by the pipe having been closed (which is how Close
// unblocks this goroutine) is an expected part of shutdown and is not logged.
func (c *MCPClient) handleResponses() {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("MCP response handler panic", "name", c.name, "error", r)
		}
	}()

	scanner := bufio.NewScanner(c.stdout)
	// Set a larger buffer to handle long JSON responses
	scanner.Buffer(make([]byte, 64*1024), 1024*1024) // 64KB initial, 1MB max

	for {
		// Stop before the next read once the client has been closed.
		select {
		case <-c.done:
			return
		default:
		}

		if !scanner.Scan() {
			if err := scanner.Err(); err != nil && !errors.Is(err, os.ErrClosed) {
				slog.Error("Error reading MCP response", "name", c.name, "error", err)
			}
			return
		}

		line := scanner.Bytes()
		msg := new(jsonRPCResponse)
		if err := json.Unmarshal(line, msg); err != nil {
			// Don't log raw line content - may contain sensitive data
			slog.Warn("Invalid JSON-RPC response", "name", c.name, "error", err, "length", len(line))
			continue
		}
		if msg.ID == nil {
			continue
		}

		// Route response to waiting caller. The read lock is held across the
		// send so the waiter cannot unregister and close its channel mid-send.
		c.mu.RLock()
		if waiter, pending := c.responses[*msg.ID]; pending {
			select {
			case waiter <- msg:
			default:
				slog.Warn("Response channel full", "name", c.name, "id", *msg.ID)
			}
		}
		c.mu.RUnlock()
	}
}
// handleErrors reads the server's stderr line by line and forwards each line
// to the debug log. It runs as a goroutine started by Start and returns when
// the client is closed, when stderr reaches EOF, or when reading fails.
//
// Logged lines are capped at 200 bytes to avoid excessive or sensitive output.
// The cut is moved back to the start of a UTF-8 sequence so the logged text
// never ends in a partial character.
//
// EOF and "pipe already closed" (which is how Close unblocks this goroutine)
// are expected ways for the stream to end and are not logged as errors.
func (c *MCPClient) handleErrors() {
	defer func() {
		if r := recover(); r != nil {
			slog.Error("MCP error handler panic", "name", c.name, "error", r)
		}
	}()

	const maxLoggedBytes = 200

	for {
		// Stop before the next read once the client has been closed.
		select {
		case <-c.done:
			return
		default:
		}

		line, isPrefix, err := c.stderr.ReadLine()
		if err != nil {
			if !errors.Is(err, io.EOF) && !errors.Is(err, os.ErrClosed) {
				slog.Error("Error reading MCP stderr", "name", c.name, "error", err)
			}
			return
		}

		if isPrefix {
			// The line did not fit in the reader's buffer; the remainder
			// arrives as further ReadLine results.
			slog.Warn("MCP stderr line too long, truncated", "name", c.name)
		}

		// Truncate stderr to avoid logging excessive/sensitive output
		msg := string(line)
		if len(msg) > maxLoggedBytes {
			cut := maxLoggedBytes
			// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; step
			// back until the cut sits on the first byte of a character.
			for cut > 0 && msg[cut]&0xC0 == 0x80 {
				cut--
			}
			msg = msg[:cut] + "...(truncated)"
		}

		slog.Debug("MCP server stderr", "name", c.name, "message", msg)
	}
}
// buildSecureEnvironment assembles the environment handed to the MCP server
// process.
//
// SECURITY REVIEW: This is a critical security function. It controls what
// environment variables are passed to MCP server processes.
//
// How the result is built:
//  1. Start from an empty list; nothing is inherited implicitly.
//  2. Copy a variable from the parent process only if the security config
//     does not filter it AND it is on the allowlist below.
//  3. Append the server's configured variables, skipping any the security
//     config filters. These are not checked against the allowlist: they are
//     set explicitly by the user/admin.
//  4. If no PATH made it through, append a minimal default PATH.
//
// NOTE(review): an inherited PATH is passed through unchanged; nothing in this
// function removes directories from it - confirm whether PATH sanitization is
// expected to happen elsewhere.
func (c *MCPClient) buildSecureEnvironment() []string {
	securityConfig := GetSecurityConfig()

	// SECURITY: Allowlist of variables that may be inherited from the parent.
	// An entry mapped to false is treated exactly like a missing entry; it is
	// listed only to document the decision.
	passthrough := map[string]bool{
		"PATH":       true,
		"HOME":       true,
		"USER":       true,
		"LANG":       true,
		"LC_ALL":     true,
		"LC_CTYPE":   true,
		"TZ":         true,
		"TMPDIR":     true,
		"TEMP":       true,
		"TMP":        true,
		"TERM":       true,
		"PYTHONPATH": true,
		"NODE_PATH":  true,
		"DISPLAY":    true,  // For GUI applications
		"EDITOR":     true,  // For text editing
		"SHELL":      false, // Explicitly blocked - could enable shell escapes
	}

	// SECURITY: Start with empty env, not os.Environ()
	env := []string{}

	// Inherited variables: blocklist first, then allowlist.
	for _, entry := range os.Environ() {
		key, value, found := strings.Cut(entry, "=")
		if !found {
			continue
		}
		if securityConfig.ShouldFilterEnvironmentVar(key) {
			slog.Debug("MCP: filtered credential env var", "name", c.name, "key", key)
			continue
		}
		if passthrough[key] {
			env = append(env, key+"="+value)
		}
	}

	// Configured variables: blocklist only (see the function comment).
	for key, value := range c.env {
		if securityConfig.ShouldFilterEnvironmentVar(key) {
			slog.Debug("MCP: blocked custom env var", "name", c.name, "key", key)
			continue
		}
		env = append(env, key+"="+value)
	}

	// Set a restricted PATH if not already set
	if !hasEnvVar(env, "PATH") {
		env = append(env, "PATH=/usr/local/bin:/usr/bin:/bin")
	}

	return env
}
// getStringFromMap returns m[key] when that value is a string, and "" in
// every other case: key absent, value of another type, or m nil.
func getStringFromMap(m map[string]interface{}, key string) string {
	// A failed assertion leaves text at its zero value, the empty string.
	text, _ := m[key].(string)
	return text
}
// hasEnvVar reports whether env, a list of "KEY=value" entries, contains an
// entry for key. An entry matches when it begins with key followed by "=",
// so "PATHX=1" does not count as "PATH".
func hasEnvVar(env []string, key string) bool {
	wanted := key + "="
	for i := 0; i < len(env); i++ {
		if strings.HasPrefix(env[i], wanted) {
			return true
		}
	}
	return false
}