Compare commits
59 Commits
v0.1.22
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a859f037da | ||
|
|
27aa2d4a19 | ||
|
|
b9f91a0b36 | ||
|
|
b538dc3858 | ||
|
|
f0e9496c85 | ||
|
|
09a6f76f4c | ||
|
|
e135167484 | ||
|
|
38296ab352 | ||
|
|
f43dea68d1 | ||
|
|
e1f50377f4 | ||
|
|
7913104527 | ||
|
|
bfbf2f7cf7 | ||
|
|
fe3cbd014f | ||
|
|
3d6f48507a | ||
|
|
f3761405c8 | ||
|
|
e49dc9f3d8 | ||
|
|
d125510b4b | ||
|
|
1ca386aa9e | ||
|
|
fb56988014 | ||
|
|
d046bee790 | ||
|
|
f11bf0740b | ||
|
|
8450bf66e6 | ||
|
|
b4e11be8ef | ||
|
|
a896079705 | ||
|
|
583950c828 | ||
|
|
8ac08a0eec | ||
|
|
60f47be64c | ||
|
|
6e56077ada | ||
|
|
98ae9467bb | ||
|
|
b7a24af083 | ||
|
|
c8b1f2369e | ||
|
|
72b12c3be7 | ||
|
|
0632dff3f8 | ||
|
|
509e2dec8a | ||
|
|
78a48de804 | ||
|
|
e7dbb00331 | ||
|
|
c3f9538636 | ||
|
|
2e06ed01d5 | ||
|
|
4072b5879b | ||
|
|
15562e887d | ||
|
|
f2245c7c77 | ||
|
|
e4b9b72f2a | ||
|
|
311f8e0c3f | ||
|
|
f07f8b7a9e | ||
|
|
4c4c730a0a | ||
|
|
e02ecfb6c8 | ||
|
|
c8059b4dcf | ||
|
|
59d87127f5 | ||
|
|
b5cf31b460 | ||
|
|
cc4915e262 | ||
|
|
667a2ba18a | ||
|
|
e054ebe059 | ||
|
|
9d3dcfd0ec | ||
|
|
6e0ea5ecc8 | ||
|
|
6eb3cddcb6 | ||
|
|
a4564232a4 | ||
|
|
a447a083f2 | ||
|
|
681a914990 | ||
|
|
27331ae3a8 |
13
README.md
13
README.md
@@ -200,18 +200,21 @@ brew install cmake go
|
||||
```
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
```
|
||||
go generate ./...
|
||||
```
|
||||
|
||||
Then build the binary:
|
||||
|
||||
```
|
||||
go build .
|
||||
```
|
||||
|
||||
More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md)
|
||||
|
||||
|
||||
### Running local builds
|
||||
|
||||
Next, start the server:
|
||||
|
||||
```
|
||||
@@ -253,6 +256,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
## Community Integrations
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
|
||||
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
|
||||
@@ -265,7 +269,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Amica](https://github.com/semperai/amica)
|
||||
- [chatd](https://github.com/BruceMacD/chatd)
|
||||
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)
|
||||
|
||||
- [MindMac](https://mindmac.app)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -278,6 +282,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [gptel Emacs client](https://github.com/karthink/gptel)
|
||||
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
|
||||
- [cmdh](https://github.com/pgibler/cmdh)
|
||||
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
|
||||
|
||||
### Database
|
||||
|
||||
@@ -304,7 +309,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LangChainDart](https://github.com/davidmigloz/langchain_dart)
|
||||
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
|
||||
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
|
||||
|
||||
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
@@ -326,3 +331,5 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
|
||||
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
|
||||
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
|
||||
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
|
||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
||||
|
||||
127
api/types.go
127
api/types.go
@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
|
||||
type ImageData []byte
|
||||
|
||||
type GenerateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
System string `json:"system"`
|
||||
Template string `json:"template"`
|
||||
Context []int `json:"context,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
Format string `json:"format"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
|
||||
type ChatRequest struct {
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Format string `json:"format"`
|
||||
Model string `json:"model"`
|
||||
Messages []Message `json:"messages"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Format string `json:"format"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
@@ -126,8 +128,9 @@ type Runner struct {
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
Options map[string]interface{} `json:"options"`
|
||||
}
|
||||
@@ -276,85 +279,20 @@ func (m *Metrics) Summary() {
|
||||
var ErrInvalidOpts = fmt.Errorf("invalid options")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// build map of json struct tags to their types
|
||||
jsonOpts := make(map[string]reflect.StructField)
|
||||
for _, field := range reflect.VisibleFields(typeOpts) {
|
||||
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
|
||||
if jsonTag != "" {
|
||||
jsonOpts[jsonTag] = field
|
||||
err = json.Unmarshal(data, opts)
|
||||
if err != nil {
|
||||
// Custom error handling
|
||||
if jsonErr, ok := err.(*json.UnmarshalTypeError); ok {
|
||||
return fmt.Errorf("invalid type for option '%v': expected %v, got %v", jsonErr.Field, jsonErr.Type, jsonErr.Value)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
invalidOpts := []string{}
|
||||
for key, val := range m {
|
||||
if opt, ok := jsonOpts[key]; ok {
|
||||
field := valueOpts.FieldByName(opt.Name)
|
||||
if field.IsValid() && field.CanSet() {
|
||||
if val == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch field.Kind() {
|
||||
case reflect.Int:
|
||||
switch t := val.(type) {
|
||||
case int64:
|
||||
field.SetInt(t)
|
||||
case float64:
|
||||
// when JSON unmarshals numbers, it uses float64, not int
|
||||
field.SetInt(int64(t))
|
||||
default:
|
||||
return fmt.Errorf("option %q must be of type integer", key)
|
||||
}
|
||||
case reflect.Bool:
|
||||
val, ok := val.(bool)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type boolean", key)
|
||||
}
|
||||
field.SetBool(val)
|
||||
case reflect.Float32:
|
||||
// JSON unmarshals to float64
|
||||
val, ok := val.(float64)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type float32", key)
|
||||
}
|
||||
field.SetFloat(val)
|
||||
case reflect.String:
|
||||
val, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type string", key)
|
||||
}
|
||||
field.SetString(val)
|
||||
case reflect.Slice:
|
||||
// JSON unmarshals to []interface{}, not []string
|
||||
val, ok := val.([]interface{})
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of type array", key)
|
||||
}
|
||||
// convert []interface{} to []string
|
||||
slice := make([]string, len(val))
|
||||
for i, item := range val {
|
||||
str, ok := item.(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("option %q must be of an array of strings", key)
|
||||
}
|
||||
slice[i] = str
|
||||
}
|
||||
field.Set(reflect.ValueOf(slice))
|
||||
default:
|
||||
return fmt.Errorf("unknown type loading config params: %v", field.Kind())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
invalidOpts = append(invalidOpts, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(invalidOpts) > 0 {
|
||||
return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -413,14 +351,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
case float64:
|
||||
if t < 0 {
|
||||
t = math.MaxFloat64
|
||||
d.Duration = time.Duration(t)
|
||||
} else {
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
}
|
||||
|
||||
d.Duration = time.Duration(t)
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.Duration < 0 {
|
||||
mf := math.MaxFloat64
|
||||
d.Duration = time.Duration(mf)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
107
cmd/cmd.go
107
cmd/cmd.go
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
@@ -146,19 +147,68 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
// check if the model exists on the server
|
||||
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
var statusError api.StatusError
|
||||
switch {
|
||||
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
|
||||
if err := PullHandler(cmd, []string{name}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case err != nil:
|
||||
return err
|
||||
}
|
||||
|
||||
return RunGenerate(cmd, args)
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
MultiModal: slices.Contains(show.Details.Families, "clip"),
|
||||
ParentModel: show.Details.ParentModel,
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
prompts := args[1:]
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
opts.Prompt = strings.Join(prompts, " ")
|
||||
if len(prompts) > 0 {
|
||||
interactive = false
|
||||
}
|
||||
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
if !interactive {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
func PushHandler(cmd *cobra.Command, args []string) error {
|
||||
@@ -410,51 +460,6 @@ func PullHandler(cmd *cobra.Command, args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RunGenerate(cmd *cobra.Command, args []string) error {
|
||||
interactive := true
|
||||
|
||||
opts := runOptions{
|
||||
Model: args[0],
|
||||
WordWrap: os.Getenv("TERM") == "xterm-256color",
|
||||
Options: map[string]interface{}{},
|
||||
}
|
||||
|
||||
format, err := cmd.Flags().GetString("format")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.Format = format
|
||||
|
||||
prompts := args[1:]
|
||||
// prepend stdin to the prompt if provided
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) {
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
prompts = append([]string{string(in)}, prompts...)
|
||||
opts.WordWrap = false
|
||||
interactive = false
|
||||
}
|
||||
opts.Prompt = strings.Join(prompts, " ")
|
||||
if len(prompts) > 0 {
|
||||
interactive = false
|
||||
}
|
||||
|
||||
nowrap, err := cmd.Flags().GetBool("nowordwrap")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
opts.WordWrap = !nowrap
|
||||
|
||||
if !interactive {
|
||||
return generate(cmd, opts)
|
||||
}
|
||||
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
|
||||
type generateContextKey string
|
||||
|
||||
type runOptions struct {
|
||||
@@ -630,10 +635,18 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if opts.MultiModal {
|
||||
opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
request := api.GenerateRequest{
|
||||
Model: opts.Model,
|
||||
Prompt: opts.Prompt,
|
||||
Context: generateContext,
|
||||
Images: opts.Images,
|
||||
Format: opts.Format,
|
||||
System: opts.System,
|
||||
Template: opts.Template,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -98,6 +99,11 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
|
||||
|
||||
if opts.MultiModal {
|
||||
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "")
|
||||
}
|
||||
|
||||
@@ -207,6 +213,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
switch multiline {
|
||||
case MultilineSystem:
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
case MultilineTemplate:
|
||||
@@ -226,7 +233,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
fmt.Fprintln(&sb)
|
||||
multiline = MultilinePrompt
|
||||
scanner.Prompt.UseAlt = true
|
||||
break
|
||||
}
|
||||
case scanner.Pasting:
|
||||
fmt.Fprintln(&sb, line)
|
||||
@@ -349,10 +355,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
if args[1] == "system" {
|
||||
opts.System = sb.String()
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
fmt.Println("Set system message.")
|
||||
sb.Reset()
|
||||
} else if args[1] == "template" {
|
||||
opts.Template = sb.String()
|
||||
fmt.Println("Set prompt template.")
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
sb.Reset()
|
||||
@@ -487,29 +496,18 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newMessage.Content = msg
|
||||
|
||||
// reset the context if we find another image
|
||||
// clear all previous images for better responses
|
||||
if len(images) > 0 {
|
||||
newMessage.Images = append(newMessage.Images, images...)
|
||||
// reset the context for the new image
|
||||
opts.Messages = []api.Message{}
|
||||
} else {
|
||||
if len(opts.Messages) > 1 {
|
||||
newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
|
||||
for i := range opts.Messages {
|
||||
opts.Messages[i].Images = nil
|
||||
}
|
||||
}
|
||||
if len(newMessage.Images) == 0 {
|
||||
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
|
||||
fmt.Println()
|
||||
sb.Reset()
|
||||
continue
|
||||
}
|
||||
|
||||
newMessage.Content = msg
|
||||
newMessage.Images = images
|
||||
}
|
||||
|
||||
if opts.System != "" {
|
||||
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
|
||||
}
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
@@ -603,10 +601,10 @@ func extractFileData(input string) (string, []api.ImageData, error) {
|
||||
if os.IsNotExist(err) {
|
||||
continue
|
||||
}
|
||||
fmt.Printf("Couldn't process image: %q\n", err)
|
||||
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
|
||||
return "", imgs, err
|
||||
}
|
||||
fmt.Printf("Added image '%s'\n", nfp)
|
||||
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
|
||||
input = strings.ReplaceAll(input, fp, "")
|
||||
imgs = append(imgs, data)
|
||||
}
|
||||
|
||||
@@ -542,7 +542,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
"role": "user",
|
||||
"content": "what is in this image?",
|
||||
"images": ["iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF
169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"]
|
||||
},
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
@@ -50,7 +50,8 @@ development and runtime packages.
|
||||
Typically the build scripts will auto-detect CUDA, however, if your Linux distro
|
||||
or installation approach uses unusual paths, you can specify the location by
|
||||
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
|
||||
libraries, and `CUDACXX` to the location of the nvcc compiler.
|
||||
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
|
||||
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
|
||||
|
||||
Then generate dependencies:
|
||||
|
||||
|
||||
112
docs/import.md
112
docs/import.md
@@ -15,7 +15,7 @@ FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
@@ -37,55 +37,69 @@ ollama run example "What is your favourite condiment?"
|
||||
|
||||
## Importing (PyTorch & Safetensors)
|
||||
|
||||
### Supported models
|
||||
> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress.
|
||||
|
||||
Ollama supports a set of model architectures, with support for more coming soon:
|
||||
### Setup
|
||||
|
||||
- Llama & Mistral
|
||||
- Falcon & RW
|
||||
- BigCode
|
||||
First, clone the `ollama/ollama` repo:
|
||||
|
||||
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
|
||||
```
|
||||
git clone git@github.com:ollama/ollama.git ollama
|
||||
cd ollama
|
||||
```
|
||||
|
||||
### Step 1: Clone the HuggingFace repository (optional)
|
||||
and then fetch its `llama.cpp` submodule:
|
||||
|
||||
```shell
|
||||
git submodule init
|
||||
git submodule update llm/llama.cpp
|
||||
```
|
||||
|
||||
Next, install the Python dependencies:
|
||||
|
||||
```
|
||||
python3 -m venv llm/llama.cpp/.venv
|
||||
source llm/llama.cpp/.venv/bin/activate
|
||||
pip install -r llm/llama.cpp/requirements.txt
|
||||
```
|
||||
|
||||
Then build the `quantize` tool:
|
||||
|
||||
```
|
||||
make -C llm/llama.cpp quantize
|
||||
```
|
||||
|
||||
### Clone the HuggingFace repository (optional)
|
||||
|
||||
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
|
||||
|
||||
Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository:
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
cd Mistral-7B-Instruct-v0.1
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model
|
||||
```
|
||||
|
||||
### Step 2: Convert and quantize to a `.bin` file (optional, for PyTorch and Safetensors)
|
||||
### Convert the model
|
||||
|
||||
If the model is in PyTorch or Safetensors format, a [Docker image](https://hub.docker.com/r/ollama/quantize) with the tooling required to convert and quantize models is available.
|
||||
|
||||
First, Install [Docker](https://www.docker.com/get-started/).
|
||||
|
||||
Next, to convert and quantize your model, run:
|
||||
> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py`
|
||||
|
||||
```
|
||||
docker run --rm -v .:/model ollama/quantize -q q4_0 /model
|
||||
python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin
|
||||
```
|
||||
|
||||
This will output two files into the directory:
|
||||
### Quantize the model
|
||||
|
||||
- `f16.bin`: the model converted to GGUF
|
||||
- `q4_0.bin` the model quantized to a 4-bit quantization (Ollama will use this file to create the Ollama model)
|
||||
```
|
||||
llm/llama.cpp/quantize converted.bin quantized.bin q4_0
|
||||
```
|
||||
|
||||
### Step 3: Write a `Modelfile`
|
||||
|
||||
Next, create a `Modelfile` for your model:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
```
|
||||
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
```
|
||||
FROM ./q4_0.bin
|
||||
FROM quantized.bin
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
@@ -149,47 +163,3 @@ The quantization options are as follow (from highest highest to lowest levels of
|
||||
- `q6_K`
|
||||
- `q8_0`
|
||||
- `f16`
|
||||
|
||||
## Manually converting & quantizing models
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Start by cloning the `llama.cpp` repo to your machine in another directory:
|
||||
|
||||
```
|
||||
git clone https://github.com/ggerganov/llama.cpp.git
|
||||
cd llama.cpp
|
||||
```
|
||||
|
||||
Next, install the Python dependencies:
|
||||
|
||||
```
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
Finally, build the `quantize` tool:
|
||||
|
||||
```
|
||||
make quantize
|
||||
```
|
||||
|
||||
### Convert the model
|
||||
|
||||
Run the correct conversion script for your model architecture:
|
||||
|
||||
```shell
|
||||
# LlamaForCausalLM or MistralForCausalLM
|
||||
python convert.py <path to model directory>
|
||||
|
||||
# FalconForCausalLM
|
||||
python convert-falcon-hf-to-gguf.py <path to model directory>
|
||||
|
||||
# GPTBigCodeForCausalLM
|
||||
python convert-starcoder-hf-to-gguf.py <path to model directory>
|
||||
```
|
||||
|
||||
### Quantize the model
|
||||
|
||||
```
|
||||
quantize <path to model dir>/ggml-model-f32.bin <path to model dir>/q4_0.bin q4_0
|
||||
```
|
||||
|
||||
@@ -12,6 +12,13 @@ On Linux systems with systemd, the logs can be found with this command:
|
||||
journalctl -u ollama
|
||||
```
|
||||
|
||||
When you run Ollama in a container, the logs go to stdout/stderr in the container:
|
||||
|
||||
```shell
|
||||
docker logs <container-name>
|
||||
```
|
||||
(Use `docker ps` to find the container name)
|
||||
|
||||
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
|
||||
|
||||
Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
|
||||
|
||||
22
gpu/gpu.go
22
gpu/gpu.go
@@ -30,8 +30,8 @@ type handles struct {
|
||||
var gpuMutex sync.Mutex
|
||||
var gpuHandles *handles = nil
|
||||
|
||||
// With our current CUDA compile flags, 5.2 and older will not work properly
|
||||
const CudaComputeMajorMin = 6
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
var CudaComputeMin = [2]C.int{5, 0}
|
||||
|
||||
// Possible locations for the nvidia-ml library
|
||||
var CudaLinuxGlobs = []string{
|
||||
@@ -122,28 +122,34 @@ func GetGPUInfo() GpuInfo {
|
||||
initGPUHandles()
|
||||
}
|
||||
|
||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||
cpuVariant := GetCPUVariant()
|
||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
||||
}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
resp := GpuInfo{}
|
||||
if gpuHandles.cuda != nil {
|
||||
if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
} else {
|
||||
} else if memInfo.count > 0 {
|
||||
// Verify minimum compute capability
|
||||
var cc C.cuda_compute_capability_t
|
||||
C.cuda_compute_capability(*gpuHandles.cuda, &cc)
|
||||
if cc.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
|
||||
C.free(unsafe.Pointer(cc.err))
|
||||
} else if cc.major >= CudaComputeMajorMin {
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
}
|
||||
} else if gpuHandles.rocm != nil {
|
||||
} else if gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
|
||||
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
|
||||
@@ -151,7 +157,7 @@ func GetGPUInfo() GpuInfo {
|
||||
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
|
||||
// Only one GPU detected and it appears to be an integrated GPU - skip it
|
||||
slog.Info("ROCm unsupported integrated GPU detected")
|
||||
} else {
|
||||
} else if memInfo.count > 0 {
|
||||
if memInfo.igpu_index >= 0 {
|
||||
// We have multiple GPUs reported, and one of them is an integrated GPU
|
||||
// so we have to set the env var to bypass it
|
||||
@@ -185,7 +191,7 @@ func GetGPUInfo() GpuInfo {
|
||||
if resp.Library == "" {
|
||||
C.cpu_check_ram(&memInfo)
|
||||
resp.Library = "cpu"
|
||||
resp.Variant = GetCPUVariant()
|
||||
resp.Variant = cpuVariant
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
|
||||
|
||||
@@ -178,7 +178,7 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
if (h.handle == NULL) {
|
||||
resp->str = strdup("nvml handle not initialized");
|
||||
resp->str = strdup("rocm handle not initialized");
|
||||
resp->status = 1;
|
||||
return;
|
||||
}
|
||||
@@ -195,4 +195,4 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
|
||||
resp->str = strdup(buf);
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
#endif // __APPLE__
|
||||
|
||||
@@ -4,7 +4,7 @@ package llm
|
||||
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
|
||||
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
|
||||
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
|
||||
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable
|
||||
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
|
||||
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
||||
@@ -161,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
|
||||
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
var imageData []ImageData
|
||||
|
||||
if len(predict.Images) > 0 {
|
||||
for cnt, i := range predict.Images {
|
||||
imageData = append(imageData, ImageData{Data: i, ID: cnt})
|
||||
}
|
||||
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
|
||||
}
|
||||
slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
|
||||
|
||||
request := map[string]any{
|
||||
"prompt": predict.Prompt,
|
||||
@@ -189,7 +186,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
|
||||
"penalize_nl": predict.Options.PenalizeNewline,
|
||||
"seed": predict.Options.Seed,
|
||||
"stop": predict.Options.Stop,
|
||||
"image_data": imageData,
|
||||
"image_data": predict.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
|
||||
@@ -26,13 +26,13 @@
|
||||
|
||||
// Expose the llama server as a callable extern "C" API
|
||||
llama_server_context *llama = NULL;
|
||||
std::atomic<bool> ext_server_running(false);
|
||||
std::thread ext_server_thread;
|
||||
|
||||
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
assert(err != NULL && sparams != NULL);
|
||||
log_set_target(stderr);
|
||||
if (!sparams->verbose_logging) {
|
||||
server_verbose = true;
|
||||
log_disable();
|
||||
}
|
||||
|
||||
@@ -122,18 +122,23 @@ void llama_server_start() {
|
||||
assert(llama != NULL);
|
||||
// TODO mutex to protect thread creation
|
||||
ext_server_thread = std::thread([&]() {
|
||||
ext_server_running = true;
|
||||
try {
|
||||
LOG_TEE("llama server main loop starting\n");
|
||||
ggml_time_init();
|
||||
while (ext_server_running.load()) {
|
||||
if (!llama->update_slots()) {
|
||||
LOG_TEE(
|
||||
"unexpected error in llama server update_slots - exiting main "
|
||||
"loop\n");
|
||||
break;
|
||||
}
|
||||
}
|
||||
llama->queue_tasks.on_new_task(std::bind(
|
||||
&llama_server_context::process_single_task, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_finish_multitask(std::bind(
|
||||
&llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_all_tasks_finished(std::bind(
|
||||
&llama_server_context::run_on_all_tasks_finished, llama));
|
||||
llama->queue_results.on_multitask_update(std::bind(
|
||||
&llama_server_queue::update_multitask,
|
||||
&llama->queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
llama->queue_tasks.start_loop();
|
||||
} catch (std::exception &e) {
|
||||
LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
|
||||
} catch (...) {
|
||||
@@ -146,13 +151,10 @@ void llama_server_start() {
|
||||
|
||||
void llama_server_stop() {
|
||||
assert(llama != NULL);
|
||||
// TODO - too verbose, remove once things are solid
|
||||
LOG_TEE("requesting llama server shutdown\n");
|
||||
ext_server_running = false;
|
||||
|
||||
// unblocks the update_slots() loop so it can clean up and exit
|
||||
llama->request_cancel(0);
|
||||
|
||||
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
|
||||
// This may take a while for any pending tasks to drain
|
||||
// TODO - consider a timeout to cancel tasks if it's taking too long
|
||||
llama->queue_tasks.terminate();
|
||||
ext_server_thread.join();
|
||||
delete llama;
|
||||
llama = NULL;
|
||||
@@ -165,7 +167,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
|
||||
resp->msg[0] = '\0';
|
||||
try {
|
||||
json data = json::parse(json_req);
|
||||
resp->id = llama->request_completion(data, false, false, -1);
|
||||
resp->id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(resp->id);
|
||||
llama->request_completion(resp->id, data, false, false, -1);
|
||||
} catch (std::exception &e) {
|
||||
snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
@@ -183,16 +187,22 @@ void llama_server_completion_next_result(const int task_id,
|
||||
resp->json_resp = NULL;
|
||||
std::string result_json;
|
||||
try {
|
||||
task_result result = llama->next_result(task_id);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
result_json =
|
||||
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
resp->id = result.id;
|
||||
resp->stop = result.stop;
|
||||
resp->error = result.error;
|
||||
if (result.error) {
|
||||
LOG_TEE("next result cancel on error\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting tak ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} else if (result.stop) {
|
||||
LOG_TEE("next result cancel on stop\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
resp->error = true;
|
||||
@@ -223,6 +233,7 @@ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama->request_cancel(task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
@@ -307,13 +318,15 @@ void llama_server_embedding(const char *json_req, char **json_resp,
|
||||
} else {
|
||||
prompt = "";
|
||||
}
|
||||
const int task_id = llama->request_completion(
|
||||
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
task_result result = llama->next_result(task_id);
|
||||
const int task_id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(task_id);
|
||||
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
std::string result_json = result.result_json.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
*json_resp = new char[size];
|
||||
snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
|
||||
@@ -39,6 +39,9 @@ init_vars() {
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then
|
||||
CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||
fi
|
||||
}
|
||||
|
||||
git_module_setup() {
|
||||
@@ -62,15 +65,17 @@ apply_patches() {
|
||||
echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
|
||||
fi
|
||||
|
||||
# apply temporary patches until fix is upstream
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
if [ -n "$(ls -A ../patches/*.diff)" ]; then
|
||||
# apply temporary patches until fix is upstream
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
done
|
||||
done
|
||||
done
|
||||
for patch in ../patches/*.diff; do
|
||||
(cd ${LLAMACPP_DIR} && git apply ${patch})
|
||||
done
|
||||
for patch in ../patches/*.diff; do
|
||||
(cd ${LLAMACPP_DIR} && git apply ${patch})
|
||||
done
|
||||
fi
|
||||
|
||||
# Avoid duplicate main symbols when we link into the cgo binary
|
||||
sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
|
||||
@@ -109,4 +114,12 @@ compress_libs() {
|
||||
# Keep the local tree clean after we're done with the build
|
||||
cleanup() {
|
||||
(cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp)
|
||||
|
||||
if [ -n "$(ls -A ../patches/*.diff)" ]; then
|
||||
for patch in ../patches/*.diff; do
|
||||
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
|
||||
(cd ${LLAMACPP_DIR}; git checkout ${file})
|
||||
done
|
||||
done
|
||||
fi
|
||||
}
|
||||
|
||||
@@ -128,6 +128,11 @@ if [ -z "${CUDA_LIB_DIR}" ] && [ -d /opt/cuda/targets/x86_64-linux/lib ]; then
|
||||
CUDA_LIB_DIR=/opt/cuda/targets/x86_64-linux/lib
|
||||
fi
|
||||
|
||||
# Allow override in case libcudart is in the wrong place
|
||||
if [ -z "${CUDART_LIB_DIR}" ]; then
|
||||
CUDART_LIB_DIR="${CUDA_LIB_DIR}"
|
||||
fi
|
||||
|
||||
if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
echo "CUDA libraries detected - building dynamic CUDA library"
|
||||
init_vars
|
||||
@@ -135,7 +140,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
if [ -n "${CUDA_MAJOR}" ]; then
|
||||
CUDA_VARIANT=_v${CUDA_MAJOR}
|
||||
fi
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
|
||||
build
|
||||
@@ -151,6 +156,8 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/"
|
||||
elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/"
|
||||
elif [ -e "${CUDART_LIB_DIR}/${lib}" ]; then
|
||||
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/lib/"
|
||||
else
|
||||
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/"
|
||||
fi
|
||||
|
||||
@@ -25,6 +25,11 @@ function init_vars {
|
||||
}
|
||||
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
|
||||
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
|
||||
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
|
||||
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
|
||||
} else {
|
||||
$script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
|
||||
}
|
||||
}
|
||||
|
||||
function git_module_setup {
|
||||
@@ -151,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
|
||||
}
|
||||
init_vars
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on")
|
||||
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
||||
build
|
||||
install
|
||||
cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"
|
||||
|
||||
Submodule llm/llama.cpp updated: cd4fddb29f...d2f650cb5b
@@ -62,7 +62,7 @@ const maxRetries = 3
|
||||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []api.ImageData
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
|
||||
index 0462fbd2..4fa7b57f 100644
|
||||
index a48582ad..9fffffd8 100644
|
||||
--- a/examples/server/server.cpp
|
||||
+++ b/examples/server/server.cpp
|
||||
@@ -1857,12 +1857,6 @@ struct llama_server_context
|
||||
@@ -1564,12 +1564,6 @@ struct llama_server_context
|
||||
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
|
||||
}
|
||||
|
||||
@@ -15,8 +15,8 @@ index 0462fbd2..4fa7b57f 100644
|
||||
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
|
||||
{
|
||||
// we have to evaluate at least 1 token to generate logits.
|
||||
@@ -1870,6 +1864,12 @@ struct llama_server_context
|
||||
slot.n_past--;
|
||||
@@ -1581,6 +1575,12 @@ struct llama_server_context
|
||||
}
|
||||
}
|
||||
|
||||
+ LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
|
||||
|
||||
90
llm/patches/02-shutdown.diff
Normal file
90
llm/patches/02-shutdown.diff
Normal file
@@ -0,0 +1,90 @@
|
||||
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
|
||||
index 11dd82c3..311495a8 100644
|
||||
--- a/examples/server/server.cpp
|
||||
+++ b/examples/server/server.cpp
|
||||
@@ -28,6 +28,7 @@
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <atomic>
|
||||
+#include <signal.h>
|
||||
|
||||
using json = nlohmann::json;
|
||||
|
||||
@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
|
||||
}
|
||||
}
|
||||
|
||||
+std::function<void(int)> shutdown_handler;
|
||||
+inline void signal_handler(int signal) { shutdown_handler(signal); }
|
||||
+
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
#if SERVER_VERBOSE != 1
|
||||
@@ -3014,8 +3018,14 @@ int main(int argc, char **argv)
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
- llama.queue_tasks.start_loop();
|
||||
|
||||
+ shutdown_handler = [&](int) {
|
||||
+ llama.queue_tasks.terminate();
|
||||
+ };
|
||||
+ signal(SIGTERM, signal_handler);
|
||||
+ signal(SIGINT, signal_handler);
|
||||
+ llama.queue_tasks.start_loop();
|
||||
+ svr.stop();
|
||||
t.join();
|
||||
|
||||
llama_backend_free();
|
||||
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
|
||||
index 70cce072..2acb1eab 100644
|
||||
--- a/examples/server/utils.hpp
|
||||
+++ b/examples/server/utils.hpp
|
||||
@@ -6,6 +6,7 @@
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <unordered_map>
|
||||
+#include <atomic>
|
||||
|
||||
#include "json.hpp"
|
||||
|
||||
@@ -190,6 +191,7 @@ inline std::string format_chatml(std::vector<json> messages)
|
||||
struct llama_server_queue {
|
||||
int id = 0;
|
||||
std::mutex mutex_tasks;
|
||||
+ std::atomic<bool> running;
|
||||
// queues
|
||||
std::vector<task_server> queue_tasks;
|
||||
std::vector<task_server> queue_tasks_deferred;
|
||||
@@ -248,9 +250,15 @@ struct llama_server_queue {
|
||||
queue_tasks_deferred.clear();
|
||||
}
|
||||
|
||||
- // Start the main loop. This call is blocking
|
||||
- [[noreturn]]
|
||||
+ // end the start_loop routine
|
||||
+ void terminate() {
|
||||
+ running = false;
|
||||
+ condition_tasks.notify_all();
|
||||
+ }
|
||||
+
|
||||
+ // Start the main loop.
|
||||
void start_loop() {
|
||||
+ running = true;
|
||||
while (true) {
|
||||
// new task arrived
|
||||
LOG_VERBOSE("have new task", {});
|
||||
@@ -294,8 +302,12 @@ struct llama_server_queue {
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(mutex_tasks);
|
||||
if (queue_tasks.empty()) {
|
||||
+ if (!running.load()) {
|
||||
+ LOG_VERBOSE("ending start_loop", {});
|
||||
+ return;
|
||||
+ }
|
||||
condition_tasks.wait(lock, [&]{
|
||||
- return !queue_tasks.empty();
|
||||
+ return (!queue_tasks.empty() || !running.load());
|
||||
});
|
||||
}
|
||||
}
|
||||
@@ -147,12 +147,7 @@ func (s SignatureData) Bytes() []byte {
|
||||
|
||||
// SignData takes a SignatureData object and signs it with a raw private key
|
||||
func (s SignatureData) Sign(rawKey []byte) (string, error) {
|
||||
privateKey, err := ssh.ParseRawPrivateKey(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signer, err := ssh.NewSignerFromKey(privateKey)
|
||||
signer, err := ssh.ParsePrivateKey(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -25,6 +25,11 @@ import (
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
const maxRetries = 6
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
var errPartStalled = errors.New("part stalled")
|
||||
|
||||
var blobDownloadManager sync.Map
|
||||
|
||||
type blobDownload struct {
|
||||
@@ -44,10 +49,11 @@ type blobDownload struct {
|
||||
}
|
||||
|
||||
type blobDownloadPart struct {
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
N int
|
||||
Offset int64
|
||||
Size int64
|
||||
Completed int64
|
||||
lastUpdated time.Time
|
||||
|
||||
*blobDownload `json:"-"`
|
||||
}
|
||||
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
|
||||
return p.Offset + p.Size
|
||||
}
|
||||
|
||||
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
n = len(b)
|
||||
p.blobDownload.Completed.Add(int64(n))
|
||||
p.lastUpdated = time.Now()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||
if err != nil {
|
||||
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
|
||||
// return immediately if the context is canceled or the device is out of space
|
||||
return err
|
||||
case errors.Is(err, errPartStalled):
|
||||
try--
|
||||
continue
|
||||
case err != nil:
|
||||
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
|
||||
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
|
||||
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
|
||||
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
n, err := io.Copy(w, io.TeeReader(resp.Body, part))
|
||||
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
// rollback progress
|
||||
b.Completed.Add(-n)
|
||||
return err
|
||||
}
|
||||
|
||||
part.Completed += n
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
part.Completed += n
|
||||
if err := b.writePart(part.Name(), part); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
// return nil or context.Canceled or UnexpectedEOF (resumable)
|
||||
return err
|
||||
})
|
||||
|
||||
g.Go(func() error {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if part.Completed >= part.Size {
|
||||
return nil
|
||||
}
|
||||
|
||||
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
|
||||
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
|
||||
// reset last updated
|
||||
part.lastUpdated = time.Time{}
|
||||
return errPartStalled
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return g.Wait()
|
||||
}
|
||||
|
||||
func (b *blobDownload) newPart(offset, size int64) error {
|
||||
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
|
||||
return json.NewEncoder(partFile).Encode(part)
|
||||
}
|
||||
|
||||
func (b *blobDownload) Write(p []byte) (n int, err error) {
|
||||
n = len(p)
|
||||
b.Completed.Add(int64(n))
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) acquire() {
|
||||
b.references.Add(1)
|
||||
}
|
||||
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
|
||||
Digest: b.Digest,
|
||||
Total: b.Total,
|
||||
Completed: b.Completed.Load(),
|
||||
})
|
||||
|
||||
if b.done || b.err != nil {
|
||||
return b.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,10 +338,6 @@ type downloadOpts struct {
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
const maxRetries = 6
|
||||
|
||||
var errMaxRetriesExceeded = errors.New("max retries exceeded")
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
|
||||
@@ -63,6 +63,7 @@ type PromptVars struct {
    Prompt string
    Response string
    First bool
    Images []llm.ImageData
}

// extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -119,10 +120,6 @@ func Prompt(promptTemplate string, p PromptVars) (string, error) {

// PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
    if p.System == "" {
        // use the default system prompt for this model if one is not specified
        p.System = m.System
    }
    pre, _, err := extractParts(m.Template)
    if err != nil {
        return "", err
@@ -150,62 +147,68 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
    return Prompt(post, p)
}

func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
type ChatHistory struct {
    Prompts []PromptVars
    LastSystem string
}

// ChatPrompts returns a list of formatted chat prompts from a list of messages
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
    // build the prompt from the list of messages
    var prompt strings.Builder
    var currentImages []api.ImageData
    lastSystem := m.System
    currentVars := PromptVars{
        First: true,
        System: m.System,
    }

    writePrompt := func() error {
        p, err := Prompt(m.Template, currentVars)
        if err != nil {
            return err
        }
        prompt.WriteString(p)
        currentVars = PromptVars{}
        return nil
    }
    prompts := []PromptVars{}
    var images []llm.ImageData

    for _, msg := range msgs {
        switch strings.ToLower(msg.Role) {
        case "system":
            if currentVars.System != "" {
                if err := writePrompt(); err != nil {
                    return "", nil, err
                }
            // if this is the first message it overrides the system prompt in the modelfile
            if !currentVars.First && currentVars.System != "" {
                prompts = append(prompts, currentVars)
                currentVars = PromptVars{}
            }
            currentVars.System = msg.Content
            lastSystem = msg.Content
        case "user":
            if currentVars.Prompt != "" {
                if err := writePrompt(); err != nil {
                    return "", nil, err
                }
                prompts = append(prompts, currentVars)
                currentVars = PromptVars{}
            }

            currentVars.Prompt = msg.Content
            currentImages = msg.Images
            for i := range msg.Images {
                id := len(images) + i
                currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
                currentVars.Images = append(currentVars.Images, llm.ImageData{
                    ID: id,
                    Data: msg.Images[i],
                })
            }

            images = append(images, currentVars.Images...)
        case "assistant":
            currentVars.Response = msg.Content
            if err := writePrompt(); err != nil {
                return "", nil, err
            }
            prompts = append(prompts, currentVars)
            currentVars = PromptVars{}
        default:
            return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
            return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
        }
    }

    // Append the last set of vars if they are non-empty
    if currentVars.Prompt != "" || currentVars.System != "" {
        p, err := m.PreResponsePrompt(currentVars)
        if err != nil {
            return "", nil, fmt.Errorf("pre-response template: %w", err)
        }
        prompt.WriteString(p)
        prompts = append(prompts, currentVars)
    }

    return prompt.String(), currentImages, nil
    return &ChatHistory{
        Prompts: prompts,
        LastSystem: lastSystem,
    }, nil
}

type ManifestV2 struct {
@@ -1,6 +1,7 @@
package server

import (
    "bytes"
    "strings"
    "testing"

@@ -233,17 +234,58 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
    }
}

func chatHistoryEqual(a, b ChatHistory) bool {
    if len(a.Prompts) != len(b.Prompts) {
        return false
    }
    for i, v := range a.Prompts {

        if v.First != b.Prompts[i].First {
            return false
        }

        if v.Response != b.Prompts[i].Response {
            return false
        }

        if v.Prompt != b.Prompts[i].Prompt {
            return false
        }

        if v.System != b.Prompts[i].System {
            return false
        }

        if len(v.Images) != len(b.Prompts[i].Images) {
            return false
        }

        for j, img := range v.Images {
            if img.ID != b.Prompts[i].Images[j].ID {
                return false
            }

            if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
                return false
            }
        }
    }
    return a.LastSystem == b.LastSystem
}

func TestChat(t *testing.T) {
    tests := []struct {
        name string
        template string
        msgs []api.Message
        want string
        wantErr string
        name string
        model Model
        msgs []api.Message
        want ChatHistory
        wantErr string
    }{
        {
            name: "Single Message",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            name: "Single Message",
            model: Model{
                Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            },
            msgs: []api.Message{
                {
                    Role: "system",
@@ -254,34 +296,22 @@ func TestChat(t *testing.T) {
                    Content: "What are the potion ingredients?",
                },
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
        },
        {
            name: "First Message",
            template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
            msgs: []api.Message{
                {
                    Role: "system",
                    Content: "You are a Wizard.",
                },
                {
                    Role: "user",
                    Content: "What are the potion ingredients?",
                },
                {
                    Role: "assistant",
                    Content: "eye of newt",
                },
                {
                    Role: "user",
                    Content: "Anything else?",
            want: ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are a Wizard.",
                        Prompt: "What are the potion ingredients?",
                        First: true,
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
        },
        {
            name: "Message History",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            name: "Message History",
            model: Model{
                Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            },
            msgs: []api.Message{
                {
                    Role: "system",
@@ -300,18 +330,85 @@ func TestChat(t *testing.T) {
                    Content: "Anything else?",
                },
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
            want: ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are a Wizard.",
                        Prompt: "What are the potion ingredients?",
                        Response: "sugar",
                        First: true,
                    },
                    {
                        Prompt: "Anything else?",
                    },
                },
                LastSystem: "You are a Wizard.",
            },
        },
        {
            name: "Assistant Only",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            name: "Assistant Only",
            model: Model{
                Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            },
            msgs: []api.Message{
                {
                    Role: "assistant",
                    Content: "everything nice",
                },
            },
            want: "[INST] [/INST]everything nice",
            want: ChatHistory{
                Prompts: []PromptVars{
                    {
                        Response: "everything nice",
                        First: true,
                    },
                },
            },
        },
        {
            name: "Last system message is preserved from modelfile",
            model: Model{
                Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
                System: "You are Mojo Jojo.",
            },
            msgs: []api.Message{
                {
                    Role: "user",
                    Content: "hi",
                },
            },
            want: ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are Mojo Jojo.",
                        Prompt: "hi",
                        First: true,
                    },
                },
                LastSystem: "You are Mojo Jojo.",
            },
        },
        {
            name: "Last system message is preserved from messages",
            model: Model{
                Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
                System: "You are Mojo Jojo.",
            },
            msgs: []api.Message{
                {
                    Role: "system",
                    Content: "You are Professor Utonium.",
                },
            },
            want: ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are Professor Utonium.",
                        First: true,
                    },
                },
                LastSystem: "You are Professor Utonium.",
            },
        },
        {
            name: "Invalid Role",
@@ -326,11 +423,8 @@ func TestChat(t *testing.T) {
    }

    for _, tt := range tests {
        m := Model{
            Template: tt.template,
        }
        t.Run(tt.name, func(t *testing.T) {
            got, _, err := m.ChatPrompt(tt.msgs)
            got, err := tt.model.ChatPrompts(tt.msgs)
            if tt.wantErr != "" {
                if err == nil {
                    t.Errorf("ChatPrompt() expected error, got nil")
@@ -338,9 +432,10 @@ func TestChat(t *testing.T) {
                if !strings.Contains(err.Error(), tt.wantErr) {
                    t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
                }
                return
            }
            if got != tt.want {
                t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
            if !chatHistoryEqual(*got, tt.want) {
                t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
            }
        })
    }
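The `ChatPrompts`/`ChatHistory` shape exercised by these tests groups a flat message list into one `PromptVars` entry per exchange: a system message attaches to the exchange that follows it, each assistant reply closes an entry, and the last system message seen is recorded separately so it can be preserved when history is trimmed. A rough standalone sketch of that grouping, using simplified stand-in types (not the package's own) and omitting images and error handling:

```
package main

import "fmt"

// Simplified stand-ins for the types introduced in the diff above.
type message struct{ Role, Content string }

type promptVars struct {
	System, Prompt, Response string
	First                    bool
}

type chatHistory struct {
	Prompts    []promptVars
	LastSystem string
}

// groupMessages mirrors the grouping idea only.
func groupMessages(msgs []message) chatHistory {
	var out chatHistory
	cur := promptVars{First: true}
	flush := func() {
		out.Prompts = append(out.Prompts, cur)
		cur = promptVars{}
	}
	for _, m := range msgs {
		switch m.Role {
		case "system":
			if !cur.First && cur.System != "" {
				flush()
			}
			cur.System = m.Content
			out.LastSystem = m.Content
		case "user":
			if cur.Prompt != "" {
				flush()
			}
			cur.Prompt = m.Content
		case "assistant":
			cur.Response = m.Content
			flush()
		}
	}
	if cur.Prompt != "" || cur.System != "" {
		out.Prompts = append(out.Prompts, cur)
	}
	return out
}

func main() {
	h := groupMessages([]message{
		{Role: "system", Content: "You are a Wizard."},
		{Role: "user", Content: "What are the potion ingredients?"},
		{Role: "assistant", Content: "sugar"},
		{Role: "user", Content: "Anything else?"},
	})
	fmt.Printf("%+v\n", h) // two entries, LastSystem: "You are a Wizard."
}
```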
210
server/routes.go
@@ -178,15 +178,17 @@ func GenerateHandler(c *gin.Context) {

    opts, err := modelOptions(model, req.Options)
    if err != nil {
        if errors.Is(err, api.ErrInvalidOpts) {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    sessionDuration := defaultSessionDuration
    var sessionDuration time.Duration
    if req.KeepAlive == nil {
        sessionDuration = defaultSessionDuration
    } else {
        sessionDuration = req.KeepAlive.Duration
    }

    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -233,6 +235,15 @@ func GenerateHandler(c *gin.Context) {
        Prompt: req.Prompt,
        First: len(req.Context) == 0,
    }

    if promptVars.System == "" {
        promptVars.System = model.System
    }

    for i := range req.Images {
        promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
    }

    p, err := model.PreResponsePrompt(promptVars)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -242,6 +253,8 @@ func GenerateHandler(c *gin.Context) {
        prompt = rebuild.String()
    }

    slog.Debug("generate handler", "prompt", prompt)

    ch := make(chan any)
    var generated strings.Builder
    go func() {
@@ -295,11 +308,19 @@ func GenerateHandler(c *gin.Context) {
            ch <- resp
        }

        var images []llm.ImageData
        for i := range req.Images {
            images = append(images, llm.ImageData{
                ID: i,
                Data: req.Images[i],
            })
        }

        // Start prediction
        predictReq := llm.PredictOpts{
            Prompt: prompt,
            Format: req.Format,
            Images: req.Images,
            Images: images,
            Options: opts,
        }
        if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -371,14 +392,17 @@ func EmbeddingHandler(c *gin.Context) {

    opts, err := modelOptions(model, req.Options)
    if err != nil {
        if errors.Is(err, api.ErrInvalidOpts) {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }
    sessionDuration := defaultSessionDuration

    var sessionDuration time.Duration
    if req.KeepAlive == nil {
        sessionDuration = defaultSessionDuration
    } else {
        sessionDuration = req.KeepAlive.Duration
    }

    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -918,13 +942,26 @@ func (s *Server) GenerateRoutes() http.Handler {
}

func Serve(ln net.Listener) error {
    level := slog.LevelInfo
    if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
        var programLevel = new(slog.LevelVar)
        h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: programLevel, AddSource: true})
        slog.SetDefault(slog.New(h))
        programLevel.Set(slog.LevelDebug)
        slog.Debug("Debug logging enabled")
        level = slog.LevelDebug
    }

    handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
        Level: level,
        AddSource: true,
        ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
            if attr.Key == slog.SourceKey {
                source := attr.Value.Any().(*slog.Source)
                source.File = filepath.Base(source.File)
            }

            return attr
        },
    })

    slog.SetDefault(slog.New(handler))

    if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
        // clean up unused layers and manifests
        if err := PruneLayers(); err != nil {
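The `Serve` change consolidates logging setup: the level is chosen once from `OLLAMA_DEBUG`, and a single text handler is installed with `AddSource` enabled and a `ReplaceAttr` hook that trims source paths to their base name. A small standalone sketch of the same `log/slog` pattern (the helper name is illustrative; the diff does this inline in `Serve`):

```
package main

import (
	"log/slog"
	"os"
	"path/filepath"
)

// setupLogging configures the default slog logger.
func setupLogging() {
	level := slog.LevelInfo
	if os.Getenv("OLLAMA_DEBUG") != "" {
		level = slog.LevelDebug
	}

	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level:     level,
		AddSource: true,
		ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
			// Log only the file's base name instead of its full path.
			if attr.Key == slog.SourceKey {
				source := attr.Value.Any().(*slog.Source)
				source.File = filepath.Base(source.File)
			}
			return attr
		},
	})

	slog.SetDefault(slog.New(handler))
}

func main() {
	setupLogging()
	slog.Debug("only shown when OLLAMA_DEBUG is set")
	slog.Info("server starting")
}
```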
@@ -1067,14 +1104,17 @@ func ChatHandler(c *gin.Context) {

    opts, err := modelOptions(model, req.Options)
    if err != nil {
        if errors.Is(err, api.ErrInvalidOpts) {
            c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
            return
        }
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }
    sessionDuration := defaultSessionDuration

    var sessionDuration time.Duration
    if req.KeepAlive == nil {
        sessionDuration = defaultSessionDuration
    } else {
        sessionDuration = req.KeepAlive.Duration
    }

    if err := load(c, model, opts, sessionDuration); err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
@@ -1094,12 +1134,20 @@ func ChatHandler(c *gin.Context) {

    checkpointLoaded := time.Now()

    prompt, images, err := model.ChatPrompt(req.Messages)
    chat, err := model.ChatPrompts(req.Messages)
    if err != nil {
        c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
        return
    }

    prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
    if err != nil {
        c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
        return
    }

    slog.Debug("chat handler", "prompt", prompt)

    ch := make(chan any)

    go func() {
@@ -1173,3 +1221,115 @@ func ChatHandler(c *gin.Context) {

    streamResponse(c, ch)
}

// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
type promptInfo struct {
    vars PromptVars
    tokenLen int
}

// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
// while preserving the most recent system message.
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
    if len(chat.Prompts) == 0 {
        return "", nil, nil
    }

    var promptsToAdd []promptInfo
    var totalTokenLength int
    var systemPromptIncluded bool

    var images []llm.ImageData
    // reverse iterate through the prompts to build the prompt string in a way that fits the max context length
    for i := len(chat.Prompts) - 1; i >= 0; i-- {
        prompt := chat.Prompts[i]
        promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
        if err != nil {
            return "", nil, err
        }

        encodedTokens, err := loaded.runner.Encode(ctx, promptText)
        if err != nil {
            return "", nil, err
        }

        if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
            break // reached max context length, stop adding more prompts
        }

        for j := range prompt.Images {
            if totalTokenLength+768 > loaded.NumCtx {
                // this decreases the token length but overestimating is fine
                prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
                continue
            }

            totalTokenLength += 768
            images = append(images, prompt.Images[j])
        }

        totalTokenLength += len(encodedTokens)
        systemPromptIncluded = systemPromptIncluded || prompt.System != ""
        promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
    }

    // ensure the system prompt is included, if not already
    if chat.LastSystem != "" && !systemPromptIncluded {
        var err error
        promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
        if err != nil {
            return "", nil, err
        }
    }

    promptsToAdd[len(promptsToAdd)-1].vars.First = true

    // construct the final prompt string from the prompts which fit within the context window
    var result string
    for i, prompt := range promptsToAdd {
        promptText, err := promptString(model, prompt.vars, i == 0)
        if err != nil {
            return "", nil, err
        }
        result = promptText + result
    }

    return result, images, nil
}

// promptString applies the model template to the prompt
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
    if isMostRecent {
        p, err := model.PreResponsePrompt(vars)
        if err != nil {
            return "", fmt.Errorf("pre-response template: %w", err)
        }
        return p, nil
    }
    p, err := Prompt(model.Template, vars)
    if err != nil {
        return "", err
    }
    return p, nil
}

// includeSystemPrompt adjusts the prompts to include the system prompt.
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
    systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
    if err != nil {
        return nil, err
    }

    for i := len(promptsToAdd) - 1; i >= 0; i-- {
        if totalTokenLength+len(systemTokens) <= loaded.NumCtx {
            promptsToAdd[i].vars.System = systemPrompt
            return promptsToAdd[:i+1], nil
        }
        totalTokenLength -= promptsToAdd[i].tokenLen
    }

    // if got here, system did not fit anywhere, so return the most recent prompt with the system message set
    recent := promptsToAdd[len(promptsToAdd)-1]
    recent.vars.System = systemPrompt
    return []promptInfo{recent}, nil
}
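`trimmedPrompt` walks the chat history from newest to oldest, keeping each templated prompt while its token count still fits in `NumCtx` (images are budgeted at a flat 768 tokens each), and then backfills the most recent system message if truncation dropped it. The core budget arithmetic can be sketched independently of the server types; the names below are illustrative only:

```
package main

import "fmt"

// keepWithinBudget returns the indexes of the newest prompts whose combined
// token counts fit within numCtx; the most recent prompt is always kept.
func keepWithinBudget(tokenLens []int, numCtx int) []int {
	var kept []int
	total := 0
	for i := len(tokenLens) - 1; i >= 0; i-- {
		if total+tokenLens[i] > numCtx && i != len(tokenLens)-1 {
			break // older prompts no longer fit
		}
		total += tokenLens[i]
		kept = append([]int{i}, kept...)
	}
	return kept
}

func main() {
	// Token lengths per templated prompt, oldest first.
	tokenLens := []int{120, 200, 90, 150}
	fmt.Println(keepWithinBudget(tokenLens, 300)) // [2 3]: only the two newest fit
}
```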
@@ -16,6 +16,7 @@ import (
    "github.com/stretchr/testify/assert"

    "github.com/jmorganca/ollama/api"
    "github.com/jmorganca/ollama/llm"
    "github.com/jmorganca/ollama/parser"
    "github.com/jmorganca/ollama/version"
)
@@ -239,3 +240,258 @@ func Test_Routes(t *testing.T) {

    }
}

func Test_ChatPrompt(t *testing.T) {
    tests := []struct {
        name string
        template string
        chat *ChatHistory
        numCtx int
        runner MockLLM
        want string
        wantErr string
    }{
        {
            name: "Single Message",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are a Wizard.",
                        Prompt: "What are the potion ingredients?",
                        First: true,
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            numCtx: 1,
            runner: MockLLM{
                encoding: []int{1}, // fit the ctxLen
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
        },
        {
            name: "First Message",
            template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are a Wizard.",
                        Prompt: "What are the potion ingredients?",
                        Response: "eye of newt",
                        First: true,
                    },
                    {
                        Prompt: "Anything else?",
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            numCtx: 2,
            runner: MockLLM{
                encoding: []int{1}, // fit the ctxLen
            },
            want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
        },
        {
            name: "Message History",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        System: "You are a Wizard.",
                        Prompt: "What are the potion ingredients?",
                        Response: "sugar",
                        First: true,
                    },
                    {
                        Prompt: "Anything else?",
                    },
                },
                LastSystem: "You are a Wizard.",
            },
            numCtx: 4,
            runner: MockLLM{
                encoding: []int{1}, // fit the ctxLen, 1 for each message
            },
            want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
        },
        {
            name: "Assistant Only",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Response: "everything nice",
                        First: true,
                    },
                },
            },
            numCtx: 1,
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] [/INST]everything nice",
        },
        {
            name: "Message History Truncated, No System",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt: "What are the potion ingredients?",
                        Response: "sugar",
                        First: true,
                    },
                    {
                        Prompt: "Anything else?",
                        Response: "spice",
                    },
                    {
                        Prompt: "... and?",
                    },
                },
            },
            numCtx: 2, // only 1 message from history and most recent message
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] Anything else? [/INST]spice[INST] ... and? [/INST]",
        },
        {
            name: "System is Preserved when Truncated",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt: "What are the magic words?",
                        Response: "abracadabra",
                    },
                    {
                        Prompt: "What is the spell for invisibility?",
                    },
                },
                LastSystem: "You are a wizard.",
            },
            numCtx: 2,
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
        },
        {
            name: "System is Preserved when Length Exceeded",
            template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt: "What are the magic words?",
                        Response: "abracadabra",
                    },
                    {
                        Prompt: "What is the spell for invisibility?",
                    },
                },
                LastSystem: "You are a wizard.",
            },
            numCtx: 1,
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
        },
        {
            name: "First is Preserved when Truncated",
            template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",

            chat: &ChatHistory{
                Prompts: []PromptVars{
                    // first message omitted for test
                    {
                        Prompt: "Do you have a magic hat?",
                        Response: "Of course.",
                    },
                    {
                        Prompt: "What is the spell for invisibility?",
                    },
                },
                LastSystem: "You are a wizard.",
            },
            numCtx: 3, // two most recent messages and room for system message
            runner: MockLLM{
                encoding: []int{1},
            },
            want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
        },
        {
            name: "Most recent message is returned when longer than ctxLen",
            template: "[INST] {{ .Prompt }} [/INST]",

            chat: &ChatHistory{
                Prompts: []PromptVars{
                    {
                        Prompt: "What is the spell for invisibility?",
                        First: true,
                    },
                },
            },
            numCtx: 1, // two most recent messages
            runner: MockLLM{
                encoding: []int{1, 2},
            },
            want: "[INST] What is the spell for invisibility? [/INST]",
        },
    }

    for _, testCase := range tests {
        tt := testCase
        m := &Model{
            Template: tt.template,
        }
        t.Run(tt.name, func(t *testing.T) {
            loaded.runner = &tt.runner
            loaded.Options = &api.Options{
                Runner: api.Runner{
                    NumCtx: tt.numCtx,
                },
            }
            // TODO: add tests for trimming images
            got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
            if tt.wantErr != "" {
                if err == nil {
                    t.Errorf("ChatPrompt() expected error, got nil")
                }
                if !strings.Contains(err.Error(), tt.wantErr) {
                    t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
                }
            }
            if got != tt.want {
                t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
            }
        })
    }
}

type MockLLM struct {
    encoding []int
}

func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
    return nil
}

func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
    return llm.encoding, nil
}

func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
    return "", nil
}

func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
    return []float64{}, nil
}

func (llm *MockLLM) Close() {
    // do nothing
}