server: add MCP client for JSON-RPC communication
Add core MCP client implementation for communicating with Model Context Protocol servers over stdio using JSON-RPC 2.0. MCPClient (server/mcp_client.go): - JSON-RPC 2.0 client with request/response correlation - Handles initialize, tools/list, tools/call lifecycle - Secure environment variable filtering (blocks secrets, sanitizes PATH) - Request timeouts and graceful shutdown - Parses MCP tool definitions into Ollama's api.Tool format - MCPClientOption pattern for dependency injection in tests API Types (api/types.go): - MCPServerConfig: Runtime server configuration - MCPTool: Tool definition with JSON Schema inputSchema - MCPToolResult: Execution result container Relates to #7865
This commit is contained in:
parent
18fdcc94e5
commit
a8027c6386
158
api/types.go
158
api/types.go
|
|
@ -7,6 +7,7 @@ import (
|
|||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
|
@ -125,6 +126,12 @@ type GenerateRequest struct {
|
|||
// each with an associated log probability. Only applies when Logprobs is true.
|
||||
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// Tools is a list of tools the model may call.
|
||||
Tools []Tool `json:"tools,omitempty"`
|
||||
|
||||
// MCPServers specifies MCP servers to use for tool functionality
|
||||
MCPServers []MCPServerConfig `json:"mcp_servers,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
|
|
@ -175,6 +182,26 @@ type ChatRequest struct {
|
|||
// each with an associated log probability. Only applies when Logprobs is true.
|
||||
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
|
||||
TopLogprobs int `json:"top_logprobs,omitempty"`
|
||||
|
||||
// MCPServers is an optional list of MCP (Model Context Protocol) servers
|
||||
// that provide tools for autonomous execution during the chat.
|
||||
MCPServers []MCPServerConfig `json:"mcp_servers,omitempty"`
|
||||
|
||||
// MaxToolRounds limits the number of tool execution rounds to prevent
|
||||
// infinite loops. Defaults to 15 if not specified.
|
||||
MaxToolRounds int `json:"max_tool_rounds,omitempty"`
|
||||
|
||||
// ToolTimeout sets the timeout for individual tool executions.
|
||||
// Defaults to 30 seconds if not specified.
|
||||
ToolTimeout *Duration `json:"tool_timeout,omitempty"`
|
||||
|
||||
// SessionID is an optional session identifier for maintaining MCP state
|
||||
// across multiple API calls. If not provided, a new session is created.
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
|
||||
// ToolsPath is the file path passed via --tools flag in interactive mode.
|
||||
// Used to generate consistent session IDs for tool continuity.
|
||||
ToolsPath string `json:"tools_path,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
|
|
@ -197,11 +224,12 @@ type Message struct {
|
|||
Content string `json:"content"`
|
||||
// Thinking contains the text that was inside thinking tags in the
|
||||
// original model output when ChatRequest.Think is enabled.
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
Thinking string `json:"thinking,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolResults []ToolResult `json:"tool_results,omitempty"` // MCP tool results
|
||||
ToolName string `json:"tool_name,omitempty"`
|
||||
ToolCallID string `json:"tool_call_id,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
|
|
@ -221,6 +249,12 @@ type ToolCall struct {
|
|||
Function ToolCallFunction `json:"function"`
|
||||
}
|
||||
|
||||
type ToolResult struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Content string `json:"content"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type ToolCallFunction struct {
|
||||
Index int `json:"index"`
|
||||
Name string `json:"name"`
|
||||
|
|
@ -306,18 +340,106 @@ func (tp ToolProperty) ToTypeScriptType() string {
|
|||
}
|
||||
|
||||
if len(tp.Type) == 1 {
|
||||
return mapToTypeScriptType(tp.Type[0])
|
||||
return tp.mapSingleType(tp.Type[0])
|
||||
}
|
||||
|
||||
var types []string
|
||||
for _, t := range tp.Type {
|
||||
types = append(types, mapToTypeScriptType(t))
|
||||
types = append(types, tp.mapSingleType(t))
|
||||
}
|
||||
return strings.Join(types, " | ")
|
||||
}
|
||||
|
||||
// mapToTypeScriptType maps JSON Schema types to TypeScript types
|
||||
func mapToTypeScriptType(jsonType string) string {
|
||||
// mapSingleType maps a single JSON Schema type to TypeScript, using Items for arrays
|
||||
func (tp ToolProperty) mapSingleType(jsonType string) string {
|
||||
switch jsonType {
|
||||
case "string":
|
||||
return "string"
|
||||
case "number", "integer":
|
||||
return "number"
|
||||
case "boolean":
|
||||
return "boolean"
|
||||
case "array":
|
||||
return tp.arrayTypeString()
|
||||
case "object":
|
||||
return "Record<string, any>"
|
||||
case "null":
|
||||
return "null"
|
||||
default:
|
||||
return "any"
|
||||
}
|
||||
}
|
||||
|
||||
// arrayTypeString generates TypeScript type for arrays, using Items schema
|
||||
func (tp ToolProperty) arrayTypeString() string {
|
||||
if tp.Items == nil {
|
||||
return "any[]"
|
||||
}
|
||||
|
||||
itemsMap, ok := tp.Items.(map[string]interface{})
|
||||
if !ok {
|
||||
return "any[]"
|
||||
}
|
||||
|
||||
// Check if items has a type
|
||||
itemType, ok := itemsMap["type"].(string)
|
||||
if !ok {
|
||||
return "any[]"
|
||||
}
|
||||
|
||||
// For object types with properties, generate inline type
|
||||
if itemType == "object" {
|
||||
if props, ok := itemsMap["properties"].(map[string]interface{}); ok {
|
||||
return objectTypeFromProperties(props) + "[]"
|
||||
}
|
||||
}
|
||||
|
||||
// Simple type array
|
||||
return mapSimpleType(itemType) + "[]"
|
||||
}
|
||||
|
||||
// objectTypeFromProperties generates a TypeScript inline object type from properties
|
||||
func objectTypeFromProperties(props map[string]interface{}) string {
|
||||
if len(props) == 0 {
|
||||
return "Record<string, any>"
|
||||
}
|
||||
|
||||
// Collect and sort field names for consistent output
|
||||
names := make([]string, 0, len(props))
|
||||
for name := range props {
|
||||
names = append(names, name)
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
var fields []string
|
||||
for _, name := range names {
|
||||
propDef := props[name]
|
||||
propMap, ok := propDef.(map[string]interface{})
|
||||
if !ok {
|
||||
fields = append(fields, name+": any")
|
||||
continue
|
||||
}
|
||||
|
||||
propType := "any"
|
||||
if t, ok := propMap["type"].(string); ok {
|
||||
propType = mapSimpleType(t)
|
||||
// Handle nested arrays
|
||||
if t == "array" {
|
||||
if items, ok := propMap["items"].(map[string]interface{}); ok {
|
||||
if itemType, ok := items["type"].(string); ok {
|
||||
propType = mapSimpleType(itemType) + "[]"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
fields = append(fields, name+": "+propType)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(fields, ", ") + "}"
|
||||
}
|
||||
|
||||
// mapSimpleType maps JSON Schema types to TypeScript (without recursion)
|
||||
func mapSimpleType(jsonType string) string {
|
||||
switch jsonType {
|
||||
case "string":
|
||||
return "string"
|
||||
|
|
@ -381,6 +503,21 @@ type Logprob struct {
|
|||
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
|
||||
}
|
||||
|
||||
// MCPServerConfig represents configuration for an MCP (Model Context Protocol) server
|
||||
type MCPServerConfig struct {
|
||||
// Name is a unique identifier for the MCP server
|
||||
Name string `json:"name"`
|
||||
|
||||
// Command is the executable command to start the MCP server
|
||||
Command string `json:"command"`
|
||||
|
||||
// Args are optional command-line arguments for the MCP server
|
||||
Args []string `json:"args,omitempty"`
|
||||
|
||||
// Env are optional environment variables for the MCP server
|
||||
Env map[string]string `json:"env,omitempty"`
|
||||
}
|
||||
|
||||
// ChatResponse is the response returned by [Client.Chat]. Its fields are
|
||||
// similar to [GenerateResponse].
|
||||
type ChatResponse struct {
|
||||
|
|
@ -720,7 +857,8 @@ type GenerateResponse struct {
|
|||
|
||||
Metrics
|
||||
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
ToolResults []ToolResult `json:"tool_results,omitempty"`
|
||||
|
||||
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,810 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// MCPClient manages communication with a single MCP server via JSON-RPC over stdio
|
||||
type MCPClient struct {
|
||||
name string
|
||||
command string
|
||||
args []string
|
||||
env map[string]string
|
||||
|
||||
// Process management
|
||||
cmd *exec.Cmd
|
||||
stdin io.WriteCloser
|
||||
stdout *bufio.Reader
|
||||
stderr *bufio.Reader
|
||||
|
||||
// State
|
||||
mu sync.RWMutex
|
||||
initialized bool
|
||||
tools []api.Tool
|
||||
requestID int64
|
||||
responses map[int64]chan *jsonRPCResponse
|
||||
|
||||
// Lifecycle
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
done chan struct{}
|
||||
|
||||
// Pipe handles (for clean shutdown)
|
||||
stdoutPipe io.ReadCloser
|
||||
stderrPipe io.ReadCloser
|
||||
|
||||
// Dependencies (injectable for testing)
|
||||
commandResolver CommandResolverInterface
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// MCPClient Options (Functional Options Pattern)
|
||||
// =============================================================================
|
||||
|
||||
// MCPClientOption configures an MCPClient during creation.
|
||||
type MCPClientOption func(*MCPClient)
|
||||
|
||||
// WithCommandResolver sets a custom command resolver for the client.
|
||||
// This is primarily useful for testing to avoid system dependencies.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// mockResolver := &MockCommandResolver{...}
|
||||
// client := NewMCPClient("test", "npx", args, env, WithCommandResolver(mockResolver))
|
||||
func WithCommandResolver(resolver CommandResolverInterface) MCPClientOption {
|
||||
return func(c *MCPClient) {
|
||||
c.commandResolver = resolver
|
||||
}
|
||||
}
|
||||
|
||||
// JSON-RPC 2.0 message types
|
||||
type jsonRPCRequest struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *int64 `json:"id,omitempty"`
|
||||
Method string `json:"method"`
|
||||
Params interface{} `json:"params,omitempty"`
|
||||
}
|
||||
|
||||
type jsonRPCResponse struct {
|
||||
JSONRPC string `json:"jsonrpc"`
|
||||
ID *int64 `json:"id"`
|
||||
Result json.RawMessage `json:"result,omitempty"`
|
||||
Error *jsonRPCError `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
type jsonRPCError struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data interface{} `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// MCP protocol message types
|
||||
type mcpInitializeRequest struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]interface{} `json:"capabilities"`
|
||||
ClientInfo mcpClientInfo `json:"clientInfo"`
|
||||
}
|
||||
|
||||
type mcpClientInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type mcpInitializeResponse struct {
|
||||
ProtocolVersion string `json:"protocolVersion"`
|
||||
Capabilities map[string]interface{} `json:"capabilities"`
|
||||
ServerInfo mcpServerInfo `json:"serverInfo"`
|
||||
}
|
||||
|
||||
type mcpServerInfo struct {
|
||||
Name string `json:"name"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type mcpListToolsRequest struct{}
|
||||
|
||||
type mcpListToolsResponse struct {
|
||||
Tools []mcpTool `json:"tools"`
|
||||
}
|
||||
|
||||
type mcpTool struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description,omitempty"`
|
||||
InputSchema map[string]interface{} `json:"inputSchema"`
|
||||
}
|
||||
|
||||
type mcpCallToolRequest struct {
|
||||
Name string `json:"name"`
|
||||
Arguments map[string]interface{} `json:"arguments"`
|
||||
}
|
||||
|
||||
type mcpCallToolResponse struct {
|
||||
Content []mcpContent `json:"content"`
|
||||
IsError bool `json:"isError,omitempty"`
|
||||
}
|
||||
|
||||
type mcpContent struct {
|
||||
Type string `json:"type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// NewMCPClient creates a new MCP client for the specified server configuration.
|
||||
// Optional MCPClientOption arguments can be used to customize behavior (e.g., for testing).
|
||||
func NewMCPClient(name, command string, args []string, env map[string]string, opts ...MCPClientOption) *MCPClient {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
client := &MCPClient{
|
||||
name: name,
|
||||
command: command, // Will be resolved after options are applied
|
||||
args: args,
|
||||
env: env,
|
||||
responses: make(map[int64]chan *jsonRPCResponse),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
done: make(chan struct{}),
|
||||
commandResolver: DefaultCommandResolver, // Default, can be overridden
|
||||
}
|
||||
|
||||
// Apply options
|
||||
for _, opt := range opts {
|
||||
opt(client)
|
||||
}
|
||||
|
||||
// Guard against nil resolver (e.g., if WithCommandResolver(nil) was called)
|
||||
if client.commandResolver == nil {
|
||||
client.commandResolver = DefaultCommandResolver
|
||||
}
|
||||
|
||||
// Resolve the command using the configured resolver
|
||||
client.command = client.commandResolver.ResolveForEnvironment(command)
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// Start spawns the MCP server process and initializes communication.
|
||||
//
|
||||
// SECURITY REVIEW: This function executes external processes. Security controls:
|
||||
// - Command must pass validation in MCPManager.validateServerConfig()
|
||||
// - Environment is filtered via buildSecureEnvironment()
|
||||
// - Process runs in isolated process group (Setpgid)
|
||||
// - Context-based cancellation for cleanup
|
||||
func (c *MCPClient) Start() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.cmd != nil {
|
||||
return errors.New("MCP client already started")
|
||||
}
|
||||
|
||||
// Handle commands that might have spaces (like "pnpm dlx")
|
||||
cmdParts := strings.Fields(c.command)
|
||||
var cmdName string
|
||||
var cmdArgs []string
|
||||
|
||||
if len(cmdParts) > 1 {
|
||||
cmdName = cmdParts[0]
|
||||
cmdArgs = append(cmdParts[1:], c.args...)
|
||||
} else {
|
||||
cmdName = c.command
|
||||
cmdArgs = c.args
|
||||
}
|
||||
|
||||
// SECURITY: Create command with context for cancellation control
|
||||
c.cmd = exec.CommandContext(c.ctx, cmdName, cmdArgs...)
|
||||
|
||||
// SECURITY: Apply filtered environment (see buildSecureEnvironment)
|
||||
c.cmd.Env = c.buildSecureEnvironment()
|
||||
|
||||
// SECURITY: Process isolation via process group
|
||||
c.cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true, // Isolate in own process group
|
||||
Pgid: 0, // New process group
|
||||
// Future: Consider adding privilege dropping for root users
|
||||
// Credential: &syscall.Credential{Uid: 65534, Gid: 65534}
|
||||
}
|
||||
|
||||
// Set up pipes for communication.
|
||||
// Each error path explicitly closes previously opened pipes to prevent leaks.
|
||||
// On success, pipes remain open for the lifetime of the MCP client.
|
||||
stdin, err := c.cmd.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create stdin pipe: %w", err)
|
||||
}
|
||||
c.stdin = stdin
|
||||
|
||||
stdout, err := c.cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
c.stdin.Close()
|
||||
return fmt.Errorf("failed to create stdout pipe: %w", err)
|
||||
}
|
||||
c.stdoutPipe = stdout
|
||||
c.stdout = bufio.NewReader(stdout)
|
||||
|
||||
stderr, err := c.cmd.StderrPipe()
|
||||
if err != nil {
|
||||
c.stdin.Close()
|
||||
stdout.Close()
|
||||
return fmt.Errorf("failed to create stderr pipe: %w", err)
|
||||
}
|
||||
c.stderrPipe = stderr
|
||||
c.stderr = bufio.NewReader(stderr)
|
||||
|
||||
// Start the process
|
||||
if err := c.cmd.Start(); err != nil {
|
||||
c.stdin.Close()
|
||||
stdout.Close()
|
||||
stderr.Close()
|
||||
return fmt.Errorf("failed to start MCP server: %w", err)
|
||||
}
|
||||
|
||||
slog.Info("MCP server started", "name", c.name, "pid", c.cmd.Process.Pid)
|
||||
|
||||
// Start message handling goroutines
|
||||
go c.handleResponses()
|
||||
go c.handleErrors()
|
||||
|
||||
// Check if the process is still running after a brief delay
|
||||
// This catches immediate failures like command not found
|
||||
processCheckDone := make(chan bool, 1)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// Non-blocking check if process has exited
|
||||
if c.cmd.ProcessState != nil {
|
||||
processCheckDone <- false
|
||||
return
|
||||
}
|
||||
// Try to check process existence without waiting
|
||||
if c.cmd.Process != nil {
|
||||
// On Unix systems, signal 0 can be used to check process existence
|
||||
if err := c.cmd.Process.Signal(syscall.Signal(0)); err != nil {
|
||||
processCheckDone <- false
|
||||
return
|
||||
}
|
||||
}
|
||||
processCheckDone <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case alive := <-processCheckDone:
|
||||
if !alive {
|
||||
// Process died immediately - collect the error
|
||||
waitErr := c.cmd.Wait()
|
||||
c.stdin.Close()
|
||||
stdout.Close()
|
||||
stderr.Close()
|
||||
return fmt.Errorf("MCP server exited immediately: %w", waitErr)
|
||||
}
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// Process seems to be running, continue
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Initialize performs the MCP handshake sequence
|
||||
func (c *MCPClient) Initialize() error {
|
||||
if err := c.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Add timeout to initialization to prevent hanging
|
||||
initCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Send initialize request
|
||||
req := mcpInitializeRequest{
|
||||
ProtocolVersion: "2024-11-05",
|
||||
Capabilities: map[string]interface{}{
|
||||
"tools": map[string]interface{}{},
|
||||
},
|
||||
ClientInfo: mcpClientInfo{
|
||||
Name: "ollama",
|
||||
Version: "0.1.0",
|
||||
},
|
||||
}
|
||||
|
||||
var resp mcpInitializeResponse
|
||||
if err := c.callWithContext(initCtx, "initialize", req, &resp); err != nil {
|
||||
return fmt.Errorf("MCP initialize failed: %w", err)
|
||||
}
|
||||
|
||||
// Send initialized notification
|
||||
if err := c.notify("notifications/initialized", nil); err != nil {
|
||||
return fmt.Errorf("MCP initialized notification failed: %w", err)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.initialized = true
|
||||
c.mu.Unlock()
|
||||
|
||||
slog.Info("MCP client initialized", "name", c.name, "server", resp.ServerInfo.Name)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListTools discovers available tools from the MCP server
|
||||
func (c *MCPClient) ListTools() ([]api.Tool, error) {
|
||||
c.mu.RLock()
|
||||
if !c.initialized {
|
||||
c.mu.RUnlock()
|
||||
return nil, errors.New("MCP client not initialized")
|
||||
}
|
||||
|
||||
// Return cached tools if available
|
||||
if len(c.tools) > 0 {
|
||||
tools := make([]api.Tool, len(c.tools))
|
||||
copy(tools, c.tools)
|
||||
c.mu.RUnlock()
|
||||
return tools, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
var resp mcpListToolsResponse
|
||||
if err := c.call("tools/list", mcpListToolsRequest{}, &resp); err != nil {
|
||||
return nil, fmt.Errorf("failed to list MCP tools: %w", err)
|
||||
}
|
||||
|
||||
// Convert MCP tools to Ollama API format
|
||||
tools := make([]api.Tool, 0, len(resp.Tools))
|
||||
for _, mcpTool := range resp.Tools {
|
||||
tool := api.Tool{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: fmt.Sprintf("%s:%s", c.name, mcpTool.Name), // Namespace with server name
|
||||
Description: mcpTool.Description,
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: make(map[string]api.ToolProperty),
|
||||
Required: []string{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Convert input schema to tool parameters
|
||||
if props, ok := mcpTool.InputSchema["properties"].(map[string]interface{}); ok {
|
||||
for propName, propDef := range props {
|
||||
propDefMap, ok := propDef.(map[string]interface{})
|
||||
if !ok {
|
||||
slog.Debug("MCP schema: property definition not a map", "tool", mcpTool.Name, "property", propName)
|
||||
continue
|
||||
}
|
||||
toolProp := api.ToolProperty{
|
||||
Description: getStringFromMap(propDefMap, "description"),
|
||||
}
|
||||
|
||||
if propType, ok := propDefMap["type"].(string); ok {
|
||||
toolProp.Type = api.PropertyType{propType}
|
||||
} else {
|
||||
slog.Debug("MCP schema: property type not a string", "tool", mcpTool.Name, "property", propName)
|
||||
}
|
||||
|
||||
// Preserve items schema for array types (needed for context injection)
|
||||
if items, ok := propDefMap["items"]; ok {
|
||||
toolProp.Items = items
|
||||
}
|
||||
|
||||
tool.Function.Parameters.Properties[propName] = toolProp
|
||||
}
|
||||
} else if mcpTool.InputSchema["properties"] != nil {
|
||||
slog.Debug("MCP schema: properties not a map", "tool", mcpTool.Name)
|
||||
}
|
||||
|
||||
if required, ok := mcpTool.InputSchema["required"].([]interface{}); ok {
|
||||
for _, req := range required {
|
||||
if reqStr, ok := req.(string); ok {
|
||||
tool.Function.Parameters.Required = append(tool.Function.Parameters.Required, reqStr)
|
||||
} else {
|
||||
slog.Debug("MCP schema: required item not a string", "tool", mcpTool.Name)
|
||||
}
|
||||
}
|
||||
} else if mcpTool.InputSchema["required"] != nil {
|
||||
slog.Debug("MCP schema: required not an array", "tool", mcpTool.Name)
|
||||
}
|
||||
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
|
||||
// Cache the tools
|
||||
c.mu.Lock()
|
||||
c.tools = tools
|
||||
c.mu.Unlock()
|
||||
|
||||
slog.Debug("MCP tools discovered", "name", c.name, "count", len(tools))
|
||||
return tools, nil
|
||||
}
|
||||
|
||||
// CallTool executes a tool call via the MCP server
|
||||
func (c *MCPClient) CallTool(name string, args map[string]interface{}) (string, error) {
|
||||
c.mu.RLock()
|
||||
if !c.initialized {
|
||||
c.mu.RUnlock()
|
||||
return "", errors.New("MCP client not initialized")
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
// Remove namespace prefix if present
|
||||
toolName := name
|
||||
if prefix := c.name + ":"; len(name) > len(prefix) && name[:len(prefix)] == prefix {
|
||||
toolName = name[len(prefix):]
|
||||
}
|
||||
|
||||
// Ensure arguments is never nil (MCP protocol requires an object, not undefined)
|
||||
if args == nil {
|
||||
args = make(map[string]interface{})
|
||||
}
|
||||
|
||||
req := mcpCallToolRequest{
|
||||
Name: toolName,
|
||||
Arguments: args,
|
||||
}
|
||||
|
||||
// Debug logging removed
|
||||
|
||||
var resp mcpCallToolResponse
|
||||
|
||||
// Set timeout for tool execution
|
||||
ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := c.callWithContext(ctx, "tools/call", req, &resp); err != nil {
|
||||
return "", fmt.Errorf("MCP tool call failed: %w", err)
|
||||
}
|
||||
|
||||
slog.Debug("MCP tool response", "name", name, "is_error", resp.IsError, "content_count", len(resp.Content))
|
||||
|
||||
if resp.IsError {
|
||||
// Log error without full response to avoid exposing sensitive data
|
||||
slog.Error("MCP tool execution error", "name", name, "content_count", len(resp.Content))
|
||||
return "", fmt.Errorf("MCP tool returned error")
|
||||
}
|
||||
|
||||
// Concatenate all text content
|
||||
var result string
|
||||
for _, content := range resp.Content {
|
||||
if content.Type == "text" {
|
||||
result += content.Text
|
||||
}
|
||||
}
|
||||
|
||||
// Debug logging removed
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetTools returns the list of tools available from this MCP server
|
||||
func (c *MCPClient) GetTools() []api.Tool {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.tools
|
||||
}
|
||||
|
||||
// Close shuts down the MCP client and terminates the server process
|
||||
func (c *MCPClient) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.cmd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Info("Shutting down MCP client", "name", c.name)
|
||||
|
||||
// Cancel context to stop goroutines
|
||||
c.cancel()
|
||||
|
||||
// Close stdin to signal shutdown
|
||||
if c.stdin != nil {
|
||||
c.stdin.Close()
|
||||
}
|
||||
|
||||
// Close stdout/stderr pipes to unblock handleResponses/handleErrors goroutines
|
||||
if c.stdoutPipe != nil {
|
||||
c.stdoutPipe.Close()
|
||||
}
|
||||
if c.stderrPipe != nil {
|
||||
c.stderrPipe.Close()
|
||||
}
|
||||
|
||||
// Wait for process to exit gracefully
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- c.cmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
slog.Warn("MCP server exited with error", "name", c.name, "error", err)
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
// Force kill if not responding
|
||||
slog.Warn("Force killing unresponsive MCP server", "name", c.name)
|
||||
c.cmd.Process.Kill()
|
||||
<-done
|
||||
}
|
||||
|
||||
c.cmd = nil
|
||||
c.initialized = false
|
||||
close(c.done)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// call sends a JSON-RPC request and waits for the response
|
||||
func (c *MCPClient) call(method string, params interface{}, result interface{}) error {
|
||||
return c.callWithContext(c.ctx, method, params, result)
|
||||
}
|
||||
|
||||
// callWithContext sends a JSON-RPC request with a custom context
|
||||
func (c *MCPClient) callWithContext(ctx context.Context, method string, params interface{}, result interface{}) error {
|
||||
id := atomic.AddInt64(&c.requestID, 1)
|
||||
|
||||
req := jsonRPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
ID: &id,
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
// Create response channel
|
||||
respChan := make(chan *jsonRPCResponse, 1)
|
||||
c.mu.Lock()
|
||||
c.responses[id] = respChan
|
||||
c.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
c.mu.Lock()
|
||||
delete(c.responses, id)
|
||||
c.mu.Unlock()
|
||||
close(respChan)
|
||||
}()
|
||||
|
||||
// Send request
|
||||
if err := c.sendRequest(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for response
|
||||
select {
|
||||
case resp := <-respChan:
|
||||
if resp.Error != nil {
|
||||
return fmt.Errorf("JSON-RPC error %d: %s", resp.Error.Code, resp.Error.Message)
|
||||
}
|
||||
|
||||
if result != nil && resp.Result != nil {
|
||||
if err := json.Unmarshal(resp.Result, result); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-c.done:
|
||||
return errors.New("MCP client closed")
|
||||
}
|
||||
}
|
||||
|
||||
// notify sends a JSON-RPC notification (no response expected)
|
||||
func (c *MCPClient) notify(method string, params interface{}) error {
|
||||
req := jsonRPCRequest{
|
||||
JSONRPC: "2.0",
|
||||
Method: method,
|
||||
Params: params,
|
||||
}
|
||||
|
||||
return c.sendRequest(req)
|
||||
}
|
||||
|
||||
// sendRequest sends a JSON-RPC request over stdin
|
||||
func (c *MCPClient) sendRequest(req jsonRPCRequest) error {
|
||||
if c.stdin == nil {
|
||||
return fmt.Errorf("client not started")
|
||||
}
|
||||
|
||||
data, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
if _, err := c.stdin.Write(append(data, '\n')); err != nil {
|
||||
return fmt.Errorf("failed to write request: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleResponses processes incoming JSON-RPC responses from stdout
|
||||
func (c *MCPClient) handleResponses() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("MCP response handler panic", "name", c.name, "error", r)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(c.stdout)
|
||||
// Set a larger buffer to handle long JSON responses
|
||||
scanner.Buffer(make([]byte, 64*1024), 1024*1024) // 64KB initial, 1MB max
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
if !scanner.Scan() {
|
||||
if err := scanner.Err(); err != nil {
|
||||
slog.Error("Error reading MCP response", "name", c.name, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
line := scanner.Bytes()
|
||||
var resp jsonRPCResponse
|
||||
if err := json.Unmarshal(line, &resp); err != nil {
|
||||
// Don't log raw line content - may contain sensitive data
|
||||
slog.Warn("Invalid JSON-RPC response", "name", c.name, "error", err, "length", len(line))
|
||||
continue
|
||||
}
|
||||
|
||||
// Route response to waiting caller
|
||||
if resp.ID != nil {
|
||||
c.mu.RLock()
|
||||
if respChan, exists := c.responses[*resp.ID]; exists {
|
||||
select {
|
||||
case respChan <- &resp:
|
||||
default:
|
||||
slog.Warn("Response channel full", "name", c.name, "id", *resp.ID)
|
||||
}
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleErrors processes stderr output from the MCP server
|
||||
func (c *MCPClient) handleErrors() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
slog.Error("MCP error handler panic", "name", c.name, "error", r)
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.done:
|
||||
return
|
||||
default:
|
||||
line, isPrefix, err := c.stderr.ReadLine()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
slog.Error("Error reading MCP stderr", "name", c.name, "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if isPrefix {
|
||||
slog.Warn("MCP stderr line too long, truncated", "name", c.name)
|
||||
}
|
||||
|
||||
// Truncate stderr to avoid logging excessive/sensitive output
|
||||
msg := string(line)
|
||||
if len(msg) > 200 {
|
||||
msg = msg[:200] + "...(truncated)"
|
||||
}
|
||||
slog.Debug("MCP server stderr", "name", c.name, "message", msg)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// buildSecureEnvironment creates a filtered environment for the MCP server.
|
||||
//
|
||||
// SECURITY REVIEW: This is a critical security function. It controls what
|
||||
// environment variables are passed to MCP server processes.
|
||||
//
|
||||
// Defense strategy (defense in depth):
|
||||
// 1. Start with empty environment (not inherited)
|
||||
// 2. Allowlist only known-safe variables
|
||||
// 3. Apply MCPSecurityConfig filtering (blocks credentials)
|
||||
// 4. Sanitize PATH to remove dangerous directories
|
||||
// 5. Add custom env vars only after security checks
|
||||
func (c *MCPClient) buildSecureEnvironment() []string {
|
||||
// SECURITY: Start with empty env, not os.Environ()
|
||||
env := []string{}
|
||||
|
||||
// Get security configuration
|
||||
securityConfig := GetSecurityConfig()
|
||||
|
||||
// SECURITY: Allowlist of safe environment variables.
|
||||
// Only these variables can be passed through from the parent process.
|
||||
allowedVars := map[string]bool{
|
||||
"PATH": true,
|
||||
"HOME": true,
|
||||
"USER": true,
|
||||
"LANG": true,
|
||||
"LC_ALL": true,
|
||||
"LC_CTYPE": true,
|
||||
"TZ": true,
|
||||
"TMPDIR": true,
|
||||
"TEMP": true,
|
||||
"TMP": true,
|
||||
"TERM": true,
|
||||
"PYTHONPATH": true,
|
||||
"NODE_PATH": true,
|
||||
"DISPLAY": true, // For GUI applications
|
||||
"EDITOR": true, // For text editing
|
||||
"SHELL": false, // Explicitly blocked - could enable shell escapes
|
||||
}
|
||||
|
||||
// Filter existing environment variables
|
||||
for _, e := range os.Environ() {
|
||||
parts := strings.SplitN(e, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
key := parts[0]
|
||||
value := parts[1]
|
||||
|
||||
// Check if variable should be filtered based on security config
|
||||
if securityConfig.ShouldFilterEnvironmentVar(key) {
|
||||
slog.Debug("MCP: filtered credential env var", "name", c.name, "key", key)
|
||||
continue
|
||||
}
|
||||
|
||||
// Only include explicitly allowed variables
|
||||
if allowed, exists := allowedVars[key]; exists && allowed {
|
||||
env = append(env, fmt.Sprintf("%s=%s", key, value))
|
||||
}
|
||||
}
|
||||
|
||||
// Add custom environment variables from server config.
|
||||
// NOTE: Custom vars only check the blocklist, not the allowlist. This is intentional:
|
||||
// inherited vars use strict allowlist (defense in depth), but custom vars are explicitly
|
||||
// configured by the user/admin, so they're trusted if not in the blocklist.
|
||||
for key, value := range c.env {
|
||||
if securityConfig.ShouldFilterEnvironmentVar(key) {
|
||||
slog.Debug("MCP: blocked custom env var", "name", c.name, "key", key)
|
||||
continue
|
||||
}
|
||||
env = append(env, fmt.Sprintf("%s=%s", key, value))
|
||||
}
|
||||
|
||||
// Set a restricted PATH if not already set
|
||||
if !hasEnvVar(env, "PATH") {
|
||||
env = append(env, "PATH=/usr/local/bin:/usr/bin:/bin")
|
||||
}
|
||||
|
||||
return env
|
||||
}
|
||||
|
||||
// getStringFromMap safely extracts a string value from a map
|
||||
func getStringFromMap(m map[string]interface{}, key string) string {
|
||||
if val, ok := m[key].(string); ok {
|
||||
return val
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// hasEnvVar checks if an environment variable exists in the env slice
|
||||
func hasEnvVar(env []string, key string) bool {
|
||||
prefix := key + "="
|
||||
for _, e := range env {
|
||||
if strings.HasPrefix(e, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Loading…
Reference in New Issue