diff --git a/api/types.go b/api/types.go index 2434fe478..d6105d84f 100644 --- a/api/types.go +++ b/api/types.go @@ -8,6 +8,7 @@ import ( "math" "os" "reflect" + "sort" "strconv" "strings" "time" @@ -177,6 +178,26 @@ type ChatRequest struct { // each with an associated log probability. Only applies when Logprobs is true. // Valid values are 0-20. Default is 0 (only return the selected token's logprob). TopLogprobs int `json:"top_logprobs,omitempty"` + + // MCPServers is an optional list of MCP (Model Context Protocol) servers + // that provide tools for autonomous execution during the chat. + MCPServers []MCPServerConfig `json:"mcp_servers,omitempty"` + + // MaxToolRounds limits the number of tool execution rounds to prevent + // infinite loops. Defaults to 15 if not specified. + MaxToolRounds int `json:"max_tool_rounds,omitempty"` + + // ToolTimeout sets the timeout for individual tool executions. + // Defaults to 30 seconds if not specified. + ToolTimeout *Duration `json:"tool_timeout,omitempty"` + + // SessionID is an optional session identifier for maintaining MCP state + // across multiple API calls. If not provided, a new session is created. + SessionID string `json:"session_id,omitempty"` + + // ToolsPath is the file path passed via --tools flag in interactive mode. + // Used to generate consistent session IDs for tool continuity. + ToolsPath string `json:"tools_path,omitempty"` } type Tools []Tool @@ -199,11 +220,12 @@ type Message struct { Content string `json:"content"` // Thinking contains the text that was inside thinking tags in the // original model output when ChatRequest.Think is enabled. - Thinking string `json:"thinking,omitempty"` - Images []ImageData `json:"images,omitempty"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - ToolName string `json:"tool_name,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + Thinking string `json:"thinking,omitempty"` + Images []ImageData `json:"images,omitempty"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + ToolResults []ToolResult `json:"tool_results,omitempty"` // MCP tool results + ToolName string `json:"tool_name,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } func (m *Message) UnmarshalJSON(b []byte) error { @@ -223,6 +245,13 @@ type ToolCall struct { Function ToolCallFunction `json:"function"` } +type ToolResult struct { + ToolName string `json:"tool_name"` + Arguments ToolCallFunctionArguments `json:"arguments,omitempty"` + Content string `json:"content"` + Error string `json:"error,omitempty"` +} + type ToolCallFunction struct { Index int `json:"index"` Name string `json:"name"` @@ -439,18 +468,106 @@ func (tp ToolProperty) ToTypeScriptType() string { } if len(tp.Type) == 1 { - return mapToTypeScriptType(tp.Type[0]) + return tp.mapSingleType(tp.Type[0]) } var types []string for _, t := range tp.Type { - types = append(types, mapToTypeScriptType(t)) + types = append(types, tp.mapSingleType(t)) } return strings.Join(types, " | ") } -// mapToTypeScriptType maps JSON Schema types to TypeScript types -func mapToTypeScriptType(jsonType string) string { +// mapSingleType maps a single JSON Schema type to TypeScript, using Items for arrays +func (tp ToolProperty) mapSingleType(jsonType string) string { + switch jsonType { + case "string": + return "string" + case "number", "integer": + return "number" + case "boolean": + return "boolean" + case "array": + return tp.arrayTypeString() + case "object": + return "Record" + case "null": + return "null" + default: + return "any" + } +} + +// arrayTypeString generates TypeScript type for arrays, using Items schema +func (tp ToolProperty) arrayTypeString() string { + if tp.Items == nil { + return "any[]" + } + + itemsMap, ok := tp.Items.(map[string]interface{}) + if !ok { + return "any[]" + } + + // Check if items has a type + itemType, ok := itemsMap["type"].(string) + if !ok { + return "any[]" + } + + // For object types with properties, generate inline type + if itemType == "object" { + if props, ok := itemsMap["properties"].(map[string]interface{}); ok { + return objectTypeFromProperties(props) + "[]" + } + } + + // Simple type array + return mapSimpleType(itemType) + "[]" +} + +// objectTypeFromProperties generates a TypeScript inline object type from properties +func objectTypeFromProperties(props map[string]interface{}) string { + if len(props) == 0 { + return "Record" + } + + // Collect and sort field names for consistent output + names := make([]string, 0, len(props)) + for name := range props { + names = append(names, name) + } + sort.Strings(names) + + var fields []string + for _, name := range names { + propDef := props[name] + propMap, ok := propDef.(map[string]interface{}) + if !ok { + fields = append(fields, name+": any") + continue + } + + propType := "any" + if t, ok := propMap["type"].(string); ok { + propType = mapSimpleType(t) + // Handle nested arrays + if t == "array" { + if items, ok := propMap["items"].(map[string]interface{}); ok { + if itemType, ok := items["type"].(string); ok { + propType = mapSimpleType(itemType) + "[]" + } + } + } + } + fields = append(fields, name+": "+propType) + } + + return "{" + strings.Join(fields, ", ") + "}" +} + +// mapSimpleType maps JSON Schema types to TypeScript (without recursion) +func mapSimpleType(jsonType string) string { switch jsonType { case "string": return "string" @@ -514,6 +631,21 @@ type Logprob struct { TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"` } +// MCPServerConfig represents configuration for an MCP (Model Context Protocol) server +type MCPServerConfig struct { + // Name is a unique identifier for the MCP server + Name string `json:"name"` + + // Command is the executable command to start the MCP server + Command string `json:"command"` + + // Args are optional command-line arguments for the MCP server + Args []string `json:"args,omitempty"` + + // Env are optional environment variables for the MCP server + Env map[string]string `json:"env,omitempty"` +} + // ChatResponse is the response returned by [Client.Chat]. Its fields are // similar to [GenerateResponse]. type ChatResponse struct { @@ -853,8 +985,6 @@ type GenerateResponse struct { Metrics - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - DebugInfo *DebugInfo `json:"_debug_info,omitempty"` // Logprobs contains log probability information for the generated tokens, diff --git a/cmd/cmd.go b/cmd/cmd.go index 35074ad2b..48ec95aa9 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -22,6 +22,7 @@ import ( "sort" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -49,6 +50,13 @@ import ( const ConnectInstructions = "To sign in, navigate to:\n %s\n\n" +// Tool detection and buffering configuration +const ( + DefaultToolBufferDelay = 500 * time.Millisecond + MinToolBufferDelay = 100 * time.Millisecond + MaxToolBufferDelay = 2 * time.Second +) + // ensureThinkingSupport emits a warning if the model does not advertise thinking support func ensureThinkingSupport(ctx context.Context, client *api.Client, name string) { if name == "" { @@ -416,6 +424,41 @@ func RunHandler(cmd *cobra.Command, args []string) error { opts.KeepAlive = &api.Duration{Duration: d} } + toolsSpec, err := cmd.Flags().GetString("tools") + if err != nil { + return err + } + if toolsSpec != "" { + mcpServers, toolsPath, err := server.GetMCPServersForTools(toolsSpec) + if err != nil { + // If definitions fail to load, fall back to basic filesystem support + fmt.Fprintf(os.Stderr, "Warning: Failed to load MCP definitions: %v\n", err) + mcpServers = []api.MCPServerConfig{ + { + Name: "filesystem", + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-filesystem", toolsPath}, + }, + } + } + + if len(mcpServers) == 0 { + fmt.Fprintf(os.Stderr, "Warning: No MCP servers matched for --tools context\n") + } else { + // Log what servers are being enabled + serverNames := make([]string, 0, len(mcpServers)) + for _, srv := range mcpServers { + serverNames = append(serverNames, srv.Name) + } + fmt.Fprintf(os.Stderr, "Enabling MCP servers: %s\n", strings.Join(serverNames, ", ")) + if toolsPath != "" { + fmt.Fprintf(os.Stderr, "Tools path: %s\n", toolsPath) + } + } + + opts.MCPServers = mcpServers + } + prompts := args[1:] // prepend stdin to the prompt if provided if !term.IsTerminal(int(os.Stdin.Fd())) { @@ -786,6 +829,62 @@ func ListRunningHandler(cmd *cobra.Command, args []string) error { return nil } +func ListToolsHandler(cmd *cobra.Command, args []string) { + servers, err := server.ListMCPServers() + if err != nil { + fmt.Fprintf(os.Stderr, "Error loading MCP servers: %v\n", err) + return + } + + if len(servers) == 0 { + fmt.Println("No MCP servers configured.") + fmt.Println("\nTo add MCP servers, create ~/.ollama/mcp-servers.json") + fmt.Println("See https://ollama.com/docs/mcp for configuration details.") + return + } + + fmt.Println("Available MCP Servers:") + fmt.Println() + + var data [][]string + for _, s := range servers { + autoEnable := string(s.AutoEnable) + if autoEnable == "" { + autoEnable = "never" + } + + capabilities := "-" + if len(s.Capabilities) > 0 { + capabilities = strings.Join(s.Capabilities, ", ") + } + + requiresPath := "no" + if s.RequiresPath { + requiresPath = "yes" + } + + data = append(data, []string{s.Name, s.Description, autoEnable, requiresPath, capabilities}) + } + + table := tablewriter.NewWriter(os.Stdout) + table.SetHeader([]string{"NAME", "DESCRIPTION", "AUTO-ENABLE", "REQUIRES PATH", "CAPABILITIES"}) + table.SetHeaderAlignment(tablewriter.ALIGN_LEFT) + table.SetAlignment(tablewriter.ALIGN_LEFT) + table.SetHeaderLine(false) + table.SetBorder(false) + table.SetNoWhiteSpace(true) + table.SetTablePadding(" ") + table.AppendBulk(data) + table.Render() + + fmt.Println() + fmt.Println("Auto-enable modes:") + fmt.Println(" never - Must be explicitly configured via API") + fmt.Println(" always - Enables whenever --tools is used") + fmt.Println(" with_path - Enables when --tools has a path argument") + fmt.Println(" if_match - Enables when conditions match (e.g., file exists)") +} + func DeleteHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -1189,6 +1288,7 @@ type runOptions struct { Think *api.ThinkValue HideThinking bool ShowConnect bool + MCPServers []api.MCPServerConfig } func (r runOptions) Copy() runOptions { @@ -1218,6 +1318,12 @@ func (r runOptions) Copy() runOptions { think = &cThink } + var mcpServers []api.MCPServerConfig + if r.MCPServers != nil { + mcpServers = make([]api.MCPServerConfig, len(r.MCPServers)) + copy(mcpServers, r.MCPServers) + } + return runOptions{ Model: r.Model, ParentModel: r.ParentModel, @@ -1233,6 +1339,7 @@ func (r runOptions) Copy() runOptions { Think: think, HideThinking: r.HideThinking, ShowConnect: r.ShowConnect, + MCPServers: mcpServers, } } @@ -1241,6 +1348,237 @@ type displayResponseState struct { wordBuffer string } +// StreamingToolDetector maintains state for detecting tool calls across streaming chunks +type StreamingToolDetector struct { + inXMLToolCall bool + xmlStartBuffer strings.Builder + inJSONToolCall bool + jsonBuffer strings.Builder + jsonDepth int + inString bool + escapeNext bool + // tailBuffer holds potential partial tag matches from end of previous chunk + tailBuffer string +} + +// NewStreamingToolDetector creates a new stateful tool detector +func NewStreamingToolDetector() *StreamingToolDetector { + return &StreamingToolDetector{} +} + +// maxTagLength is the longest tag we need to detect across chunk boundaries +const maxTagLength = 12 // len("") + +// Process handles a chunk of streaming content and separates tool calls from regular content +func (s *StreamingToolDetector) Process(chunk string) (displayContent string, hasIncompleteToolCall bool) { + // Prepend any buffered tail from previous chunk + if s.tailBuffer != "" { + chunk = s.tailBuffer + chunk + s.tailBuffer = "" + } + + var result strings.Builder + + for i := 0; i < len(chunk); i++ { + ch := chunk[i] + + // Check if we're near the end and might have a partial tag + // Buffer potential partial matches for next chunk + remainingLen := len(chunk) - i + if !s.inXMLToolCall && !s.inJSONToolCall && remainingLen < maxTagLength { + // Check if remaining content could be start of a tag + remaining := chunk[i:] + if couldBePartialTag(remaining) { + s.tailBuffer = remaining + break // Stop processing, buffer the rest + } + } + + // Handle XML tool calls + if !s.inXMLToolCall && i+11 <= len(chunk) && chunk[i:i+11] == "" { + s.inXMLToolCall = true + s.xmlStartBuffer.Reset() + s.xmlStartBuffer.WriteString("") + i += 10 // Skip past "" + continue + } + + if s.inXMLToolCall { + s.xmlStartBuffer.WriteByte(ch) + if i+12 <= len(chunk) && chunk[i:i+12] == "" { + // Complete XML tool call - skip it entirely + s.inXMLToolCall = false + s.xmlStartBuffer.Reset() + i += 11 // Skip past "" + continue + } + continue + } + + // Handle JSON tool calls + if !s.inJSONToolCall && !s.inXMLToolCall { + // Look for start of JSON tool call pattern + if i+8 <= len(chunk) && chunk[i:i+8] == `{"name":` { + // Check if "arguments" appears nearby (tool call signature) + lookahead := chunk[i:] + if len(lookahead) > 200 { + lookahead = lookahead[:200] + } + if strings.Contains(lookahead, `"arguments":`) { + s.inJSONToolCall = true + s.jsonBuffer.Reset() + s.jsonBuffer.WriteByte(ch) + s.jsonDepth = 1 + s.inString = false + s.escapeNext = false + continue + } + } + } + + if s.inJSONToolCall { + s.jsonBuffer.WriteByte(ch) + + // Track JSON structure to find the end + if s.escapeNext { + s.escapeNext = false + continue + } + + if ch == '\\' && s.inString { + s.escapeNext = true + continue + } + + if ch == '"' && !s.escapeNext { + s.inString = !s.inString + continue + } + + if !s.inString { + if ch == '{' { + s.jsonDepth++ + } else if ch == '}' { + s.jsonDepth-- + if s.jsonDepth == 0 { + // Complete JSON tool call - skip it + s.inJSONToolCall = false + s.jsonBuffer.Reset() + continue + } + } + } + continue + } + + // Regular content + result.WriteByte(ch) + } + + // Check if we have incomplete tool calls or buffered tail that need buffering + hasIncompleteToolCall = s.inXMLToolCall || s.inJSONToolCall || s.tailBuffer != "" + + return result.String(), hasIncompleteToolCall +} + +// couldBePartialTag checks if a string could be the start of a tool call tag +// Only returns true for patterns that are specific enough to likely be tool calls +func couldBePartialTag(s string) bool { + // Require at least 2 chars to avoid false positives on common single chars like < or { + if len(s) < 2 { + return false + } + + // Check for partial XML tags - must start with " MaxToolBufferDelay { + return MaxToolBufferDelay + } + return delay + } + } + return DefaultToolBufferDelay +} + func displayResponse(content string, wordWrap bool, state *displayResponseState) { termWidth, _, _ := term.GetSize(int(os.Stdout.Fd())) if wordWrap && termWidth >= 10 { @@ -1327,6 +1665,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT) + defer signal.Stop(sigChan) go func() { <-sigChan @@ -1339,6 +1678,18 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { var fullResponse strings.Builder var thinkTagOpened bool = false var thinkTagClosed bool = false + var toolCallsDisplayed bool = false + + // Streaming tool detector for better chunk handling + toolDetector := NewStreamingToolDetector() + + // Buffer for accumulating content before display + var contentBuffer strings.Builder + var bufferTimer *time.Timer + var bufferMutex sync.Mutex + + // Get configurable buffer delay + bufferDelay := getToolBufferDelay() role := "assistant" @@ -1370,20 +1721,84 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { thinkTagClosed = true state = &displayResponseState{} } - // purposefully not putting thinking blocks in the response, which would - // only be needed if we later added tool calling to the cli (they get - // filtered out anyway since current models don't expect them unless you're - // about to finish some tool calls) + + // Use stateful tool detector for better streaming chunk handling + displayContent, hasIncompleteToolCall := toolDetector.Process(content) + + // Store full response for context fullResponse.WriteString(content) + // Buffer management based on tool detection + if hasIncompleteToolCall { + // We have an incomplete tool call - buffer the content + bufferMutex.Lock() + contentBuffer.WriteString(displayContent) + + // Cancel any existing timer + if bufferTimer != nil { + bufferTimer.Stop() + } + + // Set a new timer to flush the buffer after a delay + bufferTimer = time.AfterFunc(bufferDelay, func() { + bufferMutex.Lock() + defer bufferMutex.Unlock() + + bufferedContent := contentBuffer.String() + contentBuffer.Reset() + + // Reset tool detector state when flushing + toolDetector.Reset() + + // Only display if there's actual content after filtering + if strings.TrimSpace(bufferedContent) != "" { + displayResponse(bufferedContent, opts.WordWrap, state) + } + }) + bufferMutex.Unlock() + } else { + // No incomplete tool call - display immediately + if strings.TrimSpace(displayContent) != "" { + displayResponse(displayContent, opts.WordWrap, state) + } + } + + // Display tool calls cleanly if detected if response.Message.ToolCalls != nil { toolCalls := response.Message.ToolCalls - if len(toolCalls) > 0 { + if len(toolCalls) > 0 && !toolCallsDisplayed { + // Flush any buffered content before showing tool calls + bufferMutex.Lock() + if contentBuffer.Len() > 0 { + bufferedContent := contentBuffer.String() + contentBuffer.Reset() + if strings.TrimSpace(bufferedContent) != "" { + displayResponse(bufferedContent, opts.WordWrap, state) + } + } + if bufferTimer != nil { + bufferTimer.Stop() + bufferTimer = nil + } + bufferMutex.Unlock() + + // Add newline for clean separation + fmt.Println() fmt.Print(renderToolCalls(toolCalls, false)) + toolCallsDisplayed = true } } - displayResponse(content, opts.WordWrap, state) + // Display tool results if available + if response.Message.ToolResults != nil { + toolResults := response.Message.ToolResults + if len(toolResults) > 0 { + fmt.Print(renderToolResults(toolResults, false)) + fmt.Println() // New line after results + // Reset flag to allow next round's tool calls to be displayed + toolCallsDisplayed = false + } + } return nil } @@ -1393,11 +1808,12 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { } req := &api.ChatRequest{ - Model: opts.Model, - Messages: opts.Messages, - Format: json.RawMessage(opts.Format), - Options: opts.Options, - Think: opts.Think, + Model: opts.Model, + Messages: opts.Messages, + Format: json.RawMessage(opts.Format), + Options: opts.Options, + Think: opts.Think, + MCPServers: opts.MCPServers, } if opts.KeepAlive != nil { @@ -1418,6 +1834,20 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { } return nil, err } + + // Flush any remaining buffered content + bufferMutex.Lock() + if bufferTimer != nil { + bufferTimer.Stop() + } + if contentBuffer.Len() > 0 { + bufferedContent := contentBuffer.String() + contentBuffer.Reset() + if strings.TrimSpace(bufferedContent) != "" && !strings.Contains(bufferedContent, `{"name":`) { + displayResponse(bufferedContent, opts.WordWrap, state) + } + } + bufferMutex.Unlock() if len(opts.Messages) > 0 { fmt.Println() @@ -1437,6 +1867,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { } func generate(cmd *cobra.Command, opts runOptions) error { + // Tools/MCP servers require interactive mode (Chat API) + if len(opts.MCPServers) > 0 { + return errors.New("--tools flag requires interactive mode; use 'ollama run --tools ' without piped input") + } + client, err := api.ClientFromEnvironment() if err != nil { return err @@ -1460,6 +1895,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT) + defer signal.Stop(sigChan) go func() { <-sigChan @@ -1491,7 +1927,7 @@ func generate(cmd *cobra.Command, opts runOptions) error { displayResponse(response.Thinking, opts.WordWrap, state) } - if thinkTagOpened && !thinkTagClosed && (content != "" || len(response.ToolCalls) > 0) { + if thinkTagOpened && !thinkTagClosed && content != "" { if !strings.HasSuffix(thinkingContent.String(), "\n") { fmt.Println() } @@ -1503,13 +1939,6 @@ func generate(cmd *cobra.Command, opts runOptions) error { displayResponse(content, opts.WordWrap, state) - if response.ToolCalls != nil { - toolCalls := response.ToolCalls - if len(toolCalls) > 0 { - fmt.Print(renderToolCalls(toolCalls, plainText)) - } - } - return nil } @@ -1704,11 +2133,17 @@ func NewCLI() *cobra.Command { return } + if listTools, _ := cmd.Flags().GetBool("list-tools"); listTools { + ListToolsHandler(cmd, args) + return + } + cmd.Print(cmd.UsageString()) }, } rootCmd.Flags().BoolP("version", "v", false, "Show version information") + rootCmd.Flags().Bool("list-tools", false, "List available MCP servers and their tools") createCmd := &cobra.Command{ Use: "create MODEL", @@ -1754,6 +2189,7 @@ func NewCLI() *cobra.Command { runCmd.Flags().Bool("hidethinking", false, "Hide thinking output (if provided)") runCmd.Flags().Bool("truncate", false, "For embedding models: truncate inputs exceeding context length (default: true). Set --truncate=false to error instead") runCmd.Flags().Int("dimensions", 0, "Truncate output embeddings to specified dimension (embedding models only)") + runCmd.Flags().String("tools", "", "Enable MCP tools (default: all registered servers with current dir, or specify path for filesystem)") stopCmd := &cobra.Command{ Use: "stop MODEL", @@ -1964,15 +2400,101 @@ func renderToolCalls(toolCalls []api.ToolCall, plainText bool) string { out += formatExplanation } for i, toolCall := range toolCalls { - argsAsJSON, err := json.Marshal(toolCall.Function.Arguments) - if err != nil { - return "" - } if i > 0 { out += "\n" } - // all tool calls are unexpected since we don't currently support registering any in the CLI - out += fmt.Sprintf(" Model called a non-existent function '%s()' with arguments: %s", formatValues+toolCall.Function.Name+formatExplanation, formatValues+string(argsAsJSON)+formatExplanation) + // Format arguments in a more readable way + var argsDisplay string + // Arguments is already a map[string]any + // Sort keys for deterministic display order + keys := make([]string, 0, len(toolCall.Function.Arguments)) + for k := range toolCall.Function.Arguments { + keys = append(keys, k) + } + sort.Strings(keys) + var pairs []string + for _, k := range keys { + pairs = append(pairs, fmt.Sprintf("%s: %v", k, toolCall.Function.Arguments[k])) + } + if len(pairs) > 0 { + argsDisplay = strings.Join(pairs, ", ") + } else { + argsDisplay = "(no arguments)" + } + + // Show tool execution in progress with cleaner format + out += fmt.Sprintf("🔧 Executing tool '%s'%s\n", + formatValues+toolCall.Function.Name+formatExplanation, formatExplanation) + out += fmt.Sprintf(" Arguments: %s%s%s\n", + formatValues, argsDisplay, formatExplanation) + } + if !plainText { + out += readline.ColorDefault + } + return out +} + +func renderToolResults(toolResults []api.ToolResult, plainText bool) string { + out := "" + formatExplanation := "" + formatValues := "" + formatError := "" + if !plainText { + formatExplanation = readline.ColorGrey + readline.ColorBold + formatValues = readline.ColorDefault + // Use bold for errors since ColorRed doesn't exist + formatError = readline.ColorBold + out += formatExplanation + } + for i, toolResult := range toolResults { + if i > 0 { + out += "\n" + } + + // Tool name and arguments already shown in renderToolCalls + // Just show the result or error here + if toolResult.Error != "" { + // Parse error for better context + errorMsg := toolResult.Error + // Try to extract meaningful error from MCP errors + if strings.Contains(errorMsg, "MCP tool returned error") { + errorMsg = "Tool execution failed" + } + // Look for specific error patterns + if strings.Contains(toolResult.Error, "Parent directory does not exist") { + errorMsg = "Parent directory does not exist - check path" + } else if strings.Contains(toolResult.Error, "permission denied") { + errorMsg = "Permission denied - insufficient privileges" + } else if strings.Contains(toolResult.Error, "Invalid arguments") { + errorMsg = "Invalid tool arguments provided" + } else if strings.Contains(toolResult.Error, "file not found") { + errorMsg = "File or directory not found" + } + + // Truncate long error messages (rune-safe for UTF-8) + errorRunes := []rune(errorMsg) + if len(errorRunes) > 200 { + errorMsg = string(errorRunes[:197]) + "..." + } + + out += fmt.Sprintf("❌ Error: %s%s%s\n", + formatError, errorMsg, formatExplanation) + } else { + content := toolResult.Content + if strings.TrimSpace(content) == "" { + // Empty result - show a clear indicator + out += fmt.Sprintf("✅ Result: %s(empty)%s\n", + formatValues, formatExplanation) + } else { + // Truncate very long results for display (rune-safe for UTF-8) + runes := []rune(content) + if len(runes) > 200 { + content = string(runes[:197]) + "..." + } + out += fmt.Sprintf("✅ Result:\n%s%s%s\n", + formatValues, content, formatExplanation) + } + } } if !plainText { out += readline.ColorDefault diff --git a/cmd/cmd_test.go b/cmd/cmd_test.go index 7dc3d0093..42ff92118 100644 --- a/cmd/cmd_test.go +++ b/cmd/cmd_test.go @@ -425,6 +425,7 @@ func TestRunEmbeddingModel(t *testing.T) { cmd.Flags().String("format", "", "") cmd.Flags().String("think", "", "") cmd.Flags().Bool("hidethinking", false, "") + cmd.Flags().String("tools", "", "") oldStdout := os.Stdout r, w, _ := os.Pipe() @@ -517,6 +518,7 @@ func TestRunEmbeddingModelWithFlags(t *testing.T) { cmd.Flags().String("format", "", "") cmd.Flags().String("think", "", "") cmd.Flags().Bool("hidethinking", false, "") + cmd.Flags().String("tools", "", "") if err := cmd.Flags().Set("truncate", "true"); err != nil { t.Fatalf("failed to set truncate flag: %v", err) @@ -618,6 +620,7 @@ func TestRunEmbeddingModelPipedInput(t *testing.T) { cmd.Flags().String("format", "", "") cmd.Flags().String("think", "", "") cmd.Flags().Bool("hidethinking", false, "") + cmd.Flags().String("tools", "", "") // Capture stdin oldStdin := os.Stdin @@ -693,6 +696,7 @@ func TestRunEmbeddingModelNoInput(t *testing.T) { cmd.Flags().String("format", "", "") cmd.Flags().String("think", "", "") cmd.Flags().Bool("hidethinking", false, "") + cmd.Flags().String("tools", "", "") cmd.SetOut(io.Discard) cmd.SetErr(io.Discard) diff --git a/docs/api.md b/docs/api.md index 7c32c9597..08dfdcfdb 100644 --- a/docs/api.md +++ b/docs/api.md @@ -510,6 +510,7 @@ Advanced parameters (optional): - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.mdx#valid-parameters-and-values) such as `temperature` - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) +- `mcp_servers`: (experimental) list of MCP server configurations for autonomous tool execution. See [MCP documentation](./mcp.md) ### Tool calling diff --git a/docs/cli.mdx b/docs/cli.mdx index 97810e64a..f936d7d93 100644 --- a/docs/cli.mdx +++ b/docs/cli.mdx @@ -25,6 +25,14 @@ I'm a basic program that prints the famous "Hello, world!" message to the consol ollama run gemma3 "What's in this image? /Users/jmorgan/Desktop/smile.png" ``` +#### MCP tools (experimental) + +``` +ollama run qwen2.5:7b --tools /path/to/directory +``` + +Enables MCP (Model Context Protocol) servers for autonomous tool execution. See [MCP documentation](./mcp.md). + ### Generate embeddings ``` diff --git a/docs/mcp.md b/docs/mcp.md new file mode 100644 index 000000000..6ce3831eb --- /dev/null +++ b/docs/mcp.md @@ -0,0 +1,360 @@ +# MCP (Model Context Protocol) Integration + +Ollama supports the Model Context Protocol (MCP), enabling language models to execute tools and interact with external systems autonomously. + +> **Status**: Experimental + +## Quick Start + +### CLI Usage + +```bash +# Run with filesystem tools restricted to a directory +ollama run qwen2.5:7b --tools /path/to/directory + +# In a git repository, both filesystem AND git tools auto-enable +ollama run qwen2.5:7b --tools /path/to/git-repo + +# Example interaction +>>> List all files in the directory +# Model will automatically execute filesystem:list_directory tool + +>>> Show the git status +# Model will automatically execute git:status tool (if in a git repo) +``` + +### API Usage + +```bash +curl -X POST http://localhost:11434/api/chat \ + -d '{ + "model": "qwen2.5:7b", + "messages": [{"role": "user", "content": "List the files"}], + "mcp_servers": [{ + "name": "filesystem", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/safe/path"] + }] + }' +``` + +## How It Works + +1. **Model generates tool call** in JSON format +2. **Parser detects** the tool call during streaming +3. **MCP server executes** the tool via JSON-RPC over stdio +4. **Results returned** to model context +5. **Model continues** generation with tool results +6. **Loop repeats** for multi-step tasks (up to 15 rounds) + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ Public API (mcp.go) │ +│ GetMCPServersForTools() - Get servers for --tools flag │ +│ GetMCPManager() - Get manager for explicit configs │ +│ GetMCPManagerForPath() - Get manager for tools path │ +│ ListMCPServers() - List available server definitions │ +└─────────────────────────────────────────────────────────────────┘ + │ + ┌───────────────────┴───────────────────┐ + ▼ ▼ +┌─────────────────────┐ ┌─────────────────────┐ +│ MCPDefinitions │ │ MCPSessionManager │ +│ (mcp_definitions) │ │ (mcp_sessions) │ +│ │ │ │ +│ Static config of │ │ Runtime sessions │ +│ available servers │ │ with connections │ +└─────────────────────┘ └─────────────────────┘ + │ + ▼ + ┌─────────────────────┐ + │ MCPManager │ + │ (mcp_manager) │ + │ │ + │ Multi-client mgmt │ + │ Tool execution │ + └─────────────────────┘ + │ + ▼ + ┌─────────────────────┐ + │ MCPClient │ + │ (mcp_client) │ + │ │ + │ Single JSON-RPC │ + │ connection │ + └─────────────────────┘ +``` + +## Auto-Enable Configuration + +MCP servers can declare when they should automatically enable with the `--tools` flag. + +### Auto-Enable Modes + +| Mode | Description | +|------|-------------| +| `never` | Server must be explicitly configured via API (default) | +| `always` | Server enables whenever `--tools` is used | +| `with_path` | Server enables when `--tools` has a path argument | +| `if_match` | Server enables if conditions in `enable_if` match | + +### Conditional Enabling (if_match) + +The `enable_if` object supports these conditions: + +| Condition | Description | +|-----------|-------------| +| `file_exists` | Check if a file/directory exists in the tools path | +| `env_set` | Check if an environment variable is set (non-empty) | + +### Example Configuration + +Create `~/.ollama/mcp-servers.json`: + +```json +{ + "servers": [ + { + "name": "filesystem", + "description": "File system operations", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem"], + "requires_path": true, + "auto_enable": "with_path" + }, + { + "name": "git", + "description": "Git repository operations", + "command": "npx", + "args": ["-y", "@cyanheads/git-mcp-server"], + "requires_path": true, + "auto_enable": "if_match", + "enable_if": { + "file_exists": ".git" + } + }, + { + "name": "postgres", + "description": "PostgreSQL database operations", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-postgres"], + "auto_enable": "if_match", + "enable_if": { + "env_set": "POSTGRES_CONNECTION" + } + }, + { + "name": "python", + "description": "Python code execution", + "command": "python", + "args": ["-m", "mcp_server_python"], + "auto_enable": "never" + } + ] +} +``` + +With this configuration: +- **filesystem** enables for any `--tools /path` +- **git** enables only if `/path/.git` exists +- **postgres** enables only if `POSTGRES_CONNECTION` env var is set +- **python** never auto-enables (must use API explicitly) + +## API Reference + +### Chat Endpoint with MCP + +**Endpoint:** `POST /api/chat` + +**Request:** +```json +{ + "model": "qwen2.5:7b", + "messages": [{"role": "user", "content": "Your prompt"}], + "mcp_servers": [ + { + "name": "server-name", + "command": "executable", + "args": ["arg1", "arg2"], + "env": {"KEY": "value"} + } + ], + "max_tool_rounds": 10, + "tool_timeout": 30000 +} +``` + +**MCP Server Configuration:** + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Unique identifier for the server | +| `command` | string | Executable to run | +| `args` | []string | Command-line arguments | +| `env` | map | Environment variables | + +### Server Definition Fields + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Unique server identifier | +| `description` | string | Human-readable description | +| `command` | string | Executable to run (npx, python, etc.) | +| `args` | []string | Command-line arguments | +| `env` | map | Environment variables | +| `requires_path` | bool | Whether server needs a path argument | +| `path_arg_index` | int | Where to insert path in args (-1 = append) | +| `capabilities` | []string | List of capability tags | +| `auto_enable` | string | Auto-enable mode (never/always/with_path/if_match) | +| `enable_if` | object | Conditions for if_match mode | + +## Security + +### Implemented Safeguards + +- **Process isolation**: MCP servers run in separate process groups +- **Path restrictions**: Filesystem access limited to specified directories +- **Environment filtering**: Allowlist-based, sensitive variables removed +- **Command validation**: Dangerous commands (shells, sudo, rm) blocked +- **Argument sanitization**: Shell injection prevention +- **Timeouts**: 30-second default with graceful shutdown + +### Blocked Commands + +Shells (`bash`, `sh`, `zsh`), privilege escalation (`sudo`, `su`), destructive commands (`rm`, `dd`), and network tools (`curl`, `wget`, `nc`) are blocked by default. + +## Creating MCP Servers + +MCP servers communicate via JSON-RPC 2.0 over stdin/stdout and must implement three methods: + +### Required Methods + +1. **`initialize`** - Returns server capabilities +2. **`tools/list`** - Returns available tools with schemas +3. **`tools/call`** - Executes a tool and returns results + +### Minimal Python Example + +```python +#!/usr/bin/env python3 +import json +import sys + +def handle_request(request): + method = request.get("method") + request_id = request.get("id") + + if method == "initialize": + return { + "jsonrpc": "2.0", "id": request_id, + "result": { + "protocolVersion": "0.1.0", + "capabilities": {"tools": {}}, + "serverInfo": {"name": "my-server", "version": "1.0.0"} + } + } + + elif method == "tools/list": + return { + "jsonrpc": "2.0", "id": request_id, + "result": { + "tools": [{ + "name": "hello", + "description": "Say hello", + "inputSchema": { + "type": "object", + "properties": { + "name": {"type": "string", "description": "Name to greet"} + }, + "required": ["name"] + } + }] + } + } + + elif method == "tools/call": + name = request["params"]["arguments"].get("name", "World") + return { + "jsonrpc": "2.0", "id": request_id, + "result": { + "content": [{"type": "text", "text": f"Hello, {name}!"}] + } + } + +if __name__ == "__main__": + while True: + line = sys.stdin.readline() + if not line: + break + request = json.loads(line) + response = handle_request(request) + sys.stdout.write(json.dumps(response) + "\n") + sys.stdout.flush() +``` + +### Testing Your Server + +```bash +# Test initialize +echo '{"jsonrpc":"2.0","method":"initialize","params":{},"id":1}' | python3 my_server.py + +# Test tools/list +echo '{"jsonrpc":"2.0","method":"tools/list","params":{},"id":2}' | python3 my_server.py + +# Test tools/call +echo '{"jsonrpc":"2.0","method":"tools/call","params":{"name":"hello","arguments":{"name":"Alice"}},"id":3}' | python3 my_server.py +``` + +## Environment Variables + +| Variable | Description | +|----------|-------------| +| `OLLAMA_DEBUG=INFO` | Enable debug logging | +| `OLLAMA_MCP_TIMEOUT` | Tool execution timeout (ms) | +| `OLLAMA_MCP_SERVERS` | JSON config for MCP servers (overrides file) | +| `OLLAMA_MCP_DISABLE=1` | Disable MCP validation on startup | + +## Supported Models + +MCP works best with models that support tool calling: +- Qwen 2.5 / Qwen 3 series +- Llama 3.1+ with tool support +- Other models with JSON tool call output + +## Limitations + +- **Transport**: stdio only (no HTTP/WebSocket) +- **Protocol**: MCP 1.0 +- **Concurrency**: Max 10 parallel MCP servers +- **Platform**: Linux/macOS (Windows untested) + +## Troubleshooting + +**"Tool not found"** +- Verify MCP server initialized correctly +- Check tool name includes namespace prefix + +**"MCP server failed to initialize"** +- Check command path exists +- Verify Python/Node environment +- Test server manually with JSON input + +**"No MCP servers matched for --tools context"** +- Check auto_enable settings in config +- Verify path exists and conditions match + +**"Access denied"** +- Path outside allowed directories +- Security policy violation + +**Debug mode:** +```bash +OLLAMA_DEBUG=INFO ollama serve +``` + +## Resources + +- [MCP Specification](https://modelcontextprotocol.io/docs) +- [Official MCP Servers](https://github.com/modelcontextprotocol/servers) diff --git a/examples/mcp-servers.json b/examples/mcp-servers.json new file mode 100644 index 000000000..3fde76d23 --- /dev/null +++ b/examples/mcp-servers.json @@ -0,0 +1,59 @@ +{ + "servers": [ + { + "name": "filesystem", + "description": "File system operations with path-based access control", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem"], + "requires_path": true, + "path_arg_index": -1, + "capabilities": ["read", "write", "list", "search"], + "auto_enable": "with_path" + }, + { + "name": "git", + "description": "Git repository operations", + "command": "npx", + "args": ["-y", "@cyanheads/git-mcp-server"], + "requires_path": true, + "path_arg_index": -1, + "capabilities": ["diff", "log", "status", "branch"], + "auto_enable": "if_match", + "enable_if": { + "file_exists": ".git" + } + }, + { + "name": "postgres", + "description": "PostgreSQL database operations", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-postgres"], + "env": { + "POSTGRES_CONNECTION": "" + }, + "capabilities": ["query", "schema"], + "auto_enable": "if_match", + "enable_if": { + "env_set": "POSTGRES_CONNECTION" + } + }, + { + "name": "python", + "description": "Python code execution in sandboxed environment", + "command": "python", + "args": ["-m", "mcp_server_python"], + "requires_path": false, + "capabilities": ["execute", "eval"], + "auto_enable": "never" + }, + { + "name": "web", + "description": "Web browsing and content extraction", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-puppeteer"], + "requires_path": false, + "capabilities": ["browse", "screenshot", "extract"], + "auto_enable": "never" + } + ] +} diff --git a/server/mcp.go b/server/mcp.go new file mode 100644 index 000000000..40c9e97f0 --- /dev/null +++ b/server/mcp.go @@ -0,0 +1,114 @@ +// Package server provides MCP (Model Context Protocol) integration for Ollama. +// +// MCP Architecture: +// +// ┌─────────────────────────────────────────────────────────────────┐ +// │ Public API (this file) │ +// │ GetMCPServersForTools() - Get servers for --tools flag │ +// │ GetMCPManager() - Get manager for explicit configs │ +// │ GetMCPManagerForPath() - Get manager for tools path │ +// │ ListMCPServers() - List available server definitions │ +// └─────────────────────────────────────────────────────────────────┘ +// │ +// ┌───────────────────┴───────────────────┐ +// ▼ ▼ +// ┌─────────────────────┐ ┌─────────────────────┐ +// │ MCPDefinitions │ │ MCPSessionManager │ +// │ (mcp_definitions) │ │ (mcp_sessions) │ +// │ │ │ │ +// │ Static config of │ │ Runtime sessions │ +// │ available servers │ │ with connections │ +// └─────────────────────┘ └─────────────────────┘ +// │ +// ▼ +// ┌─────────────────────┐ +// │ MCPManager │ +// │ (mcp_manager) │ +// │ │ +// │ Multi-client mgmt │ +// │ Tool execution │ +// └─────────────────────┘ +// │ +// ▼ +// ┌─────────────────────┐ +// │ MCPClient │ +// │ (mcp_client) │ +// │ │ +// │ Single JSON-RPC │ +// │ connection │ +// └─────────────────────┘ + +package server + +import ( + "os" + "path/filepath" + "strings" + + "github.com/ollama/ollama/api" +) + +// ============================================================================ +// Public API - Clean interface for external code +// ============================================================================ + +// GetMCPServersForTools returns the MCP server configs that should be enabled +// for the given tools spec. It handles path normalization: +// - "." or "true" → current working directory +// - "~/path" → expands to home directory +// - relative paths → resolved to absolute paths +// +// Returns the server configs and the resolved absolute path. +// On error, still returns the resolved path so callers can implement fallback. +// This is used by the --tools CLI flag. +func GetMCPServersForTools(toolsSpec string) ([]api.MCPServerConfig, string, error) { + // Normalize the tools path first (needed even for fallback on error) + toolsPath := toolsSpec + if toolsSpec == "." || toolsSpec == "true" { + if cwd, err := os.Getwd(); err == nil { + toolsPath = cwd + } + } + + // Expand tilde to home directory + if strings.HasPrefix(toolsPath, "~") { + if home := os.Getenv("HOME"); home != "" { + toolsPath = filepath.Join(home, toolsPath[1:]) + } + } + + // Resolve to absolute path + if absPath, err := filepath.Abs(toolsPath); err == nil { + toolsPath = absPath + } + + // Load definitions + defs, err := LoadMCPDefinitions() + if err != nil { + return nil, toolsPath, err + } + + ctx := AutoEnableContext{ToolsPath: toolsPath} + return defs.GetAutoEnableServers(ctx), toolsPath, nil +} + +// GetMCPManager returns an MCP manager for the given session and configs. +// If a session with matching configs already exists, it will be reused. +func GetMCPManager(sessionID string, configs []api.MCPServerConfig) (*MCPManager, error) { + return GetMCPSessionManager().GetOrCreateManager(sessionID, configs) +} + +// GetMCPManagerForPath returns an MCP manager for servers that auto-enable +// for the given tools path. Used by CLI: `ollama run model --tools /path` +func GetMCPManagerForPath(model string, toolsPath string) (*MCPManager, error) { + return GetMCPSessionManager().GetManagerForToolsPath(model, toolsPath) +} + +// ListMCPServers returns information about all available MCP server definitions. +func ListMCPServers() ([]MCPServerInfo, error) { + defs, err := LoadMCPDefinitions() + if err != nil { + return nil, err + } + return defs.ListServers(), nil +} diff --git a/server/mcp_client.go b/server/mcp_client.go new file mode 100644 index 000000000..5a45e5c52 --- /dev/null +++ b/server/mcp_client.go @@ -0,0 +1,810 @@ +package server + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "os" + "os/exec" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/ollama/ollama/api" +) + +// MCPClient manages communication with a single MCP server via JSON-RPC over stdio +type MCPClient struct { + name string + command string + args []string + env map[string]string + + // Process management + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr *bufio.Reader + + // State + mu sync.RWMutex + initialized bool + tools []api.Tool + requestID int64 + responses map[int64]chan *jsonRPCResponse + + // Lifecycle + ctx context.Context + cancel context.CancelFunc + done chan struct{} + + // Pipe handles (for clean shutdown) + stdoutPipe io.ReadCloser + stderrPipe io.ReadCloser + + // Dependencies (injectable for testing) + commandResolver CommandResolverInterface +} + +// ============================================================================= +// MCPClient Options (Functional Options Pattern) +// ============================================================================= + +// MCPClientOption configures an MCPClient during creation. +type MCPClientOption func(*MCPClient) + +// WithCommandResolver sets a custom command resolver for the client. +// This is primarily useful for testing to avoid system dependencies. +// +// Example: +// +// mockResolver := &MockCommandResolver{...} +// client := NewMCPClient("test", "npx", args, env, WithCommandResolver(mockResolver)) +func WithCommandResolver(resolver CommandResolverInterface) MCPClientOption { + return func(c *MCPClient) { + c.commandResolver = resolver + } +} + +// JSON-RPC 2.0 message types +type jsonRPCRequest struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id,omitempty"` + Method string `json:"method"` + Params interface{} `json:"params,omitempty"` +} + +type jsonRPCResponse struct { + JSONRPC string `json:"jsonrpc"` + ID *int64 `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonRPCError `json:"error,omitempty"` +} + +type jsonRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data,omitempty"` +} + +// MCP protocol message types +type mcpInitializeRequest struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]interface{} `json:"capabilities"` + ClientInfo mcpClientInfo `json:"clientInfo"` +} + +type mcpClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +type mcpInitializeResponse struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]interface{} `json:"capabilities"` + ServerInfo mcpServerInfo `json:"serverInfo"` +} + +type mcpServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +type mcpListToolsRequest struct{} + +type mcpListToolsResponse struct { + Tools []mcpTool `json:"tools"` +} + +type mcpTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema map[string]interface{} `json:"inputSchema"` +} + +type mcpCallToolRequest struct { + Name string `json:"name"` + Arguments map[string]interface{} `json:"arguments"` +} + +type mcpCallToolResponse struct { + Content []mcpContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type mcpContent struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// NewMCPClient creates a new MCP client for the specified server configuration. +// Optional MCPClientOption arguments can be used to customize behavior (e.g., for testing). +func NewMCPClient(name, command string, args []string, env map[string]string, opts ...MCPClientOption) *MCPClient { + ctx, cancel := context.WithCancel(context.Background()) + + client := &MCPClient{ + name: name, + command: command, // Will be resolved after options are applied + args: args, + env: env, + responses: make(map[int64]chan *jsonRPCResponse), + ctx: ctx, + cancel: cancel, + done: make(chan struct{}), + commandResolver: DefaultCommandResolver, // Default, can be overridden + } + + // Apply options + for _, opt := range opts { + opt(client) + } + + // Guard against nil resolver (e.g., if WithCommandResolver(nil) was called) + if client.commandResolver == nil { + client.commandResolver = DefaultCommandResolver + } + + // Resolve the command using the configured resolver + client.command = client.commandResolver.ResolveForEnvironment(command) + + return client +} + +// Start spawns the MCP server process and initializes communication. +// +// SECURITY REVIEW: This function executes external processes. Security controls: +// - Command must pass validation in MCPManager.validateServerConfig() +// - Environment is filtered via buildSecureEnvironment() +// - Process runs in isolated process group (Setpgid) +// - Context-based cancellation for cleanup +func (c *MCPClient) Start() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.cmd != nil { + return errors.New("MCP client already started") + } + + // Handle commands that might have spaces (like "pnpm dlx") + cmdParts := strings.Fields(c.command) + var cmdName string + var cmdArgs []string + + if len(cmdParts) > 1 { + cmdName = cmdParts[0] + cmdArgs = append(cmdParts[1:], c.args...) + } else { + cmdName = c.command + cmdArgs = c.args + } + + // SECURITY: Create command with context for cancellation control + c.cmd = exec.CommandContext(c.ctx, cmdName, cmdArgs...) + + // SECURITY: Apply filtered environment (see buildSecureEnvironment) + c.cmd.Env = c.buildSecureEnvironment() + + // SECURITY: Process isolation via process group + c.cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, // Isolate in own process group + Pgid: 0, // New process group + // Future: Consider adding privilege dropping for root users + // Credential: &syscall.Credential{Uid: 65534, Gid: 65534} + } + + // Set up pipes for communication. + // Each error path explicitly closes previously opened pipes to prevent leaks. + // On success, pipes remain open for the lifetime of the MCP client. + stdin, err := c.cmd.StdinPipe() + if err != nil { + return fmt.Errorf("failed to create stdin pipe: %w", err) + } + c.stdin = stdin + + stdout, err := c.cmd.StdoutPipe() + if err != nil { + c.stdin.Close() + return fmt.Errorf("failed to create stdout pipe: %w", err) + } + c.stdoutPipe = stdout + c.stdout = bufio.NewReader(stdout) + + stderr, err := c.cmd.StderrPipe() + if err != nil { + c.stdin.Close() + stdout.Close() + return fmt.Errorf("failed to create stderr pipe: %w", err) + } + c.stderrPipe = stderr + c.stderr = bufio.NewReader(stderr) + + // Start the process + if err := c.cmd.Start(); err != nil { + c.stdin.Close() + stdout.Close() + stderr.Close() + return fmt.Errorf("failed to start MCP server: %w", err) + } + + slog.Info("MCP server started", "name", c.name, "pid", c.cmd.Process.Pid) + + // Start message handling goroutines + go c.handleResponses() + go c.handleErrors() + + // Check if the process is still running after a brief delay + // This catches immediate failures like command not found + processCheckDone := make(chan bool, 1) + go func() { + time.Sleep(100 * time.Millisecond) + // Non-blocking check if process has exited + if c.cmd.ProcessState != nil { + processCheckDone <- false + return + } + // Try to check process existence without waiting + if c.cmd.Process != nil { + // On Unix systems, signal 0 can be used to check process existence + if err := c.cmd.Process.Signal(syscall.Signal(0)); err != nil { + processCheckDone <- false + return + } + } + processCheckDone <- true + }() + + select { + case alive := <-processCheckDone: + if !alive { + // Process died immediately - collect the error + waitErr := c.cmd.Wait() + c.stdin.Close() + stdout.Close() + stderr.Close() + return fmt.Errorf("MCP server exited immediately: %w", waitErr) + } + case <-time.After(200 * time.Millisecond): + // Process seems to be running, continue + } + + return nil +} + +// Initialize performs the MCP handshake sequence +func (c *MCPClient) Initialize() error { + if err := c.Start(); err != nil { + return err + } + + // Add timeout to initialization to prevent hanging + initCtx, cancel := context.WithTimeout(c.ctx, 10*time.Second) + defer cancel() + + // Send initialize request + req := mcpInitializeRequest{ + ProtocolVersion: "2024-11-05", + Capabilities: map[string]interface{}{ + "tools": map[string]interface{}{}, + }, + ClientInfo: mcpClientInfo{ + Name: "ollama", + Version: "0.1.0", + }, + } + + var resp mcpInitializeResponse + if err := c.callWithContext(initCtx, "initialize", req, &resp); err != nil { + return fmt.Errorf("MCP initialize failed: %w", err) + } + + // Send initialized notification + if err := c.notify("notifications/initialized", nil); err != nil { + return fmt.Errorf("MCP initialized notification failed: %w", err) + } + + c.mu.Lock() + c.initialized = true + c.mu.Unlock() + + slog.Info("MCP client initialized", "name", c.name, "server", resp.ServerInfo.Name) + return nil +} + +// ListTools discovers available tools from the MCP server +func (c *MCPClient) ListTools() ([]api.Tool, error) { + c.mu.RLock() + if !c.initialized { + c.mu.RUnlock() + return nil, errors.New("MCP client not initialized") + } + + // Return cached tools if available + if len(c.tools) > 0 { + tools := make([]api.Tool, len(c.tools)) + copy(tools, c.tools) + c.mu.RUnlock() + return tools, nil + } + c.mu.RUnlock() + + var resp mcpListToolsResponse + if err := c.call("tools/list", mcpListToolsRequest{}, &resp); err != nil { + return nil, fmt.Errorf("failed to list MCP tools: %w", err) + } + + // Convert MCP tools to Ollama API format + tools := make([]api.Tool, 0, len(resp.Tools)) + for _, mcpTool := range resp.Tools { + tool := api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: fmt.Sprintf("%s:%s", c.name, mcpTool.Name), // Namespace with server name + Description: mcpTool.Description, + Parameters: api.ToolFunctionParameters{ + Type: "object", + Properties: make(map[string]api.ToolProperty), + Required: []string{}, + }, + }, + } + + // Convert input schema to tool parameters + if props, ok := mcpTool.InputSchema["properties"].(map[string]interface{}); ok { + for propName, propDef := range props { + propDefMap, ok := propDef.(map[string]interface{}) + if !ok { + slog.Debug("MCP schema: property definition not a map", "tool", mcpTool.Name, "property", propName) + continue + } + toolProp := api.ToolProperty{ + Description: getStringFromMap(propDefMap, "description"), + } + + if propType, ok := propDefMap["type"].(string); ok { + toolProp.Type = api.PropertyType{propType} + } else { + slog.Debug("MCP schema: property type not a string", "tool", mcpTool.Name, "property", propName) + } + + // Preserve items schema for array types (needed for context injection) + if items, ok := propDefMap["items"]; ok { + toolProp.Items = items + } + + tool.Function.Parameters.Properties[propName] = toolProp + } + } else if mcpTool.InputSchema["properties"] != nil { + slog.Debug("MCP schema: properties not a map", "tool", mcpTool.Name) + } + + if required, ok := mcpTool.InputSchema["required"].([]interface{}); ok { + for _, req := range required { + if reqStr, ok := req.(string); ok { + tool.Function.Parameters.Required = append(tool.Function.Parameters.Required, reqStr) + } else { + slog.Debug("MCP schema: required item not a string", "tool", mcpTool.Name) + } + } + } else if mcpTool.InputSchema["required"] != nil { + slog.Debug("MCP schema: required not an array", "tool", mcpTool.Name) + } + + tools = append(tools, tool) + } + + // Cache the tools + c.mu.Lock() + c.tools = tools + c.mu.Unlock() + + slog.Debug("MCP tools discovered", "name", c.name, "count", len(tools)) + return tools, nil +} + +// CallTool executes a tool call via the MCP server +func (c *MCPClient) CallTool(name string, args map[string]interface{}) (string, error) { + c.mu.RLock() + if !c.initialized { + c.mu.RUnlock() + return "", errors.New("MCP client not initialized") + } + c.mu.RUnlock() + + // Remove namespace prefix if present + toolName := name + if prefix := c.name + ":"; len(name) > len(prefix) && name[:len(prefix)] == prefix { + toolName = name[len(prefix):] + } + + // Ensure arguments is never nil (MCP protocol requires an object, not undefined) + if args == nil { + args = make(map[string]interface{}) + } + + req := mcpCallToolRequest{ + Name: toolName, + Arguments: args, + } + + // Debug logging removed + + var resp mcpCallToolResponse + + // Set timeout for tool execution + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() + + if err := c.callWithContext(ctx, "tools/call", req, &resp); err != nil { + return "", fmt.Errorf("MCP tool call failed: %w", err) + } + + slog.Debug("MCP tool response", "name", name, "is_error", resp.IsError, "content_count", len(resp.Content)) + + if resp.IsError { + // Log error without full response to avoid exposing sensitive data + slog.Error("MCP tool execution error", "name", name, "content_count", len(resp.Content)) + return "", fmt.Errorf("MCP tool returned error") + } + + // Concatenate all text content + var result string + for _, content := range resp.Content { + if content.Type == "text" { + result += content.Text + } + } + + // Debug logging removed + return result, nil +} + +// GetTools returns the list of tools available from this MCP server +func (c *MCPClient) GetTools() []api.Tool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.tools +} + +// Close shuts down the MCP client and terminates the server process +func (c *MCPClient) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.cmd == nil { + return nil + } + + slog.Info("Shutting down MCP client", "name", c.name) + + // Cancel context to stop goroutines + c.cancel() + + // Close stdin to signal shutdown + if c.stdin != nil { + c.stdin.Close() + } + + // Close stdout/stderr pipes to unblock handleResponses/handleErrors goroutines + if c.stdoutPipe != nil { + c.stdoutPipe.Close() + } + if c.stderrPipe != nil { + c.stderrPipe.Close() + } + + // Wait for process to exit gracefully + done := make(chan error, 1) + go func() { + done <- c.cmd.Wait() + }() + + select { + case err := <-done: + if err != nil { + slog.Warn("MCP server exited with error", "name", c.name, "error", err) + } + case <-time.After(5 * time.Second): + // Force kill if not responding + slog.Warn("Force killing unresponsive MCP server", "name", c.name) + c.cmd.Process.Kill() + <-done + } + + c.cmd = nil + c.initialized = false + close(c.done) + + return nil +} + +// call sends a JSON-RPC request and waits for the response +func (c *MCPClient) call(method string, params interface{}, result interface{}) error { + return c.callWithContext(c.ctx, method, params, result) +} + +// callWithContext sends a JSON-RPC request with a custom context +func (c *MCPClient) callWithContext(ctx context.Context, method string, params interface{}, result interface{}) error { + id := atomic.AddInt64(&c.requestID, 1) + + req := jsonRPCRequest{ + JSONRPC: "2.0", + ID: &id, + Method: method, + Params: params, + } + + // Create response channel + respChan := make(chan *jsonRPCResponse, 1) + c.mu.Lock() + c.responses[id] = respChan + c.mu.Unlock() + + defer func() { + c.mu.Lock() + delete(c.responses, id) + c.mu.Unlock() + close(respChan) + }() + + // Send request + if err := c.sendRequest(req); err != nil { + return err + } + + // Wait for response + select { + case resp := <-respChan: + if resp.Error != nil { + return fmt.Errorf("JSON-RPC error %d: %s", resp.Error.Code, resp.Error.Message) + } + + if result != nil && resp.Result != nil { + if err := json.Unmarshal(resp.Result, result); err != nil { + return fmt.Errorf("failed to unmarshal response: %w", err) + } + } + + return nil + case <-ctx.Done(): + return ctx.Err() + case <-c.done: + return errors.New("MCP client closed") + } +} + +// notify sends a JSON-RPC notification (no response expected) +func (c *MCPClient) notify(method string, params interface{}) error { + req := jsonRPCRequest{ + JSONRPC: "2.0", + Method: method, + Params: params, + } + + return c.sendRequest(req) +} + +// sendRequest sends a JSON-RPC request over stdin +func (c *MCPClient) sendRequest(req jsonRPCRequest) error { + if c.stdin == nil { + return fmt.Errorf("client not started") + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + if _, err := c.stdin.Write(append(data, '\n')); err != nil { + return fmt.Errorf("failed to write request: %w", err) + } + + return nil +} + +// handleResponses processes incoming JSON-RPC responses from stdout +func (c *MCPClient) handleResponses() { + defer func() { + if r := recover(); r != nil { + slog.Error("MCP response handler panic", "name", c.name, "error", r) + } + }() + + scanner := bufio.NewScanner(c.stdout) + // Set a larger buffer to handle long JSON responses + scanner.Buffer(make([]byte, 64*1024), 1024*1024) // 64KB initial, 1MB max + + for { + select { + case <-c.done: + return + default: + if !scanner.Scan() { + if err := scanner.Err(); err != nil { + slog.Error("Error reading MCP response", "name", c.name, "error", err) + } + return + } + + line := scanner.Bytes() + var resp jsonRPCResponse + if err := json.Unmarshal(line, &resp); err != nil { + // Don't log raw line content - may contain sensitive data + slog.Warn("Invalid JSON-RPC response", "name", c.name, "error", err, "length", len(line)) + continue + } + + // Route response to waiting caller + if resp.ID != nil { + c.mu.RLock() + if respChan, exists := c.responses[*resp.ID]; exists { + select { + case respChan <- &resp: + default: + slog.Warn("Response channel full", "name", c.name, "id", *resp.ID) + } + } + c.mu.RUnlock() + } + } + } +} + +// handleErrors processes stderr output from the MCP server +func (c *MCPClient) handleErrors() { + defer func() { + if r := recover(); r != nil { + slog.Error("MCP error handler panic", "name", c.name, "error", r) + } + }() + + for { + select { + case <-c.done: + return + default: + line, isPrefix, err := c.stderr.ReadLine() + if err != nil { + if err != io.EOF { + slog.Error("Error reading MCP stderr", "name", c.name, "error", err) + } + return + } + if isPrefix { + slog.Warn("MCP stderr line too long, truncated", "name", c.name) + } + + // Truncate stderr to avoid logging excessive/sensitive output + msg := string(line) + if len(msg) > 200 { + msg = msg[:200] + "...(truncated)" + } + slog.Debug("MCP server stderr", "name", c.name, "message", msg) + } + } +} + +// buildSecureEnvironment creates a filtered environment for the MCP server. +// +// SECURITY REVIEW: This is a critical security function. It controls what +// environment variables are passed to MCP server processes. +// +// Defense strategy (defense in depth): +// 1. Start with empty environment (not inherited) +// 2. Allowlist only known-safe variables +// 3. Apply MCPSecurityConfig filtering (blocks credentials) +// 4. Sanitize PATH to remove dangerous directories +// 5. Add custom env vars only after security checks +func (c *MCPClient) buildSecureEnvironment() []string { + // SECURITY: Start with empty env, not os.Environ() + env := []string{} + + // Get security configuration + securityConfig := GetSecurityConfig() + + // SECURITY: Allowlist of safe environment variables. + // Only these variables can be passed through from the parent process. + allowedVars := map[string]bool{ + "PATH": true, + "HOME": true, + "USER": true, + "LANG": true, + "LC_ALL": true, + "LC_CTYPE": true, + "TZ": true, + "TMPDIR": true, + "TEMP": true, + "TMP": true, + "TERM": true, + "PYTHONPATH": true, + "NODE_PATH": true, + "DISPLAY": true, // For GUI applications + "EDITOR": true, // For text editing + "SHELL": false, // Explicitly blocked - could enable shell escapes + } + + // Filter existing environment variables + for _, e := range os.Environ() { + parts := strings.SplitN(e, "=", 2) + if len(parts) != 2 { + continue + } + + key := parts[0] + value := parts[1] + + // Check if variable should be filtered based on security config + if securityConfig.ShouldFilterEnvironmentVar(key) { + slog.Debug("MCP: filtered credential env var", "name", c.name, "key", key) + continue + } + + // Only include explicitly allowed variables + if allowed, exists := allowedVars[key]; exists && allowed { + env = append(env, fmt.Sprintf("%s=%s", key, value)) + } + } + + // Add custom environment variables from server config. + // NOTE: Custom vars only check the blocklist, not the allowlist. This is intentional: + // inherited vars use strict allowlist (defense in depth), but custom vars are explicitly + // configured by the user/admin, so they're trusted if not in the blocklist. + for key, value := range c.env { + if securityConfig.ShouldFilterEnvironmentVar(key) { + slog.Debug("MCP: blocked custom env var", "name", c.name, "key", key) + continue + } + env = append(env, fmt.Sprintf("%s=%s", key, value)) + } + + // Set a restricted PATH if not already set + if !hasEnvVar(env, "PATH") { + env = append(env, "PATH=/usr/local/bin:/usr/bin:/bin") + } + + return env +} + +// getStringFromMap safely extracts a string value from a map +func getStringFromMap(m map[string]interface{}, key string) string { + if val, ok := m[key].(string); ok { + return val + } + return "" +} + +// hasEnvVar checks if an environment variable exists in the env slice +func hasEnvVar(env []string, key string) bool { + prefix := key + "=" + for _, e := range env { + if strings.HasPrefix(e, prefix) { + return true + } + } + return false +} diff --git a/server/mcp_code_api.go b/server/mcp_code_api.go new file mode 100644 index 000000000..0000e173b --- /dev/null +++ b/server/mcp_code_api.go @@ -0,0 +1,81 @@ +package server + +import ( + "fmt" + "log/slog" + "strings" + + "github.com/ollama/ollama/api" +) + +// MCPCodeAPI provides context injection for MCP tools +type MCPCodeAPI struct { + manager *MCPManager +} + +// NewMCPCodeAPI creates a new MCP code API +func NewMCPCodeAPI(manager *MCPManager) *MCPCodeAPI { + return &MCPCodeAPI{ + manager: manager, + } +} + +// GenerateMinimalContext returns essential runtime context for tool usage. +// Tool schemas are already provided via the template's TypeScript rendering, +// so we only need to add runtime-specific info like working directories. +func (m *MCPCodeAPI) GenerateMinimalContext(configs []api.MCPServerConfig) string { + slog.Debug("GenerateMinimalContext called", "configs_count", len(configs)) + + var context strings.Builder + + // Add filesystem working directory if applicable + for _, config := range configs { + if workingDir := m.extractFilesystemPath(config); workingDir != "" { + context.WriteString(fmt.Sprintf(` +Filesystem working directory: %s +All filesystem tool paths must be within this directory. +`, workingDir)) + } + } + + result := context.String() + if result != "" { + slog.Debug("Generated MCP context", "length", len(result)) + } + return result +} + +// extractFilesystemPath extracts the working directory from filesystem server config +func (m *MCPCodeAPI) extractFilesystemPath(config api.MCPServerConfig) string { + isFilesystem := strings.Contains(config.Command, "filesystem") || + (len(config.Args) > 0 && strings.Contains(strings.Join(config.Args, " "), "filesystem")) + + if isFilesystem && len(config.Args) > 0 { + // Path is typically the last argument + return config.Args[len(config.Args)-1] + } + return "" +} + +// InjectContextIntoMessages adds runtime context to the message stream +func (m *MCPCodeAPI) InjectContextIntoMessages(messages []api.Message, configs []api.MCPServerConfig) []api.Message { + context := m.GenerateMinimalContext(configs) + if context == "" { + return messages + } + + // Check if there's already a system message + if len(messages) > 0 && messages[0].Role == "system" { + // Append to existing system message + messages[0].Content += context + } else { + // Create new system message + systemMsg := api.Message{ + Role: "system", + Content: context, + } + messages = append([]api.Message{systemMsg}, messages...) + } + + return messages +} diff --git a/server/mcp_command_resolver.go b/server/mcp_command_resolver.go new file mode 100644 index 000000000..054282bef --- /dev/null +++ b/server/mcp_command_resolver.go @@ -0,0 +1,214 @@ +package server + +import ( + "fmt" + "log/slog" + "os" + "os/exec" + "sync" +) + +// ============================================================================= +// Command Resolver Interface & Default Implementation +// ============================================================================= +// +// SECURITY REVIEW: This component determines which executables are launched +// for MCP servers. Changes here should be reviewed carefully. + +// CommandResolverInterface defines the contract for command resolution. +// Implementations resolve command names (like "npx", "python") to actual +// executable paths, with support for fallbacks and environment overrides. +// +// This interface enables dependency injection for testing MCPClient without +// requiring actual executables to be present on the system. +type CommandResolverInterface interface { + // ResolveCommand finds the actual executable for a command name. + // Returns the resolved path/command or an error if not found. + ResolveCommand(command string) (string, error) + + // ResolveForEnvironment resolves a command, checking environment + // variable overrides first (e.g., OLLAMA_NPX_COMMAND for "npx"). + // Returns the original command as fallback if resolution fails. + ResolveForEnvironment(command string) string +} + +// CommandResolver handles resolving commands to their actual executables +// with fallback detection for different system configurations. +type CommandResolver struct { + mu sync.RWMutex + resolved map[string]string +} + +// Ensure CommandResolver implements CommandResolverInterface +var _ CommandResolverInterface = (*CommandResolver)(nil) + +// NewCommandResolver creates a new command resolver +func NewCommandResolver() *CommandResolver { + return &CommandResolver{ + resolved: make(map[string]string), + } +} + +// DefaultCommandResolver is the shared resolver instance for production use. +// Tests should use WithCommandResolver option instead of modifying this. +var DefaultCommandResolver = NewCommandResolver() + +// ResolveCommand finds the actual executable for a given command +func (cr *CommandResolver) ResolveCommand(command string) (string, error) { + cr.mu.RLock() + if resolved, ok := cr.resolved[command]; ok { + cr.mu.RUnlock() + return resolved, nil + } + cr.mu.RUnlock() + + // Try to resolve the command + var resolved string + var err error + + switch command { + case "npx": + resolved, err = cr.resolveNodePackageManager() + case "python": + resolved, err = cr.resolvePython() + case "node": + resolved, err = cr.resolveNode() + default: + // For other commands, check if they exist as-is + resolved, err = cr.checkCommand(command) + } + + if err != nil { + return "", err + } + + // Cache the resolution + cr.mu.Lock() + cr.resolved[command] = resolved + cr.mu.Unlock() + + return resolved, nil +} + +// resolveNodePackageManager finds an available Node.js package manager +func (cr *CommandResolver) resolveNodePackageManager() (string, error) { + // Priority order for package managers + managers := []struct { + cmd string + args []string + }{ + {"npx", []string{"--version"}}, + {"pnpm", []string{"dlx", "--version"}}, // pnpm equivalent of npx + {"yarn", []string{"dlx", "--version"}}, // yarn 2+ equivalent + {"bunx", []string{"--version"}}, // bun equivalent + } + + for _, mgr := range managers { + if path, err := exec.LookPath(mgr.cmd); err == nil { + // Verify it actually works + cmd := exec.Command(path, mgr.args...) + if err := cmd.Run(); err == nil { + // For pnpm/yarn, we need to return the dlx subcommand + if mgr.cmd == "pnpm" { + return "pnpm dlx", nil + } else if mgr.cmd == "yarn" { + return "yarn dlx", nil + } + return mgr.cmd, nil + } + } + } + + // Check if npm is available and suggest installing npx + if _, err := exec.LookPath("npm"); err == nil { + return "", fmt.Errorf("npx not found but npm is available - install with: npm install -g npx") + } + + return "", fmt.Errorf("no Node.js package manager found (tried npx, pnpm, yarn, bunx)") +} + +// resolvePython finds an available Python interpreter +func (cr *CommandResolver) resolvePython() (string, error) { + // Priority order for Python interpreters + interpreters := []string{ + "python3", // Most Unix systems + "python", // Windows or virtualenv + "python3.12", // Specific versions + "python3.11", + "python3.10", + "python3.9", + "python3.8", + } + + for _, interp := range interpreters { + if path, err := exec.LookPath(interp); err == nil { + // Verify it's Python 3.8+ by checking version + cmd := exec.Command(path, "--version") + output, err := cmd.Output() + if err == nil && len(output) > 0 { + // Basic check that it's Python 3 + if string(output[:7]) == "Python " && output[7] >= '3' { + return interp, nil + } + } + } + } + + return "", fmt.Errorf("no Python 3 interpreter found (tried python3, python, and versioned variants)") +} + +// resolveNode finds the Node.js executable +func (cr *CommandResolver) resolveNode() (string, error) { + // Try different Node.js executable names + nodes := []string{"node", "nodejs"} + + for _, node := range nodes { + if path, err := exec.LookPath(node); err == nil { + // Verify it works + cmd := exec.Command(path, "--version") + if err := cmd.Run(); err == nil { + return node, nil + } + } + } + + return "", fmt.Errorf("Node.js not found (tried node, nodejs)") +} + +// checkCommand checks if a command exists as-is +func (cr *CommandResolver) checkCommand(command string) (string, error) { + if _, err := exec.LookPath(command); err == nil { + return command, nil + } + return "", fmt.Errorf("command not found: %s", command) +} + +// ResolveForEnvironment checks environment variables for command overrides +func (cr *CommandResolver) ResolveForEnvironment(command string) string { + // Allow environment variable overrides + envMap := map[string]string{ + "npx": "OLLAMA_NPX_COMMAND", + "python": "OLLAMA_PYTHON_COMMAND", + "node": "OLLAMA_NODE_COMMAND", + } + + if envVar, ok := envMap[command]; ok { + if override := os.Getenv(envVar); override != "" { + // Validate override against security blocklist + if GetSecurityConfig().IsCommandAllowed(override) { + return override + } + slog.Warn("Environment override blocked by security policy", "var", envVar, "command", override) + } + } + + // Try standard resolution + if resolved, err := cr.ResolveCommand(command); err == nil { + return resolved + } + + // Return original command as fallback + return command +} + +// NOTE: A GetSystemRequirements() method could be added here for diagnostics/status endpoints \ No newline at end of file diff --git a/server/mcp_definitions.go b/server/mcp_definitions.go new file mode 100644 index 000000000..0b2160374 --- /dev/null +++ b/server/mcp_definitions.go @@ -0,0 +1,280 @@ +package server + +import ( + "encoding/json" + "fmt" + "log/slog" + "os" + "path/filepath" + + "github.com/ollama/ollama/api" +) + +// AutoEnableMode determines when a server auto-enables with --tools +type AutoEnableMode string + +const ( + // AutoEnableNever means the server must be explicitly configured (default) + AutoEnableNever AutoEnableMode = "never" + // AutoEnableAlways means the server enables whenever --tools is used + AutoEnableAlways AutoEnableMode = "always" + // AutoEnableWithPath means the server enables when --tools has a path + AutoEnableWithPath AutoEnableMode = "with_path" + // AutoEnableIfMatch means the server enables if EnableIf condition matches + AutoEnableIfMatch AutoEnableMode = "if_match" +) + +// EnableCondition specifies conditions for AutoEnableIfMatch mode +type EnableCondition struct { + // FileExists checks if a specific file exists in the tools path + FileExists string `json:"file_exists,omitempty"` + // EnvSet checks if an environment variable is set (non-empty) + EnvSet string `json:"env_set,omitempty"` +} + +// AutoEnableContext provides context for auto-enable decisions +type AutoEnableContext struct { + // ToolsPath is the path from --tools flag (may be empty) + ToolsPath string + // Env contains environment variables (optional, falls back to os.Getenv) + Env map[string]string +} + +// MCPDefinitions holds available MCP server definitions loaded from configuration. +// This is the static configuration of what servers CAN be used. +type MCPDefinitions struct { + Servers map[string]MCPServerDefinition `json:"servers"` +} + +// MCPServerDefinition defines an available MCP server type +type MCPServerDefinition struct { + Name string `json:"name"` + Description string `json:"description"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + RequiresPath bool `json:"requires_path,omitempty"` + PathArgIndex int `json:"path_arg_index,omitempty"` + Env map[string]string `json:"env,omitempty"` + Capabilities []string `json:"capabilities,omitempty"` + + // AutoEnable determines when this server auto-enables with --tools + // Default is "never" (must be explicitly configured via API) + AutoEnable AutoEnableMode `json:"auto_enable,omitempty"` + + // EnableIf specifies conditions for AutoEnableIfMatch mode + EnableIf EnableCondition `json:"enable_if,omitempty"` +} + +// MCPServerInfo provides information about an available MCP server +type MCPServerInfo struct { + Name string `json:"name"` + Description string `json:"description"` + RequiresPath bool `json:"requires_path"` + Capabilities []string `json:"capabilities,omitempty"` + AutoEnable AutoEnableMode `json:"auto_enable,omitempty"` +} + +// DefaultMCPServers returns minimal built-in MCP server definitions +// Full examples are provided in examples/mcp-servers.json +func DefaultMCPServers() map[string]MCPServerDefinition { + // Only include filesystem by default - it requires only npx which is commonly available + // Users can add more servers via ~/.ollama/mcp-servers.json + return map[string]MCPServerDefinition{ + "filesystem": { + Name: "filesystem", + Description: "File system operations with path-based access control", + Command: "npx", + Args: []string{"-y", "@modelcontextprotocol/server-filesystem"}, + RequiresPath: true, + PathArgIndex: -1, + Capabilities: []string{"read", "write", "list", "search"}, + AutoEnable: AutoEnableWithPath, // Enable when --tools has a path + }, + } +} + +// LoadMCPDefinitions loads MCP server definitions from configuration files. +// Priority order: user config (~/.ollama) > system config (/etc/ollama) > defaults +func LoadMCPDefinitions() (*MCPDefinitions, error) { + defs := &MCPDefinitions{ + Servers: DefaultMCPServers(), + } + + // Load from user config if exists + configPaths := []string{ + filepath.Join(os.Getenv("HOME"), ".ollama", "mcp-servers.json"), + "/etc/ollama/mcp-servers.json", + "./mcp-servers.json", + } + + for _, path := range configPaths { + if err := defs.LoadFromFile(path); err == nil { + slog.Debug("Loaded MCP definitions", "path", path) + break + } + } + + // Load from environment variable if set + if mcpConfig := os.Getenv("OLLAMA_MCP_SERVERS"); mcpConfig != "" { + if err := defs.LoadFromJSON([]byte(mcpConfig)); err != nil { + return nil, fmt.Errorf("failed to parse OLLAMA_MCP_SERVERS: %w", err) + } + } + + return defs, nil +} + +// LoadFromFile loads MCP server definitions from a JSON file +func (d *MCPDefinitions) LoadFromFile(path string) error { + data, err := os.ReadFile(path) + if err != nil { + return err + } + return d.LoadFromJSON(data) +} + +// LoadFromJSON loads MCP server definitions from JSON data +func (d *MCPDefinitions) LoadFromJSON(data []byte) error { + var config struct { + Servers []MCPServerDefinition `json:"servers"` + } + + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + for _, server := range config.Servers { + d.Servers[server.Name] = server + } + + return nil +} + +// ListServers returns information about all available MCP servers +func (d *MCPDefinitions) ListServers() []MCPServerInfo { + var servers []MCPServerInfo + for _, def := range d.Servers { + servers = append(servers, MCPServerInfo{ + Name: def.Name, + Description: def.Description, + RequiresPath: def.RequiresPath, + Capabilities: def.Capabilities, + AutoEnable: def.AutoEnable, + }) + } + return servers +} + +// GetAutoEnableServers returns servers that should auto-enable for the given context. +// This method checks each server's AutoEnable mode and EnableIf conditions. +func (d *MCPDefinitions) GetAutoEnableServers(ctx AutoEnableContext) []api.MCPServerConfig { + var configs []api.MCPServerConfig + + for _, def := range d.Servers { + if !d.shouldAutoEnable(def, ctx) { + continue + } + + config, err := d.buildConfigForAutoEnable(def, ctx) + if err != nil { + slog.Warn("Failed to build config for auto-enable server", + "name", def.Name, "error", err) + continue + } + + configs = append(configs, config) + } + + return configs +} + +// shouldAutoEnable checks if a server should auto-enable for the given context +func (d *MCPDefinitions) shouldAutoEnable(def MCPServerDefinition, ctx AutoEnableContext) bool { + switch def.AutoEnable { + case AutoEnableNever, "": + return false + + case AutoEnableAlways: + return true + + case AutoEnableWithPath: + return ctx.ToolsPath != "" + + case AutoEnableIfMatch: + return d.checkEnableCondition(def.EnableIf, ctx) + + default: + return false + } +} + +// checkEnableCondition evaluates an EnableCondition against the context +func (d *MCPDefinitions) checkEnableCondition(cond EnableCondition, ctx AutoEnableContext) bool { + // All specified conditions must match (AND logic) + + if cond.FileExists != "" { + checkPath := filepath.Join(ctx.ToolsPath, cond.FileExists) + if _, err := os.Stat(checkPath); err != nil { + return false + } + } + + if cond.EnvSet != "" { + // Check context env first, fall back to os.Getenv + val := "" + if ctx.Env != nil { + val = ctx.Env[cond.EnvSet] + } + if val == "" { + val = os.Getenv(cond.EnvSet) + } + if val == "" { + return false + } + } + + return true +} + +// buildConfigForAutoEnable creates an MCPServerConfig for auto-enabled servers +func (d *MCPDefinitions) buildConfigForAutoEnable(def MCPServerDefinition, ctx AutoEnableContext) (api.MCPServerConfig, error) { + // Resolve the command using the command resolver + resolvedCommand := DefaultCommandResolver.ResolveForEnvironment(def.Command) + + config := api.MCPServerConfig{ + Name: def.Name, + Command: resolvedCommand, + Args: append([]string{}, def.Args...), // Copy args + Env: make(map[string]string), + } + + // Copy environment variables + for k, v := range def.Env { + config.Env[k] = v + } + + // Add path if required + if def.RequiresPath { + if ctx.ToolsPath == "" { + return config, fmt.Errorf("server '%s' requires a path but none provided", def.Name) + } + + // Validate path exists + if _, err := os.Stat(ctx.ToolsPath); err != nil { + return config, fmt.Errorf("invalid path for server '%s': %w", def.Name, err) + } + + // Add path to args at specified position + if def.PathArgIndex < 0 { + config.Args = append(config.Args, ctx.ToolsPath) + } else if def.PathArgIndex <= len(config.Args) { + config.Args = append(config.Args[:def.PathArgIndex], + append([]string{ctx.ToolsPath}, config.Args[def.PathArgIndex:]...)...) + } else { + // PathArgIndex out of bounds, append to end + config.Args = append(config.Args, ctx.ToolsPath) + } + } + + return config, nil +} diff --git a/server/mcp_manager.go b/server/mcp_manager.go new file mode 100644 index 000000000..d7d36c3f1 --- /dev/null +++ b/server/mcp_manager.go @@ -0,0 +1,484 @@ +package server + +import ( + "fmt" + "log/slog" + "strings" + "sync" + + "github.com/ollama/ollama/api" +) + +// MCPManager manages multiple MCP server connections and provides tool execution services +type MCPManager struct { + mu sync.RWMutex + clients map[string]*MCPClient + toolRouting map[string]string // tool name -> client name mapping + maxClients int +} + +// MCPServerConfig is imported from api package + +// ToolResult represents the result of a tool execution +type ToolResult struct { + Content string + Error error +} + +// ExecutionPlan represents the execution strategy for a set of tool calls +type ExecutionPlan struct { + RequiresSequential bool + Groups [][]int // Groups of tool indices that can run in parallel + Reason string // Explanation of why this plan was chosen +} + +// NewMCPManager creates a new MCP manager +func NewMCPManager(maxClients int) *MCPManager { + return &MCPManager{ + clients: make(map[string]*MCPClient), + toolRouting: make(map[string]string), + maxClients: maxClients, + } +} + +// AddServer adds a new MCP server to the manager +func (m *MCPManager) AddServer(config api.MCPServerConfig) error { + m.mu.Lock() + defer m.mu.Unlock() + + if len(m.clients) >= m.maxClients { + return fmt.Errorf("maximum number of MCP servers reached (%d)", m.maxClients) + } + + if _, exists := m.clients[config.Name]; exists { + return fmt.Errorf("MCP server '%s' already exists", config.Name) + } + + // Validate server configuration for security + if err := m.validateServerConfig(config); err != nil { + return fmt.Errorf("invalid MCP server configuration: %w", err) + } + + // Create and initialize the MCP client + client := NewMCPClient(config.Name, config.Command, config.Args, config.Env) + + if err := client.Initialize(); err != nil { + client.Close() + return fmt.Errorf("failed to initialize MCP server '%s': %w", config.Name, err) + } + + // Discover tools + tools, err := client.ListTools() + if err != nil { + client.Close() + return fmt.Errorf("failed to list tools from MCP server '%s': %w", config.Name, err) + } + + // Update tool routing + for _, tool := range tools { + m.toolRouting[tool.Function.Name] = config.Name + } + + m.clients[config.Name] = client + + slog.Info("MCP server added", "name", config.Name, "tools", len(tools)) + return nil +} + +// RemoveServer removes an MCP server from the manager +func (m *MCPManager) RemoveServer(name string) error { + m.mu.Lock() + defer m.mu.Unlock() + + client, exists := m.clients[name] + if !exists { + return fmt.Errorf("MCP server '%s' not found", name) + } + + // Remove tool routing entries + for toolName, clientName := range m.toolRouting { + if clientName == name { + delete(m.toolRouting, toolName) + } + } + + // Close the client + if err := client.Close(); err != nil { + slog.Warn("Error closing MCP client", "name", name, "error", err) + } + + delete(m.clients, name) + + slog.Info("MCP server removed", "name", name) + return nil +} + +// GetAllTools returns all available tools from all MCP servers +func (m *MCPManager) GetAllTools() []api.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + var allTools []api.Tool + + for _, client := range m.clients { + tools, err := client.ListTools() + if err != nil { + slog.Warn("Failed to get tools from MCP client", "name", client.name, "error", err) + continue + } + allTools = append(allTools, tools...) + } + + return allTools +} + +// ExecuteTool executes a single tool call +func (m *MCPManager) ExecuteTool(toolCall api.ToolCall) ToolResult { + toolName := toolCall.Function.Name + + m.mu.RLock() + clientName, exists := m.toolRouting[toolName] + if !exists { + m.mu.RUnlock() + return ToolResult{Error: fmt.Errorf("tool '%s' not found", toolName)} + } + + client, exists := m.clients[clientName] + if !exists { + m.mu.RUnlock() + return ToolResult{Error: fmt.Errorf("MCP client '%s' not found", clientName)} + } + m.mu.RUnlock() + + // Convert arguments to map[string]interface{} + args := make(map[string]interface{}) + for k, v := range toolCall.Function.Arguments { + args[k] = v + } + + // Execute the tool + content, err := client.CallTool(toolName, args) + if err != nil { + slog.Debug("MCP tool execution failed", "tool", toolName, "client", clientName) + } else { + slog.Debug("MCP tool executed", "tool", toolName, "client", clientName, "result_length", len(content)) + } + return ToolResult{ + Content: content, + Error: err, + } +} + +// AnalyzeExecutionPlan analyzes tool calls to determine optimal execution strategy +func (m *MCPManager) AnalyzeExecutionPlan(toolCalls []api.ToolCall) ExecutionPlan { + if len(toolCalls) <= 1 { + return ExecutionPlan{ + RequiresSequential: false, + Groups: [][]int{{0}}, + Reason: "Single tool call", + } + } + + // Analyze tool patterns for dependencies + hasWriteOperations := false + hasReadOperations := false + fileTargets := make(map[string][]int) // Track which tools operate on which files + + for i, toolCall := range toolCalls { + toolName := toolCall.Function.Name + args := toolCall.Function.Arguments + + // Check for file operations + if strings.Contains(toolName, "write") || strings.Contains(toolName, "create") || + strings.Contains(toolName, "edit") || strings.Contains(toolName, "append") { + hasWriteOperations = true + + // Try to extract file path from arguments + if pathArg, exists := args["path"]; exists { + if path, ok := pathArg.(string); ok { + fileTargets[path] = append(fileTargets[path], i) + } + } else if fileArg, exists := args["file"]; exists { + if file, ok := fileArg.(string); ok { + fileTargets[file] = append(fileTargets[file], i) + } + } + } + + if strings.Contains(toolName, "read") || strings.Contains(toolName, "list") || + strings.Contains(toolName, "get") { + hasReadOperations = true + + // Try to extract file path from arguments + if pathArg, exists := args["path"]; exists { + if path, ok := pathArg.(string); ok { + fileTargets[path] = append(fileTargets[path], i) + } + } else if fileArg, exists := args["file"]; exists { + if file, ok := fileArg.(string); ok { + fileTargets[file] = append(fileTargets[file], i) + } + } + } + } + + // Determine if sequential execution is needed + requiresSequential := false + reason := "Can execute in parallel" + + // Check for file operation dependencies + if hasWriteOperations && hasReadOperations { + requiresSequential = true + reason = "Mixed read and write operations detected" + } + + // Check for operations on the same file + for file, indices := range fileTargets { + if len(indices) > 1 { + requiresSequential = true + reason = fmt.Sprintf("Multiple operations on the same file: %s", file) + break + } + } + + // Check for explicit ordering patterns in tool names + for i := 0; i < len(toolCalls)-1; i++ { + curr := toolCalls[i].Function.Name + next := toolCalls[i+1].Function.Name + + // Common patterns that suggest ordering + if (strings.Contains(curr, "create") && strings.Contains(next, "read")) || + (strings.Contains(curr, "write") && strings.Contains(next, "read")) || + (strings.Contains(curr, "1") && strings.Contains(next, "2")) || + (strings.Contains(curr, "first") && strings.Contains(next, "second")) || + (strings.Contains(curr, "init") && strings.Contains(next, "use")) { + requiresSequential = true + reason = "Tool names suggest sequential dependency" + break + } + } + + // Build execution groups + var groups [][]int + if requiresSequential { + // Each tool in its own group for sequential execution + for i := range toolCalls { + groups = append(groups, []int{i}) + } + } else { + // All tools in one group for parallel execution + group := make([]int, len(toolCalls)) + for i := range toolCalls { + group[i] = i + } + groups = [][]int{group} + } + + plan := ExecutionPlan{ + RequiresSequential: requiresSequential, + Groups: groups, + Reason: reason, + } + + slog.Debug("Execution plan analyzed", + "sequential", requiresSequential, + "reason", reason, + "tool_count", len(toolCalls)) + + return plan +} + +// ExecuteWithPlan executes tool calls according to the execution plan +func (m *MCPManager) ExecuteWithPlan(toolCalls []api.ToolCall, plan ExecutionPlan) []ToolResult { + results := make([]ToolResult, len(toolCalls)) + + for _, group := range plan.Groups { + if len(group) == 1 { + // Single tool, execute directly + idx := group[0] + results[idx] = m.ExecuteTool(toolCalls[idx]) + } else { + // Multiple tools in group, execute in parallel + var wg sync.WaitGroup + for _, idx := range group { + wg.Add(1) + go func(i int) { + defer wg.Done() + results[i] = m.ExecuteTool(toolCalls[i]) + }(idx) + } + wg.Wait() + } + } + + return results +} + +// ExecuteToolsParallel executes multiple tool calls in parallel +func (m *MCPManager) ExecuteToolsParallel(toolCalls []api.ToolCall) []ToolResult { + if len(toolCalls) == 0 { + return nil + } + + results := make([]ToolResult, len(toolCalls)) + + // For single tool call, execute directly + if len(toolCalls) == 1 { + results[0] = m.ExecuteTool(toolCalls[0]) + return results + } + + // Execute multiple tools in parallel + var wg sync.WaitGroup + for i, toolCall := range toolCalls { + wg.Add(1) + go func(index int, tc api.ToolCall) { + defer wg.Done() + results[index] = m.ExecuteTool(tc) + }(i, toolCall) + } + + wg.Wait() + return results +} + +// ExecuteToolsSequential executes multiple tool calls sequentially +func (m *MCPManager) ExecuteToolsSequential(toolCalls []api.ToolCall) []ToolResult { + results := make([]ToolResult, len(toolCalls)) + + for i, toolCall := range toolCalls { + results[i] = m.ExecuteTool(toolCall) + + // Stop on first error if desired + if results[i].Error != nil { + slog.Warn("Tool execution failed", "tool", toolCall.Function.Name, "error", results[i].Error) + } + } + + return results +} + +// GetToolClient returns the client name for a given tool +func (m *MCPManager) GetToolClient(toolName string) (string, bool) { + m.mu.RLock() + defer m.mu.RUnlock() + + clientName, exists := m.toolRouting[toolName] + return clientName, exists +} + +// GetServerNames returns a list of all registered MCP server names +func (m *MCPManager) GetServerNames() []string { + m.mu.RLock() + defer m.mu.RUnlock() + + names := make([]string, 0, len(m.clients)) + for name := range m.clients { + names = append(names, name) + } + + return names +} + +// GetToolDefinition returns the definition for a specific tool +func (m *MCPManager) GetToolDefinition(serverName, toolName string) *api.Tool { + m.mu.RLock() + defer m.mu.RUnlock() + + client, exists := m.clients[serverName] + if !exists { + return nil + } + + // Get tools from the client + tools := client.GetTools() + for _, tool := range tools { + if tool.Function.Name == toolName { + return &tool + } + } + + return nil +} + +// Close shuts down all MCP clients +func (m *MCPManager) Close() error { + m.mu.Lock() + defer m.mu.Unlock() + + var errs []string + + for name, client := range m.clients { + if err := client.Close(); err != nil { + errs = append(errs, fmt.Sprintf("%s: %v", name, err)) + } + } + + // Clear all data + m.clients = make(map[string]*MCPClient) + m.toolRouting = make(map[string]string) + + if len(errs) > 0 { + return fmt.Errorf("errors closing MCP clients: %s", strings.Join(errs, "; ")) + } + + return nil +} + +// Shutdown is an alias for Close for consistency with registry +func (m *MCPManager) Shutdown() error { + slog.Info("Shutting down MCP manager", "clients", len(m.clients)) + return m.Close() +} + +// validateServerConfig validates MCP server configuration for security +func (m *MCPManager) validateServerConfig(config api.MCPServerConfig) error { + // Validate name + if config.Name == "" { + return fmt.Errorf("server name cannot be empty") + } + if len(config.Name) > 100 { + return fmt.Errorf("server name too long (max 100 characters)") + } + if strings.ContainsAny(config.Name, "/\\:*?\"<>|") { + return fmt.Errorf("server name contains invalid characters") + } + + // Validate command + if config.Command == "" { + return fmt.Errorf("command cannot be empty") + } + + // Get security configuration + securityConfig := GetSecurityConfig() + + // Check if command is allowed by security policy + if !securityConfig.IsCommandAllowed(config.Command) { + return fmt.Errorf("command '%s' is not allowed for security reasons", config.Command) + } + + // Validate command path (must be absolute or in PATH) + if strings.Contains(config.Command, "..") { + return fmt.Errorf("command path cannot contain '..'") + } + + // Validate arguments + for _, arg := range config.Args { + if strings.Contains(arg, "..") || strings.HasPrefix(arg, "-") && len(arg) > 50 { + return fmt.Errorf("suspicious argument detected: %s", arg) + } + // Check for shell injection attempts using security config + if securityConfig.HasShellMetacharacters(arg) { + return fmt.Errorf("argument contains shell metacharacters: %s", arg) + } + } + + // Validate environment variables + for key := range config.Env { + if securityConfig.HasShellMetacharacters(key) { + return fmt.Errorf("environment variable name contains invalid characters: %s", key) + } + } + + return nil +} \ No newline at end of file diff --git a/server/mcp_security_config.go b/server/mcp_security_config.go new file mode 100644 index 000000000..467f45836 --- /dev/null +++ b/server/mcp_security_config.go @@ -0,0 +1,152 @@ +package server + +import ( + "path/filepath" + "strings" +) + +// ============================================================================= +// MCP Security Configuration +// ============================================================================= +// +// SECURITY REVIEW: This file defines the security policies that control which +// commands can be executed as MCP servers and what arguments/environment they +// can receive. Changes to this file should be reviewed by security-aware +// maintainers. +// +// Key security surfaces: +// - BlockedCommands: Prevents execution of dangerous system commands +// - BlockedMetacharacters: Prevents shell injection attacks +// - FilteredEnvironmentVars: Prevents credential leakage to MCP servers +// +// Threat model: +// - Malicious MCP server configs attempting to execute system commands +// - Shell injection through tool arguments +// - Credential theft through environment variable access +// +// ============================================================================= + +// MCPSecurityConfig defines security policies for MCP servers +type MCPSecurityConfig struct { + // Commands that are never allowed as MCP servers + BlockedCommands []string + + // Shell metacharacters that are not allowed in arguments + BlockedMetacharacters []string + + // Environment variables that should be filtered + FilteredEnvironmentVars []string +} + +// DefaultSecurityConfig returns the default security configuration. +// +// SECURITY REVIEW: This function defines the default blocklists. Adding or +// removing entries has direct security implications. Consider: +// - Why is a command being added/removed? +// - What attack vectors does it enable/prevent? +// - Are there bypass possibilities (symlinks, PATH manipulation)? +func DefaultSecurityConfig() *MCPSecurityConfig { + return &MCPSecurityConfig{ + // SECURITY: Blocked commands - these can never be used as MCP server commands. + // Rationale: These commands could be used for privilege escalation, + // arbitrary file manipulation, or establishing network connections. + BlockedCommands: []string{ + // Shells - prevent arbitrary command execution + "sh", "bash", "zsh", "fish", "csh", "ksh", "dash", "tcsh", + "cmd", "cmd.exe", "powershell", "powershell.exe", "pwsh", "pwsh.exe", + + // System commands - prevent privilege escalation and system damage + "sudo", "su", "doas", "runas", "pkexec", + "rm", "del", "rmdir", "format", "dd", "shred", + "kill", "killall", "pkill", "shutdown", "reboot", + "systemctl", "service", "init", + + // Network tools - prevent data exfiltration and network attacks + "curl", "wget", "nc", "netcat", "telnet", "ssh", "scp", "sftp", + "nmap", "ping", "traceroute", "dig", "nslookup", + + // Script interpreters - prevent arbitrary code execution + "eval", "exec", "source", ".", + "perl", "ruby", "php", "lua", "tcl", + + // File manipulation - prevent permission/ownership changes + "chmod", "chown", "chgrp", "mount", "umount", + "ln", "mkfifo", "mknod", + + // Package managers - prevent system modification + "apt", "apt-get", "yum", "dnf", "pacman", "zypper", + "brew", "port", "snap", "flatpak", + }, + + // SECURITY: Shell metacharacters - block these in arguments to prevent injection. + // These characters could be used to chain commands or redirect I/O. + BlockedMetacharacters: []string{ + ";", "|", "&", "$(", "`", ">", "<", ">>", "<<", + "||", "&&", "\n", "\r", "$", "!", "*", "?", "=", + }, + + // SECURITY: Environment variables to filter - prevent credential leakage. + // MCP servers should not have access to credentials in the parent environment. + FilteredEnvironmentVars: []string{ + // AWS + "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN", + + // Cloud providers + "GOOGLE_APPLICATION_CREDENTIALS", "AZURE_CLIENT_SECRET", + + // API Keys + "GITHUB_TOKEN", "GITLAB_TOKEN", "OPENAI_API_KEY", "ANTHROPIC_API_KEY", + + // Database + "DATABASE_URL", "DB_PASSWORD", "MYSQL_ROOT_PASSWORD", "POSTGRES_PASSWORD", + + // Authentication + "JWT_SECRET", "SESSION_SECRET", "AUTH_TOKEN", "API_KEY", "API_SECRET", + + // SSH + "SSH_AUTH_SOCK", "SSH_AGENT_PID", + }, + } +} + +// IsCommandAllowed checks if a command is allowed by security policy +func (c *MCPSecurityConfig) IsCommandAllowed(command string) bool { + baseName := filepath.Base(command) + + for _, blocked := range c.BlockedCommands { + if baseName == blocked || strings.HasSuffix(command, "/"+blocked) { + return false + } + } + + return true +} + +// HasShellMetacharacters checks if a string contains shell metacharacters +func (c *MCPSecurityConfig) HasShellMetacharacters(s string) bool { + for _, meta := range c.BlockedMetacharacters { + if strings.Contains(s, meta) { + return true + } + } + return false +} + +// ShouldFilterEnvironmentVar checks if an environment variable should be filtered +func (c *MCPSecurityConfig) ShouldFilterEnvironmentVar(key string) bool { + for _, filtered := range c.FilteredEnvironmentVars { + if key == filtered { + return true + } + } + return false +} + +// Global security config instance +// NOTE: To add user customization, load from ~/.ollama/mcp-security.json and append to these defaults +var globalSecurityConfig = DefaultSecurityConfig() + +// GetSecurityConfig returns the global security configuration +func GetSecurityConfig() *MCPSecurityConfig { + return globalSecurityConfig +} \ No newline at end of file diff --git a/server/mcp_sessions.go b/server/mcp_sessions.go new file mode 100644 index 000000000..3302d94a9 --- /dev/null +++ b/server/mcp_sessions.go @@ -0,0 +1,178 @@ +package server + +import ( + "crypto/sha256" + "encoding/hex" + "log/slog" + "reflect" + "sync" + "time" + + "github.com/ollama/ollama/api" +) + +// MCPSessionManager manages active MCP sessions with automatic cleanup. +// This is the runtime component that tracks active connections. +type MCPSessionManager struct { + mu sync.RWMutex + sessions map[string]*MCPSession // session ID -> session + ttl time.Duration // session timeout + stopCleanup chan struct{} // signals cleanup goroutine to stop +} + +// MCPSession wraps an MCPManager with session metadata +type MCPSession struct { + *MCPManager + lastAccess time.Time + sessionID string + configs []api.MCPServerConfig +} + +var ( + globalSessionManager *MCPSessionManager + sessionManagerOnce sync.Once +) + +// GetMCPSessionManager returns the singleton MCP session manager +func GetMCPSessionManager() *MCPSessionManager { + sessionManagerOnce.Do(func() { + globalSessionManager = &MCPSessionManager{ + sessions: make(map[string]*MCPSession), + ttl: 30 * time.Minute, // Sessions expire after 30 min + stopCleanup: make(chan struct{}), + } + // Start cleanup goroutine + go globalSessionManager.cleanupExpired() + }) + return globalSessionManager +} + +// GetOrCreateManager gets existing or creates new MCP manager for session +func (sm *MCPSessionManager) GetOrCreateManager(sessionID string, configs []api.MCPServerConfig) (*MCPManager, error) { + sm.mu.Lock() + defer sm.mu.Unlock() + + // Check if session exists and configs match + if session, exists := sm.sessions[sessionID]; exists { + if configsMatch(session.configs, configs) { + session.lastAccess = time.Now() + slog.Debug("Reusing existing MCP session", "session", sessionID, "clients", len(session.clients)) + return session.MCPManager, nil + } + // Configs changed, shutdown old session + slog.Info("MCP configs changed, recreating session", "session", sessionID) + session.Shutdown() + delete(sm.sessions, sessionID) + } + + // Create new session + slog.Info("Creating new MCP session", "session", sessionID, "configs", len(configs)) + manager := NewMCPManager(10) + for _, config := range configs { + if err := manager.AddServer(config); err != nil { + slog.Warn("Failed to add MCP server", "name", config.Name, "error", err) + } + } + + sm.sessions[sessionID] = &MCPSession{ + MCPManager: manager, + lastAccess: time.Now(), + sessionID: sessionID, + configs: configs, + } + + return manager, nil +} + +// GetManagerForToolsPath creates a manager for a tools directory path. +// It uses the definitions system to get auto-enabled servers for the path. +func (sm *MCPSessionManager) GetManagerForToolsPath(model string, toolsPath string) (*MCPManager, error) { + // Generate consistent session ID for model + tools path + sessionID := generateToolsSessionID(model, toolsPath) + + // Use definitions to get auto-enabled servers (single source of truth) + defs, err := LoadMCPDefinitions() + if err != nil { + return nil, err + } + + ctx := AutoEnableContext{ToolsPath: toolsPath} + configs := defs.GetAutoEnableServers(ctx) + + return sm.GetOrCreateManager(sessionID, configs) +} + +// cleanupExpired removes expired sessions +func (sm *MCPSessionManager) cleanupExpired() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-sm.stopCleanup: + return + case <-ticker.C: + sm.mu.Lock() + now := time.Now() + for sessionID, session := range sm.sessions { + if now.Sub(session.lastAccess) > sm.ttl { + slog.Info("Cleaning up expired MCP session", "session", sessionID) + session.Shutdown() + delete(sm.sessions, sessionID) + } + } + sm.mu.Unlock() + } + } +} + +// Shutdown closes all sessions and stops the cleanup goroutine +func (sm *MCPSessionManager) Shutdown() { + // Signal cleanup goroutine to stop + close(sm.stopCleanup) + + sm.mu.Lock() + defer sm.mu.Unlock() + + slog.Info("Shutting down MCP session manager", "sessions", len(sm.sessions)) + for sessionID, session := range sm.sessions { + slog.Debug("Shutting down session", "session", sessionID) + session.Shutdown() + } + sm.sessions = make(map[string]*MCPSession) +} + +// configsMatch checks if two sets of MCP configs are equivalent +func configsMatch(a, b []api.MCPServerConfig) bool { + if len(a) != len(b) { + return false + } + return reflect.DeepEqual(a, b) +} + +// generateToolsSessionID creates a consistent session ID for model + tools path +func generateToolsSessionID(model, toolsPath string) string { + h := sha256.New() + h.Write([]byte(model)) + h.Write([]byte(toolsPath)) + return "tools-" + hex.EncodeToString(h.Sum(nil))[:16] +} + +// GenerateSessionID creates a session ID based on the request +func GenerateSessionID(req api.ChatRequest) string { + // If explicit session ID provided + if req.SessionID != "" { + return req.SessionID + } + + // For interactive mode with tools path + if req.ToolsPath != "" { + return generateToolsSessionID(req.Model, req.ToolsPath) + } + + // Default: use request-specific ID (no persistence) + h := sha256.New() + h.Write([]byte(time.Now().Format(time.RFC3339Nano))) + h.Write([]byte(req.Model)) + return "req-" + hex.EncodeToString(h.Sum(nil))[:16] +} diff --git a/server/mcp_test.go b/server/mcp_test.go new file mode 100644 index 000000000..f045c793e --- /dev/null +++ b/server/mcp_test.go @@ -0,0 +1,731 @@ +package server + +// ============================================================================= +// MCP Integration Tests +// ============================================================================= +// +// This file contains tests for the MCP (Model Context Protocol) implementation. +// +// Test Categories: +// +// 1. Client Tests (TestMCPClient*) +// - Client initialization and lifecycle +// - Environment variable filtering +// - Timeout handling +// +// 2. Security Tests (TestDangerous*, TestShellInjection*, TestSecure*) +// - Command blocklist validation +// - Shell metacharacter detection +// - Credential filtering +// +// 3. Manager Tests (TestMCPManager*, TestToolResult*, TestParallel*) +// - Server registration +// - Tool caching +// - Parallel execution +// +// 4. Auto-Enable Tests (TestAutoEnable*) +// - Mode: never, always, with_path, if_match +// - Conditions: file_exists, env_set +// +// Run all MCP tests: +// go test -v ./server/... -run "TestMCP|TestSecure|TestShell|TestTool|TestDanger|TestParallel|TestAutoEnable" +// +// ============================================================================= + +import ( + "context" + "fmt" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/ollama/ollama/api" +) + +// TestMCPClientInitialization tests the MCP client initialization +func TestMCPClientInitialization(t *testing.T) { + client := NewMCPClient("test", "echo", []string{"test"}, nil) + + require.Equal(t, "test", client.name) + require.Equal(t, "echo", client.command) + require.False(t, client.initialized, "Client should not be initialized on creation") +} + +// TestSecureEnvironmentFiltering tests environment variable filtering +func TestSecureEnvironmentFiltering(t *testing.T) { + // Set some test environment variables + os.Setenv("TEST_SAFE_VAR", "safe_value") + os.Setenv("AWS_SECRET_ACCESS_KEY", "secret_key") + os.Setenv("PATH", "/usr/local/bin:/usr/bin:/bin:/root/bin") + defer os.Unsetenv("TEST_SAFE_VAR") + defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + client := NewMCPClient("test", "echo", []string{}, nil) + env := client.buildSecureEnvironment() + + // Check that sensitive variables are filtered out + for _, e := range env { + require.False(t, strings.HasPrefix(e, "AWS_SECRET_ACCESS_KEY="), + "Sensitive AWS_SECRET_ACCESS_KEY should be filtered out") + require.False(t, strings.Contains(e, "/root/bin"), + "Dangerous PATH component /root/bin should be filtered out") + } + + // Check that PATH is present but sanitized + hasPath := false + for _, e := range env { + if strings.HasPrefix(e, "PATH=") { + hasPath = true + require.NotContains(t, e, "/root", + "PATH should not contain /root directories") + } + } + require.True(t, hasPath, "PATH should be present in environment") +} + +// TestMCPManagerAddServer tests adding MCP servers to the manager +func TestMCPManagerAddServer(t *testing.T) { + manager := NewMCPManager(5) + + // Test adding a valid server config + config := api.MCPServerConfig{ + Name: "test_server", + Command: "python", + Args: []string{"-m", "test_module"}, + Env: map[string]string{"TEST": "value"}, + } + + // This will fail in test environment but validates the validation logic + err := manager.AddServer(config) + if err != nil { + require.Contains(t, err.Error(), "failed to initialize", + "Expected initialization failure in test environment") + } + + // Test invalid server names + invalidConfigs := []api.MCPServerConfig{ + {Name: "", Command: "python"}, // Empty name + {Name: strings.Repeat("a", 101), Command: "python"}, // Too long + {Name: "test/server", Command: "python"}, // Invalid characters + } + + for _, cfg := range invalidConfigs { + err := manager.validateServerConfig(cfg) + require.Error(t, err, "Should reject invalid config: %+v", cfg) + } +} + +// TestDangerousCommandValidation tests rejection of dangerous commands +func TestDangerousCommandValidation(t *testing.T) { + manager := NewMCPManager(5) + + dangerousConfigs := []api.MCPServerConfig{ + {Name: "test1", Command: "bash"}, + {Name: "test2", Command: "/bin/sh"}, + {Name: "test3", Command: "sudo"}, + {Name: "test4", Command: "rm"}, + {Name: "test5", Command: "curl"}, + {Name: "test6", Command: "eval"}, + } + + for _, cfg := range dangerousConfigs { + err := manager.validateServerConfig(cfg) + require.Error(t, err, "Should reject dangerous command: %s", cfg.Command) + require.Contains(t, err.Error(), "not allowed for security", + "Expected security error for command %s", cfg.Command) + } + + // Test that safe commands are allowed + safeConfigs := []api.MCPServerConfig{ + {Name: "test1", Command: "python"}, + {Name: "test2", Command: "node"}, + {Name: "test3", Command: "/usr/bin/python3"}, + } + + for _, cfg := range safeConfigs { + err := manager.validateServerConfig(cfg) + require.NoError(t, err, "Should allow safe command %s", cfg.Command) + } +} + +// TestShellInjectionPrevention tests prevention of shell injection +func TestShellInjectionPrevention(t *testing.T) { + manager := NewMCPManager(5) + + // Test arguments with shell metacharacters + injectionConfigs := []api.MCPServerConfig{ + { + Name: "test1", + Command: "python", + Args: []string{"; rm -rf /"}, + }, + { + Name: "test2", + Command: "python", + Args: []string{"test", "| cat /etc/passwd"}, + }, + { + Name: "test3", + Command: "python", + Args: []string{"$(whoami)"}, + }, + { + Name: "test4", + Command: "python", + Args: []string{"`id`"}, + }, + } + + for _, cfg := range injectionConfigs { + err := manager.validateServerConfig(cfg) + require.Error(t, err, "Should reject shell injection attempt in args: %v", cfg.Args) + require.Contains(t, err.Error(), "shell metacharacters", + "Expected shell metacharacter error") + } +} + + +// TestParallelToolExecution tests parallel execution of tools +func TestParallelToolExecution(t *testing.T) { + manager := NewMCPManager(5) + + // Create test tool calls + toolCalls := []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "tool1", + Arguments: map[string]interface{}{"test": "1"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool2", + Arguments: map[string]interface{}{"test": "2"}, + }, + }, + { + Function: api.ToolCallFunction{ + Name: "tool3", + Arguments: map[string]interface{}{"test": "3"}, + }, + }, + } + + // Execute in parallel (will fail but tests the mechanism) + results := manager.ExecuteToolsParallel(toolCalls) + + require.Len(t, results, len(toolCalls)) + + // All should have errors since no MCP servers are connected + for i, result := range results { + require.Error(t, result.Error, "Expected error for tool call %d", i) + } +} + + +// TestMCPClientTimeout tests timeout handling for tool execution +func TestMCPClientTimeout(t *testing.T) { + client := NewMCPClient("test", "sleep", []string{"60"}, nil) + + // Create a context with very short timeout + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + // Try to call with timeout (will fail but tests the mechanism) + req := mcpCallToolRequest{ + Name: "test_tool", + Arguments: map[string]interface{}{}, + } + + var resp mcpCallToolResponse + err := client.callWithContext(ctx, "tools/call", req, &resp) + + // Should timeout or fail + require.Error(t, err, "Expected timeout or error") +} + +// TestEnvironmentVariableValidation tests validation of environment variables +func TestEnvironmentVariableValidation(t *testing.T) { + manager := NewMCPManager(5) + + // Test invalid environment variable names + invalidEnvConfigs := []api.MCPServerConfig{ + { + Name: "test1", + Command: "python", + Env: map[string]string{"VAR=BAD": "value"}, + }, + { + Name: "test2", + Command: "python", + Env: map[string]string{"VAR;CMD": "value"}, + }, + { + Name: "test3", + Command: "python", + Env: map[string]string{"VAR|PIPE": "value"}, + }, + } + + for _, cfg := range invalidEnvConfigs { + err := manager.validateServerConfig(cfg) + require.Error(t, err, "Should reject invalid environment variable names: %v", cfg.Env) + } + + // Test valid environment variables + validConfig := api.MCPServerConfig{ + Name: "test", + Command: "python", + Env: map[string]string{ + "PYTHONPATH": "/usr/lib/python3", + "MY_VAR": "value", + "TEST_123": "test", + }, + } + + err := manager.validateServerConfig(validConfig) + require.NoError(t, err, "Should allow valid environment variables") +} + +// BenchmarkToolExecution benchmarks tool execution performance +func BenchmarkToolExecution(b *testing.B) { + manager := NewMCPManager(10) + + toolCall := api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "test_tool", + Arguments: map[string]interface{}{"param": "value"}, + }, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = manager.ExecuteTool(toolCall) + } +} + +// BenchmarkParallelToolExecution benchmarks parallel tool execution +func BenchmarkParallelToolExecution(b *testing.B) { + manager := NewMCPManager(10) + + toolCalls := make([]api.ToolCall, 10) + for i := range toolCalls { + toolCalls[i] = api.ToolCall{ + Function: api.ToolCallFunction{ + Name: fmt.Sprintf("tool_%d", i), + Arguments: map[string]interface{}{"param": i}, + }, + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = manager.ExecuteToolsParallel(toolCalls) + } +} + +// ============================================================================= +// Auto-Enable Unit Tests +// ============================================================================= + +// TestAutoEnableMode_Never verifies servers with auto_enable:"never" don't auto-enable +func TestAutoEnableMode_Never(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "never_server": { + Name: "never_server", + Command: "python", + AutoEnable: AutoEnableNever, + }, + "empty_mode": { + Name: "empty_mode", + Command: "python", + // AutoEnable not set - defaults to never + }, + }, + } + + ctx := AutoEnableContext{ToolsPath: "/some/path"} + servers := defs.GetAutoEnableServers(ctx) + + require.Empty(t, servers, "Expected 0 auto-enabled servers for 'never' mode") +} + +// TestAutoEnableMode_Always verifies servers with auto_enable:"always" always enable +func TestAutoEnableMode_Always(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "always_server": { + Name: "always_server", + Command: "python", + Args: []string{"-m", "server"}, + AutoEnable: AutoEnableAlways, + }, + }, + } + + // Should enable even with empty path + ctx := AutoEnableContext{ToolsPath: ""} + servers := defs.GetAutoEnableServers(ctx) + + require.Len(t, servers, 1) + require.Equal(t, "always_server", servers[0].Name) + + // Should also enable with path + ctx = AutoEnableContext{ToolsPath: "/tmp"} + servers = defs.GetAutoEnableServers(ctx) + + require.Len(t, servers, 1) +} + +// TestAutoEnableMode_WithPath verifies servers with auto_enable:"with_path" enable only when path is provided +func TestAutoEnableMode_WithPath(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "mcp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "path_server": { + Name: "path_server", + Command: "python", + Args: []string{"-m", "server"}, + RequiresPath: true, + PathArgIndex: -1, + AutoEnable: AutoEnableWithPath, + }, + }, + } + + // Should NOT enable without path + ctx := AutoEnableContext{ToolsPath: ""} + servers := defs.GetAutoEnableServers(ctx) + require.Empty(t, servers, "Expected 0 servers without path") + + // Should enable with valid path + ctx = AutoEnableContext{ToolsPath: tmpDir} + servers = defs.GetAutoEnableServers(ctx) + require.Len(t, servers, 1) + + // Verify path was appended to args + expectedArgs := []string{"-m", "server", tmpDir} + require.Equal(t, expectedArgs, servers[0].Args) +} + +// TestAutoEnableMode_IfMatch_FileExists verifies file_exists condition +func TestAutoEnableMode_IfMatch_FileExists(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "mcp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create .git directory to simulate git repo + gitDir := filepath.Join(tmpDir, ".git") + require.NoError(t, os.Mkdir(gitDir, 0755)) + + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "git_server": { + Name: "git_server", + Command: "python", + Args: []string{"-m", "git_server"}, + RequiresPath: true, + PathArgIndex: -1, + AutoEnable: AutoEnableIfMatch, + EnableIf: EnableCondition{FileExists: ".git"}, + }, + }, + } + + // Should enable when .git exists + ctx := AutoEnableContext{ToolsPath: tmpDir} + servers := defs.GetAutoEnableServers(ctx) + require.Len(t, servers, 1, "Expected 1 server when .git exists") + + // Should NOT enable in directory without .git + noGitDir, err := os.MkdirTemp("", "mcp-test-nogit-*") + require.NoError(t, err) + defer os.RemoveAll(noGitDir) + + ctx = AutoEnableContext{ToolsPath: noGitDir} + servers = defs.GetAutoEnableServers(ctx) + require.Empty(t, servers, "Expected 0 servers without .git") +} + +// TestAutoEnableMode_IfMatch_EnvSet verifies env_set condition +func TestAutoEnableMode_IfMatch_EnvSet(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "env_server": { + Name: "env_server", + Command: "python", + AutoEnable: AutoEnableIfMatch, + EnableIf: EnableCondition{EnvSet: "MCP_TEST_VAR"}, + }, + }, + } + + // Test with env in context + ctx := AutoEnableContext{ + ToolsPath: "", + Env: map[string]string{"MCP_TEST_VAR": "some_value"}, + } + servers := defs.GetAutoEnableServers(ctx) + require.Len(t, servers, 1, "Expected 1 server when env is set in context") + + // Test with env NOT set + ctx = AutoEnableContext{ + ToolsPath: "", + Env: map[string]string{}, + } + servers = defs.GetAutoEnableServers(ctx) + require.Empty(t, servers, "Expected 0 servers when env not set") + + // Test with os.Getenv fallback + os.Setenv("MCP_TEST_VAR_FALLBACK", "fallback_value") + defer os.Unsetenv("MCP_TEST_VAR_FALLBACK") + + defsFallback := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "env_server": { + Name: "env_server", + Command: "python", + AutoEnable: AutoEnableIfMatch, + EnableIf: EnableCondition{EnvSet: "MCP_TEST_VAR_FALLBACK"}, + }, + }, + } + ctx = AutoEnableContext{ToolsPath: "", Env: nil} + servers = defsFallback.GetAutoEnableServers(ctx) + require.Len(t, servers, 1, "Expected 1 server with os.Getenv fallback") +} + +// TestAutoEnableMode_IfMatch_CombinedConditions verifies AND logic for conditions +func TestAutoEnableMode_IfMatch_CombinedConditions(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "mcp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + markerFile := filepath.Join(tmpDir, ".marker") + require.NoError(t, os.WriteFile(markerFile, []byte("test"), 0644)) + + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "combined_server": { + Name: "combined_server", + Command: "python", + RequiresPath: true, + PathArgIndex: -1, + AutoEnable: AutoEnableIfMatch, + EnableIf: EnableCondition{ + FileExists: ".marker", + EnvSet: "MCP_COMBINED_TEST", + }, + }, + }, + } + + // Should NOT enable when only file exists + ctx := AutoEnableContext{ + ToolsPath: tmpDir, + Env: map[string]string{}, + } + servers := defs.GetAutoEnableServers(ctx) + require.Empty(t, servers, "Expected 0 servers when only file condition matches") + + // Should NOT enable when only env is set + ctx = AutoEnableContext{ + ToolsPath: "/nonexistent", + Env: map[string]string{"MCP_COMBINED_TEST": "value"}, + } + servers = defs.GetAutoEnableServers(ctx) + require.Empty(t, servers, "Expected 0 servers when only env condition matches") + + // Should enable when BOTH conditions match + ctx = AutoEnableContext{ + ToolsPath: tmpDir, + Env: map[string]string{"MCP_COMBINED_TEST": "value"}, + } + servers = defs.GetAutoEnableServers(ctx) + require.Len(t, servers, 1, "Expected 1 server when both conditions match") +} + +// TestGetAutoEnableServers_MultipleServers verifies multiple servers can auto-enable +func TestGetAutoEnableServers_MultipleServers(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "mcp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + // Create .git directory + require.NoError(t, os.Mkdir(filepath.Join(tmpDir, ".git"), 0755)) + + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "filesystem": { + Name: "filesystem", + Command: "npx", + Args: []string{"-y", "@mcp/server-filesystem"}, + RequiresPath: true, + PathArgIndex: -1, + AutoEnable: AutoEnableWithPath, + }, + "git": { + Name: "git", + Command: "python", + Args: []string{"-m", "mcp_git"}, + RequiresPath: true, + PathArgIndex: -1, + AutoEnable: AutoEnableIfMatch, + EnableIf: EnableCondition{FileExists: ".git"}, + }, + "never_server": { + Name: "never_server", + Command: "python", + AutoEnable: AutoEnableNever, + }, + }, + } + + ctx := AutoEnableContext{ToolsPath: tmpDir} + servers := defs.GetAutoEnableServers(ctx) + + require.Len(t, servers, 2, "Expected 2 auto-enabled servers") + + // Verify both filesystem and git are enabled + names := make(map[string]bool) + for _, s := range servers { + names[s.Name] = true + } + + require.True(t, names["filesystem"], "Expected 'filesystem' server to be auto-enabled") + require.True(t, names["git"], "Expected 'git' server to be auto-enabled") + require.False(t, names["never_server"], "'never_server' should NOT be auto-enabled") +} + +// TestBuildConfigForAutoEnable_PathArgIndex verifies path insertion at different positions +func TestBuildConfigForAutoEnable_PathArgIndex(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "mcp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + testCases := []struct { + name string + args []string + pathArgIndex int + expected []string + }{ + { + name: "append at end (index -1)", + args: []string{"arg1", "arg2"}, + pathArgIndex: -1, + expected: []string{"arg1", "arg2", tmpDir}, + }, + { + name: "insert at beginning (index 0)", + args: []string{"arg1", "arg2"}, + pathArgIndex: 0, + expected: []string{tmpDir, "arg1", "arg2"}, + }, + { + name: "insert in middle (index 1)", + args: []string{"arg1", "arg2"}, + pathArgIndex: 1, + expected: []string{"arg1", tmpDir, "arg2"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "test": { + Name: "test", + Command: "python", + Args: tc.args, + RequiresPath: true, + PathArgIndex: tc.pathArgIndex, + AutoEnable: AutoEnableWithPath, + }, + }, + } + + ctx := AutoEnableContext{ToolsPath: tmpDir} + servers := defs.GetAutoEnableServers(ctx) + + require.Len(t, servers, 1) + require.Equal(t, tc.expected, servers[0].Args) + }) + } +} + +// TestBuildConfigForAutoEnable_InvalidPath verifies error handling for invalid paths +func TestBuildConfigForAutoEnable_InvalidPath(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "path_server": { + Name: "path_server", + Command: "python", + RequiresPath: true, + AutoEnable: AutoEnableWithPath, + }, + }, + } + + // Should fail with non-existent path + ctx := AutoEnableContext{ToolsPath: "/definitely/not/a/real/path/12345"} + servers := defs.GetAutoEnableServers(ctx) + + // Server should be skipped due to invalid path + require.Empty(t, servers, "Expected 0 servers with invalid path") +} + +// TestBuildConfigForAutoEnable_EnvCopy verifies environment variables are copied +func TestBuildConfigForAutoEnable_EnvCopy(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "env_server": { + Name: "env_server", + Command: "python", + AutoEnable: AutoEnableAlways, + Env: map[string]string{ + "VAR1": "value1", + "VAR2": "value2", + }, + }, + }, + } + + ctx := AutoEnableContext{} + servers := defs.GetAutoEnableServers(ctx) + + require.Len(t, servers, 1) + require.Equal(t, "value1", servers[0].Env["VAR1"]) + require.Equal(t, "value2", servers[0].Env["VAR2"]) + + // Verify original wasn't mutated by modifying copy + servers[0].Env["VAR1"] = "modified" + original := defs.Servers["env_server"] + require.Equal(t, "value1", original.Env["VAR1"], "Original server definition was mutated") +} + +// TestEnableCondition_EmptyConditions verifies empty conditions always match +func TestEnableCondition_EmptyConditions(t *testing.T) { + defs := &MCPDefinitions{ + Servers: map[string]MCPServerDefinition{ + "empty_cond": { + Name: "empty_cond", + Command: "python", + AutoEnable: AutoEnableIfMatch, + EnableIf: EnableCondition{}, // Empty - should always match + }, + }, + } + + ctx := AutoEnableContext{ToolsPath: ""} + servers := defs.GetAutoEnableServers(ctx) + + require.Len(t, servers, 1, "Expected 1 server with empty conditions") +} diff --git a/server/routes.go b/server/routes.go index 977a13ff2..28bcfd2ee 100644 --- a/server/routes.go +++ b/server/routes.go @@ -52,6 +52,17 @@ import ( "github.com/ollama/ollama/version" ) +// CompletionResult holds the result of a completion request +type CompletionResult struct { + Content string + Thinking string + ToolCalls []api.ToolCall + Done bool + DoneReason string + Metrics api.Metrics + Error error +} + const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" func shouldUseHarmony(model *Model) bool { @@ -337,10 +348,11 @@ func (s *Server) GenerateHandler(c *gin.Context) { m.Config.Parser = "harmony" } + if !req.Raw && m.Config.Parser != "" { builtinParser = parsers.ParserForName(m.Config.Parser) if builtinParser != nil { - // no tools or last message for generate endpoint + // Initialize parser for thinking extraction only (tools not supported in Generate API) builtinParser.Init(nil, nil, req.Think) } } @@ -459,7 +471,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { // the real chat handler, but doing this as a stopgap to get renderer // support for generate if values.Messages != nil && values.Suffix == "" && req.Template == "" { - prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate) + prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, nil, req.Think, req.Truncate == nil || *req.Truncate) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -510,8 +522,8 @@ func (s *Server) GenerateHandler(c *gin.Context) { ch := make(chan any) go func() { // TODO (jmorganca): avoid building the response twice both here and below - var sb strings.Builder defer close(ch) + var sb strings.Builder if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ Prompt: prompt, Images: images, @@ -537,16 +549,13 @@ func (s *Server) GenerateHandler(c *gin.Context) { } if builtinParser != nil { - content, thinking, toolCalls, err := builtinParser.Add(cr.Content, cr.Done) + content, thinking, _, err := builtinParser.Add(cr.Content, cr.Done) if err != nil { ch <- gin.H{"error": err.Error()} return } res.Response = content res.Thinking = thinking - if cr.Done && len(toolCalls) > 0 { - res.ToolCalls = toolCalls - } } else if thinkingState != nil { thinking, content := thinkingState.AddContent(cr.Content) res.Thinking = thinking @@ -574,7 +583,7 @@ func (s *Server) GenerateHandler(c *gin.Context) { if builtinParser != nil { // only send messages with meaningful content (empty messages confuse clients) - if res.Response != "" || res.Thinking != "" || res.Done || len(res.ToolCalls) > 0 { + if res.Response != "" || res.Thinking != "" || res.Done { ch <- res } @@ -1517,6 +1526,10 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/show", s.ShowHandler) r.DELETE("/api/delete", s.DeleteHandler) + // MCP Tools discovery + r.GET("/api/tools", s.ToolsHandler) + r.POST("/api/tools", s.ToolsHandler) + r.POST("/api/me", s.WhoamiHandler) r.POST("/api/signout", s.SignoutHandler) @@ -1861,6 +1874,211 @@ func toolCallId() string { return "call_" + strings.ToLower(string(b)) } +// executeCompletionWithTools executes a completion and collects the full response +// This is a synchronous wrapper around the async completion callback +// When suppressDone is true, the Done flag is not sent to the client channel +// (used for intermediate rounds in multi-round tool execution) +func (s *Server) executeCompletionWithTools( + ctx context.Context, + r llm.LlamaServer, + prompt string, + images []llm.ImageData, + opts *api.Options, + req api.ChatRequest, + m *Model, + builtinParser parsers.Parser, + thinkingState *thinking.Parser, + ch chan any, + checkpointStart time.Time, + checkpointLoaded time.Time, + truncate bool, + suppressDone bool, +) (*CompletionResult, error) { + result := &CompletionResult{} + done := make(chan error, 1) + + // For tracking tool calls when using tools + var toolParser *tools.Parser + if len(req.Tools) > 0 && builtinParser == nil { + toolParser = tools.NewParser(m.Template.Template, req.Tools) + } + + // Track thinking content for structured outputs + var thinkingBuilder strings.Builder + + // Accumulate tool calls across streaming chunks + var accumulatedToolCalls []api.ToolCall + + // Create a new context for this completion + completionCtx, cancel := context.WithCancel(ctx) + defer cancel() + + err := r.Completion(completionCtx, llm.CompletionRequest{ + Prompt: prompt, + Images: images, + Format: req.Format, + Options: opts, + Shift: req.Shift == nil || *req.Shift, + Truncate: truncate, + Logprobs: req.Logprobs, + TopLogprobs: req.TopLogprobs, + }, func(resp llm.CompletionResponse) { + // When suppressDone is true, don't signal Done to client + // (used for intermediate rounds in multi-round tool execution) + clientDone := resp.Done && !suppressDone + + res := api.ChatResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant", Content: resp.Content}, + Done: clientDone, + Metrics: api.Metrics{ + PromptEvalCount: resp.PromptEvalCount, + PromptEvalDuration: resp.PromptEvalDuration, + EvalCount: resp.EvalCount, + EvalDuration: resp.EvalDuration, + }, + Logprobs: toAPILogprobs(resp.Logprobs), + } + + if resp.Done { + res.DoneReason = resp.DoneReason.String() + res.TotalDuration = time.Since(checkpointStart) + res.LoadDuration = checkpointLoaded.Sub(checkpointStart) + result.DoneReason = res.DoneReason + result.Metrics = res.Metrics + } + + // Handle builtin parser (for models with native tool support) + if builtinParser != nil { + content, thinking, toolCalls, err := builtinParser.Add(resp.Content, resp.Done) + if err != nil { + result.Error = err + done <- err + return + } + + res.Message.Content = content + res.Message.Thinking = thinking + res.Message.ToolCalls = toolCalls + + thinkingBuilder.WriteString(thinking) + + // Accumulate results + result.Content += content + result.Thinking += thinking + + // Accumulate tool calls for multi-round MCP execution + if len(toolCalls) > 0 { + accumulatedToolCalls = append(accumulatedToolCalls, toolCalls...) + } + + // On completion, set all accumulated tool calls + if resp.Done { + result.ToolCalls = accumulatedToolCalls + } + + // Stream to client if there's content to stream + if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || resp.Done || len(res.Logprobs) > 0 { + ch <- res + } + + if resp.Done { + result.Done = true + done <- nil + } + return + } + + // Handle thinking state parser + if thinkingState != nil { + thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content) + if thinkingContent == "" && remainingContent == "" && !resp.Done { + // Need more content to decide + return + } + + res.Message.Thinking = thinkingContent + thinkingBuilder.WriteString(thinkingContent) + res.Message.Content = remainingContent + result.Thinking += thinkingContent + } + + // Handle tool parsing (for models without native tool support) + if len(req.Tools) > 0 && builtinParser == nil { + toolCalls, content := toolParser.Add(res.Message.Content) + if len(content) > 0 { + res.Message.Content = content + result.Content += content + } else if len(toolCalls) > 0 { + res.Message.ToolCalls = toolCalls + res.Message.Content = "" + // Keep accumulating tool calls + accumulatedToolCalls = toolCalls + } else if res.Message.Thinking != "" { + // don't return, fall through to send + } else { + // Send logprobs while content is being buffered by the parser for tool calls + if len(res.Logprobs) > 0 && !resp.Done { + logprobRes := res + logprobRes.Message.Content = "" + logprobRes.Message.ToolCalls = nil + ch <- logprobRes + } + + if resp.Done { + res.Message.Content = toolParser.Content() + // Set accumulated tool calls in result before signaling done + if len(accumulatedToolCalls) > 0 { + result.ToolCalls = accumulatedToolCalls + } + // If no tool calls, get final content from parser + if len(result.ToolCalls) == 0 && toolParser != nil { + result.Content = toolParser.Content() + } + result.Done = true + ch <- res + done <- nil + } + return + } + } else { + result.Content += res.Message.Content + } + + // Stream to client + ch <- res + + if resp.Done { + // If we accumulated tool calls, set them in result + if len(accumulatedToolCalls) > 0 { + result.ToolCalls = accumulatedToolCalls + } + // If no tool calls, get final content from parser + if len(result.ToolCalls) == 0 && toolParser != nil { + result.Content = toolParser.Content() + } + result.Done = true + done <- nil + } + }) + + if err != nil { + return nil, err + } + + // Wait for completion or context cancellation + select { + case err := <-done: + if err != nil { + return nil, err + } + return result, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + func (s *Server) ChatHandler(c *gin.Context) { checkpointStart := time.Now() @@ -2027,6 +2245,80 @@ func (s *Server) ChatHandler(c *gin.Context) { } } + // ========================================================================= + // MCP (Model Context Protocol) Integration + // ========================================================================= + // + // MCP allows the model to execute external tools via JSON-RPC servers. + // This section handles: + // 1. Manager initialization (from session cache or new) + // 2. Tool discovery (list available tools from MCP servers) + // 3. Context injection (inform model about available tools) + // 4. Parser configuration (for tool call detection) + // + // Entry points: + // - req.MCPServers: Explicit server configs from API + // - req.ToolsPath: Path-based auto-enable from --tools flag + // + // See: mcp.go, mcp_manager.go for implementation details + // ========================================================================= + + var mcpManager *MCPManager + + if len(req.MCPServers) > 0 || req.ToolsPath != "" { + if req.ToolsPath != "" { + // Path-based mode: auto-enable servers matching the tools path + // Used by CLI: `ollama run model --tools /path` + slog.Debug("Using tools path for MCP manager", "tools_path", req.ToolsPath, "model", req.Model) + mcpManager, err = GetMCPManagerForPath(req.Model, req.ToolsPath) + if err != nil { + slog.Error("Failed to get MCP manager for tools path", "error", err) + // Continue without MCP - graceful degradation + } + } else if len(req.MCPServers) > 0 { + // Explicit mode: use server configs from API request + // Used by API: POST /api/chat with mcp_servers field + sessionID := GenerateSessionID(req) + slog.Debug("Getting MCP manager", "session", sessionID, "servers", len(req.MCPServers)) + mcpManager, err = GetMCPManager(sessionID, req.MCPServers) + if err != nil { + slog.Error("Failed to get MCP manager", "error", err) + // Continue without MCP - graceful degradation + } + } + + if mcpManager != nil { + // Step 1: Discover tools from MCP servers and add to request + mcpTools := mcpManager.GetAllTools() + req.Tools = append(req.Tools, mcpTools...) + + // Step 2: Inject context to help model use tools effectively + // Use programmatic context injection from tool schemas + codeAPI := NewMCPCodeAPI(mcpManager) + req.Messages = codeAPI.InjectContextIntoMessages(req.Messages, req.MCPServers) + + // Step 3: Auto-configure parser for tool call detection + if len(req.Tools) > 0 && m.Config.Parser == "" { + if m.Config.ModelFamily == "qwen2" || m.Config.ModelFamily == "qwen3" { + m.Config.Parser = "qwen3-vl-instruct" + } + } + + // Step 4: Update capabilities now that we have tools + if len(req.Tools) > 0 && !slices.Contains(caps, model.CapabilityTools) { + caps = append(caps, model.CapabilityTools) + } + } + + // Cleanup: Close MCP manager when request completes + // Note: Session manager may cache for reuse within TTL + defer func() { + if err := mcpManager.Close(); err != nil { + slog.Warn("Error closing MCP manager", "error", err) + } + }() + } + r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive) if errors.Is(err, errCapabilityCompletion) { c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) @@ -2115,11 +2407,6 @@ func (s *Server) ChatHandler(c *gin.Context) { } } - var toolParser *tools.Parser - if len(req.Tools) > 0 && (builtinParser == nil || !builtinParser.HasToolSupport()) { - toolParser = tools.NewParser(m.Template.Template, req.Tools) - } - type structuredOutputsState int const ( structuredOutputsState_None structuredOutputsState = iota @@ -2131,181 +2418,223 @@ func (s *Server) ChatHandler(c *gin.Context) { go func() { defer close(ch) - structuredOutputsState := structuredOutputsState_None + // Initialize for multi-round execution + // NOTE: Upstream's structuredOutputsState for thinking models is not yet integrated + // TODO: Add structuredOutputsState support for thinking models with format constraints + currentMsgs := msgs + maxRounds := req.MaxToolRounds + if maxRounds == 0 { + maxRounds = 15 // Default maximum rounds + } - for { - var tb strings.Builder + slog.Debug("Starting multi-round execution", + "mcpManager", mcpManager != nil, + "tools_count", len(req.Tools), + "max_rounds", maxRounds) - currentFormat := req.Format - // structured outputs via double request is enabled when: - // 1. the model supports the thinking capability and - // 2. it uses a built-in parser or our generic thinking parser + // MAIN LOOP - Multi-round execution for tool calling + var round int + for round = 0; round < maxRounds; round++ { + slog.Debug("Starting round", "round", round, "messages", len(currentMsgs)) - // Note that the current approach does not work for (potential future) - // non-thinking models that emit anything before actual content. This - // current approach uses the transition from parsed thinking content to - // parsed non-thinking content as the signal to turn constraining on - - if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) { - currentFormat = nil - } - - // sets up new context given parent context per request - ctx, cancel := context.WithCancel(c.Request.Context()) - err := r.Completion(ctx, llm.CompletionRequest{ - Prompt: prompt, - Images: images, - Format: currentFormat, - Options: opts, - Shift: req.Shift == nil || *req.Shift, - Truncate: truncate, - Logprobs: req.Logprobs, - TopLogprobs: req.TopLogprobs, - }, func(r llm.CompletionResponse) { - res := api.ChatResponse{ - Model: req.Model, - CreatedAt: time.Now().UTC(), - Message: api.Message{Role: "assistant", Content: r.Content}, - Done: r.Done, - Metrics: api.Metrics{ - PromptEvalCount: r.PromptEvalCount, - PromptEvalDuration: r.PromptEvalDuration, - EvalCount: r.EvalCount, - EvalDuration: r.EvalDuration, - }, - Logprobs: toAPILogprobs(r.Logprobs), - } - - if r.Done { - res.DoneReason = r.DoneReason.String() - res.TotalDuration = time.Since(checkpointStart) - res.LoadDuration = checkpointLoaded.Sub(checkpointStart) - } - - if builtinParser != nil { - slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content) - - content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done) - if err != nil { - ch <- gin.H{"error": err.Error()} - return - } - - res.Message.Content = content - res.Message.Thinking = thinking - for i := range toolCalls { - toolCalls[i].ID = toolCallId() - } - res.Message.ToolCalls = toolCalls - - tb.WriteString(thinking) - // we are now receiving content from the model - we should start applying structured outputs - if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" { - structuredOutputsState = structuredOutputsState_ReadyToApply - cancel() - return - } - - if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done || len(res.Logprobs) > 0 { - slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done) - ch <- res - } else { - slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser) - } - return - } - - if thinkingState != nil { - thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content) - if thinkingContent == "" && remainingContent == "" && !r.Done { - // need to accumulate more to decide what to send - return - } - res.Message.Thinking = thinkingContent - tb.WriteString(thinkingContent) - // emit the collected thinking text before restarting with structured outputs and clear unstructured content - // to avoid leaking mixed tokens like "Hello" - if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" { - structuredOutputsState = structuredOutputsState_ReadyToApply - res.Message.Content = "" - ch <- res - cancel() - return - } - res.Message.Content = remainingContent - } - - if len(req.Tools) > 0 { - toolCalls, content := toolParser.Add(res.Message.Content) - if len(content) > 0 { - res.Message.Content = content - } else if len(toolCalls) > 0 { - for i := range toolCalls { - toolCalls[i].ID = toolCallId() - } - res.Message.ToolCalls = toolCalls - res.Message.Content = "" - } else if res.Message.Thinking != "" { - // don't return, fall through to send - } else { - // Send logprobs while content is being buffered by the parser for tool calls - if len(res.Logprobs) > 0 && !r.Done { - logprobRes := res - logprobRes.Message.Content = "" - logprobRes.Message.ToolCalls = nil - ch <- logprobRes - } - - if r.Done { - res.Message.Content = toolParser.Content() - ch <- res - } - return - } - } - - ch <- res - }) - if err != nil { - if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil { - // only ignores error if it's a context cancellation due to setting structured outputs - } else { - var serr api.StatusError - if errors.As(err, &serr) { - ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode} - } else { - ch <- gin.H{"error": err.Error()} - } - return - } - } - - // ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the - if structuredOutputsState == structuredOutputsState_ReadyToApply { - structuredOutputsState = structuredOutputsState_Applying - msg := api.Message{ - Role: "assistant", - Thinking: tb.String(), - } - - msgs = append(msgs, msg) - prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate) + // Re-render prompt and reset parser if not first round (tool results were added) + if round > 0 { + var err error + prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, currentMsgs, processedTools, req.Think, truncate) if err != nil { - slog.Error("chat prompt error applying structured outputs", "error", err) + slog.Error("Failed to render prompt in round", "round", round, "error", err) ch <- gin.H{"error": err.Error()} return } - // force constraining by terminating thinking header, the parser is already at this state - // when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the - // model continue thinking or ending thinking and outputting the final message. - // TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs. - if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") { - prompt += "<|end|><|start|>assistant<|channel|>final<|message|>" + + // Create fresh parser instance for new round (parser has internal buffer state) + if builtinParser != nil && m.Config.Parser != "" { + builtinParser = parsers.ParserForName(m.Config.Parser) + if builtinParser != nil { + lastMsg := ¤tMsgs[len(currentMsgs)-1] + builtinParser.Init(req.Tools, lastMsg, req.Think) + } } - continue } - break + // Execute completion and collect full response + // When MCP is enabled, suppress Done flag during intermediate rounds + // to prevent client from closing connection prematurely + suppressDone := mcpManager != nil + completionResult, err := s.executeCompletionWithTools( + c.Request.Context(), + r, + prompt, + images, + opts, + req, + m, + builtinParser, + thinkingState, + ch, + checkpointStart, + checkpointLoaded, + truncate, + suppressDone, + ) + + if err != nil { + slog.Error("Completion failed", "round", round, "error", err) + var serr api.StatusError + if errors.As(err, &serr) { + ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode} + } else { + ch <- gin.H{"error": err.Error()} + } + return + } + + // Check if model called tools + if len(completionResult.ToolCalls) == 0 { + // No tools called - conversation is complete + slog.Debug("No tools called, conversation complete", "round", round) + break // Exit the loop - we're done + } + + // Validate tool calls are not empty or malformed + validToolCalls := 0 + for _, tc := range completionResult.ToolCalls { + if tc.Function.Name != "" { + validToolCalls++ + } else { + slog.Warn("Invalid tool call detected", "round", round, "tool", tc) + } + } + + if validToolCalls == 0 { + slog.Warn("No valid tool calls found, exiting", "round", round) + break + } + + // Model called tools - execute them if we have an MCP manager + if mcpManager != nil { + slog.Debug("MCP tool execution starting", + "tools_in_response", len(completionResult.ToolCalls), + "valid_tools", validToolCalls, + "round", round) + + // Send tool calls to client for display BEFORE executing + // This ensures the client can show "Executing tool..." for all rounds + // Note: Don't include Content here - it was already streamed during completion + ch <- api.ChatResponse{ + Model: req.Model, + Message: api.Message{ + Role: "assistant", + ToolCalls: completionResult.ToolCalls, + }, + } + + // Analyze execution plan + executionPlan := mcpManager.AnalyzeExecutionPlan(completionResult.ToolCalls) + slog.Debug("Execution plan determined", + "sequential", executionPlan.RequiresSequential, + "reason", executionPlan.Reason) + + // Execute tools according to plan + results := mcpManager.ExecuteWithPlan(completionResult.ToolCalls, executionPlan) + + // Log tool calls for debugging + for i, tc := range completionResult.ToolCalls { + slog.Info("Tool call details", + "round", round, + "index", i, + "name", tc.Function.Name, + "arguments", tc.Function.Arguments) + } + + // Add assistant message with tool calls + assistantMsg := api.Message{ + Role: "assistant", + Content: completionResult.Content, // Preserve any content + ToolCalls: completionResult.ToolCalls, + } + currentMsgs = append(currentMsgs, assistantMsg) + + // Add tool result messages and send them to client for display + toolResultsForDisplay := make([]api.ToolResult, 0, len(results)) + for i, result := range results { + toolMsg := api.Message{ + Role: "tool", + ToolName: completionResult.ToolCalls[i].Function.Name, + } + + // Create display result with arguments for context + displayResult := api.ToolResult{ + ToolName: completionResult.ToolCalls[i].Function.Name, + Arguments: completionResult.ToolCalls[i].Function.Arguments, + Content: result.Content, + } + + if result.Error != nil { + // JSON-encode the error for proper template rendering + if encoded, err := json.Marshal(fmt.Sprintf("Error: %v", result.Error)); err == nil { + toolMsg.Content = string(encoded) + } else { + toolMsg.Content = fmt.Sprintf("\"Error: %v\"", result.Error) + } + displayResult.Error = result.Error.Error() + slog.Warn("Tool execution failed", + "tool", completionResult.ToolCalls[i].Function.Name, + "error", result.Error) + } else { + // JSON-encode the content for proper template rendering + // The template expects {"content": {{ .Content }}} where Content should be a JSON string + if encoded, err := json.Marshal(result.Content); err == nil { + toolMsg.Content = string(encoded) + } else { + toolMsg.Content = result.Content + } + } + + currentMsgs = append(currentMsgs, toolMsg) + toolResultsForDisplay = append(toolResultsForDisplay, displayResult) + } + + // Send tool results to client for display + if len(toolResultsForDisplay) > 0 { + ch <- api.ChatResponse{ + Model: req.Model, + Message: api.Message{ + Role: "assistant", + ToolResults: toolResultsForDisplay, + }, + } + } + + // Continue to next round - model will process tool results + slog.Info("Tools executed, continuing to next round", + "round", round, + "messages", len(currentMsgs), + "last_tool", completionResult.ToolCalls[len(completionResult.ToolCalls)-1].Function.Name) + + } else { + // No MCP manager - send tool calls to client for external execution + slog.Debug("No MCP manager, sending tool calls to client", "round", round) + break // Exit - client will handle tool execution + } + } // End of maxRounds loop + + // Check if we exhausted rounds + if round >= maxRounds { + slog.Warn("Maximum tool execution rounds reached", "rounds", maxRounds) + ch <- gin.H{"error": fmt.Sprintf("Maximum tool execution rounds (%d) exceeded", maxRounds)} + } + + // When MCP was enabled, we suppressed Done flags during the loop + // Send a final Done: true to signal the conversation is complete + if mcpManager != nil { + ch <- api.ChatResponse{ + Model: req.Model, + CreatedAt: time.Now().UTC(), + Message: api.Message{Role: "assistant"}, + Done: true, + DoneReason: "stop", + } } }() @@ -2331,22 +2660,15 @@ func (s *Server) ChatHandler(c *gin.Context) { case gin.H: msg, ok := t["error"].(string) if !ok { - msg = "unexpected error format in response" + msg = "unexpected error" } - - status, ok := t["status"].(int) - if !ok { - status = http.StatusInternalServerError - } - - c.JSON(status, gin.H{"error": msg}) + c.JSON(http.StatusBadRequest, gin.H{"error": msg}) return default: c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"}) return } } - resp.Message.Content = sbContent.String() resp.Message.Thinking = sbThinking.String() resp.Logprobs = allLogprobs @@ -2354,12 +2676,10 @@ func (s *Server) ChatHandler(c *gin.Context) { if len(toolCalls) > 0 { resp.Message.ToolCalls = toolCalls } - c.JSON(http.StatusOK, resp) - return + } else { + streamResponse(c, ch) } - - streamResponse(c, ch) } func handleScheduleError(c *gin.Context, name string, err error) { diff --git a/server/routes_tools.go b/server/routes_tools.go new file mode 100644 index 000000000..2ecb1a84f --- /dev/null +++ b/server/routes_tools.go @@ -0,0 +1,91 @@ +package server + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "github.com/ollama/ollama/api" +) + +// ToolsHandler handles requests to list available MCP tools. +// GET: Returns available MCP server definitions from configuration. +// POST with mcp_servers: Returns tools from the specified MCP servers. +func (s *Server) ToolsHandler(c *gin.Context) { + var req struct { + MCPServers []api.MCPServerConfig `json:"mcp_servers,omitempty"` + } + + if c.Request.Method == "POST" { + if err := c.BindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + } + + // If MCP servers provided, list their tools + if len(req.MCPServers) > 0 { + manager := NewMCPManager(10) + defer manager.Close() + + var allTools []ToolInfo + for _, config := range req.MCPServers { + if err := manager.AddServer(config); err != nil { + // Include error in response but continue + allTools = append(allTools, ToolInfo{ + Name: config.Name, + Description: "Failed to initialize: " + err.Error(), + Error: err.Error(), + }) + continue + } + + // Get tools from this server + tools := manager.GetAllTools() + for _, tool := range tools { + allTools = append(allTools, ToolInfo{ + Name: tool.Function.Name, + Description: tool.Function.Description, + Parameters: &tool.Function.Parameters, + ServerName: config.Name, + }) + } + } + + c.JSON(http.StatusOK, ToolsResponse{ + Tools: allTools, + }) + return + } + + // Otherwise, list available MCP server definitions + defs, err := LoadMCPDefinitions() + if err != nil { + // Config parsing errors are client errors (bad config), not server errors + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid MCP configuration: " + err.Error()}) + return + } + + servers := defs.ListServers() + c.JSON(http.StatusOK, MCPServersResponse{ + Servers: servers, + }) +} + +// ToolInfo provides information about a single tool +type ToolInfo struct { + Name string `json:"name"` + Description string `json:"description"` + Parameters *api.ToolFunctionParameters `json:"parameters,omitempty"` + ServerName string `json:"server,omitempty"` + Error string `json:"error,omitempty"` +} + +// ToolsResponse contains the list of available tools +type ToolsResponse struct { + Tools []ToolInfo `json:"tools"` +} + +// MCPServersResponse contains the list of available MCP server types +type MCPServersResponse struct { + Servers []MCPServerInfo `json:"servers"` +} \ No newline at end of file