Compare commits

..

9 Commits

Author SHA1 Message Date
ParthSareen
1fe7e07f63 sampler/runner: enable gpt-oss structured outputs 2025-09-03 15:04:20 -07:00
ParthSareen
40d3436cd1 cleanup passing in harmony flag and add generate support 2025-09-02 15:02:44 -07:00
ParthSareen
5bc783b58e harmony: move tests from routes to parser 2025-09-02 15:02:44 -07:00
ParthSareen
87714c1c39 harmony: add harmony parsing to runner 2025-09-02 15:02:44 -07:00
ParthSareen
f7ca3b7f7e routes: ChatHandler to get parsed harmony from runner 2025-09-02 15:02:44 -07:00
ParthSareen
72189c6d6e harmony: simplify prefill, add marshalling for functions, and update harmony check 2025-09-02 15:02:43 -07:00
ParthSareen
1d09e01431 server: update completion request signature and update token repeat 2025-09-02 13:56:18 -07:00
ParthSareen
eb7660d724 server: add thinking and tool calls to CompletionResponse 2025-09-02 13:56:18 -07:00
ParthSareen
4a5bdd5f12 harmony: move harmony parsing into a package 2025-09-02 13:56:18 -07:00
25 changed files with 268 additions and 740 deletions

View File

@@ -413,7 +413,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
### Cloud

View File

@@ -92,9 +92,6 @@ If none of those resolve the problem, gather additional information and file an
- Set `CUDA_ERROR_LEVEL=50` and try again to get more diagnostic logs
- Check dmesg for any errors `sudo dmesg | grep -i nvrm` and `sudo dmesg | grep -i nvidia`
You may get more details for initialization failures by enabling debug prints in the uvm driver. You should only use this temporarily while troubleshooting
- `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm uvm_debug_prints=1`
## AMD GPU Discovery

View File

@@ -57,28 +57,10 @@ func (kv KV) EmbeddingLength() uint64 {
return uint64(kv.Uint("embedding_length"))
}
func (kv KV) HeadCount() []uint64 {
headCountDefault := uint32(1)
headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault)
if len(headCount) == 1 {
headCountDefault = headCount[0]
}
nLayers := int(kv.BlockCount())
if len(headCount) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCount) {
out[i] = uint64(headCountDefault)
} else {
out[i] = uint64(headCount[i])
}
}
return out
}
func (kv KV) HeadCountMax() uint64 {
// TODO(drifkin): using the max value can cause an overestimation. In the
// future if array values become more popular, we can adapt the more invasive
// <https://github.com/ollama/ollama/pull/10225>
return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1))
}
@@ -86,27 +68,6 @@ func (kv KV) HeadCountMin() uint64 {
return uint64(kv.UintOrMinArrayValue("attention.head_count", 1))
}
func (kv KV) HeadCountKV() []uint64 {
headCountKVDefault := uint32(1)
headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault)
if len(headCountKV) == 1 {
headCountKVDefault = headCountKV[0]
}
nLayers := int(kv.BlockCount())
if len(headCountKV) > nLayers {
slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers)
}
out := make([]uint64, nLayers)
for i := range nLayers {
if i >= len(headCountKV) {
out[i] = uint64(headCountKVDefault)
} else {
out[i] = uint64(headCountKV[i])
}
}
return out
}
func (kv KV) HeadCountKVMax() uint64 {
return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1))
}
@@ -139,26 +100,6 @@ func (kv KV) ChatTemplate() string {
return kv.String("tokenizer.chat_template")
}
// ssm architecture parameters
func (kv KV) SSMConvKernel() uint64 {
return uint64(kv.Uint("ssm.conv_kernel"))
}
func (kv KV) SSMInnerSize() uint64 {
return uint64(kv.Uint("ssm.inner_size"))
}
func (kv KV) SSMStateSize() uint64 {
return uint64(kv.Uint("ssm.state_size"))
}
func (kv KV) SSMGroupCount() uint64 {
return uint64(kv.Uint("ssm.group_count"))
}
// general types
func (kv KV) String(key string, defaultValue ...string) string {
val, _ := keyValue(kv, key, append(defaultValue, "")...)
return val
@@ -190,27 +131,22 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 {
}
func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) {
arrVal := kv.UintOrArrayValueAsArray(key, defaultValue)
return slices.Min(arrVal), slices.Max(arrVal)
}
func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 {
if u32, ok := keyValue(kv, key, uint32(0)); ok {
return []uint32{u32}
return u32, u32
} else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok {
return u32s.values
min := slices.Min(u32s.values)
max := slices.Max(u32s.values)
return min, max
} else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok {
dst := make([]uint32, len(i32s.values))
for i, v := range i32s.values {
if v < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v)
}
dst[i] = uint32(v)
min := slices.Min(i32s.values)
max := slices.Max(i32s.values)
if min < 0 || max < 0 {
slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max)
}
return dst
return uint32(min), uint32(max)
}
return []uint32{defaultValue}
return defaultValue, defaultValue
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
@@ -550,9 +486,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCountMax()
headsArr := f.KV().HeadCount()
headsKV := f.KV().HeadCountKVMax()
headsKVArr := f.KV().HeadCountKV()
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size)
embeddingHeads := f.KV().EmbeddingHeadCountMax()
@@ -562,51 +496,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
// Default for models unless special-cased below. These defaults mirror the
// cache usage in llama.cpp under the assumption that models without special
// cases below will use the llamarunner and caching will be handled by the
// llama.cpp layer.
//
// This also assumes that a layer without heads or headsKV set is recurrent
// which is usually the case. Some models (eg nemotronh) use "blocks" in
// place of layers where some are MLP blocks that don't have any cache.
// Models like this will need a special case below to be accurately
// estimated.
var kvTotal uint64
kv = make([]uint64, f.KV().BlockCount())
kvSizeAttn := uint64(0)
kvSizeRecurrent := uint64(0)
for i := range kv {
headsL := headsArr[i]
headsKVL := headsKVArr[i]
if headsL > 0 && headsKVL > 0 {
// full attention layer
// NOTE: Assumes uniform values for all attn layers
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement)
kvSizeAttn += kv[i]
} else {
// recurrent layer
ssmDConv := f.KV().SSMConvKernel()
ssmDState := f.KV().SSMStateSize()
ssmDInner := f.KV().SSMInnerSize()
ssmNGroups := f.KV().SSMGroupCount()
nEmbdR := uint64(0)
if ssmDConv > 0 {
nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState)
}
nEmbdS := ssmDState * ssmDInner
// recurrent always uses F32 in llama.cpp backend
// https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644
bytesPerElementRecurrent := kvCacheBytesPerElement("f32")
kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent)
kvSizeRecurrent += kv[i]
}
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
kvTotal += kv[i]
}
slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent)
switch f.KV().Architecture() {
case "llama", "llama4":
@@ -899,8 +794,6 @@ func kvCacheBytesPerElement(cacheType string) float64 {
return 1 // 1/2 of fp16
case "q4_0":
return 0.5 // 1/4 of fp16
case "f32":
return 4 // f32 (default for recurrent)
default:
return 2 // f16 (default)
}

View File

@@ -1,13 +1,14 @@
package harmony
import (
"encoding/json"
"fmt"
"log/slog"
"maps"
"slices"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
)
@@ -47,13 +48,13 @@ func (s harmonyParserState) String() string {
}
type HarmonyParser struct {
state harmonyParserState
MessageStartTag string
MessageEndTag string
HeaderEndTag string
constraintsAllowed bool
acc strings.Builder
lifetimeAcc strings.Builder
state harmonyParserState
MessageStartTag string
MessageEndTag string
HeaderEndTag string
ConstrainAllowed bool
acc strings.Builder
lifetimeAcc strings.Builder
}
type HarmonyEvent interface {
@@ -90,32 +91,19 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant")
}
func (s *HarmonyParser) ConstraintsAllowed() bool {
return s.constraintsAllowed
}
func Prefill(lastMessage api.Message) string {
if lastMessage.Role != "assistant" {
return ""
// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
if prefillContentOrThinking != nil {
if *prefillContentOrThinking {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
return
} else {
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
return
}
}
switch {
case strings.TrimSpace(lastMessage.Content) != "":
return "<|start|>assistant<|channel|>final<|message|>"
case strings.TrimSpace(lastMessage.Thinking) != "":
return "<|start|>assistant<|channel|>analysis<|message|>"
default:
return ""
}
}
// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
if strings.TrimSpace(prefillString) != "" {
s.acc.WriteString(prefillString)
} else {
s.AddImplicitStart()
}
s.AddImplicitStart()
}
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
@@ -294,7 +282,6 @@ type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
ToolParser *HarmonyToolCallAccumulator
}
// NewHarmonyMessageHandler creates a new message handler
@@ -307,16 +294,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
ToolParser: &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
},
}
}
// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
contentSb := strings.Builder{}
thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{}
@@ -333,20 +316,20 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
// event.Header.Recipient is the tool name, something like
// "browser.search" for a built-in, or "functions.calc" for a
// custom one
h.ToolParser.SetToolName(event.Header.Recipient)
toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Thinking
}
case "commentary":
if event.Header.Recipient != "" {
h.state = harmonyMessageState_ToolCalling
h.ToolParser.SetToolName(event.Header.Recipient)
toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Normal
}
case "final":
h.state = harmonyMessageState_Normal
h.HarmonyParser.constraintsAllowed = true
h.HarmonyParser.ConstrainAllowed = true
}
case HarmonyEventContentEmitted:
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
@@ -364,6 +347,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
return &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
}
}
type harmonyToolCallState int
const (
@@ -405,6 +395,38 @@ type FunctionNameMap struct {
harmonyToUser map[string]string
}
func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
// necessary to avoid exposing map internals
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
return json.Marshal(alias{
UserToHarmony: m.userToHarmony,
HarmonyToUser: m.harmonyToUser,
})
}
func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
var a alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
if m.userToHarmony == nil {
m.userToHarmony = make(map[string]string)
}
if m.harmonyToUser == nil {
m.harmonyToUser = make(map[string]string)
}
maps.Copy(m.userToHarmony, a.UserToHarmony)
maps.Copy(m.harmonyToUser, a.HarmonyToUser)
return nil
}
func NewFunctionNameMap() *FunctionNameMap {
return &FunctionNameMap{
userToHarmony: make(map[string]string),

View File

@@ -1,6 +1,7 @@
package harmony
import (
"encoding/json"
"fmt"
"reflect"
"strings"
@@ -541,7 +542,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_then_content_streams", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
type step struct {
in string
wantContent string
@@ -554,7 +555,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
{in: "<|end|>", wantContent: ""},
}
for i, s := range steps {
content, thinking, tool := handler.AddContent(s.in)
content, thinking, tool := handler.AddContent(s.in, tp)
if tool != "" {
tp.Add(tool)
}
@@ -567,7 +568,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
inputs := []string{
"<|start|>assistant<|message|>Hello",
", world",
@@ -575,7 +576,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
@@ -595,7 +596,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>analysis<|message|>Thinking...",
"<|end|>",
@@ -604,7 +605,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
@@ -624,7 +625,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
inputs := []string{
"<|chan",
"nel|>analysis<|mess",
@@ -637,7 +638,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
var thinkingPieces []string
var contentPieces []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
@@ -659,7 +660,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>analysis<|message|>Think",
"<|end|>",
@@ -668,7 +669,7 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
}
var contentSb, thinkingSb strings.Builder
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
content, thinking, tool := handler.AddContent(in, tp)
if tool != "" {
tp.Add(tool)
}
@@ -686,12 +687,12 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
content, thinking, tool := handler.AddContent(in, tp)
if content != "" || thinking != "" {
continue
}
@@ -711,14 +712,14 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("tool_call_across_chunks", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
tp := handler.CreateToolParser()
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
"2\"}",
"<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
content, thinking, tool := handler.AddContent(in, tp)
if content != "" || thinking != "" {
continue
}
@@ -735,3 +736,25 @@ func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
}
})
}
func TestFunctionNameMapJSONRoundTrip(t *testing.T) {
m := NewFunctionNameMap()
gotConverted := m.ConvertAndAdd("get weather")
if gotConverted == "" {
t.Fatal("conversion returned empty")
}
b, err := json.Marshal(m)
if err != nil {
t.Fatalf("marshal: %v", err)
}
var m2 FunctionNameMap
if err := json.Unmarshal(b, &m2); err != nil {
t.Fatalf("unmarshal: %v", err)
}
if m2.userToHarmony["get weather"] != gotConverted {
t.Fatalf("userToHarmony lost: got %q want %q", m2.userToHarmony["get weather"], gotConverted)
}
if m2.harmonyToUser[gotConverted] != "get weather" {
t.Fatalf("harmonyToUser lost: got %q want %q", m2.harmonyToUser[gotConverted], "get weather")
}
}

View File

@@ -410,99 +410,3 @@ func TestAPIEmbeddings(t *testing.T) {
t.Errorf("zero length embedding response")
}
}
func TestAPIToolCalling(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
modelName := "qwen3:0.6b"
if err := PullIfMissing(ctx, client, modelName); err != nil {
t.Fatalf("pull failed %s", err)
}
tools := []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
},
},
},
},
}
req := api.ChatRequest{
Model: modelName,
Messages: []api.Message{
{
Role: "user",
Content: "Call get_weather with location set to San Francisco.",
},
},
Tools: tools,
Options: map[string]any{
"temperature": 0,
},
}
stallTimer := time.NewTimer(initialTimeout)
var gotToolCall bool
var lastToolCall api.ToolCall
fn := func(response api.ChatResponse) error {
if len(response.Message.ToolCalls) > 0 {
gotToolCall = true
lastToolCall = response.Message.ToolCalls[len(response.Message.ToolCalls)-1]
}
if !stallTimer.Reset(streamTimeout) {
return fmt.Errorf("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
t.Errorf("tool-calling chat never started. Timed out after: %s", initialTimeout.String())
case <-done:
if genErr != nil {
t.Fatalf("chat failed: %v", genErr)
}
if !gotToolCall {
t.Fatalf("expected at least one tool call, got none")
}
if lastToolCall.Function.Name != "get_weather" {
t.Errorf("unexpected tool called: got %q want %q", lastToolCall.Function.Name, "get_weather")
}
if _, ok := lastToolCall.Function.Arguments["location"]; !ok {
t.Errorf("expected tool arguments to include 'location', got: %s", lastToolCall.Function.Arguments.String())
}
case <-ctx.Done():
t.Error("outer test context done while waiting for tool-calling chat")
}
}

View File

@@ -121,7 +121,6 @@ func TestMultiModelStress(t *testing.T) {
// The intent is to go 1 over what can fit so we force the scheduler to thrash
targetLoadCount := 0
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
chooseModels:
for i, model := range chosenModels {
req := &api.GenerateRequest{Model: model}
slog.Info("loading", "model", model)
@@ -143,13 +142,6 @@ chooseModels:
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
break
}
// Effectively limit model count to 2 on CPU only systems to avoid thrashing and timeouts
for _, m := range models.Models {
if m.SizeVRAM == 0 {
slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
break chooseModels
}
}
}
}
if targetLoadCount == len(chosenModels) {

View File

@@ -36,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) {

View File

@@ -38,9 +38,8 @@ func TestAllMiniLMEmbeddings(t *testing.T) {
defer cleanup()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, client, t, req)

View File

@@ -502,22 +502,6 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -533,14 +517,21 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
}
verify()
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
}
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
case <-ctx.Done():
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for generate")
verify()
t.Error("outer test context done while waiting for generate")
}
return context
}
@@ -608,22 +599,6 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
done <- 0
}()
var response string
verify := func() {
// Verify the response contains the expected data
response = buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
}
select {
case <-stallTimer.C:
if buf.Len() == 0 {
@@ -639,14 +614,23 @@ func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatR
if genErr != nil {
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
}
verify()
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
}
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
// On slow systems, we might timeout before some models finish rambling, so check what we have so far to see
// if it's considered a pass - the stallTimer will detect hangs, but we want to consider slow systems a pass
// if they are still generating valid responses
slog.Warn("outer test context done while waiting for chat")
verify()
t.Error("outer test context done while waiting for generate")
}
return &api.Message{Role: role, Content: buf.String()}
}

View File

@@ -35,7 +35,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/parser"
)
type filteredEnv []string
@@ -174,8 +173,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
opts.NumCtx = int(trainCtx)
}
opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
@@ -1350,9 +1347,9 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
ParserType parser.TokenParserType
PrefillString string
Grammar string // set before sending the request to the subprocess
UseHarmony bool
PrefillContent *bool
}
// DoneReason represents the reason why a completion response is done
@@ -1505,8 +1502,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {
// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)

View File

@@ -201,11 +201,12 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
logutil.Trace("encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = bpe.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
_ "image/jpeg"
_ "image/png"
"math"
"os"
"reflect"
"strconv"
@@ -104,10 +103,6 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
}
arch := b.Config().Architecture()
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
arch = arch + "_embed"
}
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)

View File

@@ -1,73 +0,0 @@
package gemma3
import (
"errors"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.SentencePieceModel
*TextModel
PoolingType uint32
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
batch.Outputs = batch.Positions // return all positions
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
switch m.PoolingType {
case 0: // None
case 1: // Mean
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
default:
return nil, errors.New("unsupported pooling type")
}
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0),
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return m, nil
}

View File

@@ -141,11 +141,12 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("gemma3", New)
model.Register("gemma3_embed", newEmbedModel)
}

View File

@@ -159,11 +159,8 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
@@ -201,5 +198,5 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
return m.Output.Forward(ctx, hiddenState)
}

View File

@@ -181,11 +181,12 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
}
}
logutil.Trace("encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = spm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}

View File

@@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
slog.Debug("adding bos token to prompt", "id", v.BOS)
ids = append([]int32{v.BOS[0]}, ids...)
}
@@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
slog.Debug("adding eos token to prompt", "id", v.EOS)
ids = append(ids, v.EOS[0])
}

View File

@@ -246,7 +246,7 @@ func filesForModel(path string) ([]string, error) {
for _, match := range matches {
if ct, err := detectContentType(match); err != nil {
return nil, err
} else if len(contentType) > 0 && ct != contentType {
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, match)
}
}
@@ -255,8 +255,7 @@ func filesForModel(path string) ([]string, error) {
}
var files []string
// some safetensors files do not properly match "application/octet-stream", so skip checking their contentType
if st, _ := glob(filepath.Join(path, "*.safetensors"), ""); len(st) > 0 {
if st, _ := glob(filepath.Join(path, "*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)

View File

@@ -1,135 +0,0 @@
package parser
import (
"encoding/json"
"errors"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type TokenParserType int
const (
TokenParserTypeDefault TokenParserType = iota
TokenParserTypeHarmony
)
type TokenParser struct {
messageHandler MessageHandler
parserEngine ParserInternals
toolParser ToolParser
lastToken string
tokenRepeat int
repeatLimit int
}
const defaultTokenRepeatLimit = 30
type MessageHandler interface {
AddContent(token string) (content, thinking string, toolContent string)
}
type ParserInternals interface {
AddImplicitStartOrPrefill(prefillString string)
ConstraintsAllowed() bool
}
type ToolParser interface {
Add(token string)
Drain() (toolName *string, toolContent string)
}
// Default implementation for the TokenParser interface as a no-op passthrough
type defaultMessageHandler struct{}
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
return token, "", ""
}
type defaultEngine struct{}
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
func (defaultEngine) ConstraintsAllowed() bool {
return true
}
type defaultToolParser struct{}
func (defaultToolParser) Add(token string) {}
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
switch parserType {
case TokenParserTypeHarmony:
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
return TokenParser{
messageHandler: harmonyMessageHandler,
parserEngine: harmonyMessageHandler.HarmonyParser,
toolParser: harmonyMessageHandler.ToolParser,
repeatLimit: defaultTokenRepeatLimit,
}
default:
return TokenParser{
messageHandler: defaultMessageHandler{},
parserEngine: defaultEngine{},
toolParser: defaultToolParser{},
repeatLimit: 30,
}
}
}
func (p *TokenParser) AddContent(token string) (string, string, error) {
if p.repeatLimitReached(token) {
return "", "", errors.New("token repeat limit reached")
}
content, thinking, toolContent := p.messageHandler.AddContent(token)
p.toolParser.Add(toolContent)
return content, thinking, nil
}
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
func (p *TokenParser) repeatLimitReached(token string) bool {
if p == nil {
return false
}
trimmed := strings.TrimSpace(token)
if trimmed == p.lastToken {
p.tokenRepeat++
} else {
p.tokenRepeat = 0
}
p.lastToken = trimmed
return p.tokenRepeat >= p.repeatLimit
}
func (p *TokenParser) ConstraintsAllowed() bool {
return p.parserEngine.ConstraintsAllowed()
}
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall {
toolName, toolContent := p.toolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
return nil
}
return []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
},
}
}
return nil
}

View File

@@ -34,8 +34,8 @@ type InputCache struct {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if int(numCtx) < batchSize {
return nil, fmt.Errorf("kv size must be at least as large as batch size * parallel (kv: %v batch: %v parallel: %v)", kvSize, batchSize, numSlots)
if numCtx < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
slots := make([]InputCacheSlot, numSlots)
@@ -70,9 +70,11 @@ func kvCacheTypeFromStr(s string) ml.DType {
}
func (c *InputCache) Close() {
if c != nil && c.cache != nil {
c.cache.Close()
if c == nil {
return
}
c.cache.Close()
}
// Locking: Operations on InputCacheSlot (including finding one
@@ -93,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -111,10 +113,6 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*In
return nil, nil, err
}
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()

View File

@@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {

View File

@@ -11,7 +11,6 @@ import (
"image"
"log"
"log/slog"
"math"
"net"
"net/http"
"os"
@@ -30,12 +29,12 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -62,11 +61,6 @@ type Sequence struct {
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
// startGate
startGate *sync.Mutex
grammarReady bool
// input cache being used by this sequence
cache *InputCacheSlot
@@ -169,7 +163,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar
startGate := &sync.Mutex{}
return &Sequence{
ctxs: ctxs,
mmStore: mmStore,
@@ -185,8 +178,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
startGate: startGate,
grammarReady: false,
}, nil
}
@@ -415,8 +406,6 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
var activeBatch batchState
for {
select {
@@ -430,12 +419,7 @@ func (s *Server) run(ctx context.Context) {
if err != nil {
panic(err)
}
if supportsAsync {
go s.computeBatch(activeBatch)
} else {
s.computeBatch(activeBatch)
}
go s.computeBatch(activeBatch)
}
}
}
@@ -446,12 +430,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// before setting up the next batch so the seqs inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if pendingBatch.ctx != nil {
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
<-pendingBatch.computeStartedCh
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
} else {
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
// No pendingBatch, so the inputs will be ready in the seqs immediately
nextBatch.inputsReadyCh = make(chan struct{}, 1)
nextBatch.inputsReadyCh <- struct{}{}
@@ -563,7 +547,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -577,7 +561,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
}
if len(batchInputs) == 0 {
logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
nextBatch.ctx.Close()
nextBatch.ctx = nil
return
@@ -606,14 +590,14 @@ func (s *Server) computeBatch(activeBatch batchState) {
defer activeBatch.ctx.Close()
// Wait until inputs are ready
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
<-activeBatch.inputsReadyCh
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", activeBatch.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", activeBatch.id)
activeBatch.outputsReadyCh <- struct{}{}
}()
@@ -643,7 +627,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != activeBatch.seqs[i] {
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
continue
}
@@ -683,19 +667,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
activeBatch.ctx.ComputeWithNotify(
func() {
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
activeBatch.computeStartedCh <- struct{}{}
},
activeBatch.modelOutput)
logits := activeBatch.modelOutput.Floats()
outputs := activeBatch.modelOutput.Floats()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", activeBatch.id)
s.mu.Lock()
defer s.mu.Unlock()
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", activeBatch.id)
for i, seq := range s.seqs {
if seq == nil || nextBatchTokens[i] == nil {
continue
@@ -707,26 +690,20 @@ func (s *Server) computeBatch(activeBatch batchState) {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
seq.embedding <- outputs
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported", "id", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
// sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
if !seq.grammarReady {
seq.startGate.Lock()
}
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
vocabSize := len(logits) / len(activeBatch.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(logits), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
}
if !seq.grammarReady {
seq.startGate.Unlock()
}
nextBatchTokens[i].Token = token
@@ -735,7 +712,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -797,6 +774,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if req.UseHarmony {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
@@ -844,12 +829,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
// this accounts for the default case and also the case where there is a prefill which moves the state of the parser to allow for constraints
if tokenParser.ConstraintsAllowed() {
seq.grammarReady = true
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
@@ -864,7 +843,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
@@ -886,6 +865,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// TODO(parthsareen): generalize grammar enablement on the fly for all thinking models
if harmonyMessageHandler == nil {
seq.sampler.SetGrammar(grammar)
}
grammarSet := false
for {
select {
case <-r.Context().Done():
@@ -893,23 +878,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
if !seq.grammarReady {
seq.startGate.Lock()
}
var thinking string
var err error
content, thinking, err = tokenParser.AddContent(content)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
close(seq.quit)
return
}
// only apply the grammar once
if tokenParser.ConstraintsAllowed() && !seq.grammarReady {
seq.sampler.SetGrammar(grammar, &s.mu)
seq.grammarReady = true
seq.startGate.Unlock()
if harmonyMessageHandler != nil {
var toolContent string
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
harmonyToolParser.Add(toolContent)
if harmonyMessageHandler.HarmonyParser.ConstrainAllowed && !grammarSet {
seq.sampler.SetGrammar(grammar)
grammarSet = true
}
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
@@ -923,7 +900,27 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
toolCalls := tokenParser.Drain()
var toolCalls []api.ToolCall
if harmonyMessageHandler != nil {
// these tools still need to be transformed to the original function name
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true,
@@ -938,74 +935,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
if !seq.grammarReady {
seq.startGate.Unlock()
}
}
}
}
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
return
}
var req llm.EmbeddingRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
if err != nil {
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embedding request due to client closing the connection")
} else {
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
return
}
s.seqs[i] = seq
s.cond.Signal()
found = true
break
}
}
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
Embedding: <-seq.embedding,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
@@ -1322,7 +1255,10 @@ func Execute(args []string) error {
mux := http.NewServeMux()
// TODO: support embeddings
mux.HandleFunc("POST /load", server.load)
mux.HandleFunc("POST /embedding", server.embeddings)
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
})
mux.HandleFunc("POST /completion", server.completion)
mux.HandleFunc("GET /health", server.health)

View File

@@ -5,7 +5,6 @@ import (
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
@@ -26,9 +25,7 @@ type Sampler struct {
grammar *GrammarSampler
}
func (s *Sampler) SetGrammar(grammar *GrammarSampler, mutex *sync.Mutex) {
mutex.Lock()
defer mutex.Unlock()
func (s *Sampler) SetGrammar(grammar *GrammarSampler) {
s.grammar = grammar
}

View File

@@ -36,7 +36,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -197,12 +196,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
var parserType parser.TokenParserType
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
var functionNameMap *harmony.FunctionNameMap
if useHarmony {
@@ -354,7 +347,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
UseHarmony: useHarmony,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@@ -1599,20 +1592,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = filterThinkTags(msgs, m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
var parserType parser.TokenParserType
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillString string
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
var prefillContentOrThinking *bool
if useHarmony {
prefillString = harmony.Prefill(msgs[len(msgs)-1])
functionNameMap = harmony.NewFunctionNameMap()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
// prefill content or thinking flag if the last message is an assistant message
if lastMessage != nil && lastMessage.Role == "assistant" {
if lastMessage.Content != "" {
trueVal := true
// true sets content to be prefilled
prefillContentOrThinking = &trueVal
} else if lastMessage.Thinking != "" {
// false sets thinking to be prefilled
falseVal := false
prefillContentOrThinking = &falseVal
}
}
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
@@ -1671,12 +1673,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
PrefillString: prefillString,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
UseHarmony: useHarmony,
PrefillContent: prefillContentOrThinking,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@@ -1697,10 +1699,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if useHarmony {
// only send messages with meaningful content (empty messages confuse clients)
for i, tool := range res.Message.ToolCalls {
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}