546 lines
12 KiB
Go
546 lines
12 KiB
Go
package cmd
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/exec"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/api"
|
|
)
|
|
|
|
// Timeouts applied to the different phases of an MCP server's lifetime.
const (
	// mcpInitTimeout bounds the initialize handshake and the tools/list
	// request issued while a server is starting.
	mcpInitTimeout = 30 * time.Second

	// mcpCallTimeout bounds a single tool invocation made via RunToolCall.
	mcpCallTimeout = 60 * time.Second

	// mcpShutdownTimeout is how long stop waits for the process to exit
	// on its own before killing it.
	mcpShutdownTimeout = 5 * time.Second
)
|
|
|
|
// JSON-RPC types

// jsonrpcRequest is an outgoing JSON-RPC message. It is used both for
// requests and for notifications: a zero ID is omitted from the encoded
// form, which is how notify produces an id-less message.
type jsonrpcRequest struct {
	JSONRPC string `json:"jsonrpc"`
	ID      int    `json:"id,omitempty"`
	Method  string `json:"method"`
	Params  any    `json:"params,omitempty"`
}

// jsonrpcResponse is an incoming JSON-RPC response. Result is kept raw so
// that the caller can decode it into the type it expects; Error is non-nil
// when the peer reported a failure.
type jsonrpcResponse struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      int             `json:"id"`
	Result  json.RawMessage `json:"result,omitempty"`
	Error   *jsonrpcError   `json:"error,omitempty"`
}

// jsonrpcError is the error object carried by a failed jsonrpcResponse.
type jsonrpcError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Data    any    `json:"data,omitempty"`
}
|
|
|
|
// MCP protocol types

// mcpInitializeParams is the parameter object sent with the "initialize"
// request.
type mcpInitializeParams struct {
	ProtocolVersion string         `json:"protocolVersion"`
	Capabilities    map[string]any `json:"capabilities"`
	ClientInfo      mcpClientInfo  `json:"clientInfo"`
}

// mcpClientInfo identifies this client to the server during initialize.
type mcpClientInfo struct {
	Name    string `json:"name"`
	Version string `json:"version"`
}

// mcpInitializeResult is the decoded result of the "initialize" request.
type mcpInitializeResult struct {
	ProtocolVersion string          `json:"protocolVersion"`
	Capabilities    mcpCapabilities `json:"capabilities"`
	ServerInfo      mcpServerInfo   `json:"serverInfo"`
}

// mcpCapabilities holds the subset of server capabilities this client
// decodes. Tools is nil when the server did not include a "tools" entry.
type mcpCapabilities struct {
	Tools *mcpToolsCapability `json:"tools,omitempty"`
}

// mcpToolsCapability is the "tools" entry of the server capabilities.
type mcpToolsCapability struct {
	ListChanged bool `json:"listChanged,omitempty"`
}

// mcpServerInfo is the server's self-description returned by initialize.
type mcpServerInfo struct {
	Name    string `json:"name"`
	Version string `json:"version"`
}
|
|
|
|
// mcpTool is one tool description as returned by "tools/list".
type mcpTool struct {
	Name        string             `json:"name"`
	Description string             `json:"description,omitempty"`
	InputSchema mcpToolInputSchema `json:"inputSchema"`
}

// mcpToolInputSchema is the schema describing a tool's arguments.
// Properties values are left as decoded JSON (any); convertMCPSchema picks
// out the fields it understands.
type mcpToolInputSchema struct {
	Type       string         `json:"type"`
	Properties map[string]any `json:"properties,omitempty"`
	Required   []string       `json:"required,omitempty"`
}

// mcpToolsListResult is the decoded result of the "tools/list" request.
type mcpToolsListResult struct {
	Tools []mcpTool `json:"tools"`
}
|
|
|
|
// mcpToolCallParams is the parameter object sent with "tools/call".
type mcpToolCallParams struct {
	Name      string         `json:"name"`
	Arguments map[string]any `json:"arguments,omitempty"`
}

// mcpToolCallResult is the decoded result of "tools/call". IsError marks a
// tool-level failure reported inside an otherwise successful response.
type mcpToolCallResult struct {
	Content []mcpContent `json:"content"`
	IsError bool         `json:"isError,omitempty"`
}

// mcpContent is one content item of a tool result. callTool only consumes
// items whose Type is "text".
type mcpContent struct {
	Type string `json:"type"`
	Text string `json:"text,omitempty"`
}
|
|
|
|
// mcpServer represents a running MCP server process
|
|
type mcpServer struct {
|
|
ref api.MCPRef
|
|
cmd *exec.Cmd
|
|
stdin io.WriteCloser
|
|
stdout *bufio.Reader
|
|
stderr io.ReadCloser
|
|
tools []mcpTool
|
|
mu sync.Mutex
|
|
nextID int
|
|
started bool
|
|
}
|
|
|
|
// mcpManager manages multiple MCP servers for an agent session
|
|
type mcpManager struct {
|
|
servers map[string]*mcpServer
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// newMCPManager creates a new MCP manager
|
|
func newMCPManager() *mcpManager {
|
|
return &mcpManager{
|
|
servers: make(map[string]*mcpServer),
|
|
}
|
|
}
|
|
|
|
// loadMCPsFromRefs initializes MCP servers from refs
|
|
func (m *mcpManager) loadMCPsFromRefs(refs []api.MCPRef) error {
|
|
if len(refs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
for _, ref := range refs {
|
|
if err := m.addServer(ref); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Warning: failed to initialize MCP server %q: %v\n", ref.Name, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// addServer adds and starts an MCP server
|
|
func (m *mcpManager) addServer(ref api.MCPRef) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
if _, exists := m.servers[ref.Name]; exists {
|
|
return fmt.Errorf("MCP server %q already exists", ref.Name)
|
|
}
|
|
|
|
srv := &mcpServer{
|
|
ref: ref,
|
|
nextID: 1,
|
|
}
|
|
|
|
if err := srv.start(); err != nil {
|
|
return fmt.Errorf("starting MCP server: %w", err)
|
|
}
|
|
|
|
m.servers[ref.Name] = srv
|
|
return nil
|
|
}
|
|
|
|
// start starts the MCP server process
|
|
func (s *mcpServer) start() error {
|
|
s.mu.Lock()
|
|
|
|
if s.started {
|
|
s.mu.Unlock()
|
|
return nil
|
|
}
|
|
|
|
s.cmd = exec.Command(s.ref.Command, s.ref.Args...)
|
|
|
|
// Set environment
|
|
s.cmd.Env = os.Environ()
|
|
for k, v := range s.ref.Env {
|
|
s.cmd.Env = append(s.cmd.Env, fmt.Sprintf("%s=%s", k, v))
|
|
}
|
|
|
|
var err error
|
|
s.stdin, err = s.cmd.StdinPipe()
|
|
if err != nil {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("creating stdin pipe: %w", err)
|
|
}
|
|
|
|
stdout, err := s.cmd.StdoutPipe()
|
|
if err != nil {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("creating stdout pipe: %w", err)
|
|
}
|
|
s.stdout = bufio.NewReader(stdout)
|
|
|
|
s.stderr, err = s.cmd.StderrPipe()
|
|
if err != nil {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("creating stderr pipe: %w", err)
|
|
}
|
|
|
|
// Start stderr reader goroutine (discard stderr for now)
|
|
go func() {
|
|
scanner := bufio.NewScanner(s.stderr)
|
|
for scanner.Scan() {
|
|
_ = scanner.Text()
|
|
}
|
|
}()
|
|
|
|
if err := s.cmd.Start(); err != nil {
|
|
s.mu.Unlock()
|
|
return fmt.Errorf("starting process: %w", err)
|
|
}
|
|
|
|
s.started = true
|
|
s.mu.Unlock() // Release lock before calling initialize/listTools which use the mutex
|
|
|
|
// Initialize the server
|
|
if err := s.initialize(); err != nil {
|
|
s.stop()
|
|
return fmt.Errorf("initializing MCP server: %w", err)
|
|
}
|
|
|
|
// Get available tools
|
|
if err := s.listTools(); err != nil {
|
|
s.stop()
|
|
return fmt.Errorf("listing tools: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// initialize sends the MCP initialize request
|
|
func (s *mcpServer) initialize() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout)
|
|
defer cancel()
|
|
|
|
params := mcpInitializeParams{
|
|
ProtocolVersion: "2024-11-05",
|
|
Capabilities: map[string]any{},
|
|
ClientInfo: mcpClientInfo{
|
|
Name: "ollama",
|
|
Version: "0.1.0",
|
|
},
|
|
}
|
|
|
|
var result mcpInitializeResult
|
|
if err := s.call(ctx, "initialize", params, &result); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Send initialized notification
|
|
return s.notify("notifications/initialized", nil)
|
|
}
|
|
|
|
// listTools fetches the available tools from the MCP server
|
|
func (s *mcpServer) listTools() error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), mcpInitTimeout)
|
|
defer cancel()
|
|
|
|
var result mcpToolsListResult
|
|
if err := s.call(ctx, "tools/list", nil, &result); err != nil {
|
|
return err
|
|
}
|
|
|
|
s.tools = result.Tools
|
|
return nil
|
|
}
|
|
|
|
// call sends a JSON-RPC request and waits for the response
|
|
func (s *mcpServer) call(ctx context.Context, method string, params any, result any) error {
|
|
s.mu.Lock()
|
|
id := s.nextID
|
|
s.nextID++
|
|
s.mu.Unlock()
|
|
|
|
req := jsonrpcRequest{
|
|
JSONRPC: "2.0",
|
|
ID: id,
|
|
Method: method,
|
|
Params: params,
|
|
}
|
|
|
|
reqBytes, err := json.Marshal(req)
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling request: %w", err)
|
|
}
|
|
|
|
// Send request
|
|
s.mu.Lock()
|
|
_, err = s.stdin.Write(append(reqBytes, '\n'))
|
|
s.mu.Unlock()
|
|
if err != nil {
|
|
return fmt.Errorf("writing request: %w", err)
|
|
}
|
|
|
|
// Read response with timeout
|
|
respCh := make(chan []byte, 1)
|
|
errCh := make(chan error, 1)
|
|
|
|
go func() {
|
|
s.mu.Lock()
|
|
line, err := s.stdout.ReadBytes('\n')
|
|
s.mu.Unlock()
|
|
if err != nil {
|
|
errCh <- err
|
|
return
|
|
}
|
|
respCh <- line
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err := <-errCh:
|
|
return fmt.Errorf("reading response: %w", err)
|
|
case line := <-respCh:
|
|
var resp jsonrpcResponse
|
|
if err := json.Unmarshal(line, &resp); err != nil {
|
|
return fmt.Errorf("unmarshaling response: %w", err)
|
|
}
|
|
|
|
if resp.Error != nil {
|
|
return fmt.Errorf("MCP error %d: %s", resp.Error.Code, resp.Error.Message)
|
|
}
|
|
|
|
if result != nil && len(resp.Result) > 0 {
|
|
if err := json.Unmarshal(resp.Result, result); err != nil {
|
|
return fmt.Errorf("unmarshaling result: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// notify sends a JSON-RPC notification (no response expected)
|
|
func (s *mcpServer) notify(method string, params any) error {
|
|
req := jsonrpcRequest{
|
|
JSONRPC: "2.0",
|
|
Method: method,
|
|
Params: params,
|
|
}
|
|
|
|
reqBytes, err := json.Marshal(req)
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling notification: %w", err)
|
|
}
|
|
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
if _, err := s.stdin.Write(append(reqBytes, '\n')); err != nil {
|
|
return fmt.Errorf("writing notification: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// callTool executes a tool call on the MCP server
|
|
func (s *mcpServer) callTool(ctx context.Context, name string, arguments map[string]any) (string, error) {
|
|
params := mcpToolCallParams{
|
|
Name: name,
|
|
Arguments: arguments,
|
|
}
|
|
|
|
var result mcpToolCallResult
|
|
if err := s.call(ctx, "tools/call", params, &result); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Concatenate text content
|
|
var sb strings.Builder
|
|
for _, content := range result.Content {
|
|
if content.Type == "text" {
|
|
sb.WriteString(content.Text)
|
|
}
|
|
}
|
|
|
|
if result.IsError {
|
|
return sb.String(), errors.New(sb.String())
|
|
}
|
|
|
|
return sb.String(), nil
|
|
}
|
|
|
|
// stop shuts down the MCP server
|
|
func (s *mcpServer) stop() error {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
if !s.started {
|
|
return nil
|
|
}
|
|
|
|
// Close stdin to signal shutdown
|
|
if s.stdin != nil {
|
|
s.stdin.Close()
|
|
}
|
|
|
|
// Wait for process with timeout
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- s.cmd.Wait()
|
|
}()
|
|
|
|
select {
|
|
case <-time.After(mcpShutdownTimeout):
|
|
s.cmd.Process.Kill()
|
|
case <-done:
|
|
}
|
|
|
|
s.started = false
|
|
return nil
|
|
}
|
|
|
|
// Tools returns all tools from all MCP servers as api.Tools
|
|
func (m *mcpManager) Tools() api.Tools {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
var tools api.Tools
|
|
|
|
for serverName, srv := range m.servers {
|
|
for _, t := range srv.tools {
|
|
// Namespace tool names: mcp_{servername}_{toolname}
|
|
namespacedName := fmt.Sprintf("mcp_%s_%s", serverName, t.Name)
|
|
|
|
tool := api.Tool{
|
|
Type: "function",
|
|
Function: api.ToolFunction{
|
|
Name: namespacedName,
|
|
Description: t.Description,
|
|
Parameters: convertMCPSchema(t.InputSchema),
|
|
},
|
|
}
|
|
tools = append(tools, tool)
|
|
}
|
|
}
|
|
|
|
return tools
|
|
}
|
|
|
|
// convertMCPSchema converts MCP input schema to api.ToolFunctionParameters
|
|
func convertMCPSchema(schema mcpToolInputSchema) api.ToolFunctionParameters {
|
|
params := api.ToolFunctionParameters{
|
|
Type: schema.Type,
|
|
Required: schema.Required,
|
|
Properties: make(map[string]api.ToolProperty),
|
|
}
|
|
|
|
for name, prop := range schema.Properties {
|
|
if propMap, ok := prop.(map[string]any); ok {
|
|
tp := api.ToolProperty{}
|
|
if t, ok := propMap["type"].(string); ok {
|
|
tp.Type = api.PropertyType{t}
|
|
}
|
|
if d, ok := propMap["description"].(string); ok {
|
|
tp.Description = d
|
|
}
|
|
params.Properties[name] = tp
|
|
}
|
|
}
|
|
|
|
return params
|
|
}
|
|
|
|
// RunToolCall routes a tool call to the appropriate MCP server
|
|
func (m *mcpManager) RunToolCall(call api.ToolCall) (api.Message, bool, error) {
|
|
name := call.Function.Name
|
|
|
|
// Check if this is an MCP tool (mcp_servername_toolname)
|
|
if !strings.HasPrefix(name, "mcp_") {
|
|
return api.Message{}, false, nil
|
|
}
|
|
|
|
// Parse server name and tool name
|
|
rest := strings.TrimPrefix(name, "mcp_")
|
|
idx := strings.Index(rest, "_")
|
|
if idx == -1 {
|
|
return toolMessage(call, fmt.Sprintf("invalid MCP tool name: %s", name)), true, nil
|
|
}
|
|
|
|
serverName := rest[:idx]
|
|
toolName := rest[idx+1:]
|
|
|
|
m.mu.RLock()
|
|
srv, ok := m.servers[serverName]
|
|
m.mu.RUnlock()
|
|
|
|
if !ok {
|
|
return toolMessage(call, fmt.Sprintf("MCP server %q not found", serverName)), true, nil
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), mcpCallTimeout)
|
|
defer cancel()
|
|
|
|
result, err := srv.callTool(ctx, toolName, call.Function.Arguments)
|
|
if err != nil {
|
|
return toolMessage(call, fmt.Sprintf("error: %v", err)), true, nil
|
|
}
|
|
|
|
return toolMessage(call, result), true, nil
|
|
}
|
|
|
|
// Shutdown stops all MCP servers
|
|
func (m *mcpManager) Shutdown() {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
for _, srv := range m.servers {
|
|
srv.stop()
|
|
}
|
|
|
|
m.servers = make(map[string]*mcpServer)
|
|
}
|
|
|
|
// ServerNames returns the names of all running MCP servers
|
|
func (m *mcpManager) ServerNames() []string {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
names := make([]string, 0, len(m.servers))
|
|
for name := range m.servers {
|
|
names = append(names, name)
|
|
}
|
|
return names
|
|
}
|
|
|
|
// ToolCount returns the total number of tools across all servers
|
|
func (m *mcpManager) ToolCount() int {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
|
|
count := 0
|
|
for _, srv := range m.servers {
|
|
count += len(srv.tools)
|
|
}
|
|
return count
|
|
}
|