From 89f74a8b057420dccae2695922859bee72b9c4b9 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 30 Dec 2025 14:59:31 -0500 Subject: [PATCH] agents: add MCP server support and ENTRYPOINT command MCP (Model Context Protocol) support: - Add MCPRef type for agent MCP server references - Parse MCP command in Agentfiles (MCP name command [args...]) - Load and manage MCP servers with mcpManager - Implement agentic loop for multi-turn tool execution - Add /mcp REPL commands (add, remove, disable, enable) - Add 'ollama mcp' CLI commands for global config management - Support both model-bundled and global (~/.ollama/mcp.json) MCPs - Display MCPs in 'ollama show' output ENTRYPOINT support: - Add ENTRYPOINT command to Agentfiles for custom runtimes - Allow agents without FROM when ENTRYPOINT is specified - Execute entrypoint as subprocess with stdin/stdout connected - Support $PROMPT placeholder for prompt insertion control - Hide Model section in 'ollama show' for entrypoint-only agents - Pass user prompt as argument to entrypoint command --- api/types.go | 11 + cmd/cmd.go | 356 +++++++++++++---- cmd/interactive.go | 243 ++++++++++++ cmd/mcp.go | 545 +++++++++++++++++++++++++ cmd/mcp_cmd.go | 898 ++++++++++++++++++++++++++++++++++++++++++ go.mod | 2 +- parser/parser.go | 109 ++++- server/create.go | 47 +++ server/images.go | 7 + server/manifest.go | 23 +- server/mcp.go | 315 +++++++++++++++ server/routes.go | 12 +- types/model/config.go | 22 +- types/model/name.go | 1 + 14 files changed, 2510 insertions(+), 81 deletions(-) create mode 100644 cmd/mcp.go create mode 100644 cmd/mcp_cmd.go create mode 100644 server/mcp.go diff --git a/api/types.go b/api/types.go index cf8177e54..b5aa4b67c 100644 --- a/api/types.go +++ b/api/types.go @@ -20,6 +20,9 @@ import ( // SkillRef is an alias for model.SkillRef representing a skill reference. type SkillRef = model.SkillRef +// MCPRef is an alias for model.MCPRef representing an MCP server reference. +type MCPRef = model.MCPRef + // StatusError is an error with an HTTP status code and message. type StatusError struct { StatusCode int @@ -563,9 +566,15 @@ type CreateRequest struct { // Skills is a list of skill references for the agent (local paths or registry refs) Skills []SkillRef `json:"skills,omitempty"` + // MCPs is a list of MCP server references for the agent + MCPs []MCPRef `json:"mcps,omitempty"` + // AgentType defines the type of agent (e.g., "conversational", "task-based") AgentType string `json:"agent_type,omitempty"` + // Entrypoint specifies an external command to run instead of the built-in chat loop + Entrypoint string `json:"entrypoint,omitempty"` + // Info is a map of additional information for the model Info map[string]any `json:"info,omitempty"` @@ -618,7 +627,9 @@ type ShowResponse struct { ModifiedAt time.Time `json:"modified_at,omitempty"` Requires string `json:"requires,omitempty"` Skills []SkillRef `json:"skills,omitempty"` + MCPs []MCPRef `json:"mcps,omitempty"` AgentType string `json:"agent_type,omitempty"` + Entrypoint string `json:"entrypoint,omitempty"` } // CopyRequest is the request passed to [Client.Copy]. diff --git a/cmd/cmd.go b/cmd/cmd.go index 4b1d2cf48..95461e69b 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -15,6 +15,7 @@ import ( "net" "net/http" "os" + "os/exec" "os/signal" "path/filepath" "runtime" @@ -495,11 +496,13 @@ func RunHandler(cmd *cobra.Command, args []string) error { opts.ParentModel = info.Details.ParentModel // Check if this is an agent - isAgent := info.AgentType != "" || len(info.Skills) > 0 + isAgent := info.AgentType != "" || len(info.Skills) > 0 || len(info.MCPs) > 0 || info.Entrypoint != "" if isAgent { opts.IsAgent = true opts.AgentType = info.AgentType opts.Skills = info.Skills + opts.MCPs = info.MCPs + opts.Entrypoint = info.Entrypoint } // Check if this is an embedding model @@ -525,6 +528,11 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generateEmbedding(cmd, name, opts.Prompt, opts.KeepAlive, truncate, dimensions) } + // If agent has entrypoint, run it instead of chat loop + if opts.Entrypoint != "" { + return runEntrypoint(cmd, opts) + } + if interactive { if err := loadOrUnloadModel(cmd, &opts); err != nil { var sErr api.AuthorizationError @@ -564,6 +572,51 @@ func RunHandler(cmd *cobra.Command, args []string) error { return generate(cmd, opts) } +// runEntrypoint executes the agent's entrypoint command instead of the built-in chat loop. +func runEntrypoint(cmd *cobra.Command, opts runOptions) error { + entrypoint := opts.Entrypoint + + // Check if entrypoint contains $PROMPT placeholder + hasPlaceholder := strings.Contains(entrypoint, "$PROMPT") + + if hasPlaceholder && opts.Prompt != "" { + // Replace $PROMPT with the actual prompt + entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", opts.Prompt) + } else if hasPlaceholder { + // No prompt provided but placeholder exists - remove placeholder + entrypoint = strings.ReplaceAll(entrypoint, "$PROMPT", "") + } + + // Parse entrypoint into command and args + parts := strings.Fields(entrypoint) + if len(parts) == 0 { + return fmt.Errorf("empty entrypoint") + } + + command := parts[0] + args := parts[1:] + + // If user provided a prompt and no placeholder was used, append it as argument + if opts.Prompt != "" && !hasPlaceholder { + args = append(args, opts.Prompt) + } + + // Look up command in PATH + execPath, err := exec.LookPath(command) + if err != nil { + return fmt.Errorf("entrypoint command not found: %s", command) + } + + // Create subprocess + proc := exec.Command(execPath, args...) + proc.Stdin = os.Stdin + proc.Stdout = os.Stdout + proc.Stderr = os.Stderr + + // Run and wait + return proc.Run() +} + func SigninHandler(cmd *cobra.Command, args []string) error { client, err := api.ClientFromEnvironment() if err != nil { @@ -923,47 +976,96 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error { fmt.Fprintln(w) } - tableRender("Model", func() (rows [][]string) { - if resp.RemoteHost != "" { - rows = append(rows, []string{"", "Remote model", resp.RemoteModel}) - rows = append(rows, []string{"", "Remote URL", resp.RemoteHost}) - } - - if resp.ModelInfo != nil { - arch := resp.ModelInfo["general.architecture"].(string) - rows = append(rows, []string{"", "architecture", arch}) - - var paramStr string - if resp.Details.ParameterSize != "" { - paramStr = resp.Details.ParameterSize - } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok { - if f, ok := v.(float64); ok { - paramStr = format.HumanNumber(uint64(f)) - } - } - rows = append(rows, []string{"", "parameters", paramStr}) - - if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok { - if f, ok := v.(float64); ok { - rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)}) - } + // Only show Model section if there's actual model info (not for entrypoint-only agents) + hasModelInfo := resp.RemoteHost != "" || resp.ModelInfo != nil || resp.Details.Family != "" || resp.Details.ParameterSize != "" || resp.Details.QuantizationLevel != "" + if hasModelInfo { + tableRender("Model", func() (rows [][]string) { + if resp.RemoteHost != "" { + rows = append(rows, []string{"", "Remote model", resp.RemoteModel}) + rows = append(rows, []string{"", "Remote URL", resp.RemoteHost}) } - if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok { - if f, ok := v.(float64); ok { - rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)}) + if resp.ModelInfo != nil { + arch := resp.ModelInfo["general.architecture"].(string) + rows = append(rows, []string{"", "architecture", arch}) + + var paramStr string + if resp.Details.ParameterSize != "" { + paramStr = resp.Details.ParameterSize + } else if v, ok := resp.ModelInfo["general.parameter_count"]; ok { + if f, ok := v.(float64); ok { + paramStr = format.HumanNumber(uint64(f)) + } + } + rows = append(rows, []string{"", "parameters", paramStr}) + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.context_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "context length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } + + if v, ok := resp.ModelInfo[fmt.Sprintf("%s.embedding_length", arch)]; ok { + if f, ok := v.(float64); ok { + rows = append(rows, []string{"", "embedding length", strconv.FormatFloat(f, 'f', -1, 64)}) + } + } + } else { + rows = append(rows, []string{"", "architecture", resp.Details.Family}) + rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) + } + rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel}) + if resp.Requires != "" { + rows = append(rows, []string{"", "requires", resp.Requires}) + } + return + }) + } + + // Display agent information if this is an agent + if resp.AgentType != "" || len(resp.Skills) > 0 || len(resp.MCPs) > 0 || resp.Entrypoint != "" { + tableRender("Agent", func() (rows [][]string) { + if resp.AgentType != "" { + rows = append(rows, []string{"", "type", resp.AgentType}) + } + if resp.Entrypoint != "" { + rows = append(rows, []string{"", "entrypoint", resp.Entrypoint}) + } + if len(resp.Skills) > 0 { + for i, skill := range resp.Skills { + label := "skill" + if i > 0 { + label = "" + } + // Show skill name or digest + skillDisplay := skill.Name + if skillDisplay == "" && skill.Digest != "" { + skillDisplay = skill.Digest[:12] + "..." + } + rows = append(rows, []string{"", label, skillDisplay}) } } - } else { - rows = append(rows, []string{"", "architecture", resp.Details.Family}) - rows = append(rows, []string{"", "parameters", resp.Details.ParameterSize}) - } - rows = append(rows, []string{"", "quantization", resp.Details.QuantizationLevel}) - if resp.Requires != "" { - rows = append(rows, []string{"", "requires", resp.Requires}) - } - return - }) + if len(resp.MCPs) > 0 { + for i, mcp := range resp.MCPs { + label := "mcp" + if i > 0 { + label = "" + } + // Show MCP name and command + mcpDisplay := mcp.Name + if mcp.Command != "" { + cmdLine := mcp.Command + if len(mcp.Args) > 0 { + cmdLine += " " + strings.Join(mcp.Args, " ") + } + mcpDisplay += " (" + cmdLine + ")" + } + rows = append(rows, []string{"", label, mcpDisplay}) + } + } + return + }) + } if len(resp.Capabilities) > 0 { tableRender("Capabilities", func() (rows [][]string) { @@ -1208,6 +1310,8 @@ type runOptions struct { IsAgent bool AgentType string Skills []api.SkillRef + MCPs []api.MCPRef + Entrypoint string } func (r runOptions) Copy() runOptions { @@ -1360,6 +1464,49 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { } } + // Load MCP servers for agents (from opts and global config) + var mcpMgr *mcpManager + allMCPs := opts.MCPs + + // Load global MCPs from ~/.ollama/mcp.json + if globalConfig, err := loadMCPConfig(); err == nil && len(globalConfig.MCPServers) > 0 { + for name, srv := range globalConfig.MCPServers { + // Skip disabled MCPs + if srv.Disabled { + continue + } + // Check if already in opts.MCPs (model takes precedence) + found := false + for _, m := range opts.MCPs { + if m.Name == name { + found = true + break + } + } + if !found { + allMCPs = append(allMCPs, api.MCPRef{ + Name: name, + Command: srv.Command, + Args: srv.Args, + Env: srv.Env, + Type: srv.Type, + }) + } + } + } + + if len(allMCPs) > 0 { + mcpMgr = newMCPManager() + if err := mcpMgr.loadMCPsFromRefs(allMCPs); err != nil { + return nil, fmt.Errorf("failed to load MCP servers: %w", err) + } + if mcpMgr.ToolCount() > 0 { + fmt.Fprintf(os.Stderr, "Loaded MCP servers: %s (%d tools)\n", + strings.Join(mcpMgr.ServerNames(), ", "), mcpMgr.ToolCount()) + } + defer mcpMgr.Shutdown() + } + p := progress.NewProgress(os.Stderr) defer p.StopAndClear() @@ -1424,11 +1571,11 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { if response.Message.ToolCalls != nil { toolCalls := response.Message.ToolCalls if len(toolCalls) > 0 { - if skillsCatalog != nil { + if skillsCatalog != nil || mcpMgr != nil { // Store tool calls for execution after response is complete pendingToolCalls = append(pendingToolCalls, toolCalls...) } else { - // No skills catalog, just display tool calls + // No skills catalog or MCP, just display tool calls fmt.Print(renderToolCalls(toolCalls, false)) } } @@ -1461,43 +1608,65 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { } } - req := &api.ChatRequest{ - Model: opts.Model, - Messages: messages, - Format: json.RawMessage(opts.Format), - Options: opts.Options, - Think: opts.Think, - } - - // Add tools for agents - if skillsCatalog != nil { - req.Tools = skillsCatalog.Tools() - } - - if opts.KeepAlive != nil { - req.KeepAlive = opts.KeepAlive - } - - if err := client.Chat(cancelCtx, req, fn); err != nil { - if errors.Is(err, context.Canceled) { - return nil, nil + // Agentic loop: continue until no more tool calls + for { + req := &api.ChatRequest{ + Model: opts.Model, + Messages: messages, + Format: json.RawMessage(opts.Format), + Options: opts.Options, + Think: opts.Think, } - // this error should ideally be wrapped properly by the client - if strings.Contains(err.Error(), "upstream error") { - p.StopAndClear() - fmt.Println("An error occurred while processing your message. Please try again.") - fmt.Println() - return nil, nil + // Add tools for agents (combine skills and MCP tools) + var allTools api.Tools + if skillsCatalog != nil { + allTools = append(allTools, skillsCatalog.Tools()...) + } + if mcpMgr != nil { + allTools = append(allTools, mcpMgr.Tools()...) + } + if len(allTools) > 0 { + req.Tools = allTools } - return nil, err - } - // Execute tool calls for agents - if len(pendingToolCalls) > 0 && skillsCatalog != nil { + if opts.KeepAlive != nil { + req.KeepAlive = opts.KeepAlive + } + + if err := client.Chat(cancelCtx, req, fn); err != nil { + if errors.Is(err, context.Canceled) { + return nil, nil + } + + // this error should ideally be wrapped properly by the client + if strings.Contains(err.Error(), "upstream error") { + p.StopAndClear() + fmt.Println("An error occurred while processing your message. Please try again.") + fmt.Println() + return nil, nil + } + return nil, err + } + + // If no tool calls, we're done + if len(pendingToolCalls) == 0 || (skillsCatalog == nil && mcpMgr == nil) { + break + } + + // Execute tool calls and continue the conversation fmt.Fprintf(os.Stderr, "\n") + // Add assistant's tool call message to history + assistantMsg := api.Message{ + Role: "assistant", + Content: fullResponse.String(), + ToolCalls: pendingToolCalls, + } + messages = append(messages, assistantMsg) + // Execute each tool call and collect results + var toolResults []api.Message for _, call := range pendingToolCalls { // Show what's being executed switch call.Function.Name { @@ -1513,13 +1682,35 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { fmt.Fprintf(os.Stderr, "Executing: %s\n", call.Function.Name) } - result, handled, err := skillsCatalog.RunToolCall(call) + var result api.Message + var handled bool + var err error + + // Try skill catalog first + if skillsCatalog != nil { + result, handled, err = skillsCatalog.RunToolCall(call) + } + + // If not handled by skills, try MCP + if !handled && mcpMgr != nil { + result, handled, err = mcpMgr.RunToolCall(call) + } + if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) + // Add error result + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: fmt.Sprintf("Error: %v", err), + }) continue } if !handled { fmt.Fprintf(os.Stderr, "Warning: Unknown tool %s\n", call.Function.Name) + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: fmt.Sprintf("Unknown tool: %s", call.Function.Name), + }) continue } @@ -1527,9 +1718,31 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) { if result.Content != "" { fmt.Fprintf(os.Stderr, "Output:\n%s\n", result.Content) } + + // Add tool result to messages + toolResults = append(toolResults, api.Message{ + Role: "tool", + Content: result.Content, + }) } + // Add tool results to message history + messages = append(messages, toolResults...) + fmt.Fprintf(os.Stderr, "\n") + + // Reset state for next iteration + fullResponse.Reset() + thinkingContent.Reset() + thinkTagOpened = false + thinkTagClosed = false + pendingToolCalls = nil + state = &displayResponseState{} + + // Start new progress spinner for next API call + p = progress.NewProgress(os.Stderr) + spinner = progress.NewSpinner("") + p.Add("", spinner) } if len(opts.Messages) > 0 { @@ -2022,6 +2235,7 @@ func NewCLI() *cobra.Command { deleteCmd, runnerCmd, NewSkillCommand(), + NewMCPCommand(), ) return rootCmd diff --git a/cmd/interactive.go b/cmd/interactive.go index c82adeff7..7b774c8a5 100644 --- a/cmd/interactive.go +++ b/cmd/interactive.go @@ -36,6 +36,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(os.Stderr, " /show Show model information") fmt.Fprintln(os.Stderr, " /skills Show available skills") fmt.Fprintln(os.Stderr, " /skill Add or remove skills dynamically") + fmt.Fprintln(os.Stderr, " /mcp Show/add/remove MCP servers") fmt.Fprintln(os.Stderr, " /load Load a session or model") fmt.Fprintln(os.Stderr, " /save Save your current session") fmt.Fprintln(os.Stderr, " /clear Clear session context") @@ -616,6 +617,240 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { } fmt.Println() continue + + case strings.HasPrefix(line, "/mcp"): + args := strings.Fields(line) + + // If just "/mcp" with no args, show all MCP servers + if len(args) == 1 { + // Show MCPs from model (bundled) + global config + client, err := api.ClientFromEnvironment() + if err != nil { + fmt.Println("error: couldn't connect to ollama server") + return err + } + req := &api.ShowRequest{ + Name: opts.Model, + } + resp, err := client.Show(cmd.Context(), req) + if err != nil { + fmt.Println("error: couldn't get model info") + return err + } + + // Combine model MCPs with global config MCPs + allMCPs := make([]api.MCPRef, 0) + allMCPs = append(allMCPs, resp.MCPs...) + + // Load global config + globalConfig, _ := loadMCPConfig() + globalMCPNames := make(map[string]bool) + + if globalConfig != nil { + for name, srv := range globalConfig.MCPServers { + // Check if already in model MCPs + found := false + for _, modelMCP := range resp.MCPs { + if modelMCP.Name == name { + found = true + break + } + } + if !found { + allMCPs = append(allMCPs, api.MCPRef{ + Name: name, + Command: srv.Command, + Args: srv.Args, + Env: srv.Env, + Type: srv.Type, + }) + } + globalMCPNames[name] = true + } + } + + if len(allMCPs) == 0 { + fmt.Println("No MCP servers available.") + fmt.Println("Use '/mcp add [args...]' to add one.") + } else { + fmt.Println("Available MCP Servers:") + for _, mcp := range allMCPs { + cmdLine := mcp.Command + if len(mcp.Args) > 0 { + cmdLine += " " + strings.Join(mcp.Args, " ") + } + source := "" + disabled := "" + // Check if it's from model or global config + isFromModel := false + for _, modelMCP := range resp.MCPs { + if modelMCP.Name == mcp.Name { + isFromModel = true + break + } + } + if isFromModel { + source = " (model)" + } else if globalMCPNames[mcp.Name] { + source = " (global)" + // Check if disabled + if srv, ok := globalConfig.MCPServers[mcp.Name]; ok && srv.Disabled { + disabled = " [disabled]" + } + } + fmt.Printf(" %s: %s%s%s\n", mcp.Name, cmdLine, source, disabled) + } + } + fmt.Println() + continue + } + + switch args[1] { + case "add": + if len(args) < 4 { + fmt.Println("Usage: /mcp add [args...]") + continue + } + mcpName := args[2] + mcpCommand := args[3] + mcpArgs := args[4:] + + // Load global config + config, err := loadMCPConfig() + if err != nil { + fmt.Printf("Error loading MCP config: %v\n", err) + continue + } + + // Check if already exists + if _, exists := config.MCPServers[mcpName]; exists { + fmt.Printf("Warning: overwriting existing MCP server '%s'\n", mcpName) + } + + // Add to global config + config.MCPServers[mcpName] = MCPServerConfig{ + Type: "stdio", + Command: mcpCommand, + Args: mcpArgs, + } + + // Save config + if err := saveMCPConfig(config); err != nil { + fmt.Printf("Error saving MCP config: %v\n", err) + continue + } + + cmdLine := mcpCommand + if len(mcpArgs) > 0 { + cmdLine += " " + strings.Join(mcpArgs, " ") + } + fmt.Printf("Added MCP server '%s' (%s) to %s\n", mcpName, cmdLine, getMCPConfigPath()) + fmt.Println("Note: MCP server will be started on next message.") + + case "remove", "rm": + if len(args) < 3 { + fmt.Println("Usage: /mcp remove ") + continue + } + mcpName := args[2] + + // Load global config + config, err := loadMCPConfig() + if err != nil { + fmt.Printf("Error loading MCP config: %v\n", err) + continue + } + + if _, exists := config.MCPServers[mcpName]; !exists { + fmt.Printf("MCP server '%s' not found in global config\n", mcpName) + continue + } + + delete(config.MCPServers, mcpName) + + if err := saveMCPConfig(config); err != nil { + fmt.Printf("Error saving MCP config: %v\n", err) + continue + } + + fmt.Printf("Removed MCP server '%s' from %s\n", mcpName, getMCPConfigPath()) + fmt.Println("Note: Changes will take effect on next message.") + + case "disable": + if len(args) < 3 { + fmt.Println("Usage: /mcp disable ") + continue + } + mcpName := args[2] + + config, err := loadMCPConfig() + if err != nil { + fmt.Printf("Error loading MCP config: %v\n", err) + continue + } + + srv, exists := config.MCPServers[mcpName] + if !exists { + fmt.Printf("MCP server '%s' not found in global config\n", mcpName) + continue + } + + if srv.Disabled { + fmt.Printf("MCP server '%s' is already disabled\n", mcpName) + continue + } + + srv.Disabled = true + config.MCPServers[mcpName] = srv + + if err := saveMCPConfig(config); err != nil { + fmt.Printf("Error saving MCP config: %v\n", err) + continue + } + + fmt.Printf("Disabled MCP server '%s'\n", mcpName) + fmt.Println("Note: Changes will take effect on next message.") + + case "enable": + if len(args) < 3 { + fmt.Println("Usage: /mcp enable ") + continue + } + mcpName := args[2] + + config, err := loadMCPConfig() + if err != nil { + fmt.Printf("Error loading MCP config: %v\n", err) + continue + } + + srv, exists := config.MCPServers[mcpName] + if !exists { + fmt.Printf("MCP server '%s' not found in global config\n", mcpName) + continue + } + + if !srv.Disabled { + fmt.Printf("MCP server '%s' is already enabled\n", mcpName) + continue + } + + srv.Disabled = false + config.MCPServers[mcpName] = srv + + if err := saveMCPConfig(config); err != nil { + fmt.Printf("Error saving MCP config: %v\n", err) + continue + } + + fmt.Printf("Enabled MCP server '%s'\n", mcpName) + fmt.Println("Note: Changes will take effect on next message.") + + default: + fmt.Printf("Unknown mcp command '%s'. Use /mcp, /mcp add, /mcp remove, /mcp disable, or /mcp enable\n", args[1]) + } + continue + case strings.HasPrefix(line, "/help"), strings.HasPrefix(line, "/?"): args := strings.Fields(line) if len(args) > 1 { @@ -630,6 +865,14 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error { fmt.Fprintln(os.Stderr, " /skill remove Remove a skill by name") fmt.Fprintln(os.Stderr, " /skill list List current session skills") fmt.Fprintln(os.Stderr, "") + case "mcp", "/mcp": + fmt.Fprintln(os.Stderr, "Available Commands:") + fmt.Fprintln(os.Stderr, " /mcp Show all MCP servers") + fmt.Fprintln(os.Stderr, " /mcp add [args...] Add an MCP server to global config") + fmt.Fprintln(os.Stderr, " /mcp remove Remove an MCP server from global config") + fmt.Fprintln(os.Stderr, " /mcp disable Disable an MCP server (keep in config)") + fmt.Fprintln(os.Stderr, " /mcp enable Re-enable a disabled MCP server") + fmt.Fprintln(os.Stderr, "") case "shortcut", "shortcuts": usageShortcuts() } diff --git a/cmd/mcp.go b/cmd/mcp.go new file mode 100644 index 000000000..f96cd2ee9 --- /dev/null +++ b/cmd/mcp.go @@ -0,0 +1,545 @@ +package cmd + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "os" + "os/exec" + "strings" + "sync" + "time" + + "github.com/ollama/ollama/api" +) + +const ( + mcpInitTimeout = 30 * time.Second + mcpCallTimeout = 60 * time.Second + mcpShutdownTimeout = 5 * time.Second +) + +// JSON-RPC types +type jsonrpcRequest struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id,omitempty"` + Method string `json:"method"` + Params any `json:"params,omitempty"` +} + +type jsonrpcResponse struct { + JSONRPC string `json:"jsonrpc"` + ID int `json:"id"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonrpcError `json:"error,omitempty"` +} + +type jsonrpcError struct { + Code int `json:"code"` + Message string `json:"message"` + Data any `json:"data,omitempty"` +} + +// MCP protocol types +type mcpInitializeParams struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities map[string]any `json:"capabilities"` + ClientInfo mcpClientInfo `json:"clientInfo"` +} + +type mcpClientInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +type mcpInitializeResult struct { + ProtocolVersion string `json:"protocolVersion"` + Capabilities mcpCapabilities `json:"capabilities"` + ServerInfo mcpServerInfo `json:"serverInfo"` +} + +type mcpCapabilities struct { + Tools *mcpToolsCapability `json:"tools,omitempty"` +} + +type mcpToolsCapability struct { + ListChanged bool `json:"listChanged,omitempty"` +} + +type mcpServerInfo struct { + Name string `json:"name"` + Version string `json:"version"` +} + +type mcpTool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema mcpToolInputSchema `json:"inputSchema"` +} + +type mcpToolInputSchema struct { + Type string `json:"type"` + Properties map[string]any `json:"properties,omitempty"` + Required []string `json:"required,omitempty"` +} + +type mcpToolsListResult struct { + Tools []mcpTool `json:"tools"` +} + +type mcpToolCallParams struct { + Name string `json:"name"` + Arguments map[string]any `json:"arguments,omitempty"` +} + +type mcpToolCallResult struct { + Content []mcpContent `json:"content"` + IsError bool `json:"isError,omitempty"` +} + +type mcpContent struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// mcpServer represents a running MCP server process +type mcpServer struct { + ref api.MCPRef + cmd *exec.Cmd + stdin io.WriteCloser + stdout *bufio.Reader + stderr io.ReadCloser + tools []mcpTool + mu sync.Mutex + nextID int + started bool +} + +// mcpManager manages multiple MCP servers for an agent session +type mcpManager struct { + servers map[string]*mcpServer + mu sync.RWMutex +} + +// newMCPManager creates a new MCP manager +func newMCPManager() *mcpManager { + return &mcpManager{ + servers: make(map[string]*mcpServer), + } +} + +// loadMCPsFromRefs initializes MCP servers from refs +func (m *mcpManager) loadMCPsFromRefs(refs []api.MCPRef) error { + if len(refs) == 0 { + return nil + } + + for _, ref := range refs { + if err := m.addServer(ref); err != nil { + fmt.Fprintf(os.Stderr, "Warning: failed to initialize MCP server %q: %v\n", ref.Name, err) + } + } + + return nil +} + +// addServer adds and starts an MCP server +func (m *mcpManager) addServer(ref api.MCPRef) error { + m.mu.Lock() + defer m.mu.Unlock() + + if _, exists := m.servers[ref.Name]; exists { + return fmt.Errorf("MCP server %q already exists", ref.Name) + } + + srv := &mcpServer{ + ref: ref, + nextID: 1, + } + + if err := srv.start(); err != nil { + return fmt.Errorf("starting MCP server: %w", err) + } + + m.servers[ref.Name] = srv + return nil +} + +// start starts the MCP server process +func (s *mcpServer) start() error { + s.mu.Lock() + + if s.started { + s.mu.Unlock() + return nil + } + + s.cmd = exec.Command(s.ref.Command, s.ref.Args...) + + // Set environment + s.cmd.Env = os.Environ() + for k, v := range s.ref.Env { + s.cmd.Env = append(s.cmd.Env, fmt.Sprintf("%s=%s", k, v)) + } + + var err error + s.stdin, err = s.cmd.StdinPipe() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("creating stdin pipe: %w", err) + } + + stdout, err := s.cmd.StdoutPipe() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("creating stdout pipe: %w", err) + } + s.stdout = bufio.NewReader(stdout) + + s.stderr, err = s.cmd.StderrPipe() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("creating stderr pipe: %w", err) + } + + // Start stderr reader goroutine (discard stderr for now) + go func() { + scanner := bufio.NewScanner(s.stderr) + for scanner.Scan() { + _ = scanner.Text() + } + }() + + if err := s.cmd.Start(); err != nil { + s.mu.Unlock() + return fmt.Errorf("starting process: %w", err) + } + + s.started = true + s.mu.Unlock() // Release lock before calling initialize/listTools which use the mutex + + // Initialize the server + if err := s.initialize(); err != nil { + s.stop() + return fmt.Errorf("initializing MCP server: %w", err) + } + + // Get available tools + if err := s.listTools(); err != nil { + s.stop() + return fmt.Errorf("listing tools: %w", err) + } + + return nil +} + +// initialize sends the MCP initialize request +func (s *mcpServer) initialize() error { + ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout) + defer cancel() + + params := mcpInitializeParams{ + ProtocolVersion: "2024-11-05", + Capabilities: map[string]any{}, + ClientInfo: mcpClientInfo{ + Name: "ollama", + Version: "0.1.0", + }, + } + + var result mcpInitializeResult + if err := s.call(ctx, "initialize", params, &result); err != nil { + return err + } + + // Send initialized notification + return s.notify("notifications/initialized", nil) +} + +// listTools fetches the available tools from the MCP server +func (s *mcpServer) listTools() error { + ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout) + defer cancel() + + var result mcpToolsListResult + if err := s.call(ctx, "tools/list", nil, &result); err != nil { + return err + } + + s.tools = result.Tools + return nil +} + +// call sends a JSON-RPC request and waits for the response +func (s *mcpServer) call(ctx context.Context, method string, params any, result any) error { + s.mu.Lock() + id := s.nextID + s.nextID++ + s.mu.Unlock() + + req := jsonrpcRequest{ + JSONRPC: "2.0", + ID: id, + Method: method, + Params: params, + } + + reqBytes, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("marshaling request: %w", err) + } + + // Send request + s.mu.Lock() + _, err = s.stdin.Write(append(reqBytes, '\n')) + s.mu.Unlock() + if err != nil { + return fmt.Errorf("writing request: %w", err) + } + + // Read response with timeout + respCh := make(chan []byte, 1) + errCh := make(chan error, 1) + + go func() { + s.mu.Lock() + line, err := s.stdout.ReadBytes('\n') + s.mu.Unlock() + if err != nil { + errCh <- err + return + } + respCh <- line + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: + return fmt.Errorf("reading response: %w", err) + case line := <-respCh: + var resp jsonrpcResponse + if err := json.Unmarshal(line, &resp); err != nil { + return fmt.Errorf("unmarshaling response: %w", err) + } + + if resp.Error != nil { + return fmt.Errorf("MCP error %d: %s", resp.Error.Code, resp.Error.Message) + } + + if result != nil && len(resp.Result) > 0 { + if err := json.Unmarshal(resp.Result, result); err != nil { + return fmt.Errorf("unmarshaling result: %w", err) + } + } + + return nil + } +} + +// notify sends a JSON-RPC notification (no response expected) +func (s *mcpServer) notify(method string, params any) error { + req := jsonrpcRequest{ + JSONRPC: "2.0", + Method: method, + Params: params, + } + + reqBytes, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("marshaling notification: %w", err) + } + + s.mu.Lock() + defer s.mu.Unlock() + if _, err := s.stdin.Write(append(reqBytes, '\n')); err != nil { + return fmt.Errorf("writing notification: %w", err) + } + + return nil +} + +// callTool executes a tool call on the MCP server +func (s *mcpServer) callTool(ctx context.Context, name string, arguments map[string]any) (string, error) { + params := mcpToolCallParams{ + Name: name, + Arguments: arguments, + } + + var result mcpToolCallResult + if err := s.call(ctx, "tools/call", params, &result); err != nil { + return "", err + } + + // Concatenate text content + var sb strings.Builder + for _, content := range result.Content { + if content.Type == "text" { + sb.WriteString(content.Text) + } + } + + if result.IsError { + return sb.String(), errors.New(sb.String()) + } + + return sb.String(), nil +} + +// stop shuts down the MCP server +func (s *mcpServer) stop() error { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.started { + return nil + } + + // Close stdin to signal shutdown + if s.stdin != nil { + s.stdin.Close() + } + + // Wait for process with timeout + done := make(chan error, 1) + go func() { + done <- s.cmd.Wait() + }() + + select { + case <-time.After(mcpShutdownTimeout): + s.cmd.Process.Kill() + case <-done: + } + + s.started = false + return nil +} + +// Tools returns all tools from all MCP servers as api.Tools +func (m *mcpManager) Tools() api.Tools { + m.mu.RLock() + defer m.mu.RUnlock() + + var tools api.Tools + + for serverName, srv := range m.servers { + for _, t := range srv.tools { + // Namespace tool names: mcp_{servername}_{toolname} + namespacedName := fmt.Sprintf("mcp_%s_%s", serverName, t.Name) + + tool := api.Tool{ + Type: "function", + Function: api.ToolFunction{ + Name: namespacedName, + Description: t.Description, + Parameters: convertMCPSchema(t.InputSchema), + }, + } + tools = append(tools, tool) + } + } + + return tools +} + +// convertMCPSchema converts MCP input schema to api.ToolFunctionParameters +func convertMCPSchema(schema mcpToolInputSchema) api.ToolFunctionParameters { + params := api.ToolFunctionParameters{ + Type: schema.Type, + Required: schema.Required, + Properties: make(map[string]api.ToolProperty), + } + + for name, prop := range schema.Properties { + if propMap, ok := prop.(map[string]any); ok { + tp := api.ToolProperty{} + if t, ok := propMap["type"].(string); ok { + tp.Type = api.PropertyType{t} + } + if d, ok := propMap["description"].(string); ok { + tp.Description = d + } + params.Properties[name] = tp + } + } + + return params +} + +// RunToolCall routes a tool call to the appropriate MCP server +func (m *mcpManager) RunToolCall(call api.ToolCall) (api.Message, bool, error) { + name := call.Function.Name + + // Check if this is an MCP tool (mcp_servername_toolname) + if !strings.HasPrefix(name, "mcp_") { + return api.Message{}, false, nil + } + + // Parse server name and tool name + rest := strings.TrimPrefix(name, "mcp_") + idx := strings.Index(rest, "_") + if idx == -1 { + return toolMessage(call, fmt.Sprintf("invalid MCP tool name: %s", name)), true, nil + } + + serverName := rest[:idx] + toolName := rest[idx+1:] + + m.mu.RLock() + srv, ok := m.servers[serverName] + m.mu.RUnlock() + + if !ok { + return toolMessage(call, fmt.Sprintf("MCP server %q not found", serverName)), true, nil + } + + ctx, cancel := context.WithTimeout(context.Background(), mcpCallTimeout) + defer cancel() + + result, err := srv.callTool(ctx, toolName, call.Function.Arguments) + if err != nil { + return toolMessage(call, fmt.Sprintf("error: %v", err)), true, nil + } + + return toolMessage(call, result), true, nil +} + +// Shutdown stops all MCP servers +func (m *mcpManager) Shutdown() { + m.mu.Lock() + defer m.mu.Unlock() + + for _, srv := range m.servers { + srv.stop() + } + + m.servers = make(map[string]*mcpServer) +} + +// ServerNames returns the names of all running MCP servers +func (m *mcpManager) ServerNames() []string { + m.mu.RLock() + defer m.mu.RUnlock() + + names := make([]string, 0, len(m.servers)) + for name := range m.servers { + names = append(names, name) + } + return names +} + +// ToolCount returns the total number of tools across all servers +func (m *mcpManager) ToolCount() int { + m.mu.RLock() + defer m.mu.RUnlock() + + count := 0 + for _, srv := range m.servers { + count += len(srv.tools) + } + return count +} diff --git a/cmd/mcp_cmd.go b/cmd/mcp_cmd.go new file mode 100644 index 000000000..8890af3a9 --- /dev/null +++ b/cmd/mcp_cmd.go @@ -0,0 +1,898 @@ +package cmd + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + "text/tabwriter" + "time" + + "github.com/spf13/cobra" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/progress" + "github.com/ollama/ollama/server" + "github.com/ollama/ollama/types/model" +) + +// MCPConfigFile represents the global MCP configuration file structure. +type MCPConfigFile struct { + MCPServers map[string]MCPServerConfig `json:"mcpServers"` +} + +// MCPServerConfig represents a single MCP server configuration. +type MCPServerConfig struct { + Type string `json:"type,omitempty"` + Command string `json:"command"` + Args []string `json:"args,omitempty"` + Env map[string]string `json:"env,omitempty"` + Disabled bool `json:"disabled,omitempty"` +} + +// getMCPConfigPath returns the path to the global MCP config file. +func getMCPConfigPath() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return filepath.Join(home, ".ollama", "mcp.json") +} + +// loadMCPConfig loads the global MCP configuration file. +func loadMCPConfig() (*MCPConfigFile, error) { + configPath := getMCPConfigPath() + if configPath == "" { + return nil, fmt.Errorf("could not determine home directory") + } + + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + // Return empty config if file doesn't exist + return &MCPConfigFile{ + MCPServers: make(map[string]MCPServerConfig), + }, nil + } + return nil, fmt.Errorf("reading config: %w", err) + } + + var config MCPConfigFile + if err := json.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + if config.MCPServers == nil { + config.MCPServers = make(map[string]MCPServerConfig) + } + + return &config, nil +} + +// saveMCPConfig saves the global MCP configuration file. +func saveMCPConfig(config *MCPConfigFile) error { + configPath := getMCPConfigPath() + if configPath == "" { + return fmt.Errorf("could not determine home directory") + } + + // Ensure directory exists + if err := os.MkdirAll(filepath.Dir(configPath), 0o755); err != nil { + return fmt.Errorf("creating config directory: %w", err) + } + + data, err := json.MarshalIndent(config, "", " ") + if err != nil { + return fmt.Errorf("marshaling config: %w", err) + } + + if err := os.WriteFile(configPath, data, 0o644); err != nil { + return fmt.Errorf("writing config: %w", err) + } + + return nil +} + +// MCPAddHandler handles the mcp add command. +func MCPAddHandler(cmd *cobra.Command, args []string) error { + if len(args) < 2 { + return fmt.Errorf("usage: ollama mcp add NAME COMMAND [ARGS...]") + } + + name := args[0] + command := args[1] + cmdArgs := args[2:] + + // Load existing config + config, err := loadMCPConfig() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + // Check if already exists + if _, exists := config.MCPServers[name]; exists { + fmt.Fprintf(os.Stderr, "Warning: overwriting existing MCP server '%s'\n", name) + } + + // Add the new server + config.MCPServers[name] = MCPServerConfig{ + Type: "stdio", + Command: command, + Args: cmdArgs, + } + + // Save config + if err := saveMCPConfig(config); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + configPath := getMCPConfigPath() + fmt.Fprintf(os.Stderr, "Added MCP server '%s' to %s\n", name, configPath) + fmt.Fprintf(os.Stderr, " Command: %s %s\n", command, strings.Join(cmdArgs, " ")) + + return nil +} + +// MCPRemoveGlobalHandler handles removing an MCP from global config. +func MCPRemoveGlobalHandler(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return fmt.Errorf("usage: ollama mcp remove-global NAME [NAME...]") + } + + config, err := loadMCPConfig() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + for _, name := range args { + if _, exists := config.MCPServers[name]; !exists { + fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name) + continue + } + + delete(config.MCPServers, name) + fmt.Fprintf(os.Stderr, "Removed MCP server '%s' from global config\n", name) + } + + if err := saveMCPConfig(config); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + return nil +} + +// MCPListGlobalHandler handles listing global MCP servers. +func MCPListGlobalHandler(cmd *cobra.Command, args []string) error { + config, err := loadMCPConfig() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + if len(config.MCPServers) == 0 { + fmt.Println("No global MCP servers configured") + fmt.Printf("Add one with: ollama mcp add NAME COMMAND [ARGS...]\n") + return nil + } + + fmt.Printf("Global MCP servers (%s):\n\n", getMCPConfigPath()) + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintln(w, "NAME\tCOMMAND\tSTATUS") + + for name, srv := range config.MCPServers { + cmdLine := srv.Command + if len(srv.Args) > 0 { + cmdLine += " " + strings.Join(srv.Args, " ") + } + status := "enabled" + if srv.Disabled { + status = "disabled" + } + fmt.Fprintf(w, "%s\t%s\t%s\n", name, cmdLine, status) + } + + return w.Flush() +} + +// MCPDisableHandler handles disabling an MCP server in global config. +func MCPDisableHandler(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return fmt.Errorf("usage: ollama mcp disable NAME [NAME...]") + } + + config, err := loadMCPConfig() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + for _, name := range args { + srv, exists := config.MCPServers[name] + if !exists { + fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name) + continue + } + + if srv.Disabled { + fmt.Fprintf(os.Stderr, "MCP server '%s' is already disabled\n", name) + continue + } + + srv.Disabled = true + config.MCPServers[name] = srv + fmt.Fprintf(os.Stderr, "Disabled MCP server '%s'\n", name) + } + + if err := saveMCPConfig(config); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + return nil +} + +// MCPEnableHandler handles enabling an MCP server in global config. +func MCPEnableHandler(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return fmt.Errorf("usage: ollama mcp enable NAME [NAME...]") + } + + config, err := loadMCPConfig() + if err != nil { + return fmt.Errorf("loading config: %w", err) + } + + for _, name := range args { + srv, exists := config.MCPServers[name] + if !exists { + fmt.Fprintf(os.Stderr, "MCP server '%s' not found in global config\n", name) + continue + } + + if !srv.Disabled { + fmt.Fprintf(os.Stderr, "MCP server '%s' is already enabled\n", name) + continue + } + + srv.Disabled = false + config.MCPServers[name] = srv + fmt.Fprintf(os.Stderr, "Enabled MCP server '%s'\n", name) + } + + if err := saveMCPConfig(config); err != nil { + return fmt.Errorf("saving config: %w", err) + } + + return nil +} + +// MCPPushHandler handles the mcp push command. +func MCPPushHandler(cmd *cobra.Command, args []string) error { + if len(args) != 2 { + return fmt.Errorf("usage: ollama mcp push NAME[:TAG] PATH") + } + + name := args[0] + path := args[1] + + // Expand path + if strings.HasPrefix(path, "~") { + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("expanding home directory: %w", err) + } + path = filepath.Join(home, path[1:]) + } + + absPath, err := filepath.Abs(path) + if err != nil { + return fmt.Errorf("resolving path: %w", err) + } + + // Validate MCP directory - check for mcp.json, package.json, or any config file + validFiles := []string{"mcp.json", "package.json", "server.py", "server.js", "main.py", "index.js"} + found := false + for _, vf := range validFiles { + if _, err := os.Stat(filepath.Join(absPath, vf)); err == nil { + found = true + break + } + } + if !found { + return fmt.Errorf("MCP directory should contain one of: %s", strings.Join(validFiles, ", ")) + } + + // Parse MCP name (will set Kind="mcp") + n := server.ParseMCPName(name) + if n.Model == "" { + return fmt.Errorf("invalid MCP name: %s", name) + } + + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + // Create MCP layer + displayName := n.DisplayShortest() + status := fmt.Sprintf("Creating MCP layer for %s", displayName) + spinner := progress.NewSpinner(status) + p.Add(status, spinner) + + layer, err := server.CreateMCPLayer(absPath) + if err != nil { + return fmt.Errorf("creating MCP layer: %w", err) + } + + spinner.Stop() + + // Create MCP manifest + manifest, configLayer, err := createMCPManifest(absPath, layer) + if err != nil { + return fmt.Errorf("creating MCP manifest: %w", err) + } + + // Write manifest locally + manifestPath, err := server.GetMCPManifestPath(n) + if err != nil { + return fmt.Errorf("getting manifest path: %w", err) + } + + if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil { + return fmt.Errorf("creating manifest directory: %w", err) + } + + manifestJSON, err := json.Marshal(manifest) + if err != nil { + return fmt.Errorf("marshaling manifest: %w", err) + } + + if err := os.WriteFile(manifestPath, manifestJSON, 0o644); err != nil { + return fmt.Errorf("writing manifest: %w", err) + } + + fmt.Fprintf(os.Stderr, "MCP %s created locally\n", displayName) + fmt.Fprintf(os.Stderr, " Config: %s (%s)\n", configLayer.Digest, format.HumanBytes(configLayer.Size)) + fmt.Fprintf(os.Stderr, " Layer: %s (%s)\n", layer.Digest, format.HumanBytes(layer.Size)) + + // Push to registry + client, err := api.ClientFromEnvironment() + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + + insecure, _ := cmd.Flags().GetBool("insecure") + + fmt.Fprintf(os.Stderr, "\nPushing to registry...\n") + + fn := func(resp api.ProgressResponse) error { + if resp.Digest != "" { + bar := progress.NewBar(resp.Status, resp.Total, resp.Completed) + p.Add(resp.Digest, bar) + } else if resp.Status != "" { + spinner := progress.NewSpinner(resp.Status) + p.Add(resp.Status, spinner) + } + return nil + } + + req := &api.PushRequest{ + Model: displayName, + Insecure: insecure, + } + + if err := client.Push(context.Background(), req, fn); err != nil { + // If push fails, still show success for local creation + fmt.Fprintf(os.Stderr, "\nNote: Local MCP created but push failed: %v\n", err) + fmt.Fprintf(os.Stderr, "You can try pushing later with: ollama mcp push %s\n", name) + return nil + } + + fmt.Fprintf(os.Stderr, "Successfully pushed %s\n", displayName) + return nil +} + +// MCPPullHandler handles the mcp pull command. +func MCPPullHandler(cmd *cobra.Command, args []string) error { + if len(args) != 1 { + return fmt.Errorf("usage: ollama mcp pull NAME[:TAG]") + } + + name := args[0] + n := server.ParseMCPName(name) + if n.Model == "" { + return fmt.Errorf("invalid MCP name: %s", name) + } + + client, err := api.ClientFromEnvironment() + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + + insecure, _ := cmd.Flags().GetBool("insecure") + + p := progress.NewProgress(os.Stderr) + defer p.Stop() + + fn := func(resp api.ProgressResponse) error { + if resp.Digest != "" { + bar := progress.NewBar(resp.Status, resp.Total, resp.Completed) + p.Add(resp.Digest, bar) + } else if resp.Status != "" { + spinner := progress.NewSpinner(resp.Status) + p.Add(resp.Status, spinner) + } + return nil + } + + displayName := n.DisplayShortest() + req := &api.PullRequest{ + Model: displayName, + Insecure: insecure, + } + + if err := client.Pull(context.Background(), req, fn); err != nil { + return fmt.Errorf("pulling MCP: %w", err) + } + + fmt.Fprintf(os.Stderr, "Successfully pulled %s\n", displayName) + return nil +} + +// MCPListHandler handles the mcp list command. +func MCPListHandler(cmd *cobra.Command, args []string) error { + mcps, err := listLocalMCPs() + if err != nil { + return fmt.Errorf("listing MCPs: %w", err) + } + + if len(mcps) == 0 { + fmt.Println("No MCPs installed") + return nil + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0) + fmt.Fprintln(w, "NAME\tTAG\tSIZE\tMODIFIED") + + for _, mcp := range mcps { + fmt.Fprintf(w, "%s/%s\t%s\t%s\t%s\n", + mcp.Namespace, + mcp.Name, + mcp.Tag, + format.HumanBytes(mcp.Size), + format.HumanTime(mcp.ModifiedAt, "Never"), + ) + } + + return w.Flush() +} + +// MCPRemoveHandler handles the mcp rm command. +func MCPRemoveHandler(cmd *cobra.Command, args []string) error { + if len(args) == 0 { + return fmt.Errorf("usage: ollama mcp rm NAME[:TAG] [NAME[:TAG]...]") + } + + for _, name := range args { + n := server.ParseMCPName(name) + if n.Model == "" { + fmt.Fprintf(os.Stderr, "Invalid MCP name: %s\n", name) + continue + } + + displayName := n.DisplayShortest() + manifestPath, err := server.GetMCPManifestPath(n) + if err != nil { + fmt.Fprintf(os.Stderr, "Error getting manifest path for %s: %v\n", name, err) + continue + } + + if _, err := os.Stat(manifestPath); os.IsNotExist(err) { + fmt.Fprintf(os.Stderr, "MCP not found: %s\n", displayName) + continue + } + + if err := os.Remove(manifestPath); err != nil { + fmt.Fprintf(os.Stderr, "Error removing %s: %v\n", displayName, err) + continue + } + + // Clean up empty parent directories + dir := filepath.Dir(manifestPath) + for dir != filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") { + entries, _ := os.ReadDir(dir) + if len(entries) == 0 { + os.Remove(dir) + dir = filepath.Dir(dir) + } else { + break + } + } + + fmt.Fprintf(os.Stderr, "Deleted '%s'\n", displayName) + } + + return nil +} + +// MCPShowHandler handles the mcp show command. +func MCPShowHandler(cmd *cobra.Command, args []string) error { + if len(args) != 1 { + return fmt.Errorf("usage: ollama mcp show NAME[:TAG]") + } + + name := args[0] + n := server.ParseMCPName(name) + if n.Model == "" { + return fmt.Errorf("invalid MCP name: %s", name) + } + + displayName := n.DisplayShortest() + manifestPath, err := server.GetMCPManifestPath(n) + if err != nil { + return fmt.Errorf("getting manifest path: %w", err) + } + + data, err := os.ReadFile(manifestPath) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("MCP not found: %s", displayName) + } + return fmt.Errorf("reading manifest: %w", err) + } + + var manifest server.Manifest + if err := json.Unmarshal(data, &manifest); err != nil { + return fmt.Errorf("parsing manifest: %w", err) + } + + fmt.Printf("MCP: %s\n\n", displayName) + + fmt.Println("Layers:") + for _, layer := range manifest.Layers { + fmt.Printf(" %s %s %s\n", layer.MediaType, layer.Digest[:19], format.HumanBytes(layer.Size)) + } + + // Try to read and display mcp.json or package.json content + if len(manifest.Layers) > 0 { + for _, layer := range manifest.Layers { + if layer.MediaType == server.MediaTypeMCP { + mcpPath, err := server.GetMCPsPath(layer.Digest) + if err == nil { + // Try mcp.json first + mcpJSONPath := filepath.Join(mcpPath, "mcp.json") + if content, err := os.ReadFile(mcpJSONPath); err == nil { + fmt.Println("\nConfig (mcp.json):") + fmt.Println(string(content)) + } else { + // Try package.json + pkgJSONPath := filepath.Join(mcpPath, "package.json") + if content, err := os.ReadFile(pkgJSONPath); err == nil { + fmt.Println("\nConfig (package.json):") + fmt.Println(string(content)) + } + } + + // List files in the MCP + fmt.Println("\nFiles:") + filepath.Walk(mcpPath, func(path string, info os.FileInfo, err error) error { + if err != nil { + return nil + } + relPath, _ := filepath.Rel(mcpPath, path) + if relPath == "." { + return nil + } + if info.IsDir() { + fmt.Printf(" %s/\n", relPath) + } else { + fmt.Printf(" %s (%s)\n", relPath, format.HumanBytes(info.Size())) + } + return nil + }) + } + } + } + } + + return nil +} + +// MCPInfo represents information about an installed MCP. +type MCPInfo struct { + Namespace string + Name string + Tag string + Size int64 + ModifiedAt time.Time +} + +// listLocalMCPs returns a list of locally installed MCPs. +// MCPs are stored with 5-part paths: host/namespace/kind/model/tag +// where kind is "mcp". +func listLocalMCPs() ([]MCPInfo, error) { + manifestsPath := filepath.Join(os.Getenv("HOME"), ".ollama", "models", "manifests") + + var mcps []MCPInfo + + // Walk through all registries + registries, err := os.ReadDir(manifestsPath) + if err != nil { + if os.IsNotExist(err) { + return mcps, nil + } + return nil, err + } + + for _, registry := range registries { + if !registry.IsDir() { + continue + } + + // Walk namespaces + namespaces, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name())) + if err != nil { + continue + } + + for _, namespace := range namespaces { + if !namespace.IsDir() { + continue + } + + // Walk kinds looking for "mcp" + kinds, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name())) + if err != nil { + continue + } + + for _, kind := range kinds { + if !kind.IsDir() { + continue + } + + // Only process mcp kind + if kind.Name() != server.MCPNamespace { + continue + } + + // Walk MCP names (model names) + mcpNames, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name())) + if err != nil { + continue + } + + for _, mcpName := range mcpNames { + if !mcpName.IsDir() { + continue + } + + // Walk tags + tags, err := os.ReadDir(filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), mcpName.Name())) + if err != nil { + continue + } + + for _, tag := range tags { + manifestPath := filepath.Join(manifestsPath, registry.Name(), namespace.Name(), kind.Name(), mcpName.Name(), tag.Name()) + fi, err := os.Stat(manifestPath) + if err != nil || fi.IsDir() { + continue + } + + // Read manifest to get size + data, err := os.ReadFile(manifestPath) + if err != nil { + continue + } + + var manifest server.Manifest + if err := json.Unmarshal(data, &manifest); err != nil { + continue + } + + var totalSize int64 + for _, layer := range manifest.Layers { + totalSize += layer.Size + } + + // Build display name using model.Name + n := model.Name{ + Host: registry.Name(), + Namespace: namespace.Name(), + Kind: kind.Name(), + Model: mcpName.Name(), + Tag: tag.Name(), + } + + mcps = append(mcps, MCPInfo{ + Namespace: n.Namespace + "/" + n.Kind, + Name: n.Model, + Tag: n.Tag, + Size: totalSize, + ModifiedAt: fi.ModTime(), + }) + } + } + } + } + } + + return mcps, nil +} + +// createMCPManifest creates a manifest for a standalone MCP. +func createMCPManifest(mcpDir string, layer server.Layer) (*server.Manifest, *server.Layer, error) { + // Try to read mcp.json or package.json to extract metadata + name, description := extractMCPMetadata(mcpDir) + if name == "" { + // Use directory name as fallback + name = filepath.Base(mcpDir) + } + + // Create config + config := map[string]any{ + "name": name, + "description": description, + "architecture": "amd64", + "os": "linux", + } + + configJSON, err := json.Marshal(config) + if err != nil { + return nil, nil, fmt.Errorf("marshaling config: %w", err) + } + + // Create config layer + configLayer, err := server.NewLayer(strings.NewReader(string(configJSON)), "application/vnd.docker.container.image.v1+json") + if err != nil { + return nil, nil, fmt.Errorf("creating config layer: %w", err) + } + + manifest := &server.Manifest{ + SchemaVersion: 2, + MediaType: "application/vnd.docker.distribution.manifest.v2+json", + Config: configLayer, + Layers: []server.Layer{layer}, + } + + return manifest, &configLayer, nil +} + +// extractMCPMetadata extracts name and description from mcp.json or package.json. +func extractMCPMetadata(mcpDir string) (name, description string) { + // Try mcp.json first + mcpJSONPath := filepath.Join(mcpDir, "mcp.json") + if data, err := os.ReadFile(mcpJSONPath); err == nil { + var config map[string]any + if err := json.Unmarshal(data, &config); err == nil { + if n, ok := config["name"].(string); ok { + name = n + } + if d, ok := config["description"].(string); ok { + description = d + } + return name, description + } + } + + // Try package.json + pkgJSONPath := filepath.Join(mcpDir, "package.json") + if data, err := os.ReadFile(pkgJSONPath); err == nil { + var config map[string]any + if err := json.Unmarshal(data, &config); err == nil { + if n, ok := config["name"].(string); ok { + name = n + } + if d, ok := config["description"].(string); ok { + description = d + } + return name, description + } + } + + return "", "" +} + +// NewMCPCommand creates the mcp parent command with subcommands. +func NewMCPCommand() *cobra.Command { + mcpCmd := &cobra.Command{ + Use: "mcp", + Short: "Manage MCP servers", + Long: "Commands for managing MCP (Model Context Protocol) servers (add, push, pull, list, rm, show)", + } + + // Global config commands + addCmd := &cobra.Command{ + Use: "add NAME COMMAND [ARGS...]", + Short: "Add an MCP server to global config", + Long: `Add an MCP server to the global config (~/.ollama/mcp.json). +Global MCP servers are available to all agents. + +Examples: + ollama mcp add web-search uv run ./mcp-server.py + ollama mcp add calculator python3 /path/to/calc.py`, + Args: cobra.MinimumNArgs(2), + RunE: MCPAddHandler, + DisableFlagParsing: true, // Allow args with dashes + } + + removeGlobalCmd := &cobra.Command{ + Use: "remove-global NAME [NAME...]", + Aliases: []string{"rm-global"}, + Short: "Remove an MCP server from global config", + Args: cobra.MinimumNArgs(1), + RunE: MCPRemoveGlobalHandler, + } + + listGlobalCmd := &cobra.Command{ + Use: "list-global", + Short: "List global MCP servers", + Args: cobra.NoArgs, + RunE: MCPListGlobalHandler, + } + + // Registry commands + pushCmd := &cobra.Command{ + Use: "push NAME[:TAG] PATH", + Short: "Push an MCP server to a registry", + Long: "Package a local MCP server directory and push it to a registry", + Args: cobra.ExactArgs(2), + PreRunE: checkServerHeartbeat, + RunE: MCPPushHandler, + } + pushCmd.Flags().Bool("insecure", false, "Use an insecure registry") + + pullCmd := &cobra.Command{ + Use: "pull NAME[:TAG]", + Short: "Pull an MCP server from a registry", + Args: cobra.ExactArgs(1), + PreRunE: checkServerHeartbeat, + RunE: MCPPullHandler, + } + pullCmd.Flags().Bool("insecure", false, "Use an insecure registry") + + listCmd := &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List installed MCP servers (from registry)", + Args: cobra.NoArgs, + RunE: MCPListHandler, + } + + rmCmd := &cobra.Command{ + Use: "rm NAME[:TAG] [NAME[:TAG]...]", + Aliases: []string{"remove", "delete"}, + Short: "Remove an MCP server (from registry)", + Args: cobra.MinimumNArgs(1), + RunE: MCPRemoveHandler, + } + + showCmd := &cobra.Command{ + Use: "show NAME[:TAG]", + Short: "Show MCP server details", + Args: cobra.ExactArgs(1), + RunE: MCPShowHandler, + } + + disableCmd := &cobra.Command{ + Use: "disable NAME [NAME...]", + Short: "Disable an MCP server (keep in config)", + Long: `Disable an MCP server without removing it from config. +Disabled servers will not be started when running agents. +Use 'ollama mcp enable' to re-enable.`, + Args: cobra.MinimumNArgs(1), + RunE: MCPDisableHandler, + } + + enableCmd := &cobra.Command{ + Use: "enable NAME [NAME...]", + Short: "Enable a disabled MCP server", + Long: `Re-enable a previously disabled MCP server.`, + Args: cobra.MinimumNArgs(1), + RunE: MCPEnableHandler, + } + + mcpCmd.AddCommand(addCmd, removeGlobalCmd, listGlobalCmd, disableCmd, enableCmd, pushCmd, pullCmd, listCmd, rmCmd, showCmd) + + return mcpCmd +} diff --git a/go.mod b/go.mod index f7c9ff295..126980412 100644 --- a/go.mod +++ b/go.mod @@ -83,5 +83,5 @@ require ( golang.org/x/term v0.36.0 golang.org/x/text v0.30.0 google.golang.org/protobuf v1.34.1 - gopkg.in/yaml.v3 v3.0.1 // indirect + gopkg.in/yaml.v3 v3.0.1 ) diff --git a/parser/parser.go b/parser/parser.go index 364eed53a..4bdcdfb03 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "crypto/sha256" + "encoding/json" "errors" "fmt" "io" @@ -59,6 +60,7 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) var messages []api.Message var licenses []string var skills []api.SkillRef + var mcps []api.MCPRef params := make(map[string]any) for _, c := range f.Commands { @@ -121,8 +123,21 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) messages = append(messages, api.Message{Role: role, Content: msg}) case "skill": skills = append(skills, api.SkillRef{Name: c.Args}) + case "mcp": + mcpRef, err := parseMCPArg(c.Args, relativeDir) + if err != nil { + return nil, fmt.Errorf("invalid MCP: %w", err) + } + mcps = append(mcps, mcpRef) case "agent_type": - req.AgentType = c.Args + // Handle "AGENT TYPE conversational" -> strip "TYPE " prefix + args := c.Args + if strings.HasPrefix(strings.ToLower(args), "type ") { + args = strings.TrimSpace(args[5:]) + } + req.AgentType = args + case "entrypoint": + req.Entrypoint = c.Args default: if slices.Contains(deprecatedParameters, c.Name) { fmt.Printf("warning: parameter %s is deprecated\n", c.Name) @@ -158,6 +173,9 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error) if len(skills) > 0 { req.Skills = skills } + if len(mcps) > 0 { + req.MCPs = mcps + } return req, nil } @@ -341,7 +359,7 @@ func (c Command) String() string { switch c.Name { case "model": fmt.Fprintf(&sb, "FROM %s", c.Args) - case "license", "template", "system", "adapter", "renderer", "parser", "requires", "skill", "agent_type": + case "license", "template", "system", "adapter", "renderer", "parser", "requires", "skill", "agent_type", "entrypoint": fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args)) case "message": role, message, _ := strings.Cut(c.Args, ": ") @@ -367,7 +385,7 @@ const ( var ( errMissingFrom = errors.New("no FROM line") errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"") - errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", \"requires\", \"skill\", or \"agent_type\"") + errInvalidCommand = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", \"message\", \"requires\", \"skill\", \"agent_type\", \"mcp\", or \"entrypoint\"") ) type ParserError struct { @@ -431,6 +449,9 @@ func ParseFile(r io.Reader) (*Modelfile, error) { switch s := strings.ToLower(b.String()); s { case "from": cmd.Name = "model" + case "agent": + // "AGENT TYPE" -> "agent_type", consume next word + cmd.Name = "agent_type" case "parameter": // transition to stateParameter which sets command name next = stateParameter @@ -508,6 +529,10 @@ func ParseFile(r io.Reader) (*Modelfile, error) { if cmd.Name == "model" { return &f, nil } + // Allow entrypoint-only agents without FROM + if cmd.Name == "entrypoint" { + return &f, nil + } } return nil, errMissingFrom @@ -627,7 +652,7 @@ func isValidMessageRole(role string) bool { func isValidCommand(cmd string) bool { switch strings.ToLower(cmd) { - case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires", "skill", "agent_type": + case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message", "requires", "skill", "agent_type", "agent", "mcp", "entrypoint": return true default: return false @@ -674,3 +699,79 @@ func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User func expandPath(path, relativeDir string) (string, error) { return expandPathImpl(path, relativeDir, user.Current, user.Lookup) } + +// parseMCPArg parses MCP command arguments. +// Supports two formats: +// +// JSON: {"name": "web-search", "command": "uv", "args": ["run", "./script.py"]} +// Simple: web-search uv run ./script.py (name, command, args...) +func parseMCPArg(args string, relativeDir string) (api.MCPRef, error) { + args = strings.TrimSpace(args) + if args == "" { + return api.MCPRef{}, errors.New("MCP requires arguments") + } + + // Try JSON format first + if strings.HasPrefix(args, "{") { + var ref api.MCPRef + if err := json.Unmarshal([]byte(args), &ref); err != nil { + return api.MCPRef{}, fmt.Errorf("invalid JSON: %w", err) + } + if ref.Name == "" { + return api.MCPRef{}, errors.New("MCP name is required") + } + if ref.Command == "" { + return api.MCPRef{}, errors.New("MCP command is required") + } + if ref.Type == "" { + ref.Type = "stdio" + } + // Expand relative paths in args + for i, arg := range ref.Args { + if isLocalPath(arg) { + expanded, err := expandPath(arg, relativeDir) + if err != nil { + return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err) + } + ref.Args[i] = expanded + } + } + return ref, nil + } + + // Simple format: name command args... + parts := strings.Fields(args) + if len(parts) < 2 { + return api.MCPRef{}, errors.New("MCP requires at least name and command") + } + + ref := api.MCPRef{ + Name: parts[0], + Command: parts[1], + Type: "stdio", + } + if len(parts) > 2 { + ref.Args = parts[2:] + } + + // Expand relative paths in args + for i, arg := range ref.Args { + if isLocalPath(arg) { + expanded, err := expandPath(arg, relativeDir) + if err != nil { + return api.MCPRef{}, fmt.Errorf("expanding path %q: %w", arg, err) + } + ref.Args[i] = expanded + } + } + + return ref, nil +} + +// isLocalPath checks if a string looks like a local filesystem path. +func isLocalPath(s string) bool { + return strings.HasPrefix(s, "/") || + strings.HasPrefix(s, "./") || + strings.HasPrefix(s, "../") || + strings.HasPrefix(s, "~") +} diff --git a/server/create.go b/server/create.go index 5ebc06d07..f1ad2eeda 100644 --- a/server/create.go +++ b/server/create.go @@ -63,7 +63,9 @@ func (s *Server) CreateHandler(c *gin.Context) { config.Parser = r.Parser config.Requires = r.Requires config.Skills = r.Skills + config.MCPs = r.MCPs config.AgentType = r.AgentType + config.Entrypoint = r.Entrypoint for v := range r.Files { if !fs.ValidPath(v) { @@ -159,6 +161,9 @@ func (s *Server) CreateHandler(c *gin.Context) { ch <- gin.H{"error": err.Error()} return } + } else if r.Entrypoint != "" { + // Entrypoint-only agent: no base model needed + slog.Debug("create entrypoint-only agent", "entrypoint", r.Entrypoint) } else { ch <- gin.H{"error": errNeitherFromOrFiles.Error(), "status": http.StatusBadRequest} return @@ -551,6 +556,12 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML, return err } + // Handle MCP layers for agents + layers, config.MCPs, err = setMCPLayers(layers, config.MCPs, fn) + if err != nil { + return err + } + configLayer, err := createConfigLayer(layers, *config) if err != nil { return err @@ -894,6 +905,42 @@ func setSkillLayers(layers []Layer, skills []model.SkillRef, fn func(resp api.Pr return layers, updatedSkills, nil } +// setMCPLayers handles MCP server references. +// Currently, MCPs are stored as config data (command/args). +// Future: support bundling MCP server directories as layers. +func setMCPLayers(layers []Layer, mcps []model.MCPRef, fn func(resp api.ProgressResponse)) ([]Layer, []model.MCPRef, error) { + if len(mcps) == 0 { + return layers, mcps, nil + } + + // Remove any existing MCP layers + layers = removeLayer(layers, MediaTypeMCP) + + var updatedMCPs []model.MCPRef + + for _, mcp := range mcps { + // Validate MCP has required fields + if mcp.Name == "" { + return nil, nil, fmt.Errorf("MCP server requires a name") + } + if mcp.Command == "" { + return nil, nil, fmt.Errorf("MCP server %q requires a command", mcp.Name) + } + + // Set default type if not specified + if mcp.Type == "" { + mcp.Type = "stdio" + } + + // For now, just keep MCPs as config data + // Future: detect local paths in args and bundle them + updatedMCPs = append(updatedMCPs, mcp) + fn(api.ProgressResponse{Status: fmt.Sprintf("configuring MCP: %s", mcp.Name)}) + } + + return layers, updatedMCPs, nil +} + func createConfigLayer(layers []Layer, config model.ConfigV2) (*Layer, error) { digests := make([]string, len(layers)) for i, layer := range layers { diff --git a/server/images.go b/server/images.go index 9e9114d72..01df68be4 100644 --- a/server/images.go +++ b/server/images.go @@ -232,6 +232,13 @@ func (m *Model) String() string { }) } + if m.Config.Entrypoint != "" { + modelfile.Commands = append(modelfile.Commands, parser.Command{ + Name: "entrypoint", + Args: m.Config.Entrypoint, + }) + } + for k, v := range m.Options { switch v := v.(type) { case []any: diff --git a/server/manifest.go b/server/manifest.go index 0d348dd06..e42f80197 100644 --- a/server/manifest.go +++ b/server/manifest.go @@ -129,11 +129,30 @@ func Manifests(continueOnError bool) (map[model.Name]*Manifest, error) { return nil, err } - // TODO(mxyng): use something less brittle - matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*")) + // Find both 4-part (models) and 5-part (skills/agents) manifest paths + matches4, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*")) if err != nil { return nil, err } + matches5, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*", "*")) + if err != nil { + return nil, err + } + + // Combine matches, filtering to only include files + var matches []string + for _, match := range matches4 { + fi, err := os.Stat(match) + if err == nil && !fi.IsDir() { + matches = append(matches, match) + } + } + for _, match := range matches5 { + fi, err := os.Stat(match) + if err == nil && !fi.IsDir() { + matches = append(matches, match) + } + } ms := make(map[model.Name]*Manifest) for _, match := range matches { diff --git a/server/mcp.go b/server/mcp.go new file mode 100644 index 000000000..f1bb2e576 --- /dev/null +++ b/server/mcp.go @@ -0,0 +1,315 @@ +package server + +import ( + "archive/tar" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/ollama/ollama/envconfig" + "github.com/ollama/ollama/types/model" +) + +// MediaTypeMCP is the media type for MCP server layers in manifests. +const MediaTypeMCP = "application/vnd.ollama.image.mcp" + +// GetMCPsPath returns the path to the extracted MCPs cache directory. +// If digest is empty, returns the mcps directory itself. +// If digest is provided, returns the path to the extracted MCP for that digest. +func GetMCPsPath(digest string) (string, error) { + // only accept actual sha256 digests + pattern := "^sha256[:-][0-9a-fA-F]{64}$" + re := regexp.MustCompile(pattern) + + if digest != "" && !re.MatchString(digest) { + return "", ErrInvalidDigestFormat + } + + digest = strings.ReplaceAll(digest, ":", "-") + path := filepath.Join(envconfig.Models(), "mcps", digest) + dirPath := filepath.Dir(path) + if digest == "" { + dirPath = path + } + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return "", fmt.Errorf("%w: ensure path elements are traversable", err) + } + + return path, nil +} + +// ExtractMCPBlob extracts an MCP tar.gz blob to the mcps cache. +// The blob is expected to be at the blobs path for the given digest. +// Returns the path to the extracted MCP directory. +func ExtractMCPBlob(digest string) (string, error) { + // Get the blob path + blobPath, err := GetBlobsPath(digest) + if err != nil { + return "", fmt.Errorf("getting blob path: %w", err) + } + + // Get the extraction path + mcpPath, err := GetMCPsPath(digest) + if err != nil { + return "", fmt.Errorf("getting mcp path: %w", err) + } + + // Check if already extracted (look for any file) + entries, err := os.ReadDir(mcpPath) + if err == nil && len(entries) > 0 { + return mcpPath, nil + } + + // Open the blob + f, err := os.Open(blobPath) + if err != nil { + return "", fmt.Errorf("opening blob: %w", err) + } + defer f.Close() + + // Create gzip reader + gzr, err := gzip.NewReader(f) + if err != nil { + return "", fmt.Errorf("creating gzip reader: %w", err) + } + defer gzr.Close() + + // Create tar reader + tr := tar.NewReader(gzr) + + // Create the mcp directory + if err := os.MkdirAll(mcpPath, 0o755); err != nil { + return "", fmt.Errorf("creating mcp directory: %w", err) + } + + // Extract files + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return "", fmt.Errorf("reading tar: %w", err) + } + + // Clean the name and ensure it doesn't escape the target directory + name := filepath.Clean(header.Name) + if strings.HasPrefix(name, "..") { + return "", fmt.Errorf("invalid path in archive: %s", header.Name) + } + + target := filepath.Join(mcpPath, name) + + // Verify the target is within mcpPath + if !strings.HasPrefix(target, filepath.Clean(mcpPath)+string(os.PathSeparator)) && target != filepath.Clean(mcpPath) { + return "", fmt.Errorf("path escapes mcp directory: %s", header.Name) + } + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(target, 0o755); err != nil { + return "", fmt.Errorf("creating directory: %w", err) + } + case tar.TypeReg: + // Ensure parent directory exists + if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil { + return "", fmt.Errorf("creating parent directory: %w", err) + } + + outFile, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, os.FileMode(header.Mode)) + if err != nil { + return "", fmt.Errorf("creating file: %w", err) + } + + if _, err := io.Copy(outFile, tr); err != nil { + outFile.Close() + return "", fmt.Errorf("writing file: %w", err) + } + outFile.Close() + } + } + + return mcpPath, nil +} + +// CreateMCPLayer creates an MCP layer from a local directory. +// The directory can optionally contain an mcp.json or package.json file. +// Returns the created layer. +func CreateMCPLayer(mcpDir string) (Layer, error) { + // Verify directory exists + info, err := os.Stat(mcpDir) + if err != nil { + return Layer{}, fmt.Errorf("mcp directory not found: %w", err) + } + if !info.IsDir() { + return Layer{}, fmt.Errorf("mcp path is not a directory: %s", mcpDir) + } + + // Create a temporary file for the tar.gz + blobsPath, err := GetBlobsPath("") + if err != nil { + return Layer{}, fmt.Errorf("getting blobs path: %w", err) + } + + tmpFile, err := os.CreateTemp(blobsPath, "mcp-*.tar.gz") + if err != nil { + return Layer{}, fmt.Errorf("creating temp file: %w", err) + } + tmpPath := tmpFile.Name() + defer func() { + tmpFile.Close() + os.Remove(tmpPath) + }() + + // Create gzip writer + gzw := gzip.NewWriter(tmpFile) + defer gzw.Close() + + // Create tar writer + tw := tar.NewWriter(gzw) + defer tw.Close() + + // Walk the mcp directory and add files to tar + err = filepath.Walk(mcpDir, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + // Get relative path + relPath, err := filepath.Rel(mcpDir, path) + if err != nil { + return err + } + + // Skip the root directory itself + if relPath == "." { + return nil + } + + // Create tar header + header, err := tar.FileInfoHeader(info, "") + if err != nil { + return err + } + header.Name = relPath + + if err := tw.WriteHeader(header); err != nil { + return err + } + + // Write file contents if it's a regular file + if !info.IsDir() { + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.Copy(tw, f); err != nil { + return err + } + } + + return nil + }) + if err != nil { + return Layer{}, fmt.Errorf("creating tar archive: %w", err) + } + + // Close writers to flush + if err := tw.Close(); err != nil { + return Layer{}, fmt.Errorf("closing tar writer: %w", err) + } + if err := gzw.Close(); err != nil { + return Layer{}, fmt.Errorf("closing gzip writer: %w", err) + } + if err := tmpFile.Close(); err != nil { + return Layer{}, fmt.Errorf("closing temp file: %w", err) + } + + // Open the temp file for reading + tmpFile, err = os.Open(tmpPath) + if err != nil { + return Layer{}, fmt.Errorf("reopening temp file: %w", err) + } + defer tmpFile.Close() + + // Create the layer (this will compute the digest and move to blobs) + layer, err := NewLayer(tmpFile, MediaTypeMCP) + if err != nil { + return Layer{}, fmt.Errorf("creating layer: %w", err) + } + + // Extract the mcp to the cache so it's ready to use + if _, err := ExtractMCPBlob(layer.Digest); err != nil { + return Layer{}, fmt.Errorf("extracting mcp: %w", err) + } + + return layer, nil +} + +// IsLocalMCPPath checks if an MCP reference looks like a local path. +// Local paths are explicitly prefixed with /, ./, ../, or ~. +func IsLocalMCPPath(name string) bool { + return strings.HasPrefix(name, "/") || + strings.HasPrefix(name, "./") || + strings.HasPrefix(name, "../") || + strings.HasPrefix(name, "~") +} + +// MCPNamespace is the namespace used for standalone MCPs in the registry. +const MCPNamespace = "mcp" + +// IsMCPReference checks if a name refers to an MCP (has mcp/ prefix). +func IsMCPReference(name string) bool { + name = strings.ReplaceAll(name, string(os.PathSeparator), "/") + parts := strings.Split(name, "/") + + // mcp/name or mcp/name:tag + if len(parts) >= 1 && parts[0] == MCPNamespace { + return true + } + // namespace/mcp/name (e.g., myuser/mcp/websearch) + if len(parts) >= 2 && parts[1] == MCPNamespace { + return true + } + return false +} + +// ParseMCPName parses an MCP reference string into a model.Name. +// The Kind field is set to "mcp". +func ParseMCPName(name string) model.Name { + n := model.ParseName(name) + + // If Kind wasn't set (old format without mcp/), set it + if n.Kind == "" { + n.Kind = MCPNamespace + } + + return n +} + +// GetMCPManifestPath returns the path to the MCP manifest file. +func GetMCPManifestPath(n model.Name) (string, error) { + if n.Model == "" { + return "", fmt.Errorf("mcp name is required") + } + + // Ensure Kind is set + if n.Kind == "" { + n.Kind = MCPNamespace + } + + path := filepath.Join( + envconfig.Models(), + "manifests", + n.Filepath(), + ) + + return path, nil +} diff --git a/server/routes.go b/server/routes.go index 2bac04a05..beeb0e50e 100644 --- a/server/routes.go +++ b/server/routes.go @@ -969,6 +969,9 @@ func getExistingName(n model.Name) (model.Name, error) { if set.Namespace == "" && strings.EqualFold(e.Namespace, n.Namespace) { n.Namespace = e.Namespace } + if set.Kind == "" && strings.EqualFold(e.Kind, n.Kind) { + n.Kind = e.Kind + } if set.Model == "" && strings.EqualFold(e.Model, n.Model) { n.Model = e.Model } @@ -1108,7 +1111,9 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { ModifiedAt: manifest.fi.ModTime(), Requires: m.Config.Requires, Skills: m.Config.Skills, + MCPs: m.Config.MCPs, AgentType: m.Config.AgentType, + Entrypoint: m.Config.Entrypoint, } if m.Config.RemoteHost != "" { @@ -1163,11 +1168,16 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { fmt.Fprint(&sb, m.String()) resp.Modelfile = sb.String() - // skip loading tensor information if this is a remote model + // skip loading tensor information if this is a remote model or a skill if m.Config.RemoteHost != "" && m.Config.RemoteModel != "" { return resp, nil } + // Skills don't have model weights, skip tensor loading + if m.ModelPath == "" { + return resp, nil + } + kvData, tensors, err := getModelData(m.ModelPath, req.Verbose) if err != nil { return nil, err diff --git a/types/model/config.go b/types/model/config.go index e64eb1b79..61c1f4645 100644 --- a/types/model/config.go +++ b/types/model/config.go @@ -8,6 +8,22 @@ type SkillRef struct { Digest string `json:"digest,omitempty"` } +// MCPRef represents a reference to an MCP (Model Context Protocol) server. +type MCPRef struct { + // Name is the identifier for the MCP server (used for tool namespacing) + Name string `json:"name,omitempty"` + // Digest is the content-addressable digest of the bundled MCP server blob + Digest string `json:"digest,omitempty"` + // Command is the executable to run (e.g., "uv", "node", "python3") + Command string `json:"command,omitempty"` + // Args are the arguments to pass to the command + Args []string `json:"args,omitempty"` + // Env is optional environment variables for the MCP server + Env map[string]string `json:"env,omitempty"` + // Type is the transport type (currently only "stdio" is supported) + Type string `json:"type,omitempty"` +} + // ConfigV2 represents the configuration metadata for a model. type ConfigV2 struct { ModelFormat string `json:"model_format"` @@ -29,8 +45,10 @@ type ConfigV2 struct { BaseName string `json:"base_name,omitempty"` // agent-specific fields - Skills []SkillRef `json:"skills,omitempty"` - AgentType string `json:"agent_type,omitempty"` + Skills []SkillRef `json:"skills,omitempty"` + MCPs []MCPRef `json:"mcps,omitempty"` + AgentType string `json:"agent_type,omitempty"` + Entrypoint string `json:"entrypoint,omitempty"` // required by spec Architecture string `json:"architecture"` diff --git a/types/model/name.go b/types/model/name.go index 25c29bde3..313564ed9 100644 --- a/types/model/name.go +++ b/types/model/name.go @@ -137,6 +137,7 @@ func ParseName(s string) Name { var ValidKinds = map[string]bool{ "skill": true, "agent": true, + "mcp": true, } // ParseNameBare parses s as a name string and returns a Name. No merge with