484 lines
13 KiB
Go
484 lines
13 KiB
Go
package server
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
// MCPManager manages multiple MCP server connections and provides tool execution services
|
|
type MCPManager struct {
|
|
mu sync.RWMutex
|
|
clients map[string]*MCPClient
|
|
toolRouting map[string]string // tool name -> client name mapping
|
|
maxClients int
|
|
}
|
|
|
|
// MCPServerConfig is defined in the api package (api.MCPServerConfig); it is not redeclared here.
|
|
|
|
// ToolResult carries the outcome of a single tool invocation.
type ToolResult struct {
	// Content is the text returned by the tool.
	Content string

	// Error is non-nil when the tool could not be resolved or the call failed.
	Error error
}
|
|
|
|
// ExecutionPlan describes how a batch of tool calls should be scheduled.
type ExecutionPlan struct {
	// RequiresSequential reports whether the calls must run one at a time.
	RequiresSequential bool

	// Groups lists tool-call indices; the indices inside one group can run
	// in parallel.
	Groups [][]int

	// Reason explains why this plan was chosen.
	Reason string
}
|
|
|
|
// NewMCPManager creates a new MCP manager
|
|
func NewMCPManager(maxClients int) *MCPManager {
|
|
return &MCPManager{
|
|
clients: make(map[string]*MCPClient),
|
|
toolRouting: make(map[string]string),
|
|
maxClients: maxClients,
|
|
}
|
|
}
|
|
|
|
// AddServer adds a new MCP server to the manager
|
|
func (m *MCPManager) AddServer(config api.MCPServerConfig) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if len(m.clients) >= m.maxClients {
|
|
return fmt.Errorf("maximum number of MCP servers reached (%d)", m.maxClients)
|
|
}
|
|
|
|
if _, exists := m.clients[config.Name]; exists {
|
|
return fmt.Errorf("MCP server '%s' already exists", config.Name)
|
|
}
|
|
|
|
// Validate server configuration for security
|
|
if err := m.validateServerConfig(config); err != nil {
|
|
return fmt.Errorf("invalid MCP server configuration: %w", err)
|
|
}
|
|
|
|
// Create and initialize the MCP client
|
|
client := NewMCPClient(config.Name, config.Command, config.Args, config.Env)
|
|
|
|
if err := client.Initialize(); err != nil {
|
|
client.Close()
|
|
return fmt.Errorf("failed to initialize MCP server '%s': %w", config.Name, err)
|
|
}
|
|
|
|
// Discover tools
|
|
tools, err := client.ListTools()
|
|
if err != nil {
|
|
client.Close()
|
|
return fmt.Errorf("failed to list tools from MCP server '%s': %w", config.Name, err)
|
|
}
|
|
|
|
// Update tool routing
|
|
for _, tool := range tools {
|
|
m.toolRouting[tool.Function.Name] = config.Name
|
|
}
|
|
|
|
m.clients[config.Name] = client
|
|
|
|
slog.Info("MCP server added", "name", config.Name, "tools", len(tools))
|
|
return nil
|
|
}
|
|
|
|
// RemoveServer removes an MCP server from the manager
|
|
func (m *MCPManager) RemoveServer(name string) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
client, exists := m.clients[name]
|
|
if !exists {
|
|
return fmt.Errorf("MCP server '%s' not found", name)
|
|
}
|
|
|
|
// Remove tool routing entries
|
|
for toolName, clientName := range m.toolRouting {
|
|
if clientName == name {
|
|
delete(m.toolRouting, toolName)
|
|
}
|
|
}
|
|
|
|
// Close the client
|
|
if err := client.Close(); err != nil {
|
|
slog.Warn("Error closing MCP client", "name", name, "error", err)
|
|
}
|
|
|
|
delete(m.clients, name)
|
|
|
|
slog.Info("MCP server removed", "name", name)
|
|
return nil
|
|
}
|
|
|
|
// GetAllTools returns all available tools from all MCP servers
|
|
func (m *MCPManager) GetAllTools() []api.Tool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
var allTools []api.Tool
|
|
|
|
for _, client := range m.clients {
|
|
tools, err := client.ListTools()
|
|
if err != nil {
|
|
slog.Warn("Failed to get tools from MCP client", "name", client.name, "error", err)
|
|
continue
|
|
}
|
|
allTools = append(allTools, tools...)
|
|
}
|
|
|
|
return allTools
|
|
}
|
|
|
|
// ExecuteTool executes a single tool call
|
|
func (m *MCPManager) ExecuteTool(toolCall api.ToolCall) ToolResult {
|
|
toolName := toolCall.Function.Name
|
|
|
|
m.mu.RLock()
|
|
clientName, exists := m.toolRouting[toolName]
|
|
if !exists {
|
|
m.mu.RUnlock()
|
|
return ToolResult{Error: fmt.Errorf("tool '%s' not found", toolName)}
|
|
}
|
|
|
|
client, exists := m.clients[clientName]
|
|
if !exists {
|
|
m.mu.RUnlock()
|
|
return ToolResult{Error: fmt.Errorf("MCP client '%s' not found", clientName)}
|
|
}
|
|
m.mu.RUnlock()
|
|
|
|
// Convert arguments to map[string]interface{}
|
|
args := make(map[string]interface{})
|
|
for k, v := range toolCall.Function.Arguments {
|
|
args[k] = v
|
|
}
|
|
|
|
// Execute the tool
|
|
content, err := client.CallTool(toolName, args)
|
|
if err != nil {
|
|
slog.Debug("MCP tool execution failed", "tool", toolName, "client", clientName)
|
|
} else {
|
|
slog.Debug("MCP tool executed", "tool", toolName, "client", clientName, "result_length", len(content))
|
|
}
|
|
return ToolResult{
|
|
Content: content,
|
|
Error: err,
|
|
}
|
|
}
|
|
|
|
// AnalyzeExecutionPlan analyzes tool calls to determine optimal execution strategy
|
|
func (m *MCPManager) AnalyzeExecutionPlan(toolCalls []api.ToolCall) ExecutionPlan {
|
|
if len(toolCalls) <= 1 {
|
|
return ExecutionPlan{
|
|
RequiresSequential: false,
|
|
Groups: [][]int{{0}},
|
|
Reason: "Single tool call",
|
|
}
|
|
}
|
|
|
|
// Analyze tool patterns for dependencies
|
|
hasWriteOperations := false
|
|
hasReadOperations := false
|
|
fileTargets := make(map[string][]int) // Track which tools operate on which files
|
|
|
|
for i, toolCall := range toolCalls {
|
|
toolName := toolCall.Function.Name
|
|
args := toolCall.Function.Arguments
|
|
|
|
// Check for file operations
|
|
if strings.Contains(toolName, "write") || strings.Contains(toolName, "create") ||
|
|
strings.Contains(toolName, "edit") || strings.Contains(toolName, "append") {
|
|
hasWriteOperations = true
|
|
|
|
// Try to extract file path from arguments
|
|
if pathArg, exists := args["path"]; exists {
|
|
if path, ok := pathArg.(string); ok {
|
|
fileTargets[path] = append(fileTargets[path], i)
|
|
}
|
|
} else if fileArg, exists := args["file"]; exists {
|
|
if file, ok := fileArg.(string); ok {
|
|
fileTargets[file] = append(fileTargets[file], i)
|
|
}
|
|
}
|
|
}
|
|
|
|
if strings.Contains(toolName, "read") || strings.Contains(toolName, "list") ||
|
|
strings.Contains(toolName, "get") {
|
|
hasReadOperations = true
|
|
|
|
// Try to extract file path from arguments
|
|
if pathArg, exists := args["path"]; exists {
|
|
if path, ok := pathArg.(string); ok {
|
|
fileTargets[path] = append(fileTargets[path], i)
|
|
}
|
|
} else if fileArg, exists := args["file"]; exists {
|
|
if file, ok := fileArg.(string); ok {
|
|
fileTargets[file] = append(fileTargets[file], i)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Determine if sequential execution is needed
|
|
requiresSequential := false
|
|
reason := "Can execute in parallel"
|
|
|
|
// Check for file operation dependencies
|
|
if hasWriteOperations && hasReadOperations {
|
|
requiresSequential = true
|
|
reason = "Mixed read and write operations detected"
|
|
}
|
|
|
|
// Check for operations on the same file
|
|
for file, indices := range fileTargets {
|
|
if len(indices) > 1 {
|
|
requiresSequential = true
|
|
reason = fmt.Sprintf("Multiple operations on the same file: %s", file)
|
|
break
|
|
}
|
|
}
|
|
|
|
// Check for explicit ordering patterns in tool names
|
|
for i := 0; i < len(toolCalls)-1; i++ {
|
|
curr := toolCalls[i].Function.Name
|
|
next := toolCalls[i+1].Function.Name
|
|
|
|
// Common patterns that suggest ordering
|
|
if (strings.Contains(curr, "create") && strings.Contains(next, "read")) ||
|
|
(strings.Contains(curr, "write") && strings.Contains(next, "read")) ||
|
|
(strings.Contains(curr, "1") && strings.Contains(next, "2")) ||
|
|
(strings.Contains(curr, "first") && strings.Contains(next, "second")) ||
|
|
(strings.Contains(curr, "init") && strings.Contains(next, "use")) {
|
|
requiresSequential = true
|
|
reason = "Tool names suggest sequential dependency"
|
|
break
|
|
}
|
|
}
|
|
|
|
// Build execution groups
|
|
var groups [][]int
|
|
if requiresSequential {
|
|
// Each tool in its own group for sequential execution
|
|
for i := range toolCalls {
|
|
groups = append(groups, []int{i})
|
|
}
|
|
} else {
|
|
// All tools in one group for parallel execution
|
|
group := make([]int, len(toolCalls))
|
|
for i := range toolCalls {
|
|
group[i] = i
|
|
}
|
|
groups = [][]int{group}
|
|
}
|
|
|
|
plan := ExecutionPlan{
|
|
RequiresSequential: requiresSequential,
|
|
Groups: groups,
|
|
Reason: reason,
|
|
}
|
|
|
|
slog.Debug("Execution plan analyzed",
|
|
"sequential", requiresSequential,
|
|
"reason", reason,
|
|
"tool_count", len(toolCalls))
|
|
|
|
return plan
|
|
}
|
|
|
|
// ExecuteWithPlan executes tool calls according to the execution plan
|
|
func (m *MCPManager) ExecuteWithPlan(toolCalls []api.ToolCall, plan ExecutionPlan) []ToolResult {
|
|
results := make([]ToolResult, len(toolCalls))
|
|
|
|
for _, group := range plan.Groups {
|
|
if len(group) == 1 {
|
|
// Single tool, execute directly
|
|
idx := group[0]
|
|
results[idx] = m.ExecuteTool(toolCalls[idx])
|
|
} else {
|
|
// Multiple tools in group, execute in parallel
|
|
var wg sync.WaitGroup
|
|
for _, idx := range group {
|
|
wg.Add(1)
|
|
go func(i int) {
|
|
defer wg.Done()
|
|
results[i] = m.ExecuteTool(toolCalls[i])
|
|
}(idx)
|
|
}
|
|
wg.Wait()
|
|
}
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
// ExecuteToolsParallel executes multiple tool calls in parallel
|
|
func (m *MCPManager) ExecuteToolsParallel(toolCalls []api.ToolCall) []ToolResult {
|
|
if len(toolCalls) == 0 {
|
|
return nil
|
|
}
|
|
|
|
results := make([]ToolResult, len(toolCalls))
|
|
|
|
// For single tool call, execute directly
|
|
if len(toolCalls) == 1 {
|
|
results[0] = m.ExecuteTool(toolCalls[0])
|
|
return results
|
|
}
|
|
|
|
// Execute multiple tools in parallel
|
|
var wg sync.WaitGroup
|
|
for i, toolCall := range toolCalls {
|
|
wg.Add(1)
|
|
go func(index int, tc api.ToolCall) {
|
|
defer wg.Done()
|
|
results[index] = m.ExecuteTool(tc)
|
|
}(i, toolCall)
|
|
}
|
|
|
|
wg.Wait()
|
|
return results
|
|
}
|
|
|
|
// ExecuteToolsSequential executes multiple tool calls sequentially
|
|
func (m *MCPManager) ExecuteToolsSequential(toolCalls []api.ToolCall) []ToolResult {
|
|
results := make([]ToolResult, len(toolCalls))
|
|
|
|
for i, toolCall := range toolCalls {
|
|
results[i] = m.ExecuteTool(toolCall)
|
|
|
|
// Stop on first error if desired
|
|
if results[i].Error != nil {
|
|
slog.Warn("Tool execution failed", "tool", toolCall.Function.Name, "error", results[i].Error)
|
|
}
|
|
}
|
|
|
|
return results
|
|
}
|
|
|
|
// GetToolClient returns the client name for a given tool
|
|
func (m *MCPManager) GetToolClient(toolName string) (string, bool) {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
clientName, exists := m.toolRouting[toolName]
|
|
return clientName, exists
|
|
}
|
|
|
|
// GetServerNames returns a list of all registered MCP server names
|
|
func (m *MCPManager) GetServerNames() []string {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
names := make([]string, 0, len(m.clients))
|
|
for name := range m.clients {
|
|
names = append(names, name)
|
|
}
|
|
|
|
return names
|
|
}
|
|
|
|
// GetToolDefinition returns the definition for a specific tool
|
|
func (m *MCPManager) GetToolDefinition(serverName, toolName string) *api.Tool {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
client, exists := m.clients[serverName]
|
|
if !exists {
|
|
return nil
|
|
}
|
|
|
|
// Get tools from the client
|
|
tools := client.GetTools()
|
|
for _, tool := range tools {
|
|
if tool.Function.Name == toolName {
|
|
return &tool
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close shuts down all MCP clients
|
|
func (m *MCPManager) Close() error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
var errs []string
|
|
|
|
for name, client := range m.clients {
|
|
if err := client.Close(); err != nil {
|
|
errs = append(errs, fmt.Sprintf("%s: %v", name, err))
|
|
}
|
|
}
|
|
|
|
// Clear all data
|
|
m.clients = make(map[string]*MCPClient)
|
|
m.toolRouting = make(map[string]string)
|
|
|
|
if len(errs) > 0 {
|
|
return fmt.Errorf("errors closing MCP clients: %s", strings.Join(errs, "; "))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Shutdown is an alias for Close for consistency with registry
|
|
func (m *MCPManager) Shutdown() error {
|
|
slog.Info("Shutting down MCP manager", "clients", len(m.clients))
|
|
return m.Close()
|
|
}
|
|
|
|
// validateServerConfig validates MCP server configuration for security
|
|
func (m *MCPManager) validateServerConfig(config api.MCPServerConfig) error {
|
|
// Validate name
|
|
if config.Name == "" {
|
|
return fmt.Errorf("server name cannot be empty")
|
|
}
|
|
if len(config.Name) > 100 {
|
|
return fmt.Errorf("server name too long (max 100 characters)")
|
|
}
|
|
if strings.ContainsAny(config.Name, "/\\:*?\"<>|") {
|
|
return fmt.Errorf("server name contains invalid characters")
|
|
}
|
|
|
|
// Validate command
|
|
if config.Command == "" {
|
|
return fmt.Errorf("command cannot be empty")
|
|
}
|
|
|
|
// Get security configuration
|
|
securityConfig := GetSecurityConfig()
|
|
|
|
// Check if command is allowed by security policy
|
|
if !securityConfig.IsCommandAllowed(config.Command) {
|
|
return fmt.Errorf("command '%s' is not allowed for security reasons", config.Command)
|
|
}
|
|
|
|
// Validate command path (must be absolute or in PATH)
|
|
if strings.Contains(config.Command, "..") {
|
|
return fmt.Errorf("command path cannot contain '..'")
|
|
}
|
|
|
|
// Validate arguments
|
|
for _, arg := range config.Args {
|
|
if strings.Contains(arg, "..") || strings.HasPrefix(arg, "-") && len(arg) > 50 {
|
|
return fmt.Errorf("suspicious argument detected: %s", arg)
|
|
}
|
|
// Check for shell injection attempts using security config
|
|
if securityConfig.HasShellMetacharacters(arg) {
|
|
return fmt.Errorf("argument contains shell metacharacters: %s", arg)
|
|
}
|
|
}
|
|
|
|
// Validate environment variables
|
|
for key := range config.Env {
|
|
if securityConfig.HasShellMetacharacters(key) {
|
|
return fmt.Errorf("environment variable name contains invalid characters: %s", key)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
} |