diff --git a/cmd/eval/README.md b/cmd/eval/README.md
new file mode 100644
index 000000000..25be8c274
--- /dev/null
+++ b/cmd/eval/README.md
@@ -0,0 +1,54 @@
+# eval
+
+Evaluation tool for testing Ollama models.
+
+## Usage
+
+Run all tests:
+
+```bash
+go run . -model llama3.2:latest
+```
+
+Run one or more specific suites (comma-separated):
+
+```bash
+go run . -model llama3.2:latest -suite tool-calling-basic,basic-qa -v
+```
+
+List available suites:
+
+```bash
+go run . -list
+```
+
+Results are also written to `eval-results.json` (change with `-export`). The
+per-test timeout defaults to 60 seconds (change with `-timeout`).
+
+## Adding Tests
+
+Edit `suites.go` to add new test suites. Each test needs:
+
+- `Name`: test identifier
+- `Prompt`: what to send to the model
+- `Check`: function to validate the response (defaults to `HasResponse()`)
+
+Example:
+
+```go
+{
+	Name:   "my-test",
+	Prompt: "What is 2+2?",
+	Check:  ContainsWord("4"),
+}
+```
+
+Available check functions:
+
+- `HasResponse()` - response is non-empty
+- `Contains(s)` - response contains substring (case-insensitive)
+- `ContainsWord(s)` - response contains `s` as a whole word (case-insensitive)
+- `CallsTool(name)` - model called specific tool
+- `NoTools()` - model called no tools
+- `MinTools(n)` - model made at least n tool calls
+- `All(checks...)` - all checks pass
diff --git a/cmd/eval/eval.go b/cmd/eval/eval.go
new file mode 100644
index 000000000..9eb3e20be
--- /dev/null
+++ b/cmd/eval/eval.go
@@ -0,0 +1,203 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"regexp"
+	"strings"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+// Test is a single evaluation test.
+type Test struct {
+	Name    string
+	Prompt  string
+	System  string
+	Tools   []api.Tool
+	Think   bool
+	Options map[string]any
+
+	// Check reports whether the response text and tool calls count as a
+	// pass. Run treats a nil Check as HasResponse().
+	Check func(response string, tools []api.ToolCall) bool
+}
+
+// Suite is a named collection of tests.
+type Suite struct {
+	Name  string
+	Tests []Test
+}
+
+// Result holds the outcome of running one test.
+type Result struct {
+	Name      string
+	Passed    bool
+	Error     error
+	Duration  time.Duration
+	Response  string
+	Tools     []string
+	ToolCalls []api.ToolCall
+	Thinking  bool
+}
+
+// MarshalJSON encodes the result with Error rendered as its message string
+// (or null when there is no error). encoding/json would otherwise emit "{}"
+// for typical error values, losing the failure reason in exported results.
+func (r Result) MarshalJSON() ([]byte, error) {
+	// plain has Result's fields but not this method, which avoids recursion.
+	type plain Result
+	out := struct {
+		plain
+		// Error shadows plain.Error in the encoded output.
+		Error *string
+	}{plain: plain(r)}
+	if r.Error != nil {
+		msg := r.Error.Error()
+		out.Error = &msg
+	}
+	return json.Marshal(out)
+}
+
+// Run sends test to model through client and returns the outcome. If the
+// chat call fails, the error is stored in Result.Error and Passed stays false.
+func Run(ctx context.Context, client *api.Client, model string, test Test) Result {
+	result := Result{Name: test.Name}
+
+	// Default a nil Check here so Run cannot panic on a nil function call.
+	check := test.Check
+	if check == nil {
+		check = HasResponse()
+	}
+
+	req := &api.ChatRequest{
+		Model: model,
+		Messages: []api.Message{
+			{Role: "user", Content: test.Prompt},
+		},
+		Options: test.Options,
+	}
+
+	// The system message, when present, goes ahead of the user prompt.
+	if test.System != "" {
+		req.Messages = append([]api.Message{
+			{Role: "system", Content: test.System},
+		}, req.Messages...)
+	}
+
+	if len(test.Tools) > 0 {
+		req.Tools = test.Tools
+	}
+
+	if test.Think {
+		req.Think = &api.ThinkValue{Value: true}
+	}
+
+	var resp strings.Builder
+	var toolCalls []api.ToolCall
+
+	// Accumulate content and tool calls across every callback invocation.
+	start := time.Now()
+	err := client.Chat(ctx, req, func(r api.ChatResponse) error {
+		resp.WriteString(r.Message.Content)
+		if r.Message.Thinking != "" {
+			result.Thinking = true
+		}
+		toolCalls = append(toolCalls, r.Message.ToolCalls...)
+		return nil
+	})
+	result.Duration = time.Since(start)
+
+	if err != nil {
+		result.Error = err
+		return result
+	}
+
+	result.Response = resp.String()
+	result.Tools = uniqueToolNames(toolCalls)
+	result.ToolCalls = toolCalls
+	result.Passed = check(result.Response, toolCalls)
+
+	return result
+}
+
+// uniqueToolNames returns the distinct function names in calls, in order of
+// first appearance.
+func uniqueToolNames(calls []api.ToolCall) []string {
+	seen := make(map[string]bool)
+	var names []string
+	for _, c := range calls {
+		if !seen[c.Function.Name] {
+			seen[c.Function.Name] = true
+			names = append(names, c.Function.Name)
+		}
+	}
+	return names
+}
+
+// Check functions for common test patterns.
+
+// HasResponse passes when the response contains non-whitespace text.
+func HasResponse() func(string, []api.ToolCall) bool {
+	return func(resp string, _ []api.ToolCall) bool {
+		return strings.TrimSpace(resp) != ""
+	}
+}
+
+// Contains passes when the response contains s, ignoring case.
+func Contains(s string) func(string, []api.ToolCall) bool {
+	want := strings.ToLower(s) // lowered once, not on every check
+	return func(resp string, _ []api.ToolCall) bool {
+		return strings.Contains(strings.ToLower(resp), want)
+	}
+}
+
+// ContainsWord passes when the response contains s delimited by word
+// boundaries, ignoring case, so "no" does not match "know" or "cannot".
+// s should begin and end with a letter, digit, or underscore.
+func ContainsWord(s string) func(string, []api.ToolCall) bool {
+	re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(s) + `\b`)
+	return func(resp string, _ []api.ToolCall) bool {
+		return re.MatchString(resp)
+	}
+}
+
+// CallsTool passes when at least one tool call targets the named function.
+func CallsTool(name string) func(string, []api.ToolCall) bool {
+	return func(_ string, tools []api.ToolCall) bool {
+		for _, t := range tools {
+			if t.Function.Name == name {
+				return true
+			}
+		}
+		return false
+	}
+}
+
+// NoTools passes when the model made no tool calls.
+func NoTools() func(string, []api.ToolCall) bool {
+	return func(_ string, tools []api.ToolCall) bool {
+		return len(tools) == 0
+	}
+}
+
+// MinTools passes when the model made at least n tool calls. Repeated calls
+// to the same tool each count.
+func MinTools(n int) func(string, []api.ToolCall) bool {
+	return func(_ string, tools []api.ToolCall) bool {
+		return len(tools) >= n
+	}
+}
+
+// All passes only when every check passes; it stops at the first failure.
+func All(checks ...func(string, []api.ToolCall) bool) func(string, []api.ToolCall) bool {
+	return func(resp string, tools []api.ToolCall) bool {
+		for _, check := range checks {
+			if !check(resp, tools) {
+				return false
+			}
+		}
+		return true
+	}
+}
diff --git a/cmd/eval/main.go b/cmd/eval/main.go
new file mode 100644
index 000000000..2098023f4
--- /dev/null
+++ b/cmd/eval/main.go
@@ -0,0 +1,242 @@
+package main
+
+import (
+	"context"
+	"encoding/json"
+	"flag"
+	"fmt"
+	"os"
+	"strings"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+func main() {
+	model := flag.String("model", "", "model to evaluate")
+	suite := flag.String("suite", "", "comma-separated list of suites to run (empty runs all)")
+	list := flag.Bool("list", false, "list available suites")
+	verbose := flag.Bool("v", false, "verbose output")
+	timeout := flag.Int("timeout", 60, "timeout per test in seconds")
+	export := flag.String("export", "eval-results.json", "export results to file")
+	flag.Parse()
+
+	if *list {
+		for _, s := range suites {
+			fmt.Printf("%s (%d tests)\n", s.Name, len(s.Tests))
+		}
+		return
+	}
+
+	if *model == "" {
+		fmt.Fprintf(os.Stderr, "error: -model parameter is required\n")
+		os.Exit(1)
+	}
+
+	// A non-positive timeout would give every test an already-expired context.
+	if *timeout <= 0 {
+		fmt.Fprintf(os.Stderr, "error: -timeout must be greater than 0\n")
+		os.Exit(1)
+	}
+
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "error: %v\n", err)
+		os.Exit(1)
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	err = client.Heartbeat(ctx)
+	cancel()
+	if err != nil {
+		// Include the cause so a refused connection and a timeout look different.
+		fmt.Fprintf(os.Stderr, "error: cannot connect to ollama: %v\n", err)
+		os.Exit(1)
+	}
+
+	selected, err := selectSuites(*suite)
+	if err != nil {
+		fmt.Fprintf(os.Stderr, "error: %v\n", err)
+		os.Exit(1)
+	}
+
+	var results []Result
+	for _, s := range selected {
+		if *verbose {
+			fmt.Printf("\n%s (%d tests)\n", s.Name, len(s.Tests))
+		}
+		for i, test := range s.Tests {
+			if test.Options == nil {
+				test.Options = map[string]any{"temperature": 0.1}
+			}
+
+			if *verbose {
+				fmt.Printf(" [%d/%d] %s... ", i+1, len(s.Tests), test.Name)
+			}
+
+			ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*timeout)*time.Second)
+			result := Run(ctx, client, *model, test)
+			cancel()
+
+			results = append(results, result)
+			if *verbose {
+				printResult(result)
+			}
+		}
+	}
+
+	printSummary(results)
+
+	if *export != "" {
+		if err := writeJSON(*export, results); err != nil {
+			fmt.Fprintf(os.Stderr, "warning: export failed: %v\n", err)
+		} else if *verbose {
+			fmt.Printf("\nResults: %s\n", *export)
+		}
+	}
+
+	if anyFailed(results) {
+		os.Exit(1)
+	}
+}
+
+// selectSuites resolves a comma-separated list of suite names against the
+// registered suites, preserving the order given. An empty string selects
+// every suite; blank entries are skipped. It returns an error naming all
+// unknown suites.
+func selectSuites(names string) ([]Suite, error) {
+	if names == "" {
+		return suites, nil
+	}
+
+	var selected []Suite
+	var notFound []string
+	for _, name := range strings.Split(names, ",") {
+		name = strings.TrimSpace(name)
+		if name == "" {
+			continue
+		}
+
+		found := false
+		for _, s := range suites {
+			if s.Name == name {
+				selected = append(selected, s)
+				found = true
+				break
+			}
+		}
+		if !found {
+			notFound = append(notFound, name)
+		}
+	}
+
+	if len(notFound) > 0 {
+		return nil, fmt.Errorf("suite(s) not found: %s", strings.Join(notFound, ", "))
+	}
+	return selected, nil
+}
+
+// printResult writes the verbose status line for one test, followed by its
+// tool calls and response unless the test errored.
+func printResult(r Result) {
+	switch {
+	case r.Error != nil:
+		fmt.Printf("ERROR: %v\n", r.Error)
+		return
+	case r.Passed:
+		fmt.Printf("PASS (%.2fs)", r.Duration.Seconds())
+		var tags []string
+		if len(r.Tools) > 0 {
+			tags = append(tags, "tools: "+strings.Join(r.Tools, ","))
+		}
+		if r.Thinking {
+			tags = append(tags, "thinking")
+		}
+		if len(tags) > 0 {
+			fmt.Printf(" [%s]", strings.Join(tags, ", "))
+		}
+		fmt.Println()
+	default:
+		fmt.Printf("FAIL (%.2fs)\n", r.Duration.Seconds())
+	}
+	printDetails(r)
+}
+
+// printDetails lists each tool call with its JSON-encoded arguments, then the
+// response text if there is any.
+func printDetails(r Result) {
+	if len(r.ToolCalls) > 0 {
+		fmt.Printf(" Tool Calls:\n")
+		for _, tc := range r.ToolCalls {
+			args, err := json.Marshal(tc.Function.Arguments)
+			if err != nil {
+				// Show why the arguments are missing instead of an empty string.
+				args = []byte(fmt.Sprintf("<cannot encode arguments: %v>", err))
+			}
+			fmt.Printf(" - %s: %s\n", tc.Function.Name, args)
+		}
+	}
+	if r.Response != "" {
+		fmt.Printf(" Response: %s\n", r.Response)
+	}
+}
+
+// printSummary prints the pass count, the pass rate, and the error count if any.
+func printSummary(results []Result) {
+	var passed, errored int
+	for _, r := range results {
+		switch {
+		case r.Error != nil:
+			errored++
+		case r.Passed:
+			passed++
+		}
+	}
+
+	total := len(results)
+	rate := 0.0
+	if total > 0 {
+		rate = float64(passed) / float64(total) * 100
+	}
+
+	fmt.Printf("\n%d/%d passed (%.1f%%)", passed, total, rate)
+	if errored > 0 {
+		fmt.Printf(", %d errors", errored)
+	}
+	fmt.Println()
+}
+
+// anyFailed reports whether any result did not pass or carries an error.
+func anyFailed(results []Result) bool {
+	for _, r := range results {
+		if !r.Passed || r.Error != nil {
+			return true
+		}
+	}
+	return false
+}
+
+// writeJSON writes results to path as indented JSON, creating or truncating
+// the file. It returns the first error from creating, encoding, or closing
+// the file.
+func writeJSON(path string, results []Result) (err error) {
+	// Encode "no results" as [] rather than null.
+	if results == nil {
+		results = []Result{}
+	}
+
+	f, err := os.Create(path)
+	if err != nil {
+		return err
+	}
+	defer func() {
+		// A failed Close can mean the data never reached disk; surface it.
+		if cerr := f.Close(); err == nil {
+			err = cerr
+		}
+	}()
+
+	enc := json.NewEncoder(f)
+	enc.SetIndent("", " ")
+	return enc.Encode(results)
+}
diff --git a/cmd/eval/suites.go b/cmd/eval/suites.go
new file mode 100644
index 000000000..d92d9a4ac
--- /dev/null
+++ b/cmd/eval/suites.go
@@ -0,0 +1,182 @@
+package main
+
+import "github.com/ollama/ollama/api"
+
+// suites lists every evaluation suite, in the order they run by default.
+var suites = []Suite{
+	{
+		Name: "basic-qa",
+		Tests: []Test{
+			{
+				Name:   "simple-math",
+				Prompt: "What is 2+2? Reply with just the number.",
+				Check:  ContainsWord("4"),
+			},
+			{
+				Name:   "capital-city",
+				Prompt: "What is the capital of France? Reply with just the city name.",
+				Check:  Contains("Paris"),
+			},
+			{
+				Name:   "greeting",
+				Prompt: "Say hello",
+				Check:  HasResponse(),
+			},
+		},
+	},
+	{
+		Name: "reasoning",
+		Tests: []Test{
+			{
+				Name:   "logic-puzzle",
+				Prompt: "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Answer yes or no.",
+				Check:  ContainsWord("no"),
+			},
+			{
+				Name:   "counting",
+				Prompt: "How many letters are in the word 'HELLO'?",
+				Check:  ContainsWord("5"),
+			},
+		},
+	},
+	{
+		Name: "instruction-following",
+		Tests: []Test{
+			{
+				Name:   "json-output",
+				Prompt: "Reply with a JSON object containing a 'status' field set to 'ok'.",
+				Check:  All(Contains("status"), Contains("ok")),
+			},
+			{
+				Name:   "system-prompt",
+				Prompt: "What is your name?",
+				System: "You are a helpful assistant named TestBot. When asked your name, always respond with 'TestBot'.",
+				Check:  Contains("TestBot"),
+			},
+		},
+	},
+	{
+		Name: "tool-calling-basic",
+		Tests: []Test{
+			{
+				Name:   "single-tool",
+				Prompt: "What's the weather like in San Francisco?",
+				Tools:  []api.Tool{weatherTool},
+				Check:  CallsTool("get_weather"),
+			},
+			{
+				Name:   "tool-selection",
+				Prompt: "What time is it in Tokyo?",
+				Tools:  []api.Tool{weatherTool, timeTool},
+				Check:  CallsTool("get_time"),
+			},
+			{
+				Name:   "no-tool-needed",
+				Prompt: "What is 2+2?",
+				Tools:  []api.Tool{weatherTool, timeTool},
+				Check:  NoTools(),
+			},
+		},
+	},
+	{
+		Name: "tool-calling-advanced",
+		Tests: []Test{
+			{
+				Name:   "parallel-calls",
+				Prompt: "Get the weather in both New York and Los Angeles.",
+				Tools:  []api.Tool{weatherTool},
+				Check:  All(CallsTool("get_weather"), MinTools(2)),
+			},
+			{
+				Name:   "multi-param",
+				Prompt: "Search for Italian restaurants with prices between $20 and $40.",
+				Tools:  []api.Tool{restaurantTool},
+				Check:  CallsTool("search_restaurants"),
+			},
+		},
+	},
+	{
+		Name: "tool-calling-thinking",
+		Tests: []Test{
+			{
+				Name:   "thinking-before-tool",
+				Prompt: "I need to know the weather in Paris before I decide what to pack.",
+				Tools:  []api.Tool{weatherTool},
+				Think:  true,
+				Check:  CallsTool("get_weather"),
+			},
+			{
+				Name:   "thinking-multi-tool",
+				Prompt: "I'm planning a trip to London. I need to know what time it is there and what the weather is like.",
+				Tools:  []api.Tool{weatherTool, timeTool},
+				Think:  true,
+				Check:  MinTools(1),
+			},
+		},
+	},
+}
+
+// weatherTool declares get_weather, which requires a string "location".
+var weatherTool = api.Tool{
+	Type: "function",
+	Function: api.ToolFunction{
+		Name:        "get_weather",
+		Description: "Get the current weather in a given location",
+		Parameters: api.ToolFunctionParameters{
+			Type:     "object",
+			Required: []string{"location"},
+			Properties: map[string]api.ToolProperty{
+				"location": {
+					Type:        api.PropertyType{"string"},
+					Description: "The city and state",
+				},
+			},
+		},
+	},
+}
+
+// timeTool declares get_time, which requires a string "timezone".
+var timeTool = api.Tool{
+	Type: "function",
+	Function: api.ToolFunction{
+		Name:        "get_time",
+		Description: "Get the current time in a timezone",
+		Parameters: api.ToolFunctionParameters{
+			Type:     "object",
+			Required: []string{"timezone"},
+			Properties: map[string]api.ToolProperty{
+				"timezone": {
+					Type:        api.PropertyType{"string"},
+					Description: "The timezone name",
+				},
+			},
+		},
+	},
+}
+
+// restaurantTool declares search_restaurants: "cuisine" is required, and the
+var restaurantTool = api.Tool{
+	Type: "function",
+	Function: api.ToolFunction{
+		Name:        "search_restaurants",
+		Description: "Search for restaurants",
+		Parameters: api.ToolFunctionParameters{
+			Type:     "object",
+			Required: []string{"cuisine"},
+			Properties: map[string]api.ToolProperty{
+				"cuisine": {
+					Type:        api.PropertyType{"string"},
+					Description: "Type of cuisine",
+				},
+				"min_price": {
+					Type:        api.PropertyType{"number"},
+					Description: "Minimum price",
+				},
+				"max_price": {
+					Type:        api.PropertyType{"number"},
+					Description: "Maximum price",
+				},
+			},
+		},
+	},
+}