Compare commits
14 Commits
v0.11.9-rc
...
brucemacd/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5c9eb5aa2 | ||
|
|
20b53eaa72 | ||
|
|
6745182885 | ||
|
|
f810ec741c | ||
|
|
e119783e66 | ||
|
|
1a558f98e2 | ||
|
|
7b91c9ce51 | ||
|
|
950d33aa30 | ||
|
|
9714e38dd0 | ||
|
|
4378ae4ffa | ||
|
|
5994e8e8fd | ||
|
|
b3e6120736 | ||
|
|
fb92b61754 | ||
|
|
8149a3c86e |
@@ -413,6 +413,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
||||
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
|
||||
|
||||
### Cloud
|
||||
|
||||
|
||||
@@ -92,6 +92,9 @@ If none of those resolve the problem, gather additional information and file an
|
||||
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
|
||||
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
|
||||
|
||||
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
|
||||
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
|
||||
|
||||
|
||||
## AMD GPU Discovery
|
||||
|
||||
|
||||
135
fs/ggml/ggml.go
135
fs/ggml/ggml.go
@@ -57,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() []uint64 {
|
||||
headCountDefault := uint32(1)
|
||||
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
|
||||
if len(headCount) == 1 {
|
||||
headCountDefault = headCount[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(headCount) > nLayers {
|
||||
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(headCount) {
|
||||
out[i] = uint64(headCountDefault)
|
||||
} else {
|
||||
out[i] = uint64(headCount[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountMax() uint64 {
|
||||
// TODO(drifkin): using the max value can cause an overestimation. In the
|
||||
// future if array values become more popular, we can adapt the more invasive
|
||||
// <https://github.com/ollama/ollama/pull/10225>
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
@@ -68,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 {
|
||||
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() []uint64 {
|
||||
headCountKVDefault := uint32(1)
|
||||
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
|
||||
if len(headCountKV) == 1 {
|
||||
headCountKVDefault = headCountKV[0]
|
||||
}
|
||||
nLayers := int(kv.BlockCount())
|
||||
if len(headCountKV) > nLayers {
|
||||
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
|
||||
}
|
||||
out := make([]uint64, nLayers)
|
||||
for i := range nLayers {
|
||||
if i >= len(headCountKV) {
|
||||
out[i] = uint64(headCountKVDefault)
|
||||
} else {
|
||||
out[i] = uint64(headCountKV[i])
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVMax() uint64 {
|
||||
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
|
||||
}
|
||||
@@ -100,6 +139,26 @@ func (kv KV) ChatTemplate() string {
|
||||
return kv.String("tokenizer.chat_template")
|
||||
}
|
||||
|
||||
// ssm architecture parameters
|
||||
|
||||
func (kv KV) SSMConvKernel() uint64 {
|
||||
return uint64(kv.Uint("ssm.conv_kernel"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMInnerSize() uint64 {
|
||||
return uint64(kv.Uint("ssm.inner_size"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMStateSize() uint64 {
|
||||
return uint64(kv.Uint("ssm.state_size"))
|
||||
}
|
||||
|
||||
func (kv KV) SSMGroupCount() uint64 {
|
||||
return uint64(kv.Uint("ssm.group_count"))
|
||||
}
|
||||
|
||||
// general types
|
||||
|
||||
func (kv KV) String(key string, defaultValue ...string) string {
|
||||
val, _ := keyValue(kv, key, append(defaultValue, "")...)
|
||||
return val
|
||||
@@ -131,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
|
||||
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
|
||||
return slices.Min(arrVal), slices.Max(arrVal)
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
|
||||
if u32, ok := keyValue(kv, key, uint32(0)); ok {
|
||||
return u32, u32
|
||||
return []uint32{u32}
|
||||
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
|
||||
min := slices.Min(u32s.values)
|
||||
max := slices.Max(u32s.values)
|
||||
return min, max
|
||||
return u32s.values
|
||||
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
|
||||
min := slices.Min(i32s.values)
|
||||
max := slices.Max(i32s.values)
|
||||
if min < 0 || max < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
|
||||
dst := make([]uint32, len(i32s.values))
|
||||
for i, v := range i32s.values {
|
||||
if v < 0 {
|
||||
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
|
||||
}
|
||||
dst[i] = uint32(v)
|
||||
}
|
||||
return uint32(min), uint32(max)
|
||||
return dst
|
||||
}
|
||||
|
||||
return defaultValue, defaultValue
|
||||
return []uint32{defaultValue}
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
@@ -486,7 +550,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsArr := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
headsKVArr := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCountMax()
|
||||
@@ -496,12 +562,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
|
||||
// Default for models unless special-cased below. These defaults mirror the
|
||||
// cache usage in llama.cpp under the assumption that models without special
|
||||
// cases below will use the llamarunner and caching will be handled by the
|
||||
// llama.cpp layer.
|
||||
//
|
||||
// This also assumes that a layer without heads or headsKV set is recurrent
|
||||
// which is usually the case. Some models (eg nemotronh) use "blocks" in
|
||||
// place of layers where some are MLP blocks that don't have any cache.
|
||||
// Models like this will need a special case below to be accurately
|
||||
// estimated.
|
||||
var kvTotal uint64
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
kvSizeAttn := uint64(0)
|
||||
kvSizeRecurrent := uint64(0)
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
headsL := headsArr[i]
|
||||
headsKVL := headsKVArr[i]
|
||||
if headsL > 0 && headsKVL > 0 {
|
||||
// full attention layer
|
||||
// NOTE: Assumes uniform values for all attn layers
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
|
||||
kvSizeAttn += kv[i]
|
||||
} else {
|
||||
// recurrent layer
|
||||
ssmDConv := f.KV().SSMConvKernel()
|
||||
ssmDState := f.KV().SSMStateSize()
|
||||
ssmDInner := f.KV().SSMInnerSize()
|
||||
ssmNGroups := f.KV().SSMGroupCount()
|
||||
nEmbdR := uint64(0)
|
||||
if ssmDConv > 0 {
|
||||
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
|
||||
}
|
||||
nEmbdS := ssmDState * ssmDInner
|
||||
|
||||
// recurrent always uses F32 in llama.cpp backend
|
||||
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
|
||||
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
|
||||
|
||||
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
|
||||
kvSizeRecurrent += kv[i]
|
||||
}
|
||||
kvTotal += kv[i]
|
||||
}
|
||||
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama", "llama4":
|
||||
@@ -794,6 +899,8 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
return 1 // 1/2 of fp16
|
||||
case "q4_0":
|
||||
return 0.5 // 1/4 of fp16
|
||||
case "f32":
|
||||
return 4 // f32 (default for recurrent)
|
||||
default:
|
||||
return 2 // f16 (default)
|
||||
}
|
||||
|
||||
@@ -1,18 +1,31 @@
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/template"
|
||||
)
|
||||
|
||||
type harmonyParserState int
|
||||
|
||||
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if template.Contains("<|start|>") && template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
const (
|
||||
harmonyParserState_LookingForMessageStart harmonyParserState = iota
|
||||
harmonyParserState_ParsingHeader
|
||||
@@ -76,18 +89,28 @@ func (s *HarmonyParser) AddImplicitStart() {
|
||||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
|
||||
if lastMessage != nil && lastMessage.Role == "assistant" {
|
||||
// handle prefilling conditions
|
||||
if lastMessage.Content != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
|
||||
return
|
||||
} else if lastMessage.Thinking != "" {
|
||||
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
|
||||
return
|
||||
}
|
||||
func Prefill(lastMessage api.Message) string {
|
||||
if lastMessage.Role != "assistant" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(lastMessage.Content) != "":
|
||||
return "<|start|>assistant<|channel|>final<|message|>"
|
||||
case strings.TrimSpace(lastMessage.Thinking) != "":
|
||||
return "<|start|>assistant<|channel|>analysis<|message|>"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
|
||||
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
|
||||
if strings.TrimSpace(prefillString) != "" {
|
||||
s.acc.WriteString(prefillString)
|
||||
} else {
|
||||
s.AddImplicitStart()
|
||||
}
|
||||
s.AddImplicitStart()
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
|
||||
@@ -292,7 +315,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case HarmonyEventHeaderComplete:
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
|
||||
logutil.Trace("harmony event header complete", "header", event.Header)
|
||||
switch event.Header.Channel {
|
||||
case "analysis":
|
||||
if event.Header.Recipient != "" {
|
||||
@@ -315,7 +338,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
||||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
|
||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||
if h.state == harmonyMessageState_Normal {
|
||||
contentSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_Thinking {
|
||||
|
||||
@@ -3,6 +3,7 @@ package harmony
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -535,3 +536,202 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
|
||||
t.Run("thinking_then_content_streams", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
type step struct {
|
||||
in string
|
||||
wantContent string
|
||||
wantThinking string
|
||||
}
|
||||
steps := []step{
|
||||
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
|
||||
{in: "<|end|>", wantThinking: ""},
|
||||
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
|
||||
{in: "<|end|>", wantContent: ""},
|
||||
}
|
||||
for i, s := range steps {
|
||||
content, thinking, tool := handler.AddContent(s.in, tp)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if content != s.wantContent || thinking != s.wantThinking {
|
||||
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
inputs := []string{
|
||||
"<|start|>assistant<|message|>Hello",
|
||||
", world",
|
||||
"!<|end|>",
|
||||
}
|
||||
var got []string
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in, tp)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if thinking != "" {
|
||||
t.Fatalf("unexpected thinking %q", thinking)
|
||||
}
|
||||
if content != "" {
|
||||
got = append(got, content)
|
||||
}
|
||||
}
|
||||
want := []string{"Hello", ", world", "!"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
inputs := []string{
|
||||
"<|channel|>analysis<|message|>Thinking...",
|
||||
"<|end|>",
|
||||
"<|start|>assistant<|message|>Answer",
|
||||
"<|end|>",
|
||||
}
|
||||
var got []string
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in, tp)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if thinking != "" {
|
||||
got = append(got, thinking)
|
||||
}
|
||||
if content != "" {
|
||||
got = append(got, content)
|
||||
}
|
||||
}
|
||||
want := []string{"Thinking...", "Answer"}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
inputs := []string{
|
||||
"<|chan",
|
||||
"nel|>analysis<|mess",
|
||||
"age|>Deep ",
|
||||
"thought",
|
||||
"<|end|>",
|
||||
"<|start|>assistant<|message|>Done",
|
||||
"<|end|>",
|
||||
}
|
||||
var thinkingPieces []string
|
||||
var contentPieces []string
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in, tp)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
if thinking != "" {
|
||||
thinkingPieces = append(thinkingPieces, thinking)
|
||||
}
|
||||
if content != "" {
|
||||
contentPieces = append(contentPieces, content)
|
||||
}
|
||||
}
|
||||
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
|
||||
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
|
||||
}
|
||||
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
|
||||
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
inputs := []string{
|
||||
"<|channel|>analysis<|message|>Think",
|
||||
"<|end|>",
|
||||
"<|start|>assistant<|message|>Answer",
|
||||
"<|end|>",
|
||||
}
|
||||
var contentSb, thinkingSb strings.Builder
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in, tp)
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
contentSb.WriteString(content)
|
||||
thinkingSb.WriteString(thinking)
|
||||
}
|
||||
if contentSb.String() != "Answer" {
|
||||
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
|
||||
}
|
||||
if thinkingSb.String() != "Think" {
|
||||
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
inputs := []string{
|
||||
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
|
||||
}
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in, tp)
|
||||
if content != "" || thinking != "" {
|
||||
continue
|
||||
}
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
}
|
||||
name, args := tp.Drain()
|
||||
if name == nil || *name != "functions.calculate" {
|
||||
t.Fatalf("unexpected tool name: %v", name)
|
||||
}
|
||||
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
||||
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("tool_call_across_chunks", func(t *testing.T) {
|
||||
handler := NewHarmonyMessageHandler()
|
||||
handler.HarmonyParser.AddImplicitStart()
|
||||
tp := handler.CreateToolParser()
|
||||
inputs := []string{
|
||||
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
|
||||
"2\"}",
|
||||
"<|end|>",
|
||||
}
|
||||
for _, in := range inputs {
|
||||
content, thinking, tool := handler.AddContent(in, tp)
|
||||
if content != "" || thinking != "" {
|
||||
continue
|
||||
}
|
||||
if tool != "" {
|
||||
tp.Add(tool)
|
||||
}
|
||||
}
|
||||
name, args := tp.Drain()
|
||||
if name == nil || *name != "functions.calculate" {
|
||||
t.Fatalf("unexpected tool name: %v", name)
|
||||
}
|
||||
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
|
||||
t.Fatalf("unexpected tool args: got %s want %s", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -410,3 +410,99 @@ func TestAPIEmbeddings(t *testing.T) {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIToolCalling(t *testing.T) {
|
||||
initialTimeout := 60 * time.Second
|
||||
streamTimeout := 30 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
modelName := "qwen3:0.6b"
|
||||
if err := PullIfMissing(ctx, client, modelName); err != nil {
|
||||
t.Fatalf("pull failed %s", err)
|
||||
}
|
||||
|
||||
tools := []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req := api.ChatRequest{
|
||||
Model: modelName,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Call get_weather with location set to San Francisco.",
|
||||
},
|
||||
},
|
||||
Tools: tools,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
},
|
||||
}
|
||||
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var gotToolCall bool
|
||||
var lastToolCall api.ToolCall
|
||||
|
||||
fn := func(response api.ChatResponse) error {
|
||||
if len(response.Message.ToolCalls) > 0 {
|
||||
gotToolCall = true
|
||||
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
|
||||
}
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return fmt.Errorf("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
|
||||
case <-done:
|
||||
if genErr != nil {
|
||||
t.Fatalf("chat failed: %v", genErr)
|
||||
}
|
||||
|
||||
if !gotToolCall {
|
||||
t.Fatalf("expected at least one tool call, got none")
|
||||
}
|
||||
|
||||
if lastToolCall.Function.Name != "get_weather" {
|
||||
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
|
||||
}
|
||||
|
||||
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
|
||||
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for tool-calling chat")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,6 +121,7 @@ func TestMultiModelStress(t *testing.T) {
|
||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||
targetLoadCount := 0
|
||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||
chooseModels:
|
||||
for i, model := range chosenModels {
|
||||
req := &api.GenerateRequest{Model: model}
|
||||
slog.Info("loading", "model", model)
|
||||
@@ -142,6 +143,13 @@ func TestMultiModelStress(t *testing.T) {
|
||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||
break
|
||||
}
|
||||
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
|
||||
for _, m := range models.Models {
|
||||
if m.SizeVRAM == 0 {
|
||||
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
|
||||
break chooseModels
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetLoadCount == len(chosenModels) {
|
||||
|
||||
@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
|
||||
@@ -38,8 +38,9 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
Model: "all-minilm",
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
}
|
||||
|
||||
res, err := embeddingTestHelper(ctx, client, t, req)
|
||||
|
||||
@@ -502,6 +502,22 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
var response string
|
||||
verify := func() {
|
||||
// Verify the response contains the expected data
|
||||
response = buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
@@ -517,21 +533,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||
}
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
verify()
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||
// if they are still generating valid responses
|
||||
slog.Warn("outer test context done while waiting for generate")
|
||||
verify()
|
||||
}
|
||||
return context
|
||||
}
|
||||
@@ -599,6 +608,22 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
var response string
|
||||
verify := func() {
|
||||
// Verify the response contains the expected data
|
||||
response = buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
@@ -614,23 +639,14 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
|
||||
verify()
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
|
||||
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
|
||||
// if they are still generating valid responses
|
||||
slog.Warn("outer test context done while waiting for chat")
|
||||
verify()
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
||||
@@ -173,6 +173,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
opts.NumCtx = int(trainCtx)
|
||||
}
|
||||
|
||||
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
|
||||
|
||||
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
|
||||
|
||||
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
|
||||
@@ -678,8 +680,12 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
|
||||
|
||||
if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
|
||||
for _, gpu := range gpus {
|
||||
available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
|
||||
if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
|
||||
available = 0
|
||||
}
|
||||
slog.Info("gpu memory", "id", gpu.ID,
|
||||
"available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory),
|
||||
"available", format.HumanBytes2(available),
|
||||
"free", format.HumanBytes2(gpu.FreeMemory),
|
||||
"minimum", format.HumanBytes2(gpu.MinimumMemory),
|
||||
"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
|
||||
@@ -861,7 +867,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
|
||||
}
|
||||
layers[i] += memory.CPU.Weights[i].Size
|
||||
layers[i] += memory.CPU.Cache[i].Size
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
|
||||
logutil.Trace("layer to assign", "layer", i, "size", format.HumanBytes2(layers[i]))
|
||||
}
|
||||
|
||||
gpuLayers := ml.GPULayersList{}
|
||||
@@ -1343,7 +1349,9 @@ type CompletionRequest struct {
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
UseHarmony bool
|
||||
PrefillString string
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1356,6 +1364,8 @@ const (
|
||||
DoneReasonLength
|
||||
// DoneReasonConnectionClosed indicates the completion stopped due to the connection being closed
|
||||
DoneReasonConnectionClosed
|
||||
// DoneReasonTokenRepeatLimit indicates the completion stopped due to a token repeat limit
|
||||
DoneReasonTokenRepeatLimit
|
||||
)
|
||||
|
||||
func (d DoneReason) String() string {
|
||||
@@ -1364,19 +1374,23 @@ func (d DoneReason) String() string {
|
||||
return "length"
|
||||
case DoneReasonStop:
|
||||
return "stop"
|
||||
case DoneReasonTokenRepeatLimit:
|
||||
return "token_repeat_limit"
|
||||
default:
|
||||
return "" // closed
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string `json:"content"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
Content string `json:"content"`
|
||||
Thinking string `json:"thinking"`
|
||||
ToolCalls []api.ToolCall `json:"tool_calls"`
|
||||
DoneReason DoneReason `json:"done_reason"`
|
||||
Done bool `json:"done"`
|
||||
PromptEvalCount int `json:"prompt_eval_count"`
|
||||
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
|
||||
EvalCount int `json:"eval_count"`
|
||||
EvalDuration time.Duration `json:"eval_duration"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
@@ -1494,7 +1508,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
|
||||
}
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
|
||||
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
@@ -1507,16 +1522,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Done {
|
||||
fn(c)
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
|
||||
fn(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
@@ -27,3 +28,11 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func Trace(msg string, args ...any) {
|
||||
slog.Log(context.TODO(), LevelTrace, msg, args...)
|
||||
}
|
||||
|
||||
func TraceContext(ctx context.Context, msg string, args ...any) {
|
||||
slog.Log(ctx, LevelTrace, msg, args...)
|
||||
}
|
||||
|
||||
@@ -271,7 +271,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
|
||||
C.ggml_set_name(tt, cname)
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
||||
if layer == -1 {
|
||||
@@ -378,7 +378,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
for bs := range maps.Values(bbs) {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
||||
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
||||
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
|
||||
}
|
||||
|
||||
@@ -811,7 +811,7 @@ func (c *Context) Reserve() {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
||||
logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
||||
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ package model
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
@@ -202,12 +201,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
@@ -243,6 +241,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -105,6 +104,10 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
||||
}
|
||||
|
||||
arch := b.Config().Architecture()
|
||||
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
@@ -198,7 +201,7 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
||||
names := fn(tagsCopy)
|
||||
for _, name := range names {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor)
|
||||
logutil.Trace("found tensor", "", tensor)
|
||||
vv.Set(reflect.ValueOf(tensor))
|
||||
break
|
||||
}
|
||||
|
||||
73
model/models/gemma3/embed.go
Normal file
73
model/models/gemma3/embed.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package gemma3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*TextModel
|
||||
PoolingType uint32
|
||||
|
||||
Dense [2]*nn.Linear `gguf:"dense"`
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
batch.Outputs = batch.Positions // return all positions
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
switch m.PoolingType {
|
||||
case 0: // None
|
||||
case 1: // Mean
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
default:
|
||||
return nil, errors.New("unsupported pooling type")
|
||||
}
|
||||
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
PoolingType: c.Uint("pooling_type", 0),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
@@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3", New)
|
||||
model.Register("gemma3_embed", newEmbedModel)
|
||||
}
|
||||
|
||||
@@ -159,8 +159,11 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
@@ -198,5 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
@@ -12,4 +12,5 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3vl"
|
||||
)
|
||||
|
||||
@@ -44,8 +44,8 @@ func New(c fs.Config) (model.Model, error) {
|
||||
},
|
||||
),
|
||||
TextModel: NewTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: NewVisionModel(c),
|
||||
ImageProcessor: NewImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
@@ -65,8 +65,8 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
patchDim := m.ImageProcessor.NumChannels * m.ImageProcessor.TemporalPatchSize *
|
||||
m.ImageProcessor.PatchSize * m.ImageProcessor.PatchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
|
||||
@@ -345,8 +345,8 @@ func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor
|
||||
return positionalEmbedding
|
||||
}
|
||||
|
||||
// newVisionModel creates a new instance of the Qwen vision model
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
// NewVisionModel creates a new instance of the Qwen vision model
|
||||
func NewVisionModel(c fs.Config) *VisionModel {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
|
||||
numHeads := int(c.Uint("vision.attention.head_count", 16))
|
||||
|
||||
@@ -11,40 +11,40 @@ import (
|
||||
|
||||
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
|
||||
type ImageProcessor struct {
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
mergeSize int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
factor int
|
||||
rescaleFactor float32
|
||||
imageMean []float32
|
||||
imageStd []float32
|
||||
NumChannels int
|
||||
PatchSize int
|
||||
TemporalPatchSize int
|
||||
MergeSize int
|
||||
MinPixels int
|
||||
MaxPixels int
|
||||
Factor int
|
||||
RescaleFactor float32
|
||||
ImageMean []float32
|
||||
ImageStd []float32
|
||||
}
|
||||
|
||||
// newImageProcessor creates a new image processor with default values
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
func NewImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
|
||||
return ImageProcessor{
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: 2,
|
||||
mergeSize: mergeSize,
|
||||
minPixels: 56 * 56,
|
||||
maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
|
||||
factor: patchSize * mergeSize,
|
||||
rescaleFactor: 1.0 / 255.0,
|
||||
imageMean: imageproc.ClipDefaultMean[:],
|
||||
imageStd: imageproc.ClipDefaultSTD[:],
|
||||
NumChannels: int(c.Uint("vision.num_channels", 3)), // not set
|
||||
PatchSize: patchSize,
|
||||
TemporalPatchSize: 2,
|
||||
MergeSize: mergeSize,
|
||||
MinPixels: 56 * 56,
|
||||
MaxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
|
||||
Factor: patchSize * mergeSize,
|
||||
RescaleFactor: 1.0 / 255.0,
|
||||
ImageMean: imageproc.ClipDefaultMean[:],
|
||||
ImageStd: imageproc.ClipDefaultSTD[:],
|
||||
}
|
||||
}
|
||||
|
||||
// SmartResize implements the smart resize algorithm
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
factor := p.Factor
|
||||
|
||||
if height < factor || width < factor {
|
||||
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
|
||||
@@ -57,13 +57,13 @@ func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
|
||||
if hBar*wBar > p.maxPixels {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
|
||||
if hBar*wBar > p.MaxPixels {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.MaxPixels))
|
||||
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if hBar*wBar < p.minPixels {
|
||||
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
|
||||
} else if hBar*wBar < p.MinPixels {
|
||||
beta := math.Sqrt(float64(p.MinPixels) / float64(height*width))
|
||||
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
@@ -90,16 +90,16 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
|
||||
|
||||
normalizedPixels := imageproc.Normalize(
|
||||
resizedImg,
|
||||
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
|
||||
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
|
||||
[3]float32{p.ImageMean[0], p.ImageMean[1], p.ImageMean[2]},
|
||||
[3]float32{p.ImageStd[0], p.ImageStd[1], p.ImageStd[2]},
|
||||
true, // rescale
|
||||
true, // channelFirst
|
||||
)
|
||||
|
||||
// Calculate grid dimensions
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Height: resizedHeight / p.PatchSize,
|
||||
Width: resizedWidth / p.PatchSize,
|
||||
Temporal: 1, // For single images, temporal dimension is 1
|
||||
}
|
||||
|
||||
@@ -113,10 +113,10 @@ func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error)
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := p.numChannels
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.mergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
channels := p.NumChannels
|
||||
patchSize := p.PatchSize
|
||||
mergeSize := p.MergeSize
|
||||
temporalPatchSize := p.TemporalPatchSize
|
||||
|
||||
// Calculate output dimensions
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
153
model/models/qwen3vl/model.go
Normal file
153
model/models/qwen3vl/model.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package qwen3vl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/model/models/qwen25vl"
|
||||
"github.com/ollama/ollama/model/models/qwen3"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
TextModel *qwen3.Model
|
||||
*qwen25vl.VisionModel
|
||||
|
||||
qwen25vl.ImageProcessor
|
||||
}
|
||||
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
textModel, err := qwen3.New(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
TextModel: textModel.(*qwen3.Model),
|
||||
VisionModel: qwen25vl.NewVisionModel(c),
|
||||
ImageProcessor: qwen25vl.NewImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *qwen25vl.Grid, error) {
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.NumChannels * m.ImageProcessor.TemporalPatchSize *
|
||||
m.ImageProcessor.PatchSize * m.ImageProcessor.PatchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
|
||||
return pixelValues, grid, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
pixels, grid, err := m.PixelValues(ctx, multimodalData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen-3-VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
var (
|
||||
imageToken int32 = 151655
|
||||
visionStartToken int32 = 151652
|
||||
visionEndToken int32 = 151653
|
||||
)
|
||||
|
||||
nImg := 0
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
// If not a multimodal input, add it to the result unchanged
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
|
||||
// the image tokens with a prompt, so we add a prefix here
|
||||
nImg++
|
||||
pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||
}
|
||||
for i := range pre {
|
||||
result = append(result, &input.Input{Token: pre[i]})
|
||||
}
|
||||
|
||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, &input.Input{Token: visionStartToken})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
result = append(result, &input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: patchesPerChunk,
|
||||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
|
||||
result = append(result, &input.Input{Token: visionEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
return m.TextModel.Forward(ctx, batch)
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen3vl", New)
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package model
|
||||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
@@ -25,7 +24,7 @@ func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
||||
}
|
||||
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
var maxTokenLen int
|
||||
@@ -39,7 +38,7 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
@@ -182,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = spm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
@@ -246,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
||||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
||||
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding bos token to prompt", "id", v.BOS)
|
||||
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
|
||||
ids = append([]int32{v.BOS[0]}, ids...)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
||||
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding eos token to prompt", "id", v.EOS)
|
||||
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
|
||||
ids = append(ids, v.EOS[0])
|
||||
}
|
||||
|
||||
|
||||
@@ -246,7 +246,7 @@ func filesForModel(path string) ([]string, error) {
|
||||
for _, match := range matches {
|
||||
if ct, err := detectContentType(match); err != nil {
|
||||
return nil, err
|
||||
} else if ct != contentType {
|
||||
} else if len(contentType) > 0 && ct != contentType {
|
||||
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
|
||||
}
|
||||
}
|
||||
@@ -255,7 +255,8 @@ func filesForModel(path string) ([]string, error) {
|
||||
}
|
||||
|
||||
var files []string
|
||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
|
||||
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
|
||||
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
|
||||
// safetensors files might be unresolved git lfs references; skip if they are
|
||||
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
|
||||
files = append(files, st...)
|
||||
|
||||
@@ -34,8 +34,8 @@ type InputCache struct {
|
||||
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
|
||||
numCtx := kvSize / int32(numSlots)
|
||||
|
||||
if numCtx < 1 {
|
||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||
if int(numCtx) < batchSize {
|
||||
return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
|
||||
}
|
||||
|
||||
slots := make([]InputCacheSlot, numSlots)
|
||||
@@ -70,11 +70,9 @@ func kvCacheTypeFromStr(s string) ml.DType {
|
||||
}
|
||||
|
||||
func (c *InputCache) Close() {
|
||||
if c == nil {
|
||||
return
|
||||
if c != nil && c.cache != nil {
|
||||
c.cache.Close()
|
||||
}
|
||||
|
||||
c.cache.Close()
|
||||
}
|
||||
|
||||
// Locking: Operations on InputCacheSlot (including finding one
|
||||
@@ -95,7 +93,7 @@ type InputCacheSlot struct {
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
@@ -113,6 +111,10 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*i
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
|
||||
@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
|
||||
|
||||
// Check error state
|
||||
if (err != nil) != tt.wantErr {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"image"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -29,6 +30,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -405,6 +407,8 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
|
||||
|
||||
var activeBatch batchState
|
||||
for {
|
||||
select {
|
||||
@@ -418,7 +422,12 @@ func (s *Server) run(ctx context.Context) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
go s.computeBatch(activeBatch)
|
||||
|
||||
if supportsAsync {
|
||||
go s.computeBatch(activeBatch)
|
||||
} else {
|
||||
s.computeBatch(activeBatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -429,12 +438,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
// before setting up the next batch so the seqs inputs are ready to receive their
|
||||
// token values and we get the correct input pointers for the batchInputs
|
||||
if pendingBatch.ctx != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
|
||||
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
|
||||
<-pendingBatch.computeStartedCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
|
||||
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
|
||||
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||
// No pendingBatch, so the inputs will be ready in the seqs immediately
|
||||
nextBatch.inputsReadyCh = make(chan struct{}, 1)
|
||||
nextBatch.inputsReadyCh <- struct{}{}
|
||||
@@ -546,7 +555,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
}
|
||||
|
||||
@@ -560,7 +569,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
|
||||
logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
|
||||
nextBatch.ctx.Close()
|
||||
nextBatch.ctx = nil
|
||||
return
|
||||
@@ -589,14 +598,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
defer activeBatch.ctx.Close()
|
||||
|
||||
// Wait until inputs are ready
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
|
||||
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
|
||||
<-activeBatch.inputsReadyCh
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
|
||||
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
|
||||
|
||||
// Once we complete, signal the next batch of inputs are ready
|
||||
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
|
||||
defer func() {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
|
||||
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
|
||||
activeBatch.outputsReadyCh <- struct{}{}
|
||||
}()
|
||||
|
||||
@@ -626,7 +635,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// Detect if the sequence we're processing has already been completed and replaced
|
||||
// with a new sequence
|
||||
if seq != activeBatch.seqs[i] {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -666,18 +675,19 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
|
||||
activeBatch.ctx.ComputeWithNotify(
|
||||
func() {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
|
||||
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
|
||||
activeBatch.computeStartedCh <- struct{}{}
|
||||
},
|
||||
activeBatch.modelOutput)
|
||||
logits := activeBatch.modelOutput.Floats()
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
|
||||
outputs := activeBatch.modelOutput.Floats()
|
||||
|
||||
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
|
||||
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil || nextBatchTokens[i] == nil {
|
||||
continue
|
||||
@@ -689,16 +699,15 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
if seq.embeddingOnly {
|
||||
// TODO(jessegross): Embedding support
|
||||
slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
|
||||
seq.embedding <- outputs
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(logits) / len(activeBatch.batch.Outputs)
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
return
|
||||
@@ -711,7 +720,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
|
||||
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
@@ -773,6 +782,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||
if req.UseHarmony {
|
||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillString)
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
}
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
@@ -834,7 +851,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
@@ -855,6 +872,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var lastToken string
|
||||
tokenRepeat := 0
|
||||
const tokenRepeatLimit = 30
|
||||
|
||||
for {
|
||||
select {
|
||||
@@ -863,8 +883,27 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
case content, ok := <-seq.responses:
|
||||
if ok {
|
||||
if strings.TrimSpace(content) == lastToken {
|
||||
tokenRepeat++
|
||||
}
|
||||
if tokenRepeat == tokenRepeatLimit {
|
||||
http.Error(w, "token repeat limit reached", http.StatusInternalServerError)
|
||||
seq.doneReason = llm.DoneReasonTokenRepeatLimit
|
||||
close(seq.quit)
|
||||
return
|
||||
}
|
||||
lastToken = strings.TrimSpace(content)
|
||||
|
||||
var thinking string
|
||||
if harmonyMessageHandler != nil {
|
||||
var toolContent string
|
||||
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
|
||||
harmonyToolParser.Add(toolContent)
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
Content: content,
|
||||
Thinking: thinking,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
@@ -873,7 +912,29 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
flusher.Flush()
|
||||
} else {
|
||||
var toolCalls []api.ToolCall
|
||||
if harmonyMessageHandler != nil {
|
||||
// these tools still need to be transformed to the original function name
|
||||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
|
||||
close(seq.quit)
|
||||
return
|
||||
}
|
||||
toolCalls = append(toolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: *toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
ToolCalls: toolCalls,
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
@@ -890,6 +951,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
|
||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
var req llm.EmbeddingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embedding request due to client closing the connection")
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||
Embedding: <-seq.embedding,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
@@ -1206,10 +1328,7 @@ func Execute(args []string) error {
|
||||
mux := http.NewServeMux()
|
||||
// TODO: support embeddings
|
||||
mux.HandleFunc("POST /load", server.load)
|
||||
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /embedding", server.embeddings)
|
||||
mux.HandleFunc("POST /completion", server.completion)
|
||||
mux.HandleFunc("GET /health", server.health)
|
||||
|
||||
|
||||
131
server/routes.go
131
server/routes.go
@@ -46,18 +46,6 @@ import (
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func experimentEnabled(name string) bool {
|
||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||
}
|
||||
@@ -207,13 +195,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
useHarmony := shouldUseHarmony(m) && !req.Raw
|
||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
|
||||
var functionNameMap *harmony.FunctionNameMap
|
||||
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
functionNameMap = harmony.NewFunctionNameMap()
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for gptoss models
|
||||
@@ -357,16 +343,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
var sb strings.Builder
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
UseHarmony: useHarmony,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: cr.Content,
|
||||
Done: cr.Done,
|
||||
Thinking: cr.Thinking,
|
||||
ToolCalls: cr.ToolCalls,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: cr.PromptEvalCount,
|
||||
PromptEvalDuration: cr.PromptEvalDuration,
|
||||
@@ -375,12 +364,22 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
},
|
||||
}
|
||||
|
||||
if res.Done {
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
|
||||
res.Response = content
|
||||
res.Thinking = thinking
|
||||
harmonyToolParser.Add(toolContent)
|
||||
} else if thinkingState != nil {
|
||||
for i, tool := range res.ToolCalls {
|
||||
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
||||
}
|
||||
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
if thinkingState != nil {
|
||||
thinking, content := thinkingState.AddContent(cr.Content)
|
||||
res.Thinking = thinking
|
||||
res.Response = content
|
||||
@@ -391,30 +390,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if cr.Done {
|
||||
if useHarmony {
|
||||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||
ch <- gin.H{"error": errStr}
|
||||
return
|
||||
}
|
||||
|
||||
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: *toolName,
|
||||
Arguments: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
res.DoneReason = cr.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
|
||||
if !req.Raw {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
|
||||
if err != nil {
|
||||
@@ -1616,27 +1591,21 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||
|
||||
useHarmony := shouldUseHarmony(m)
|
||||
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
|
||||
|
||||
processedTools := req.Tools
|
||||
var functionNameMap *harmony.FunctionNameMap
|
||||
var prefillString string
|
||||
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||
var lastMessage *api.Message
|
||||
if len(msgs) > 0 {
|
||||
lastMessage = &msgs[len(msgs)-1]
|
||||
}
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
|
||||
prefillString = harmony.Prefill(msgs[len(msgs)-1])
|
||||
functionNameMap = harmony.NewFunctionNameMap()
|
||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||
// renamed to be valid Harmony function names.
|
||||
processedTools = make([]api.Tool, len(req.Tools))
|
||||
copy(processedTools, req.Tools)
|
||||
for i, tool := range processedTools {
|
||||
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1689,15 +1658,17 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
UseHarmony: useHarmony,
|
||||
PrefillString: prefillString,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
@@ -1713,31 +1684,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
if useHarmony {
|
||||
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
|
||||
res.Message.Content = content
|
||||
res.Message.Thinking = thinking
|
||||
harmonyToolParser.Add(toolContent)
|
||||
|
||||
if r.Done {
|
||||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||
ch <- gin.H{"error": errStr}
|
||||
return
|
||||
}
|
||||
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
|
||||
}
|
||||
for i, tool := range res.Message.ToolCalls {
|
||||
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
|
||||
}
|
||||
|
||||
// only send messages with meaningful content (empty messages confuse clients)
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
|
||||
ch <- res
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -118,7 +117,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "content streams as it arrives",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
|
||||
input: llm.CompletionResponse{Content: "Hello", Done: false},
|
||||
wantContent: "Hello",
|
||||
},
|
||||
{
|
||||
@@ -126,7 +125,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
wantContent: ", world",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "!",
|
||||
},
|
||||
},
|
||||
@@ -135,20 +134,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "thinking streams separately from content",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
|
||||
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
|
||||
wantThinking: "Thinking...",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
|
||||
// No output expected - just closes the analysis message and resets state to normal
|
||||
input: llm.CompletionResponse{Content: "Answer", Done: false},
|
||||
wantContent: "Answer",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
|
||||
wantContent: "Answer", // After message end, state is reset to normal
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
// No output expected - just closes the assistant message
|
||||
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -156,24 +150,16 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "partial tags buffer until complete",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|chan", Done: false},
|
||||
// No output - partial tag
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
|
||||
// No output - still building tags
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
|
||||
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
|
||||
wantThinking: "Deep ",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
|
||||
input: llm.CompletionResponse{Thinking: "thought", Done: false},
|
||||
wantThinking: "thought",
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Done", // After message end, state is reset to normal
|
||||
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Done",
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -181,7 +167,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "simple assistant after analysis",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "Answer",
|
||||
wantThinking: "Think",
|
||||
},
|
||||
@@ -191,7 +177,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "tool call parsed and returned correctly",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
|
||||
wantContent: "The weather is sunny",
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
@@ -210,15 +196,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
name: "tool call with streaming JSON across chunks",
|
||||
steps: []step{
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
|
||||
// No output yet - incomplete JSON
|
||||
input: llm.CompletionResponse{Done: false},
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
|
||||
// Still no output - incomplete JSON
|
||||
},
|
||||
{
|
||||
input: llm.CompletionResponse{Content: "2\"}", Done: true},
|
||||
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
|
||||
wantToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
@@ -400,9 +381,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mockResponses := []llm.CompletionResponse{
|
||||
{Content: "<|message|>First ", Done: false},
|
||||
{Content: "First ", Done: false},
|
||||
{Content: "chunk ", Done: false},
|
||||
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
}
|
||||
|
||||
mock := mockRunner{
|
||||
@@ -507,189 +488,3 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
type expectedChunk struct {
|
||||
afterResponse int // Which mock response this chunk should appear after
|
||||
content string // Expected content in this chunk
|
||||
thinking string // Expected thinking in this chunk
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
mockResponses []llm.CompletionResponse
|
||||
expectedChunks []expectedChunk
|
||||
wantContent string
|
||||
wantThinking string
|
||||
}{
|
||||
{
|
||||
name: "simple message without thinking",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
|
||||
{Content: "how can I help?", Done: false},
|
||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 1, content: "Hello, "},
|
||||
{afterResponse: 2, content: "how can I help?"},
|
||||
},
|
||||
wantContent: "Hello, how can I help?",
|
||||
},
|
||||
{
|
||||
name: "message with analysis channel for thinking",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|channel|>analysis<|message|>", Done: false},
|
||||
{Content: "Let me think ", Done: false},
|
||||
{Content: "about this problem...", Done: false},
|
||||
{Content: "<|end|>", Done: false},
|
||||
{Content: "<|start|>assistant<|message|>", Done: false},
|
||||
{Content: "The answer ", Done: false},
|
||||
{Content: "is 42", Done: false},
|
||||
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 2, thinking: "Let me think "},
|
||||
{afterResponse: 3, thinking: "about this problem..."},
|
||||
{afterResponse: 6, content: "The answer "},
|
||||
{afterResponse: 7, content: "is 42"},
|
||||
},
|
||||
wantContent: "The answer is 42",
|
||||
wantThinking: "Let me think about this problem...",
|
||||
},
|
||||
{
|
||||
name: "streaming with partial tags across boundaries",
|
||||
mockResponses: []llm.CompletionResponse{
|
||||
{Content: "<|chan", Done: false},
|
||||
{Content: "nel|>analy", Done: false},
|
||||
{Content: "sis<|mess", Done: false},
|
||||
{Content: "age|>Think", Done: false},
|
||||
{Content: "ing deeply...<|end|>", Done: false},
|
||||
{Content: "<|start|>assi", Done: false},
|
||||
{Content: "stant<|message|>Result ", Done: false},
|
||||
{Content: "computed<|e", Done: false},
|
||||
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
|
||||
},
|
||||
expectedChunks: []expectedChunk{
|
||||
{afterResponse: 4, thinking: "Think"},
|
||||
{afterResponse: 5, thinking: "ing deeply..."},
|
||||
{afterResponse: 7, content: "Result "},
|
||||
{afterResponse: 8, content: "computed"},
|
||||
},
|
||||
wantContent: "Result computed",
|
||||
wantThinking: "Thinking deeply...",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Channel to synchronize mock responses with chunk verification
|
||||
responsesSent := make(chan int, len(tc.mockResponses))
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
// Send mock responses one at a time, notifying when each is sent
|
||||
for i, resp := range tc.mockResponses {
|
||||
fn(resp)
|
||||
responsesSent <- i + 1
|
||||
}
|
||||
close(responsesSent)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a minimal model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
|
||||
// Create model with passthrough template
|
||||
stream := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "harmony-test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("failed to create model: %d", w.Code)
|
||||
}
|
||||
|
||||
// Test chat endpoint with streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "harmony-test",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
// Parse streaming response
|
||||
var chunks []api.ChatResponse
|
||||
var content, thinking strings.Builder
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
for decoder.More() {
|
||||
var chunk api.ChatResponse
|
||||
if err := decoder.Decode(&chunk); err != nil {
|
||||
t.Fatalf("failed to decode chunk: %v", err)
|
||||
}
|
||||
chunks = append(chunks, chunk)
|
||||
|
||||
// Accumulate content and thinking from each chunk
|
||||
content.WriteString(chunk.Message.Content)
|
||||
thinking.WriteString(chunk.Message.Thinking)
|
||||
|
||||
// Debug output
|
||||
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
|
||||
}
|
||||
|
||||
// Verify we got streaming chunks
|
||||
if len(chunks) == 0 {
|
||||
t.Fatal("expected streaming chunks, got none")
|
||||
}
|
||||
|
||||
gotContent := content.String()
|
||||
gotThinking := thinking.String()
|
||||
|
||||
if gotContent != tc.wantContent {
|
||||
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
|
||||
}
|
||||
if gotThinking != tc.wantThinking {
|
||||
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
|
||||
}
|
||||
|
||||
// Verify last chunk has done=true
|
||||
lastChunk := chunks[len(chunks)-1]
|
||||
if !lastChunk.Done {
|
||||
t.Error("expected last chunk to have done=true")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user