Compare commits

..

6 Commits

Author SHA1 Message Date
Bruce MacDonald
375a662775 cmd: automatically open the browser to register
If a key is not found to belong to any account on ollama.com automatically
open the browser to allow the user to connect their key to an account. This
also outputs a code derived from the key that a user can look at to verify
that they are connecting the device they expect.
2024-12-02 16:41:02 -08:00
Bruce MacDonald
ae9165d661 remove images_test.go (uses filesystem key) 2024-11-27 15:52:30 -08:00
Bruce MacDonald
a262b86a5e fix lint checks 2024-11-27 15:45:45 -08:00
Bruce MacDonald
4d5d3c3276 Update error.go 2024-11-27 15:24:08 -08:00
Bruce MacDonald
ea90ee7415 Update cmd.go 2024-11-27 15:22:27 -08:00
Bruce MacDonald
40134c6587 server: show user feedback when key is anonymous
When an ollama key is not registered with any account on ollama.com this is
not obvious. In the current CLI an error message that the user is not
authorized is displayed. This change brings back previous behavior to show
the user their key and where they should add it. It protects against adding
unexpected keys by checking that the key is available locally.

A follow-up change should add structured errors from the API. This change
just relies on a known error message.
2024-11-27 15:01:12 -08:00
20 changed files with 224 additions and 490 deletions

View File

@@ -346,9 +346,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page) - [Web management](https://github.com/lemonit-eric-mao/ollama-web-management) (Web management page)
- [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.) - [Promptery](https://github.com/promptery/promptery) (desktop client for Ollama.)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama) - [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [SpaceLlama](https://github.com/tcsenpai/spacellama) (Firefox and Chrome extension to quickly summarize web pages with ollama in a sidebar)
- [YouLama](https://github.com/tcsenpai/youlama) (Webapp to quickly summarize any YouTube video, supporting Invidious as well)
- [DualMind](https://github.com/tcsenpai/dualmind) (Experimental app allowing two models to talk to each other in the terminal or in a web interface)
- [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol) - [ollamarama-matrix](https://github.com/h1ddenpr0cess20/ollamarama-matrix) (Ollama chatbot for the Matrix chat protocol)
- [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app) - [ollama-chat-app](https://github.com/anan1213095357/ollama-chat-app) (Flutter-based chat app)
- [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings) - [Perfect Memory AI](https://www.perfectmemory.ai/) (Productivity AI assists personalized by what you have seen on your screen, heard and said in the meetings)
@@ -359,7 +356,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama) - [Nosia](https://github.com/nosia-ai/nosia) (Easy to install and use RAG platform based on Ollama)
- [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux) - [Witsy](https://github.com/nbonamy/witsy) (An AI Desktop application avaiable for Mac/Windows/Linux)
- [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support) - [Abbey](https://github.com/US-Artificial-Intelligence/abbey) (A configurable AI interface server with notebooks, document storage, and YouTube support)
- [Minima](https://github.com/dmayboroda/minima) (RAG with on-premises or fully local workflow)
### Cloud ### Cloud

View File

@@ -146,7 +146,6 @@ type ToolCall struct {
} }
type ToolCallFunction struct { type ToolCallFunction struct {
Index int `json:"index,omitempty"`
Name string `json:"name"` Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"` Arguments ToolCallFunctionArguments `json:"arguments"`
} }

View File

@@ -8,6 +8,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/base64"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@@ -16,9 +17,11 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"strconv" "strconv"
"strings" "strings"
@@ -29,16 +32,19 @@ import (
"github.com/containerd/console" "github.com/containerd/console"
"github.com/mattn/go-runewidth" "github.com/mattn/go-runewidth"
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/pkg/browser"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server" "github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@@ -513,6 +519,76 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generate(cmd, opts) return generate(cmd, opts)
} }
func generateFingerprint(key string) string {
hash := sha256.Sum256([]byte(key))
fingerprint := base64.RawURLEncoding.EncodeToString(hash[:6])
var formatted strings.Builder
for i, char := range fingerprint {
if i > 0 && i%2 == 0 {
formatted.WriteRune('-')
}
formatted.WriteRune(char)
}
return formatted.String()
}
// tryConnect handles key validation when a connection fails due to an unknown key.
// It attempts to open the browser for interactive sessions to let users connect their key,
// falling back to command-line instructions for non-interactive sessions.
// Returns nil if browser flow succeeds, or an error with connection instructions otherwise.
func tryConnect(unknownKeyErr error) error {
// find SSH public key in the error message
// TODO (brucemacd): the API should return structured errors so that this message parsing isn't needed
sshKeyPattern := `ssh-\w+ [^\s"]+`
re := regexp.MustCompile(sshKeyPattern)
matches := re.FindStringSubmatch(unknownKeyErr.Error())
if len(matches) > 0 {
serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey()
if err != nil {
return unknownKeyErr
}
if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
if err != nil {
return unknownKeyErr
}
localPubKey = strings.TrimSpace(string(svcPubKey))
}
// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
if serverPubKey != localPubKey {
return unknownKeyErr
}
if term.IsTerminal(int(os.Stdout.Fd())) {
// URL encode the key and device name for the browser URL
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(localPubKey))
d, _ := os.Hostname()
encodedDevice := url.QueryEscape(d)
browserURL := fmt.Sprintf("https://ollama.com/connect?host=%s&key=%s", encodedDevice, encodedKey)
if err := browser.OpenURL(browserURL); err == nil {
fmt.Printf("\nOpening browser to add your key...\n")
fmt.Printf("\nCheck that this code matches what is shown in your browser:\n")
fmt.Printf("\n %s\n", generateFingerprint(localPubKey))
return nil
}
}
// only return error for non-interactive terminals or if browser opening failed
return fmt.Errorf("%s\nAdd your key at:\nhttps://ollama.com/settings/keys", unknownKeyErr.Error())
}
return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@@ -561,13 +637,22 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure} request := api.PushRequest{Name: args[0], Insecure: insecure}
n := model.ParseName(args[0]) n := model.ParseName(args[0])
isOllamaHost := strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com")
if err := client.Push(cmd.Context(), &request, fn); err != nil { if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil { if spinner != nil {
spinner.Stop() spinner.Stop()
} }
if p != nil {
p.Stop()
}
if strings.Contains(err.Error(), "access denied") { if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own") return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
} }
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// return an error with a more user-friendly message
return tryConnect(err)
}
return err return err
} }
@@ -575,7 +660,7 @@ func PushHandler(cmd *cobra.Command, args []string) error {
spinner.Stop() spinner.Stop()
destination := n.String() destination := n.String()
if strings.HasSuffix(n.Host, ".ollama.ai") || strings.HasSuffix(n.Host, ".ollama.com") { if isOllamaHost {
destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest") destination = "https://ollama.com/" + strings.TrimSuffix(n.DisplayShortest(), ":latest")
} }
fmt.Printf("\nYou can find your model at:\n\n") fmt.Printf("\nYou can find your model at:\n\n")

View File

@@ -8,6 +8,7 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
"path/filepath"
"strings" "strings"
"testing" "testing"
@@ -15,6 +16,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/errtypes"
) )
func TestShowInfo(t *testing.T) { func TestShowInfo(t *testing.T) {
@@ -179,14 +181,18 @@ Weigh anchor!
t.Run("license", func(t *testing.T) { t.Run("license", func(t *testing.T) {
var b bytes.Buffer var b bytes.Buffer
license := "MIT License\nCopyright (c) Ollama\n" license, err := os.ReadFile(filepath.Join("..", "LICENSE"))
if err != nil {
t.Fatal(err)
}
if err := showInfo(&api.ShowResponse{ if err := showInfo(&api.ShowResponse{
Details: api.ModelDetails{ Details: api.ModelDetails{
Family: "test", Family: "test",
ParameterSize: "7B", ParameterSize: "7B",
QuantizationLevel: "FP16", QuantizationLevel: "FP16",
}, },
License: license, License: string(license),
}, &b); err != nil { }, &b); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -368,15 +374,13 @@ func TestGetModelfileName(t *testing.T) {
func TestPushHandler(t *testing.T) { func TestPushHandler(t *testing.T) {
tests := []struct { tests := []struct {
name string
modelName string modelName string
serverResponse map[string]func(w http.ResponseWriter, r *http.Request) serverResponse map[string]func(w http.ResponseWriter, r *http.Request)
expectedError string expectedError string
expectedOutput string expectedOutput string
}{ }{
{ {
name: "successful push", modelName: "successful-push",
modelName: "test-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) { "/api/push": func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost { if r.Method != http.MethodPost {
@@ -389,8 +393,8 @@ func TestPushHandler(t *testing.T) {
return return
} }
if req.Name != "test-model" { if req.Name != "successful-push" {
t.Errorf("expected model name 'test-model', got %s", req.Name) t.Errorf("expected model name 'successful-push', got %s", req.Name)
} }
// Simulate progress updates // Simulate progress updates
@@ -409,11 +413,10 @@ func TestPushHandler(t *testing.T) {
} }
}, },
}, },
expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/test-model\n", expectedOutput: "\nYou can find your model at:\n\n\thttps://ollama.com/successful-push\n",
}, },
{ {
name: "unauthorized push", modelName: "unauthorized-push",
modelName: "unauthorized-model",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){ serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) { "/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@@ -428,10 +431,29 @@ func TestPushHandler(t *testing.T) {
}, },
expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own", expectedError: "you are not authorized to push to this namespace, create the model under a namespace you own",
}, },
{
modelName: "unknown-key-err",
serverResponse: map[string]func(w http.ResponseWriter, r *http.Request){
"/api/push": func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusUnauthorized)
uerr := errtypes.UnknownOllamaKey{
Key: "aaa",
}
err := json.NewEncoder(w).Encode(map[string]string{
"error": uerr.Error(),
})
if err != nil {
t.Fatal(err)
}
},
},
expectedError: "unauthorized: unknown ollama key \"aaa\"",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.modelName, func(t *testing.T) {
mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if handler, ok := tt.serverResponse[r.URL.Path]; ok { if handler, ok := tt.serverResponse[r.URL.Path]; ok {
handler(w, r) handler(w, r)

View File

@@ -49,10 +49,10 @@ Advanced parameters (optional):
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `system`: system message to (overrides what is defined in the `Modelfile`) - `system`: system message to (overrides what is defined in the `Modelfile`)
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`) - `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects - `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API - `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `context` (deprecated): the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
#### JSON mode #### JSON mode

View File

@@ -63,7 +63,7 @@ SYSTEM You are Mario from super mario bros, acting as an assistant.
To use this: To use this:
1. Save it as a file (e.g. `Modelfile`) 1. Save it as a file (e.g. `Modelfile`)
2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>` 2. `ollama create choose-a-model-name -f <location of the file e.g. ./Modelfile>'`
3. `ollama run choose-a-model-name` 3. `ollama run choose-a-model-name`
4. Start using the model! 4. Start using the model!

1
go.mod
View File

@@ -22,6 +22,7 @@ require (
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
golang.org/x/image v0.22.0 golang.org/x/image v0.22.0
) )

2
go.sum
View File

@@ -159,6 +159,8 @@ github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2
github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI=
github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4= github.com/pierrec/lz4/v4 v4.1.8 h1:ieHkV+i2BRzngO4Wd/3HGowuZStgq6QkPsD1eolNAO4=
github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.8/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c h1:+mdjkGKdHQG3305AYmdv1U2eRNDiU2ErMBj1gwrq8eQ=
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c/go.mod h1:7rwL4CYBLnjLxUqIJNnCWiEdr3bn6IUYi15bNlnbCCU=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=

View File

@@ -93,7 +93,7 @@ make -j
## Vendoring ## Vendoring
Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we cary a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes. Ollama currently vendors [llama.cpp](https://github.com/ggerganov/llama.cpp/) and [ggml](https://github.com/ggerganov/ggml) through a vendoring model. While we generally strive to contribute changes back upstream to avoid drift, we cary a small set of patches which are applied to the tracking commit. A set of make targets are available to aid developers in updating to a newer tracking commit, or to work on changes.
If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory. If you update the vendoring code, start by running the following command to establish the tracking llama.cpp repo in the `./vendor/` directory.
@@ -105,35 +105,35 @@ make apply-patches
**Pin to new base commit** **Pin to new base commit**
To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring` To update to a newer base commit, select the upstream git tag or commit and update `llama/vendoring.env`
#### Applying patches #### Applying patches
When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution. When updating to a newer base commit, the existing patches may not apply cleanly and require manual merge resolution.
Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure. Start by applying the patches. If any of the patches have conflicts, the `git am` will stop at the first failure.
``` ```
make apply-patches make apply-patches
``` ```
If you see an error message about a conflict, go into the `./vendor/` directory, and perform merge resolution using your preferred tool to the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue` . If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated. If you see an error message about a conflict, go into the `./vendor/` directory, and perform merge resolution using your preferred tool to the patch commit which failed. Save the file(s) and continue the patch series with `git am --continue` . If any additional patches fail, follow the same pattern until the full patch series is applied. Once finished, run a final `create-patches` and `sync` target to ensure everything is updated.
``` ```
make create-patches sync make create-patches sync
``` ```
Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo. Build and test Ollama, and make any necessary changes to the Go code based on the new base commit. Submit your PR to the Ollama repo.
### Generating Patches ### Generating Patches
When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied: When working on new fixes or features that impact vendored code, use the following model. First get a clean tracking repo with all current patches applied:
``` ```
make apply-patches make apply-patches
``` ```
Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build, a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, and run the following to refresh the vendored code with your changes, build the backend(s) and build ollama: Now edit the upstream native code in the `./vendor/` directory. You do not need to commit every change in order to build, a dirty working tree in the tracking repo is OK while developing. Simply save in your editor, and run the following to refresh the vendored code with your changes, build the backend(s) and build ollama:
``` ```
make sync make sync
@@ -142,9 +142,9 @@ go build .
``` ```
> [!IMPORTANT] > [!IMPORTANT]
> Do **NOT** run `apply-patches` while you're iterating as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s). > Do **NOT** run `apply-patches` while you're iterating as that will reset the tracking repo. It will detect a dirty tree and abort, but if your tree is clean and you accidentally ran this target, use `git reflog` to recover your commit(s).
Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with Iterate until you're ready to submit PRs. Once your code is ready, commit a change in the `./vendor/` directory, then generate the patches for ollama with
``` ```
make create-patches make create-patches
@@ -157,4 +157,4 @@ In your `./vendor/` directory, create a branch, and cherry-pick the new commit t
Commit the changes in the ollama repo and submit a PR to Ollama, which will include the vendored code update with your change, along with the patches. Commit the changes in the ollama repo and submit a PR to Ollama, which will include the vendored code update with your change, along with the patches.
After your PR upstream is merged, follow the **Updating Base Commit** instructions above, however first remove your patch before running `apply-patches` since the new base commit contains your change already. After your PR upstream is merged, follow the **Updating Base Commit** instructions above, however first remove your patch before running `apply-patches` since the new base commit contains your change already.

View File

@@ -833,21 +833,10 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
} }
} }
type multiLPath []string
func (m *multiLPath) Set(value string) error {
*m = append(*m, value)
return nil
}
func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}
func (s *Server) loadModel( func (s *Server) loadModel(
params llama.ModelParams, params llama.ModelParams,
mpath string, mpath string,
lpath multiLPath, lpath string,
ppath string, ppath string,
kvSize int, kvSize int,
flashAttention bool, flashAttention bool,
@@ -868,12 +857,10 @@ func (s *Server) loadModel(
panic(err) panic(err)
} }
if lpath.String() != "" { if lpath != "" {
for _, path := range lpath { err := s.model.ApplyLoraFromFile(s.lc, lpath, 1.0, threads)
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads) if err != nil {
if err != nil { panic(err)
panic(err)
}
} }
} }
@@ -903,6 +890,7 @@ func main() {
mainGpu := flag.Int("main-gpu", 0, "Main GPU") mainGpu := flag.Int("main-gpu", 0, "Main GPU")
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention") flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size") kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
lpath := flag.String("lora", "", "Path to lora layer file")
port := flag.Int("port", 8080, "Port to expose the server on") port := flag.Int("port", 8080, "Port to expose the server on")
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)") verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
@@ -912,9 +900,6 @@ func main() {
multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") multiUserCache := flag.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
requirements := flag.Bool("requirements", false, "print json requirement information") requirements := flag.Bool("requirements", false, "print json requirement information")
var lpaths multiLPath
flag.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
flag.Parse() flag.Parse()
if *requirements { if *requirements {
printRequirements(os.Stdout) printRequirements(os.Stdout)
@@ -961,7 +946,7 @@ func main() {
params := llama.ModelParams{ params := llama.ModelParams{
NumGpuLayers: *nGpuLayers, NumGpuLayers: *nGpuLayers,
MainGpu: *mainGpu, MainGpu: *mainGpu,
UseMmap: !*noMmap && lpaths.String() == "", UseMmap: !*noMmap && *lpath == "",
UseMlock: *mlock, UseMlock: *mlock,
TensorSplit: tensorSplitFloats, TensorSplit: tensorSplitFloats,
Progress: func(progress float32) { Progress: func(progress float32) {
@@ -970,7 +955,7 @@ func main() {
} }
server.ready.Add(1) server.ready.Add(1)
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache) go server.loadModel(params, *mpath, *lpath, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
server.cond = sync.NewCond(&server.mu) server.cond = sync.NewCond(&server.mu)

View File

@@ -144,6 +144,10 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
// Loop through potential servers // Loop through potential servers
finalErr := errors.New("no suitable llama servers found") finalErr := errors.New("no suitable llama servers found")
if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
}
rDir, err := runners.Refresh(build.EmbedFS) rDir, err := runners.Refresh(build.EmbedFS)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -197,9 +201,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, ggml *GGML, adapter
} }
if len(adapters) > 0 { if len(adapters) > 0 {
for _, adapter := range adapters { // TODO: applying multiple adapters is not supported by the llama.cpp server yet
params = append(params, "--lora", adapter) params = append(params, "--lora", adapters[0])
}
} }
if len(projectors) > 0 { if len(projectors) > 0 {

View File

@@ -140,7 +140,6 @@ type CompletionChunk struct {
type ToolCall struct { type ToolCall struct {
ID string `json:"id"` ID string `json:"id"`
Index int `json:"index"`
Type string `json:"type"` Type string `json:"type"`
Function struct { Function struct {
Name string `json:"name"` Name string `json:"name"`
@@ -201,13 +200,12 @@ func toolCallId() string {
return "call_" + strings.ToLower(string(b)) return "call_" + strings.ToLower(string(b))
} }
func toToolCalls(tc []api.ToolCall) []ToolCall { func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := make([]ToolCall, len(tc)) toolCalls := make([]ToolCall, len(r.Message.ToolCalls))
for i, tc := range tc { for i, tc := range r.Message.ToolCalls {
toolCalls[i].ID = toolCallId() toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function" toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name toolCalls[i].Function.Name = tc.Function.Name
toolCalls[i].Index = tc.Function.Index
args, err := json.Marshal(tc.Function.Arguments) args, err := json.Marshal(tc.Function.Arguments)
if err != nil { if err != nil {
@@ -217,11 +215,7 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
toolCalls[i].Function.Arguments = string(args) toolCalls[i].Function.Arguments = string(args)
} }
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{ return ChatCompletion{
Id: id, Id: id,
Object: "chat.completion", Object: "chat.completion",
@@ -250,7 +244,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
} }
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{ return ChatCompletionChunk{
Id: id, Id: id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
@@ -259,7 +252,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama", SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{ Choices: []ChunkChoice{{
Index: 0, Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, Delta: Message{Role: "assistant", Content: r.Message.Content},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason

View File

@@ -195,86 +195,7 @@ func TestChatMiddleware(t *testing.T) {
Stream: &False, Stream: &False,
}, },
}, },
{
name: "chat handler with streaming tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris?"}
],
"stream": true,
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris?",
},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &True,
},
},
{ {
name: "chat handler error forwarding", name: "chat handler error forwarding",
body: `{ body: `{

View File

@@ -23,13 +23,16 @@ import (
"strings" "strings"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/registry"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
@@ -802,12 +805,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
if mp.ProtocolScheme == "http" && !regOpts.Insecure { if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http") return errors.New("insecure protocol http")
} }
if mp.Namespace != strings.ToLower(mp.Namespace) {
return fmt.Errorf("namespace must be lowercase, but is %s", mp.Namespace)
}
if mp.Repository != strings.ToLower(mp.Repository) {
return fmt.Errorf("model name must be lowercase, but is %s", mp.Repository)
}
manifest, _, err := GetManifest(mp) manifest, _, err := GetManifest(mp)
if err != nil { if err != nil {
@@ -986,8 +983,6 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
} }
var errUnauthorized = errors.New("unauthorized: access denied")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
for range 2 { for range 2 {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
@@ -1025,13 +1020,33 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil { if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err) return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
} }
var re registry.Errs
if err := json.Unmarshal(responseBody, &re); err == nil && len(re.Errors) > 0 {
if re.HasCode(registry.ErrCodeAnonymous) {
// if the error is due to anonymous access return a custom error
// this error is used by the CLI to direct a user to add their key to an account
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, re
}
return nil, errtypes.UnknownOllamaKey{
Key: pubKey,
}
}
return nil, re
}
// Fallback to returning the raw response if parsing fails
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody) return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default: default:
return resp, nil return resp, nil
} }
} }
return nil, errUnauthorized // should never be reached
return nil, fmt.Errorf("failed to make upload request")
} }
// testMakeRequestDialContext specifies the dial function for the http client in // testMakeRequestDialContext specifies the dial function for the http client in

View File

@@ -1,50 +0,0 @@
package server
import (
"context"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestPushModel(t *testing.T) {
noOpProgress := func(resp api.ProgressResponse) {}
tests := []struct {
modelStr string
regOpts *registryOptions
wantErr string
}{
{
modelStr: "http://example.com/namespace/repo:tag",
regOpts: &registryOptions{Insecure: false},
wantErr: "insecure protocol http",
},
{
modelStr: "docker://Example/repo:tag",
regOpts: &registryOptions{},
wantErr: "namespace must be lowercase, but is Example",
},
{
modelStr: "docker://example/Repo:tag",
regOpts: &registryOptions{},
wantErr: "model name must be lowercase, but is Repo",
},
}
for _, tt := range tests {
t.Run(tt.modelStr, func(t *testing.T) {
err := PushModel(context.Background(), tt.modelStr, tt.regOpts, noOpProgress)
if tt.wantErr != "" {
if err == nil {
t.Errorf("PushModel() error = %v, wantErr %v", err, tt.wantErr)
} else if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("PushModel() error = %v, wantErr %v", err, tt.wantErr)
}
return
}
})
}
}

View File

@@ -39,7 +39,6 @@ func TestExecuteWithTools(t *testing.T) {
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true}, The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true}, [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},

View File

@@ -251,7 +251,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var b bytes.Buffer var b bytes.Buffer
if req.Context != nil { if req.Context != nil {
slog.Warn("the context field is deprecated and will be removed in a future version of Ollama")
s, err := r.Detokenize(c.Request.Context(), req.Context) s, err := r.Detokenize(c.Request.Context(), req.Context)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1459,7 +1458,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil { if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@@ -1469,8 +1467,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
defer close(ch) defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
@@ -1496,37 +1492,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart) res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
} }
// TODO: tool call checking and filtering should be moved outside of this callback once streaming ch <- res
// however this was a simple change for now without reworking streaming logic of this (and other)
// handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res
return
}
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
}
res.Message.Content = ""
sb.Reset()
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
}
ch <- res
}
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }

View File

@@ -8,7 +8,6 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -26,14 +25,10 @@ type mockRunner struct {
// CompletionRequest is only valid until the next call to Completion // CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest llm.CompletionRequest
llm.CompletionResponse llm.CompletionResponse
CompletionFn func(context.Context, llm.CompletionRequest, func(llm.CompletionResponse)) error
} }
func (m *mockRunner) Completion(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error { func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r m.CompletionRequest = r
if m.CompletionFn != nil {
return m.CompletionFn(ctx, r, fn)
}
fn(m.CompletionResponse) fn(m.CompletionResponse)
return nil return nil
} }
@@ -93,14 +88,9 @@ func TestGenerateChat(t *testing.T) {
Model: "test", Model: "test",
Modelfile: fmt.Sprintf(`FROM %s Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """ TEMPLATE """
{{- if .Tools }} {{- if .System }}System: {{ .System }} {{ end }}
{{ .Tools }} {{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{ end }} {{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
{{- range .Messages }}
{{- .Role }}: {{ .Content }}
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}"""
`, createBinFile(t, llm.KV{ `, createBinFile(t, llm.KV{
"general.architecture": "llama", "general.architecture": "llama",
"llama.block_count": uint32(1), "llama.block_count": uint32(1),
@@ -273,7 +263,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "user: Hello!\n"); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -302,7 +292,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\n"); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -324,7 +314,7 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You can perform magic tricks.\nuser: Hello!\n"); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
@@ -347,242 +337,12 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 200, got %d", w.Code) t.Errorf("expected status 200, got %d", w.Code)
} }
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "system: You are a helpful assistant.\nuser: Hello!\nassistant: I can help you with that.\nsystem: You can perform magic tricks.\nuser: Help me write tests.\n"); diff != "" { if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("mismatch (-got +want):\n%s", diff)
} }
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!") checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
}) })
t.Run("messages with tools (non-streaming)", func(t *testing.T) {
if w.Code != http.StatusOK {
t.Fatalf("failed to create test-system model: %d", w.Code)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
}
mock.CompletionResponse = llm.CompletionResponse{
Content: `{"name":"get_weather","arguments":{"location":"Seattle, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "done",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
}
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &streamRequest,
})
if w.Code != http.StatusOK {
var errResp struct {
Error string `json:"error"`
}
if err := json.NewDecoder(w.Body).Decode(&errResp); err != nil {
t.Logf("Failed to decode error response: %v", err)
} else {
t.Logf("Error response: %s", errResp.Error)
}
}
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.ToolCalls == nil {
t.Error("expected tool calls, got nil")
}
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
},
}
if diff := cmp.Diff(resp.Message.ToolCalls[0], expectedToolCall); diff != "" {
t.Errorf("tool call mismatch (-got +want):\n%s", diff)
}
})
t.Run("messages with tools (streaming)", func(t *testing.T) {
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
}{
"location": {
Type: "string",
Description: "The city and state",
},
"unit": {
Type: "string",
Enum: []string{"celsius", "fahrenheit"},
},
},
},
},
},
}
// Simulate streaming response with multiple chunks
var wg sync.WaitGroup
wg.Add(1)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
// Send chunks with small delays to simulate streaming
responses := []llm.CompletionResponse{
{
Content: `{"name":"get_`,
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
},
{
Content: `weather","arguments":{"location":"Seattle`,
Done: false,
PromptEvalCount: 2,
PromptEvalDuration: 1,
},
{
Content: `, WA","unit":"celsius"}}`,
Done: true,
DoneReason: "tool_call",
PromptEvalCount: 3,
PromptEvalDuration: 1,
},
}
for _, resp := range responses {
select {
case <-ctx.Done():
return ctx.Err()
default:
fn(resp)
time.Sleep(10 * time.Millisecond) // Small delay between chunks
}
}
return nil
}
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: &stream,
})
wg.Wait()
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Read and validate the streamed responses
decoder := json.NewDecoder(w.Body)
var finalToolCall api.ToolCall
for {
var resp api.ChatResponse
if err := decoder.Decode(&resp); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
if resp.Done {
if len(resp.Message.ToolCalls) != 1 {
t.Errorf("expected 1 tool call in final response, got %d", len(resp.Message.ToolCalls))
}
finalToolCall = resp.Message.ToolCalls[0]
}
}
expectedToolCall := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Seattle, WA",
"unit": "celsius",
},
},
}
if diff := cmp.Diff(finalToolCall, expectedToolCall); diff != "" {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
}
})
} }
func TestGenerate(t *testing.T) { func TestGenerate(t *testing.T) {

View File

@@ -16,6 +16,6 @@ type UnknownOllamaKey struct {
Key string Key string
} }
func (e *UnknownOllamaKey) Error() string { func (e UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key)) return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
} }

37
types/registry/error.go Normal file
View File

@@ -0,0 +1,37 @@
package registry
import (
"fmt"
"slices"
"strings"
)
const ErrCodeAnonymous = "ANONYMOUS_ACCESS_DENIED"
type Err struct {
Code string `json:"code"`
Message string `json:"message"`
}
// Errs represents the structure of error responses from the registry
// TODO (brucemacd): this struct should be imported from some shared package that is used between the registry and ollama
type Errs struct {
Errors []Err `json:"errors"`
}
func (e Errs) Error() string {
if len(e.Errors) == 0 {
return "unknown registry error"
}
var msgs []string
for _, err := range e.Errors {
msgs = append(msgs, fmt.Sprintf("%s: %s", err.Code, err.Message))
}
return strings.Join(msgs, "; ")
}
func (e Errs) HasCode(code string) bool {
return slices.ContainsFunc(e.Errors, func(err Err) bool {
return err.Code == code
})
}