Merge remote-tracking branch 'upstream/main' into vulkanV3

commit 0d4f3341c3

api/types.go | 22
@@ -313,10 +313,11 @@ func (t *ToolFunction) String() string {
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
+	Model      string     `json:"model"`
+	CreatedAt  time.Time  `json:"created_at"`
+	Message    Message    `json:"message"`
+	DoneReason string     `json:"done_reason,omitempty"`
+	DebugInfo  *DebugInfo `json:"_debug_info,omitempty"`
 
 	Done bool `json:"done"`
 
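The optional `_debug_info` field is omitted from the wire format unless the server populates it. A minimal, self-contained sketch of how a client could probe for it — the payload below is hypothetical; only the JSON shape comes from the structs above:

```go
package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	// hypothetical response body; "_debug_info" is the key added above
	body := []byte(`{"model":"m","done":true,"_debug_info":{"image_count":1}}`)

	var resp struct {
		Model     string          `json:"model"`
		Done      bool            `json:"done"`
		DebugInfo json.RawMessage `json:"_debug_info,omitempty"`
	}
	if err := json.Unmarshal(body, &resp); err != nil {
		panic(err)
	}
	if resp.DebugInfo != nil { // nil whenever debugging was not requested
		fmt.Printf("debug: %s\n", resp.DebugInfo)
	}
}
```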
@@ -329,13 +330,6 @@ type DebugInfo struct {
 	ImageCount int `json:"image_count,omitempty"`
 }
 
-// DebugTemplateResponse is returned when _debug_render_only is set to true
-type DebugTemplateResponse struct {
-	Model     string    `json:"model"`
-	CreatedAt time.Time `json:"created_at"`
-	DebugInfo DebugInfo `json:"_debug_info"`
-}
-
 type Metrics struct {
 	TotalDuration time.Duration `json:"total_duration,omitempty"`
 	LoadDuration  time.Duration `json:"load_duration,omitempty"`
@@ -443,6 +437,8 @@ type CreateRequest struct {
 	System     string         `json:"system,omitempty"`
 	Parameters map[string]any `json:"parameters,omitempty"`
 	Messages   []Message      `json:"messages,omitempty"`
+	Renderer   string         `json:"renderer,omitempty"`
+	Parser     string         `json:"parser,omitempty"`
 
 	// Deprecated: set the model name with Model instead
 	Name string `json:"name"`
@@ -480,6 +476,8 @@ type ShowResponse struct {
 	Parameters string         `json:"parameters,omitempty"`
 	Template   string         `json:"template,omitempty"`
 	System     string         `json:"system,omitempty"`
+	Renderer   string         `json:"renderer,omitempty"`
+	Parser     string         `json:"parser,omitempty"`
 	Details    ModelDetails   `json:"details,omitempty"`
 	Messages   []Message      `json:"messages,omitempty"`
 	ModelInfo  map[string]any `json:"model_info,omitempty"`
@@ -592,6 +590,8 @@ type GenerateResponse struct {
 	Metrics
 
 	ToolCalls []ToolCall `json:"tool_calls,omitempty"`
+
+	DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
 }
 
 // ModelDetails provides details about a model.

@@ -28,6 +28,7 @@ type bertModel struct {
 	LayerNormEPS        float32 `json:"layer_norm_eps"`
 	LayerNormEpsilon    float32 `json:"layer_norm_epsilon"`
 	NormEpsilon         float32 `json:"norm_epsilon"`
+	normalizeEmbeddings bool
 
 	PoolingType uint32
 }
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
 
 	var pooling string
 	for _, m := range modules {
-		if m.Type == "sentence_transformers.models.Pooling" {
+		switch m.Type {
+		case "sentence_transformers.models.Pooling":
 			pooling = m.Path
-			break
+		case "sentence_transformers.models.Normalize":
+			p.normalizeEmbeddings = true
 		}
 	}
 
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
 	kv["general.architecture"] = "bert"
 	kv["bert.attention.causal"] = false
 	kv["bert.pooling_type"] = p.PoolingType
+	kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
 
 	kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)
 

@@ -45,10 +45,18 @@ func cudaVariant(gpuInfo CudaGPUInfo) string {
 		}
 	}
 
+	// Check GPU compute capability FIRST
+	isOldGPU := gpuInfo.computeMajor < 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor < 5)
+	if isOldGPU {
+		// GPU is Pascal or older (CC <= 7.4) - use CUDA v12 (supports CC 6.1)
+		return "v12"
+	}
+
+	// GPU is Turing or newer (CC >= 7.5) - can use newer CUDA
 	if gpuInfo.DriverMajor < 13 {
 		// The detected driver is older than 580 (Aug 2025)
 		// Warn if their CC is compatible with v13 and they should upgrade their driver to get better performance
-		if gpuInfo.computeMajor > 7 || (gpuInfo.computeMajor == 7 && gpuInfo.computeMinor >= 5) {
+		if !isOldGPU {
			slog.Warn("old CUDA driver detected - please upgrade to a newer driver for best performance", "version", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor))
 		}
 		return "v12"

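The new ordering checks the card's compute capability before looking at the driver: anything older than Turing can never use the v13 toolkit, so the driver version only matters for newer cards. A hedged distillation of that predicate — the standalone function here is mine, not the repo's:

```go
// isOldGPU reproduces the gate above: compute capability below 7.5
// (Pascal, Volta, and older) is always served by the CUDA v12 build.
func isOldGPU(computeMajor, computeMinor int) bool {
	return computeMajor < 7 || (computeMajor == 7 && computeMinor < 5)
}

// isOldGPU(6, 1) == true  -> v12 regardless of driver version
// isOldGPU(7, 5) == false -> v13 eligible once the driver is 13.x (>= 580)
```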
@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
 go run . serve
 ```
 
+> [!NOTE]
+> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
+
 ## macOS (Apple Silicon)
 
 macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.

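In practice, the forced rebuild the note describes is just the two commands the docs already use, run back to back:

```shell
go clean -cache   # discard stale CGO build artifacts
go run . serve    # recompiles the native code from scratch
```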
@@ -3,29 +3,15 @@ package harmony
 import (
 	"fmt"
 	"log/slog"
-	"slices"
 	"strings"
 	"unicode"
 
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/logutil"
-	"github.com/ollama/ollama/template"
 )
 
 type harmonyParserState int
 
-func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
-	if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
-		// heuristic to check whether the template expects to be parsed via harmony:
-		// search for harmony tags that are nearly always used
-		if template.Contains("<|start|>") && template.Contains("<|end|>") {
-			return true
-		}
-	}
-
-	return false
-}
-
 const (
 	harmonyParserState_LookingForMessageStart harmonyParserState = iota
 	harmonyParserState_ParsingHeader

@@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() {
 	s.acc.WriteString("<|start|>assistant")
 }
 
-func Prefill(lastMessage api.Message) string {
-	if lastMessage.Role != "assistant" {
-		return ""
-	}
-
-	switch {
-	case strings.TrimSpace(lastMessage.Content) != "":
-		return "<|start|>assistant<|channel|>final<|message|>"
-	case strings.TrimSpace(lastMessage.Thinking) != "":
-		return "<|start|>assistant<|channel|>analysis<|message|>"
-	default:
-		return ""
-	}
-}
-
-// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
-func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
-	if strings.TrimSpace(prefillString) != "" {
-		s.acc.WriteString(prefillString)
-	} else {
-		s.AddImplicitStart()
-	}
-}
+func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
+	if lastMessage != nil && lastMessage.Role == "assistant" {
+		// handle prefilling conditions
+		if lastMessage.Content != "" {
+			s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
+			return
+		} else if lastMessage.Thinking != "" {
+			s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
+			return
+		}
+	}
+	s.AddImplicitStart()
+}
 
 func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {

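The replacement folds the old Prefill helper's decision into the parser itself: a non-empty assistant Content wins, then Thinking, then the bare implicit start tag. A sketch of that precedence as a standalone function — the name prefillFor is hypothetical:

```go
// prefillFor mirrors AddImplicitStartOrPrefill's branch order above.
func prefillFor(lastMessage *api.Message) string {
	if lastMessage != nil && lastMessage.Role == "assistant" {
		if lastMessage.Content != "" {
			return "<|start|>assistant<|channel|>final<|message|>"
		}
		if lastMessage.Thinking != "" {
			return "<|start|>assistant<|channel|>analysis<|message|>"
		}
	}
	return "<|start|>assistant"
}
```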
@@ -289,7 +265,6 @@ type HarmonyMessageHandler struct {
 	state           harmonyMessageState
 	HarmonyParser   *HarmonyParser
 	FunctionNameMap *FunctionNameMap
-	ToolParser      *HarmonyToolCallAccumulator
 }
 
 // NewHarmonyMessageHandler creates a new message handler

@@ -302,16 +277,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
 			HeaderEndTag:    "<|message|>",
 		},
 		FunctionNameMap: NewFunctionNameMap(),
-		ToolParser: &HarmonyToolCallAccumulator{
-			state:           harmonyToolCallState_Normal,
-			currentToolName: nil,
-		},
 	}
 }
 
 // AddContent processes the content and returns the content, thinking, and tool content.
 // content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
-func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
+func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
 	contentSb := strings.Builder{}
 	thinkingSb := strings.Builder{}
 	toolContentSb := strings.Builder{}

@@ -328,14 +299,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
 				// event.Header.Recipient is the tool name, something like
 				// "browser.search" for a built-in, or "functions.calc" for a
 				// custom one
-				h.ToolParser.SetToolName(event.Header.Recipient)
+				toolParser.SetToolName(event.Header.Recipient)
 			} else {
 				h.state = harmonyMessageState_Thinking
 			}
 		case "commentary":
 			if event.Header.Recipient != "" {
 				h.state = harmonyMessageState_ToolCalling
-				h.ToolParser.SetToolName(event.Header.Recipient)
+				toolParser.SetToolName(event.Header.Recipient)
 			} else {
 				h.state = harmonyMessageState_Normal
 			}

@@ -358,6 +329,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
 	return contentSb.String(), thinkingSb.String(), toolContentSb.String()
 }
 
+func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
+	return &HarmonyToolCallAccumulator{
+		state:           harmonyToolCallState_Normal,
+		currentToolName: nil,
+	}
+}
+
 type harmonyToolCallState int
 
 const (

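With the accumulator removed from HarmonyMessageHandler, each request now creates its own via CreateToolParser and threads it through AddContent. A hedged sketch of the new call pattern — the chunks channel and emit callback are stand-ins for the caller's streaming plumbing:

```go
handler := NewHarmonyMessageHandler()
toolParser := handler.CreateToolParser() // per-request state, no longer shared on the handler

for chunk := range chunks { // chunks chan string: hypothetical stream of model output
	content, thinking, toolContent := handler.AddContent(chunk, toolParser)
	if toolContent != "" {
		toolParser.Add(toolContent)
	}
	emit(content, thinking) // placeholder for the caller's streaming callback
}

name, args := toolParser.Drain()
_ = name // *string, e.g. "functions.calculate"
_ = args // raw JSON arguments accumulated across chunks
```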
@@ -3,7 +3,6 @@ package harmony
 import (
 	"fmt"
-	"reflect"
 	"strings"
 	"testing"
 )

@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
 		})
 	}
 }
-
-func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
-	t.Run("thinking_then_content_streams", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		type step struct {
-			in           string
-			wantContent  string
-			wantThinking string
-		}
-		steps := []step{
-			{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
-			{in: "<|end|>", wantThinking: ""},
-			{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
-			{in: "<|end|>", wantContent: ""},
-		}
-		for i, s := range steps {
-			content, thinking, tool := handler.AddContent(s.in)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if content != s.wantContent || thinking != s.wantThinking {
-				t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
-			}
-		}
-	})
-
-	t.Run("content_streams_as_it_arrives", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		inputs := []string{
-			"<|start|>assistant<|message|>Hello",
-			", world",
-			"!<|end|>",
-		}
-		var got []string
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if thinking != "" {
-				t.Fatalf("unexpected thinking %q", thinking)
-			}
-			if content != "" {
-				got = append(got, content)
-			}
-		}
-		want := []string{"Hello", ", world", "!"}
-		if !reflect.DeepEqual(got, want) {
-			t.Fatalf("content pieces mismatch: got %v want %v", got, want)
-		}
-	})
-
-	t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		inputs := []string{
-			"<|channel|>analysis<|message|>Thinking...",
-			"<|end|>",
-			"<|start|>assistant<|message|>Answer",
-			"<|end|>",
-		}
-		var got []string
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if thinking != "" {
-				got = append(got, thinking)
-			}
-			if content != "" {
-				got = append(got, content)
-			}
-		}
-		want := []string{"Thinking...", "Answer"}
-		if !reflect.DeepEqual(got, want) {
-			t.Fatalf("content pieces mismatch: got %v want %v", got, want)
-		}
-	})
-
-	t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		inputs := []string{
-			"<|chan",
-			"nel|>analysis<|mess",
-			"age|>Deep ",
-			"thought",
-			"<|end|>",
-			"<|start|>assistant<|message|>Done",
-			"<|end|>",
-		}
-		var thinkingPieces []string
-		var contentPieces []string
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			if thinking != "" {
-				thinkingPieces = append(thinkingPieces, thinking)
-			}
-			if content != "" {
-				contentPieces = append(contentPieces, content)
-			}
-		}
-		if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
-			t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
-		}
-		if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
-			t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
-		}
-	})
-
-	t.Run("simple_assistant_after_analysis", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		inputs := []string{
-			"<|channel|>analysis<|message|>Think",
-			"<|end|>",
-			"<|start|>assistant<|message|>Answer",
-			"<|end|>",
-		}
-		var contentSb, thinkingSb strings.Builder
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in)
-			if tool != "" {
-				tp.Add(tool)
-			}
-			contentSb.WriteString(content)
-			thinkingSb.WriteString(thinking)
-		}
-		if contentSb.String() != "Answer" {
-			t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
-		}
-		if thinkingSb.String() != "Think" {
-			t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
-		}
-	})
-
-	t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		inputs := []string{
-			"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
-		}
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in)
-			if content != "" || thinking != "" {
-				continue
-			}
-			if tool != "" {
-				tp.Add(tool)
-			}
-		}
-		name, args := tp.Drain()
-		if name == nil || *name != "functions.calculate" {
-			t.Fatalf("unexpected tool name: %v", name)
-		}
-		if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
-			t.Fatalf("unexpected tool args: got %s want %s", got, want)
-		}
-	})
-
-	t.Run("tool_call_across_chunks", func(t *testing.T) {
-		handler := NewHarmonyMessageHandler()
-		handler.HarmonyParser.AddImplicitStart()
-		tp := handler.ToolParser
-		inputs := []string{
-			"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
-			"2\"}",
-			"<|end|>",
-		}
-		for _, in := range inputs {
-			content, thinking, tool := handler.AddContent(in)
-			if content != "" || thinking != "" {
-				continue
-			}
-			if tool != "" {
-				tp.Add(tool)
-			}
-		}
-		name, args := tp.Drain()
-		if name == nil || *name != "functions.calculate" {
-			t.Fatalf("unexpected tool name: %v", name)
-		}
-		if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
-			t.Fatalf("unexpected tool args: got %s want %s", got, want)
-		}
-	})
-}

@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
 	// Set up the test data
 	req := api.GenerateRequest{
 		Model:  smol,
-		Prompt: "Write me a story with a ton of emojis?",
+		Prompt: "Write me a story in english with a lot of emojis",
 		Stream: &stream,
 		Options: map[string]any{
 			"temperature": 0,

@@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		KeepAlive: &api.Duration{Duration: 10 * time.Second},
 	}, {
 		Model:     smol,
-		Prompt:    "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
+		Prompt:    "how do rainbows form? Be brief but factual in your reply",
 		Stream:    &stream,
 		KeepAlive: &api.Duration{Duration: 10 * time.Second},
 	}, {
@@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
 		[][]string{
 			{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
 			{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
-			{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
+			{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
 			{"fourth", "july", "declaration", "independence"},
-			{"nitrogen", "oxygen", "carbon", "dioxide"},
+			{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
 		}
 }

@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
 	}
 	nChunks := C.mtmd_input_chunks_size(ic)
 	numEmbed := llamaContext.Model().NEmbd()
-	lastChunkSize := 0
+	embed := make([][]float32, 0)
 	for i := range int(nChunks) {
 		chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
 		numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
-		lastChunkSize = numTokens
 		slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
 
 		// Encode the chunk
 		if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
 			return nil, errors.New("unable to encode mtmd image chunk")
 		}
-	}
 
-	// Get the embeddings
-	embed := make([][]float32, lastChunkSize)
-	embd := C.mtmd_get_output_embd(c.c)
-	if nil == embd {
-		return nil, errors.New("failed to get image embedding")
-	}
+		// Get the embeddings for this chunk
+		chunkEmbed := make([][]float32, numTokens)
+		chunkEmbd := C.mtmd_get_output_embd(c.c)
+		if nil == chunkEmbd {
+			continue
+		}
 
-	// Extend the embedding array for each token
-	s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
-	rows := make([]float32, len(s))
-	copy(rows, s)
-	for i := range lastChunkSize {
-		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+		// Extend the embedding array for each token
+		s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
+		rows := make([]float32, len(s))
+		copy(rows, s)
+		for i := range numTokens {
+			chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+		}
+		embed = append(embed, chunkEmbed...)
 	}
 
+	slog.Debug("image embeddings", "totalEmbeddings", len(embed))
 	return embed, nil
 }

@@ -35,7 +35,6 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
-	"github.com/ollama/ollama/parser"
 )
 
 type filteredEnv []string

@@ -1350,9 +1349,7 @@ type CompletionRequest struct {
 	Images  []ImageData
 	Options *api.Options
 
-	Grammar       string // set before sending the request to the subprocess
-	ParserType    parser.TokenParserType
-	PrefillString string
+	Grammar string // set before sending the request to the subprocess
 }
 
 // DoneReason represents the reason why a completion response is done

@@ -1379,15 +1376,13 @@ func (d DoneReason) String() string {
 }
 
 type CompletionResponse struct {
-	Content            string         `json:"content"`
-	Thinking           string         `json:"thinking"`
-	ToolCalls          []api.ToolCall `json:"tool_calls"`
-	DoneReason         DoneReason     `json:"done_reason"`
-	Done               bool           `json:"done"`
-	PromptEvalCount    int            `json:"prompt_eval_count"`
-	PromptEvalDuration time.Duration  `json:"prompt_eval_duration"`
-	EvalCount          int            `json:"eval_count"`
-	EvalDuration       time.Duration  `json:"eval_duration"`
+	Content            string        `json:"content"`
+	DoneReason         DoneReason    `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
 }
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {

@@ -1505,8 +1500,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
 		}
 		switch {
-		// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
-		case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
+		case strings.TrimSpace(c.Content) == lastToken:
 			tokenRepeat++
 		default:
 			lastToken = strings.TrimSpace(c.Content)

@@ -1519,14 +1513,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			return ctx.Err()
 		}
 
+		if c.Content != "" {
+			fn(CompletionResponse{
+				Content: c.Content,
+			})
+		}
+
 		if c.Done {
 			fn(c)
 			return nil
 		}
-
-		if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
-			fn(c)
-		}
 	}
 }

@@ -416,6 +416,7 @@ type Tensor interface {
 	AddID(ctx Context, t2, ids Tensor) Tensor
 
 	Softmax(ctx Context) Tensor
+	L2Norm(ctx Context, eps float32) Tensor
 	LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
 	RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
 	Scale(ctx Context, s float64) Tensor
@@ -429,12 +430,13 @@ type Tensor interface {
 	Sin(ctx Context) Tensor
 	Cos(ctx Context) Tensor
 	Tanh(ctx Context) Tensor
-	GELU(ctx Context) Tensor
-	QuickGELU(ctx Context) Tensor
-	SILU(ctx Context) Tensor
-	RELU(ctx Context) Tensor
+	GELU(ctx Context, up ...Tensor) Tensor
+	SILU(ctx Context, up ...Tensor) Tensor
+	RELU(ctx Context, up ...Tensor) Tensor
 	Sigmoid(ctx Context) Tensor
-	SwiGLU(ctx Context, up Tensor, alpha, limit float32) Tensor
+
+	// AlphaLimitSILU is a variant of SILU that clamps the input to the range [-limit, limit]
+	SILUAlphaLimit(ctx Context, up Tensor, alpha, limit float32) Tensor
 
 	Reshape(ctx Context, shape ...int) Tensor
 	View(ctx Context, offset int, shape ...int) Tensor

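The activation methods now take an optional second tensor: with no argument they remain plain in-place activations, and with an up projection they become the fused split-gate kernels (GEGLU/SwiGLU/ReGLU). Every MLP call site in this diff migrates the same way:

```go
// before: activate the gate, then multiply by the up projection
hidden = mlp.Gate.Forward(ctx, x).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, x))

// after: the optional argument selects the fused ggml *_split kernel
hidden = mlp.Gate.Forward(ctx, x).SILU(ctx, mlp.Up.Forward(ctx, x))
```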
@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
 	}
 }
 
+func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
+	return &Tensor{
+		b: t.b,
+		t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
+	}
+}
+
 func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
 	tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
 	if w != nil {
@@ -1424,35 +1431,46 @@ func (t *Tensor) IM2Col(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int
 	}
 }
 
-func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) GELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_geglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) QuickGELU(ctx ml.Context) ml.Tensor {
-	return &Tensor{
-		b: t.b,
-		t: C.ggml_gelu_quick_inplace(ctx.(*Context).ctx, t.t),
+func (t *Tensor) SILU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_swiglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
 	}
-}
-
-func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) RELU(ctx ml.Context) ml.Tensor {
+func (t *Tensor) RELU(ctx ml.Context, t2 ...ml.Tensor) ml.Tensor {
+	if len(t2) > 0 {
+		return &Tensor{
+			b: t.b,
+			t: C.ggml_reglu_split(ctx.(*Context).ctx, t.t, t2[0].(*Tensor).t),
+		}
+	}
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_relu_inplace(ctx.(*Context).ctx, t.t),
 	}
 }
 
-func (t *Tensor) SwiGLU(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
+func (t *Tensor) SILUAlphaLimit(ctx ml.Context, up ml.Tensor, alpha, limit float32) ml.Tensor {
 	return &Tensor{
 		b: t.b,
 		t: C.ggml_swiglu_oai(ctx.(*Context).ctx, t.t, up.(*Tensor).t, C.float(alpha), C.float(limit)),

@@ -26,6 +26,7 @@ func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache
 }
 
 func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
+	ctx.Forward(query)
 	if key != nil && value != nil {
 		if query.Dim(0) != key.Dim(0) {
 			panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))

@@ -39,6 +40,7 @@ func AttentionWithSinks(ctx ml.Context, query, key, value, sinks ml.Tensor, scal
 			panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
 		}
 
+		ctx.Forward(key, value)
 		if cache != nil {
 			cache.Put(ctx, key, value)
 		}

@@ -0,0 +1,42 @@
+package pooling
+
+import (
+	"github.com/ollama/ollama/ml"
+)
+
+type Type uint32
+
+const (
+	TypeNone Type = iota
+	TypeMean
+	TypeCLS
+	TypeLast
+)
+
+func (t Type) String() string {
+	switch t {
+	case TypeMean:
+		return "Mean"
+	case TypeCLS:
+		return "CLS"
+	case TypeLast:
+		return "Last"
+	default:
+		return "Unknown"
+	}
+}
+
+func (t Type) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
+	switch t {
+	case TypeMean:
+		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
+		return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
+	case TypeCLS:
+		return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
+	case TypeLast:
+		hiddenStates = hiddenStates.View(ctx, (hiddenStates.Dim(1)-1)*hiddenStates.Stride(1), hiddenStates.Dim(0))
+		return hiddenStates
+	default:
+		panic("unknown pooling type")
+	}
+}
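Models select one of these at load time from GGUF metadata, with TypeNone (the zero value) marking a generative model that needs no pooling. A small sketch of the intended wiring, mirroring how model.New branches on the same key elsewhere in this diff:

```go
// c is an fs.Config read from the GGUF header
poolType := pooling.Type(c.Uint("pooling_type", 0))
if poolType != pooling.TypeNone {
	// collapse per-token hidden states into one embedding per sequence
	hiddenStates = poolType.Forward(ctx, hiddenStates)
}
```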
@@ -0,0 +1,79 @@
+package pooling_test
+
+import (
+	"bytes"
+	"os"
+	"slices"
+	"testing"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/ollama/ollama/discover"
+	fsggml "github.com/ollama/ollama/fs/ggml"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/backend/ggml"
+	"github.com/ollama/ollama/ml/nn/pooling"
+)
+
+func setup(tb testing.TB, n int) ml.Backend {
+	tb.Helper()
+
+	f, err := os.CreateTemp(tb.TempDir(), "*.bin")
+	if err != nil {
+		tb.Fatal(err)
+	}
+	defer f.Close()
+
+	if err := fsggml.WriteGGUF(f, fsggml.KV{
+		"general.architecture": "test",
+		"test.block_count":     uint32(1),
+	}, []*fsggml.Tensor{
+		{Name: "blk.0.weight", Shape: []uint64{1}, WriterTo: bytes.NewBuffer(make([]byte, 4))},
+	}); err != nil {
+		tb.Fatal(err)
+	}
+
+	var gpuLayers ml.GPULayersList
+	if gpus := discover.GetGPUInfo(); len(gpus) > 0 {
+		gpuLayers = append(gpuLayers, ml.GPULayers{
+			ID: gpus[0].ID,
+			Layers: slices.Collect(func(yield func(int) bool) {
+				for i := range n {
+					if !yield(i) {
+						return
+					}
+				}
+			}),
+		})
+	}
+	b, err := ggml.New(f.Name(), ml.BackendParams{AllocMemory: true, GPULayers: gpuLayers})
+	if err != nil {
+		tb.Fatal(err)
+	}
+
+	return b
+}
+
+func TestForward(t *testing.T) {
+	cases := map[pooling.Type][]float32{
+		pooling.TypeMean: {4, 5, 6, 7, 8, 9, 10, 11},
+		pooling.TypeCLS:  {0, 1, 2, 3, 4, 5, 6, 7},
+		pooling.TypeLast: {8, 9, 10, 11, 12, 13, 14, 15},
+	}
+	for typ, want := range cases {
+		t.Run(typ.String(), func(t *testing.T) {
+			b := setup(t, 99)
+			defer b.Close()
+
+			ctx := b.NewContext()
+			defer ctx.Close()
+
+			tt := ctx.Input().Arange(0, 16, 1, ml.DTypeF32).Reshape(ctx, 8, 2)
+			tt = typ.Forward(ctx, tt)
+
+			ctx.Forward(tt).Compute(tt)
+			if diff := cmp.Diff(want, tt.Floats()); diff != "" {
+				t.Error(diff)
+			}
+		})
+	}
+}

@@ -54,10 +54,9 @@ type Batch struct {
 	// Inputs is the input tokens, including placeholders for multimodal inputs.
 	Inputs ml.Tensor
 
-	// Multimodal is a set of multimodal embeddings previously created by
-	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
-	// models or for batches without multimodal elements.
-	Multimodal []MultimodalIndex
+	// Outputs are the set of indicies into Inputs for which output data should
+	// be returned.
+	Outputs ml.Tensor
 
 	// Positions is the position for each Input, relative to its sequence. Equal
 	// in length to Inputs.

@@ -66,7 +65,8 @@ type Batch struct {
 	// Sequences is the sequence for each Input. Equal in length to Inputs.
 	Sequences []int
 
-	// Outputs are the set of indicies into Inputs for which output data should
-	// be returned.
-	Outputs []int32
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
+	Multimodal []MultimodalIndex
 }

@@ -5,7 +5,6 @@ import (
 	"fmt"
 	_ "image/jpeg"
 	_ "image/png"
-	"math"
 	"os"
 	"reflect"
 	"strconv"

@@ -21,10 +20,15 @@ import (
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
+	"github.com/ollama/ollama/ml/nn/pooling"
 	"github.com/ollama/ollama/model/input"
 )
 
-var ErrNoVisionModel = errors.New("this model is missing data required for image input")
+var (
+	ErrNoVisionModel        = errors.New("this model is missing data required for image input")
+	ErrUnsupportedModel     = errors.New("model not supported")
+	ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
+)
 
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {

@@ -104,7 +108,7 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
 	}
 
 	arch := b.Config().Architecture()
-	if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
+	if pooling.Type(b.Config().Uint("pooling_type")) != pooling.TypeNone {
 		arch = arch + "_embed"
 	}
 

@@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
 		vv = vv.Elem()
 	}
 
-	vv = vv.Elem()
+	vv = reflect.Indirect(vv)
 	if v.IsNil() {
 		vv = reflect.New(v.Type().Elem()).Elem()
 	}

@@ -0,0 +1,181 @@
+package bert
+
+import (
+	"cmp"
+	"math"
+
+	"github.com/ollama/ollama/fs"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/pooling"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Model struct {
+	model.Base
+	model.TextProcessor
+
+	TokenEmbedding     *nn.Embedding `gguf:"token_embd"`
+	TypeEmbedding      *nn.Embedding `gguf:"token_types"`
+	PositionEmbedding  *nn.Embedding `gguf:"position_embd"`
+	TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
+
+	Layers []EncoderLayer `gguf:"blk"`
+
+	Options
+}
+
+// Forward implements model.Model.
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
+	hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
+	hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
+	hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
+
+	for _, layer := range m.Layers {
+		hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
+	}
+
+	hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
+	if m.normalize {
+		hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
+	}
+
+	return hiddenStates, nil
+}
+
+type EncoderLayer struct {
+	*Attention
+	AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
+
+	*MLP
+	MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
+}
+
+func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	// Attention
+	residual := hiddenStates
+	hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
+	hiddenStates = hiddenStates.Add(ctx, residual)
+	hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
+
+	// MLP
+	residual = hiddenStates
+	hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
+	hiddenStates = hiddenStates.Add(ctx, residual)
+	hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
+
+	return hiddenStates
+}
+
+type Attention struct {
+	Query     *nn.Linear    `gguf:"attn_q"`
+	QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
+
+	Key     *nn.Linear    `gguf:"attn_k"`
+	KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
+
+	Value *nn.Linear `gguf:"attn_v"`
+
+	Output *nn.Linear `gguf:"attn_output"`
+}
+
+func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	batchSize := hiddenStates.Dim(1)
+
+	query := a.Query.Forward(ctx, hiddenStates)
+	if a.QueryNorm != nil {
+		query = a.QueryNorm.Forward(ctx, query, opts.eps)
+	}
+	query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
+
+	key := a.Key.Forward(ctx, hiddenStates)
+	if a.KeyNorm != nil {
+		key = a.KeyNorm.Forward(ctx, key, opts.eps)
+	}
+	key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+	value := a.Value.Forward(ctx, hiddenStates)
+	value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
+
+	attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
+	attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
+	return a.Output.Forward(ctx, attention)
+}
+
+type MLP struct {
+	Up   *nn.Linear `gguf:"ffn_up"`
+	Down *nn.Linear `gguf:"ffn_down"`
+}
+
+func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
+	return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
+}
+
+type Options struct {
+	hiddenSize,
+	numHeads,
+	numKVHeads,
+	keyLength,
+	valueLength int
+	poolingType pooling.Type
+	eps         float32
+	normalize   bool
+}
+
+func (o Options) headDim() int {
+	return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
+}
+
+func New(c fs.Config) (model.Model, error) {
+	var processor model.TextProcessor
+	switch c.String("tokenizer.ggml.model", "bert") {
+	case "bert":
+		processor = model.NewWordPiece(
+			&model.Vocabulary{
+				Values: c.Strings("tokenizer.ggml.tokens"),
+				Scores: c.Floats("tokenizer.ggml.scores"),
+				Types:  c.Ints("tokenizer.ggml.token_type"),
+				AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
+				BOS: []int32{
+					int32(cmp.Or(
+						c.Uint("tokenizer.ggml.cls_token_id"),
+						c.Uint("tokenizer.ggml.bos_token_id"),
+					)),
+				},
+				AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
+				EOS: []int32{
+					int32(cmp.Or(
+						c.Uint("tokenizer.ggml.separator_token_id"),
+						//nolint:misspell
+						// NOTE: "seperator_token_id" is a typo in model metadata but we need to
+						// support it for compatibility.
+						c.Uint("tokenizer.ggml.seperator_token_id"),
+						c.Uint("tokenizer.ggml.eos_token_id"),
+					)),
+				},
+			},
+		)
+	default:
+		return nil, model.ErrUnsupportedTokenizer
+	}
+
+	return &Model{
+		TextProcessor: processor,
+		Layers:        make([]EncoderLayer, c.Uint("block_count")),
+		Options: Options{
+			hiddenSize:  int(c.Uint("embedding_length")),
+			numHeads:    int(c.Uint("attention.head_count")),
+			numKVHeads:  int(c.Uint("attention.head_count_kv")),
+			eps:         c.Float("attention.layer_norm_epsilon"),
+			poolingType: pooling.Type(c.Uint("pooling_type")),
+			normalize:   c.Bool("normalize_embeddings", true),
+		},
+	}, nil
+}
+
+func init() {
+	model.Register("bert", New)
+	model.Register("bert_embed", New)
+}

@@ -24,7 +24,7 @@ type Options struct {
 
 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece
 
 	TokenEmbedding *nn.Embedding `gguf:"token_embd"`
 	Layers         []Layer       `gguf:"blk"`

@@ -40,7 +40,7 @@ const (
 
 func New(c fs.Config) (model.Model, error) {
 	m := Model{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),

@@ -138,7 +138,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))

@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 	var lastLayerOutputs ml.Tensor
 	if i == len(m.Layers)-1 {
-		lastLayerOutputs = outputs
+		lastLayerOutputs = batch.Outputs
 	}
 
 	hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)

@@ -1,49 +1,38 @@
 package gemma3
 
 import (
-	"errors"
-
 	"github.com/ollama/ollama/fs"
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/ml/nn/pooling"
 	"github.com/ollama/ollama/model"
 	"github.com/ollama/ollama/model/input"
 )
 
 type embedModel struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece
 
 	*TextModel
-	PoolingType uint32
+	poolingType pooling.Type
 
 	Dense [2]*nn.Linear `gguf:"dense"`
 }
 
 func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	batch.Outputs = batch.Positions // return all positions
 	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
 
-	switch m.PoolingType {
-	case 0: // None
-	case 1: // Mean
-		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
-		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	default:
-		return nil, errors.New("unsupported pooling type")
-	}
-
+	hiddenStates = m.poolingType.Forward(ctx, hiddenStates)
 	for _, dense := range m.Dense {
 		hiddenStates = dense.Forward(ctx, hiddenStates)
 	}
 
+	hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
 	return hiddenStates, nil
 }
 
 func newEmbedModel(c fs.Config) (model.Model, error) {
 	m := &embedModel{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),

@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
 		},
 		),
 		TextModel:   newTextModel(c),
-		PoolingType: c.Uint("pooling_type", 0),
+		poolingType: pooling.Type(c.Uint("pooling_type", 0)),
 	}
 
 	m.Cache = kvcache.NewWrapperCache(

@@ -16,7 +16,7 @@ import (
 
 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece
 
 	*VisionModel `gguf:"v"`
 	*TextModel

@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
 
 func New(c fs.Config) (model.Model, error) {
 	m := Model{
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),

@@ -123,7 +123,7 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 
 func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))

@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 
 	var lastLayerOutputs ml.Tensor
 	if i == len(m.Layers)-1 {
-		lastLayerOutputs = outputs
+		lastLayerOutputs = batch.Outputs
 	}
 
 	hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)

@@ -10,7 +10,7 @@ import (
 
 type Model struct {
 	model.Base
-	model.SentencePieceModel
+	model.SentencePiece
 
 	*TextModel
 }

@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 func New(c fs.Config) (model.Model, error) {
 	m := Model{
 		TextModel: newTextModel(c),
-		SentencePieceModel: model.NewSentencePieceModel(
+		SentencePiece: model.NewSentencePiece(
 			&model.Vocabulary{
 				Values: c.Strings("tokenizer.ggml.tokens"),
 				Scores: c.Floats("tokenizer.ggml.scores"),

@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
 
 	hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
 	hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
-	hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
+	hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
 
 	hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
 	return m.Output.Forward(ctx, hiddenStates), nil

@@ -170,8 +170,7 @@ func (d TextLayer) Forward(ctx ml.Context, hiddenStates, perLayerInput, position
 	}
 
 	active = d.PerLayerInputGate.Forward(ctx, active)
-	active = active.GELU(ctx)
-	active = active.Mul(ctx, perLayerInput)
+	active = active.GELU(ctx, perLayerInput)
 
 	active = d.PerLayerProjection.Forward(ctx, active)
 	active = d.PostPerLayerNorm.Forward(ctx, active, opts.eps)

@@ -292,7 +291,7 @@ func (mlp TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, activationSpa
 		hiddenStates = hiddenStates.Sub(ctx, cutoff).RELU(ctx)
 	}
 
-	hiddenStates = hiddenStates.GELU(ctx).Mul(ctx, upStates)
+	hiddenStates = hiddenStates.GELU(ctx, upStates)
 	hiddenStates = mlp.Down.Forward(ctx, hiddenStates)
 	return hiddenStates
 }

@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
 	}
 
 	var outputs ml.Tensor
-	if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
-		outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+	if i == len(m.TransformerBlocks)-1 {
+		outputs = batch.Outputs
 	}
 
 	hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)

@@ -210,7 +210,7 @@ func (mlp *MLPBlock) Forward(ctx ml.Context, hiddenStates, one ml.Tensor, opts *
 		up = mlp.Up.Forward(ctx, hiddenStates, selectedExperts)
 	}
 
-	hiddenStates = gate.SwiGLU(ctx, up, 1.702, 7)
+	hiddenStates = gate.SILUAlphaLimit(ctx, up, 1.702, 7)
 
 	experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
 	experts = experts.Mul(ctx, routingWeights)

@@ -118,7 +118,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 	var outputs ml.Tensor
 	if i == len(m.Layers)-1 {
-		outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+		outputs = batch.Outputs
 	}
 
 	hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)

@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {

@@ -58,14 +58,14 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }
 
 type TextExperts struct {
-	Gate *nn.Linear `gguf:"ffn_gate_exps"`
-	Up   *nn.Linear `gguf:"ffn_up_exps"`
-	Down *nn.Linear `gguf:"ffn_down_exps"`
+	Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
+	Up   *nn.LinearBatch `gguf:"ffn_up_exps"`
+	Down *nn.LinearBatch `gguf:"ffn_down_exps"`
 }
 
 func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {

@@ -76,9 +76,9 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
 	hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
 	hiddenStates = hiddenStates.Mul(ctx, scores)
 
-	upStates := e.Up.Weight.MulmatID(ctx, hiddenStates, experts)
-	gateStates := e.Gate.Weight.MulmatID(ctx, hiddenStates, experts)
-	downStates := e.Down.Weight.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
+	upStates := e.Up.Forward(ctx, hiddenStates, experts)
+	gateStates := e.Gate.Forward(ctx, hiddenStates, experts)
+	downStates := e.Down.Forward(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
 
 	nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
 	for i := 1; i < opts.numExpertsUsed; i++ {

@@ -96,7 +96,7 @@ type TextSharedExpert struct {
 }
 
 func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
 }
 
 func init() {

@@ -65,7 +65,7 @@ type MLP struct {
 }
 
 func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -51,7 +51,7 @@ type VisionMLP struct {
 }
 
 func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	}
 
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {

@@ -58,7 +58,7 @@ type TextMLP struct {
 }
 
 func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextModelOptions) ml.Tensor {
-	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
+	hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
 	return mlp.Down.Forward(ctx, hiddenState)
 }

@@ -1,6 +1,7 @@
 package models
 
 import (
+	_ "github.com/ollama/ollama/model/models/bert"
 	_ "github.com/ollama/ollama/model/models/gemma2"
 	_ "github.com/ollama/ollama/model/models/gemma3"
 	_ "github.com/ollama/ollama/model/models/gemma3n"

@@ -59,7 +59,7 @@ type MLP struct {
 }
 
 func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
-	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
+	hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
 	return mlp.Down.Forward(ctx, hiddenStates)
 }

@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 	var outputs ml.Tensor
 	if i == len(m.Layers)-1 {
-		outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
+		outputs = batch.Outputs
 	}
 
 	hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)

@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
-	outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
-
-	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
 }
 
 func init() {

@@ -90,7 +90,7 @@ type MLP struct {

func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
    // Apply SwiGLU activation gating
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
    hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
    // Project back to hidden dimension
    return mlp.Down.Forward(ctx, hiddenState)
}

@@ -100,8 +100,7 @@ type VisionMLP struct {
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
    // Using activation as specified in config (likely GELU or SiLU/Swish)
    gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
    upOutput := mlp.Up.Forward(ctx, hiddenStates)
    hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
    hiddenStates = gateOutput.SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))

    return mlp.Down.Forward(ctx, hiddenStates)
}

@@ -30,10 +30,10 @@ func (o Options) headDim() int {
}

type Attention struct {
    QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
    Query *nn.Linear `gguf:"attn_q"`
    KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
    QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
    Key *nn.Linear `gguf:"attn_k"`
    KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
    Value *nn.Linear `gguf:"attn_v"`
    Output *nn.Linear `gguf:"attn_output"`
}

@@ -65,10 +65,10 @@ type MLP interface {
}

type sparse struct {
    Router *nn.Linear `gguf:"ffn_gate_inp"`
    Gate *nn.Linear `gguf:"ffn_gate_exps"`
    Up *nn.Linear `gguf:"ffn_up_exps"`
    Down *nn.Linear `gguf:"ffn_down_exps"`
    Router *nn.Linear `gguf:"ffn_gate_inp"`
    Gate *nn.LinearBatch `gguf:"ffn_gate_exps"`
    Up *nn.LinearBatch `gguf:"ffn_up_exps"`
    Down *nn.LinearBatch `gguf:"ffn_down_exps"`
}

func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {

@@ -87,13 +87,9 @@ func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options

    hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))

    upStates := mlp.Up.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
    hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates, selectedExperts))

    hiddenStates = mlp.Gate.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
    hiddenStates = hiddenStates.SILU(ctx)
    hiddenStates = hiddenStates.Mul(ctx, upStates)

    experts := mlp.Down.Weight.MulmatID(ctx, hiddenStates, selectedExperts)
    experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts)
    experts = experts.Mul(ctx, routingWeights)

    nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))

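For context, the `nn.LinearBatch` that replaces the raw `Weight.MulmatID` calls above is, conceptually, a linear layer holding one weight matrix per expert, whose forward pass picks a matrix per token by index. A rough sketch of that idea in plain Go (illustrative only, not the actual nn package):

// linearBatch holds one out x in weight matrix per expert.
type linearBatch struct {
    weights [][][]float32 // weights[expert][outRow][inCol]
}

// forward multiplies each token vector by the weight matrix of its selected
// expert, mirroring what a batched MulmatID-style tensor op computes.
func (l *linearBatch) forward(xs [][]float32, experts []int) [][]float32 {
    outs := make([][]float32, len(xs))
    for t, x := range xs {
        w := l.weights[experts[t]]
        out := make([]float32, len(w))
        for o, row := range w {
            var sum float32
            for i, v := range row {
                sum += v * x[i]
            }
            out[o] = sum
        }
        outs[t] = out
    }
    return outs
}
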
@@ -111,7 +107,8 @@ type dense struct {
}

func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
    hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
    hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).
        SILU(ctx, mlp.Up.Forward(ctx, hiddenStates))
    return mlp.Down.Forward(ctx, hiddenStates)
}

@@ -165,7 +162,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

    var outputs ml.Tensor
    if i == len(m.Layers)-1 {
        outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
        outputs = batch.Outputs
    }

    hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)

@@ -0,0 +1,37 @@
package parsers

import (
    "github.com/ollama/ollama/api"
)

type Parser interface {
    Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
    HasToolSupport() bool
    HasThinkingSupport() bool
}

func ParserForName(name string) Parser {
    switch name {
    case "qwen3-coder":
        parser := &Qwen3CoderParser{}
        return parser
    case "passthrough":
        return &PassthroughParser{}
    default:
        return nil
    }
}

type PassthroughParser struct{}

func (p *PassthroughParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
    return s, "", nil, nil
}

func (p *PassthroughParser) HasToolSupport() bool {
    return false
}

func (p *PassthroughParser) HasThinkingSupport() bool {
    return false
}

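As a usage sketch (a hypothetical harness, not part of this change): a caller streams model output through a Parser chunk by chunk, accumulating whatever unambiguous content and tool calls each Add call returns. Assumes the strings and api packages are imported.

func streamThrough(p Parser, chunks []string, tools []api.Tool) (string, []api.ToolCall, error) {
    var content strings.Builder
    var calls []api.ToolCall
    for _, chunk := range chunks {
        // Each Add may buffer ambiguous bytes and emit them on a later call.
        c, _, tc, err := p.Add(chunk, tools)
        if err != nil {
            return "", nil, err
        }
        content.WriteString(c)
        calls = append(calls, tc...)
    }
    return content.String(), calls, nil
}
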
@@ -0,0 +1,410 @@
package parsers

import (
    "context"
    "encoding/json"
    "encoding/xml"
    "fmt"
    "log/slog"
    "math"
    "regexp"
    "strconv"
    "strings"
    "unicode"

    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/logutil"
)

type qwenParserState int

const (
    toolOpenTag  = "<tool_call>"
    toolCloseTag = "</tool_call>"
)

const (
    qwenParserState_LookingForToolStart qwenParserState = iota
    qwenParserState_CollectingToolContent
)

type Qwen3CoderParser struct {
    state qwenParserState
    acc   strings.Builder
}

func (p *Qwen3CoderParser) HasToolSupport() bool {
    return true
}

func (p *Qwen3CoderParser) HasThinkingSupport() bool {
    return false
}

func (p *Qwen3CoderParser) Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error) {
    p.acc.WriteString(s)

    events := p.parseEvents()

    var toolCalls []api.ToolCall
    var sb strings.Builder
    for _, event := range events {
        switch event := event.(type) {
        case qwenEventRawToolCall:
            toolCall, err := parseToolCall(event, tools)
            if err != nil {
                slog.Warn("qwen tool call parsing failed", "error", err)
                return "", "", nil, err
            }
            toolCalls = append(toolCalls, toolCall)
        case qwenEventContent:
            // TODO(drifkin): if the same turn contains multiple interleaved content
            // events, we naively append them together here. See the note below about
            // `qwenEvent`s for more details
            sb.WriteString(event.content)
        }
    }

    return sb.String(), "", toolCalls, nil
}

func (p *Qwen3CoderParser) parseEvents() []qwenEvent {
    var all []qwenEvent

    keepLooping := true
    for keepLooping {
        var events []qwenEvent
        events, keepLooping = eat(p)
        if len(events) > 0 {
            all = append(all, events...)
        }
    }

    if len(all) > 0 {
        slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "acc", p.acc.String())
    }

    return all
}

// we use some internal event types in order to communicate between `Add` and
// `eat`. We do this to support interleaving content and parallel tool calls in
// the parser, even though qwen3-coder isn't supposed to do this. Our API
// doesn't currently support models outputting multiple messages in a turn, so
// we wouldn't be able to represent it yet, but there's no reason to prevent the
// parser from supporting it, especially for future models if they end up using
// a similar format.
type qwenEvent interface {
    isQwenEvent()
}

type qwenEventRawToolCall struct {
    raw string
}

type qwenEventContent struct {
    content string
}

func (qwenEventContent) isQwenEvent()     {}
func (qwenEventRawToolCall) isQwenEvent() {}

// eat consumes the parser's buffer, and returns a list of any unambiguous
// events from the current parser state. If the parser transitions to another
// state, it may have additional events to emit on the next call, which is what
// the second return value indicates
func eat(p *Qwen3CoderParser) ([]qwenEvent, bool) {
    var events []qwenEvent

    switch p.state {
    case qwenParserState_LookingForToolStart:
        if strings.Contains(p.acc.String(), toolOpenTag) {
            // we found a full tool open tag, so we can emit the content before the
            // tag, being sure to trim any trailing whitespace
            split := strings.SplitN(p.acc.String(), toolOpenTag, 2)
            before := split[0]
            before = strings.TrimRightFunc(before, unicode.IsSpace)
            if len(before) > 0 {
                events = append(events, qwenEventContent{content: before})
            }
            after := split[1]
            p.acc.Reset()
            p.acc.WriteString(after)
            p.state = qwenParserState_CollectingToolContent
            return events, true
        } else if overlap := overlap(p.acc.String(), toolOpenTag); overlap > 0 {
            // we found a partial tool open tag, so we can emit the unambiguous part,
            // which is the (trailing-whitespace trimmed) content before the partial
            // tool open tag
            beforePartialTag := p.acc.String()[:len(p.acc.String())-overlap]
            trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
            ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
            unambiguous := p.acc.String()[:ambiguousStart]
            ambiguous := p.acc.String()[ambiguousStart:]
            p.acc.Reset()
            p.acc.WriteString(ambiguous)
            events = append(events, qwenEventContent{content: unambiguous})
            return events, false
        } else {
            // we found content that is entirely not a tool call. We should withhold
            // any trailing whitespace in case this is the end of the content
            whitespaceLen := trailingWhitespaceLen(p.acc.String())
            ambiguousStart := len(p.acc.String()) - whitespaceLen
            unambiguous := p.acc.String()[:ambiguousStart]
            ambiguous := p.acc.String()[ambiguousStart:]
            p.acc.Reset()
            p.acc.WriteString(ambiguous)
            if len(unambiguous) > 0 {
                events = append(events, qwenEventContent{content: unambiguous})
            }
            return events, false
        }
    case qwenParserState_CollectingToolContent:
        if strings.Contains(p.acc.String(), toolCloseTag) {
            split := strings.SplitN(p.acc.String(), toolCloseTag, 2)
            before := split[0]
            if len(before) == 0 {
                slog.Warn("qwen tool call closing tag found but no content before it")
            }
            // remove any whitespace between the tool call and any content after it
            after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
            p.acc.Reset()
            p.acc.WriteString(after)
            events = append(events, qwenEventRawToolCall{raw: before})
            p.state = qwenParserState_LookingForToolStart
            return events, true
        } else {
            // note that we don't need to check the overlap here because we only plan
            // on parsing the tool call once we see the full closing tag. We don't
            // stream back the unparsed tool content, so there's no need to be eager
            // here
            return events, false
        }
    default:
        panic("unreachable")
    }
}

// TODO(drifkin): move this to a shared location
// longest overlap between suffix of s and prefix of delim
func overlap(s, delim string) int {
    max := min(len(delim), len(s))
    for i := max; i > 0; i-- {
        if strings.HasSuffix(s, delim[:i]) {
            return i
        }
    }
    return 0
}

func trailingWhitespaceLen(s string) int {
    for i := len(s) - 1; i >= 0; i-- {
        if !unicode.IsSpace(rune(s[i])) {
            return len(s) - i - 1
        }
    }
    return len(s)
}

type XMLFunctionCall struct {
    XMLName    xml.Name       `xml:"function"`
    Name       string         `xml:"name,attr"`
    Parameters []XMLParameter `xml:"parameter"`
}

type XMLParameter struct {
    Name  string `xml:"name,attr"`
    Value string `xml:",chardata"`
}

// parseToolCall parses a raw tool call string into an api.ToolCall.
// The raw string follows an xml-like format, here's an example:
//
// <function=get_current_temperature>
// <parameter=location>
// San Francisco
// </parameter>
// <parameter=unit>
// celsius
// </parameter>
// </function>
func parseToolCall(raw qwenEventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
    toolCall := api.ToolCall{}

    xmlString := transformToXML(raw.raw)

    var functionCall XMLFunctionCall
    err := xml.Unmarshal([]byte(xmlString), &functionCall)
    if err != nil {
        return api.ToolCall{}, err
    }

    toolCall.Function = api.ToolCallFunction{
        Name: functionCall.Name,
    }

    // Find the matching tool to get parameter types
    var matchedTool *api.Tool
    for i := range tools {
        if tools[i].Function.Name == functionCall.Name {
            matchedTool = &tools[i]
            break
        }
    }

    toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
    for _, parameter := range functionCall.Parameters {
        // Look up the parameter type if we found the tool
        var paramType api.PropertyType
        if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
            if prop, ok := matchedTool.Function.Parameters.Properties[parameter.Name]; ok {
                paramType = prop.Type
            }
        }

        toolCall.Function.Arguments[parameter.Name] = parseValue(parameter.Value, paramType)
    }

    return toolCall, nil
}

// parseValue converts a raw string value to the appropriate type based on the parameter type specification.
//
// For union types (multiple types in PropertyType, which we support but doesn't
// seem as though the reference parser does type coercion with those types in
// mind) we use a type precedence approach:
// 1. null - checked first regardless of declared types (matches reference implementation)
// 2. boolean - only "true"/"false" are valid booleans
// 3. integer - must parse as a whole number
// 4. number - must parse as numeric (returns int if no decimal part)
// 5. array - must parse as valid JSON array
// 6. object - must parse as valid JSON object
// 7. string - always succeeds (least specific type)
//
// This precedence ensures we return the most specific type that successfully parses,
// following the principle of least surprise. For example, with PropertyType{"string", "number"},
// "123" becomes 123 (number), while "hello" becomes "hello" (string).
func parseValue(raw string, paramType api.PropertyType) any {
    // first remove a single leading newline, and a single trailing newline (if
    // they exist). This follows the reference implementation
    raw = strings.TrimPrefix(raw, "\n")
    raw = strings.TrimSuffix(raw, "\n")

    // Check for null first (case-insensitive) - this takes precedence over any type
    if strings.ToLower(raw) == "null" {
        return nil
    }

    // If no type is specified, default to string
    if len(paramType) == 0 {
        return raw
    }

    // Check if any of the specified types match, using type precedence
    // Order: boolean -> integer -> number -> array -> object -> string
    typeSet := make(map[string]bool)
    for _, t := range paramType {
        typeSet[t] = true
    }

    // Try boolean first (most restrictive)
    if typeSet["boolean"] {
        lower := strings.ToLower(raw)
        switch lower {
        case "true":
            return true
        case "false":
            return false
        }
        // If not a valid boolean but boolean is the only type, return false (matching reference)
        if len(paramType) == 1 {
            return false
        }
        // Otherwise try other types
    }

    // Try integer
    if typeSet["integer"] {
        if i, err := strconv.ParseInt(raw, 10, 64); err == nil {
            // Return as int if it fits in int32, otherwise int64
            if i >= math.MinInt32 && i <= math.MaxInt32 {
                return int(i)
            }
            return i
        }
        // If integer is the only type and parsing failed, fall back to string
        if len(paramType) == 1 {
            return raw
        }
    }

    // Try number (float)
    if typeSet["number"] {
        if f, err := strconv.ParseFloat(raw, 64); err == nil {
            // If the number has no decimal part, return as int (matching reference)
            if f == math.Trunc(f) {
                i := int64(f)
                if i >= math.MinInt32 && i <= math.MaxInt32 {
                    return int(i)
                }
                return i
            }
            return f
        }
        // If number is the only type and parsing failed, fall back to string
        if len(paramType) == 1 {
            return raw
        }
    }

    // Try array
    if typeSet["array"] {
        var arr []interface{}
        if err := json.Unmarshal([]byte(raw), &arr); err == nil {
            return arr
        }
        // If array is the only type and parsing failed, fall back to string
        if len(paramType) == 1 {
            return raw
        }
    }

    // Try object
    if typeSet["object"] {
        var obj map[string]interface{}
        if err := json.Unmarshal([]byte(raw), &obj); err == nil {
            return obj
        }
        // If object is the only type and parsing failed, fall back to string
        if len(paramType) == 1 {
            return raw
        }
    }

    // String always succeeds (or if "string" is in the type set)
    if typeSet["string"] {
        return raw
    }

    // If we get here, none of the types matched and string wasn't an option.
    // We return string as a fallback. The reference implementation will attempt
    // to parse the value as a python literal, but we purposefully don't support
    // that
    return raw
}

var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)

// transformToXML transforms a raw qwen tool call with xml-like tags into valid
// xml so that it can be parsed by any xml parser
func transformToXML(raw string) string {
    // take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
    // care to properly escape the string that becomes the attribute value
    return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
        groups := qwenTagRegex.FindStringSubmatch(match)
        tag := groups[1]
        var escapedValue strings.Builder
        xml.EscapeText(&escapedValue, []byte(groups[2]))
        return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
    })
}

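To make the buffering behavior concrete, a few worked examples derived from the definitions above, wrapped in a hypothetical illustration function (the expected results follow directly from overlap and transformToXML as written):

func exampleQwenHelpers() {
    fmt.Println(overlap("hello <tool", toolOpenTag))      // 5: "<tool" stays buffered as a possible open tag
    fmt.Println(overlap("hello", toolOpenTag))            // 0: only trailing whitespace would be withheld
    fmt.Println(transformToXML("<function=get_weather>")) // <function name="get_weather">
}
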
@@ -0,0 +1,830 @@
package parsers

import (
    "reflect"
    "testing"

    "github.com/ollama/ollama/api"
)

// tool creates a test tool with the given name and properties
func tool(name string, props map[string]api.ToolProperty) api.Tool {
    t := api.Tool{Type: "function", Function: api.ToolFunction{Name: name}}
    t.Function.Parameters.Type = "object"
    t.Function.Parameters.Properties = props
    return t
}

func TestQwenParserStreaming(t *testing.T) {
    type step struct {
        input      string
        wantEvents []qwenEvent
    }

    cases := []struct {
        desc  string
        steps []step
        only  bool
    }{
        {
            desc: "simple message streamed word by word",
            steps: []step{
                {
                    input:      "hi",
                    wantEvents: []qwenEvent{qwenEventContent{content: "hi"}},
                },
                {
                    input:      " there",
                    wantEvents: []qwenEvent{qwenEventContent{content: " there"}},
                },
            },
        },
        {
            desc: "content before tool call",
            steps: []step{
                {
                    input:      "hi there<tool_call>",
                    wantEvents: []qwenEvent{qwenEventContent{content: "hi there"}},
                },
            },
        },
        {
            desc: "multiple tool calls in one message",
            steps: []step{
                {
                    input: "before1<tool_call>in tool call</tool_call>after1<tool_call>in tool call 2</tool_call>after2",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "before1"},
                        qwenEventRawToolCall{raw: "in tool call"},
                        qwenEventContent{content: "after1"},
                        qwenEventRawToolCall{raw: "in tool call 2"},
                        qwenEventContent{content: "after2"},
                    },
                },
            },
        },
        {
            desc: "tool calls with split tags",
            steps: []step{
                {
                    input: "before<tool",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "before"},
                    },
                },
                {
                    input:      "_call>in tool call</tool",
                    wantEvents: []qwenEvent{},
                },
                {
                    input: "_call>af",
                    wantEvents: []qwenEvent{
                        qwenEventRawToolCall{raw: "in tool call"},
                        qwenEventContent{content: "af"},
                    },
                },
                {
                    input: "ter",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "ter"},
                    },
                },
            },
        },
        {
            desc: "trailing whitespace between content and tool call",
            steps: []step{
                {
                    input: "abc\n<tool_call>def</tool_call>",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "abc"},
                        qwenEventRawToolCall{raw: "def"},
                    },
                },
            },
        },
        {
            desc: "trailing whitespace between tool call and content",
            steps: []step{
                {
                    input: "<tool_call>abc</tool_call>\ndef",
                    wantEvents: []qwenEvent{
                        qwenEventRawToolCall{raw: "abc"},
                        qwenEventContent{content: "def"},
                    },
                },
            },
        },
        {
            desc: "empty content before tool call",
            steps: []step{
                {
                    input: "\n<tool_call>abc</tool_call>",
                    wantEvents: []qwenEvent{
                        qwenEventRawToolCall{raw: "abc"},
                    },
                },
            },
        },
        {
            desc: "partial tool open tag fakeout",
            steps: []step{
                {
                    input: "abc\n<tool_call",
                    wantEvents: []qwenEvent{
                        // \n should not be emitted yet because `<tool_call` might be a tool
                        // open tag, in which case the whitespace should be trimmed
                        qwenEventContent{content: "abc"},
                    },
                },
                {
                    input: " fakeout",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "\n<tool_call fakeout"},
                    },
                },
            },
        },
        {
            desc: "token-by-token whitespace handling",
            steps: []step{
                {
                    input: "a",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "a"},
                    },
                },
                {
                    input:      "\n",
                    wantEvents: []qwenEvent{},
                },
                {
                    input: "b",
                    wantEvents: []qwenEvent{
                        qwenEventContent{content: "\nb"},
                    },
                },
            },
        },
    }

    anyOnlies := false
    for _, tc := range cases {
        if tc.only {
            anyOnlies = true
        }
    }

    for _, tc := range cases {
        if anyOnlies && !tc.only {
            continue
        }

        t.Run(tc.desc, func(t *testing.T) {
            parser := Qwen3CoderParser{}

            for i, step := range tc.steps {
                parser.acc.WriteString(step.input)
                gotEvents := parser.parseEvents()

                if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
                    // avoid deep equal on empty vs. nil slices
                    continue
                }

                if !reflect.DeepEqual(gotEvents, step.wantEvents) {
                    t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
                }
            }
        })
    }
}

func TestQwenToolParser(t *testing.T) {
    type step struct {
        name         string
        rawToolCall  string
        tools        []api.Tool
        wantToolCall api.ToolCall
    }

    steps := []step{
        {
            name:  "simple tool call",
            tools: []api.Tool{},
            rawToolCall: `<function=get_current_temperature>
<parameter=location>
San Francisco
</parameter>
<parameter=unit>
celsius
</parameter>
</function>`,
            wantToolCall: api.ToolCall{
                Function: api.ToolCallFunction{
                    Name: "get_current_temperature",
                    Arguments: map[string]any{
                        "location": "San Francisco",
                        "unit":     "celsius",
                    },
                },
            },
        },
        {
            name:  "names with spaces",
            tools: []api.Tool{},
            rawToolCall: `<function=get current temperature>
<parameter=location with spaces>
San Francisco
</parameter>
<parameter=unit with spaces>
celsius
</parameter>
</function>`,
            wantToolCall: api.ToolCall{
                Function: api.ToolCallFunction{
                    Name: "get current temperature",
                    Arguments: map[string]any{
                        "location with spaces": "San Francisco",
                        "unit with spaces":     "celsius",
                    },
                },
            },
        },
        // this mirrors the reference implementation's behavior, but unclear if it
        // ever happens. If so, then we should probably remove them instead, this
        // test is to just document the current behavior and test that we don't get
        // xml errors
        {
            name:  "names with quotes",
            tools: []api.Tool{},
            rawToolCall: `<function="get current temperature">
<parameter="location with spaces">
San Francisco
</parameter>
<parameter="unit with spaces">
"celsius"
</parameter>
</function>`,
            wantToolCall: api.ToolCall{
                Function: api.ToolCallFunction{
                    Name: "\"get current temperature\"",
                    Arguments: map[string]any{
                        "\"location with spaces\"": "San Francisco",
                        "\"unit with spaces\"":     "\"celsius\"",
                    },
                },
            },
        },
        {
            name: "tool call with typed parameters",
            tools: []api.Tool{
                tool("calculate", map[string]api.ToolProperty{
                    "x":       {Type: api.PropertyType{"number"}},
                    "y":       {Type: api.PropertyType{"integer"}},
                    "enabled": {Type: api.PropertyType{"boolean"}},
                    "items":   {Type: api.PropertyType{"array"}},
                }),
            },
            rawToolCall: `<function=calculate>
<parameter=x>
3.14
</parameter>
<parameter=y>
42
</parameter>
<parameter=enabled>
true
</parameter>
<parameter=items>
["a", "b", "c"]
</parameter>
</function>`,
            wantToolCall: api.ToolCall{
                Function: api.ToolCallFunction{
                    Name: "calculate",
                    Arguments: map[string]any{
                        "x":       3.14,
                        "y":       42,
                        "enabled": true,
                        "items":   []any{"a", "b", "c"},
                    },
                },
            },
        },
    }

    for i, step := range steps {
        gotToolCall, err := parseToolCall(qwenEventRawToolCall{raw: step.rawToolCall}, step.tools)
        if err != nil {
            t.Errorf("step %d (%s): %v", i, step.name, err)
        }
        if !reflect.DeepEqual(gotToolCall, step.wantToolCall) {
            t.Errorf("step %d (%s): got tool call %#v, want %#v", i, step.name, gotToolCall, step.wantToolCall)
        }
    }
}

func TestQwenToolCallValueParsing(t *testing.T) {
    cases := []struct {
        desc      string
        raw       string
        paramType api.PropertyType
        want      any
    }{
        {
            desc:      "default string value (no type specified)",
            paramType: api.PropertyType{},
            raw:       "some-string",
            want:      "some-string",
        },
        {
            desc:      "trim a single leading and trailing newline",
            paramType: api.PropertyType{},
            raw:       "\nsome-string\n",
            want:      "some-string",
        },
        {
            desc:      "trim at most one leading and trailing newline",
            paramType: api.PropertyType{},
            raw:       "\n\nsome-string\n\n",
            want:      "\nsome-string\n",
        },
        {
            desc:      "newline really has to be the first character to be trimmed",
            paramType: api.PropertyType{},
            raw:       " \nsome-string\n ",
            want:      " \nsome-string\n ",
        },
        {
            desc:      "numeric type",
            paramType: api.PropertyType{"number"},
            raw:       "123",
            want:      123,
        },
        // Integer parsing tests
        {
            desc:      "integer type",
            paramType: api.PropertyType{"integer"},
            raw:       "42",
            want:      42,
        },
        {
            desc:      "negative integer",
            paramType: api.PropertyType{"integer"},
            raw:       "-100",
            want:      -100,
        },
        {
            desc:      "zero integer",
            paramType: api.PropertyType{"integer"},
            raw:       "0",
            want:      0,
        },
        {
            desc:      "integer with leading zeros",
            paramType: api.PropertyType{"integer"},
            raw:       "007",
            want:      7,
        },
        {
            desc:      "large integer",
            paramType: api.PropertyType{"integer"},
            raw:       "2147483648", // Just beyond int32 max
            want:      int64(2147483648),
        },
        // Float/number parsing tests
        {
            desc:      "float type",
            paramType: api.PropertyType{"number"},
            raw:       "3.14",
            want:      3.14,
        },
        {
            desc:      "negative float",
            paramType: api.PropertyType{"number"},
            raw:       "-273.15",
            want:      -273.15,
        },
        {
            desc:      "float without decimal part",
            paramType: api.PropertyType{"number"},
            raw:       "100.0",
            want:      100,
        },
        {
            desc:      "scientific notation positive",
            paramType: api.PropertyType{"number"},
            raw:       "1.23e5",
            want:      123000, // Will be int since it has no decimal part
        },
        {
            desc:      "scientific notation negative",
            paramType: api.PropertyType{"number"},
            raw:       "1.5e-3",
            want:      0.0015,
        },
        {
            desc:      "very small float",
            paramType: api.PropertyType{"number"},
            raw:       "0.00000001",
            want:      0.00000001,
        },
        // String parsing tests
        {
            desc:      "explicit string type",
            paramType: api.PropertyType{"string"},
            raw:       "hello world",
            want:      "hello world",
        },
        {
            desc:      "string with special characters",
            paramType: api.PropertyType{"string"},
            raw:       "/usr/local/bin/test-file_v2.0.sh",
            want:      "/usr/local/bin/test-file_v2.0.sh",
        },
        {
            desc:      "string with quotes",
            paramType: api.PropertyType{"string"},
            raw:       `He said "hello" to me`,
            want:      `He said "hello" to me`,
        },
        {
            desc:      "multiline string",
            paramType: api.PropertyType{"string"},
            raw:       "line one\nline two\nline three",
            want:      "line one\nline two\nline three",
        },
        {
            desc:      "empty string",
            paramType: api.PropertyType{"string"},
            raw:       "",
            want:      "",
        },
        {
            desc:      "string that looks like a number",
            paramType: api.PropertyType{"string"},
            raw:       "12345",
            want:      "12345",
        },
        // Boolean parsing tests
        {
            desc:      "boolean true",
            paramType: api.PropertyType{"boolean"},
            raw:       "true",
            want:      true,
        },
        {
            desc:      "boolean false",
            paramType: api.PropertyType{"boolean"},
            raw:       "false",
            want:      false,
        },
        {
            desc:      "boolean case insensitive true",
            paramType: api.PropertyType{"boolean"},
            raw:       "True",
            want:      true,
        },
        {
            desc:      "boolean case insensitive false",
            paramType: api.PropertyType{"boolean"},
            raw:       "FALSE",
            want:      false,
        },
        // Null parsing tests
        {
            desc:      "null value lowercase",
            paramType: api.PropertyType{"string"},
            raw:       "null",
            want:      nil,
        },
        {
            desc:      "null value case insensitive",
            paramType: api.PropertyType{"integer"},
            raw:       "NULL",
            want:      nil,
        },
        // Array parsing tests
        {
            desc:      "array of strings",
            paramType: api.PropertyType{"array"},
            raw:       `["foo", "bar", "baz"]`,
            want:      []any{"foo", "bar", "baz"},
        },
        {
            desc:      "array of numbers",
            paramType: api.PropertyType{"array"},
            raw:       `[1, 2.5, 3]`,
            want:      []any{float64(1), 2.5, float64(3)},
        },
        {
            desc:      "array of mixed types",
            paramType: api.PropertyType{"array"},
            raw:       `["string", 123, true, null]`,
            want:      []any{"string", float64(123), true, nil},
        },
        {
            desc:      "empty array",
            paramType: api.PropertyType{"array"},
            raw:       `[]`,
            want:      []any{},
        },
        // Object parsing tests
        {
            desc:      "simple object",
            paramType: api.PropertyType{"object"},
            raw:       `{"key": "value", "number": 42}`,
            want:      map[string]any{"key": "value", "number": float64(42)},
        },
        {
            desc:      "nested object",
            paramType: api.PropertyType{"object"},
            raw:       `{"outer": {"inner": "value"}}`,
            want:      map[string]any{"outer": map[string]any{"inner": "value"}},
        },
        {
            desc:      "empty object",
            paramType: api.PropertyType{"object"},
            raw:       `{}`,
            want:      map[string]any{},
        },
        // Error cases and fallback behavior
        {
            desc:      "invalid integer falls back to string",
            paramType: api.PropertyType{"integer"},
            raw:       "not-a-number",
            want:      "not-a-number",
        },
        {
            desc:      "invalid float falls back to string",
            paramType: api.PropertyType{"number"},
            raw:       "3.14.159",
            want:      "3.14.159",
        },
        {
            desc:      "invalid boolean falls back to false",
            paramType: api.PropertyType{"boolean"},
            raw:       "yes",
            want:      false,
        },
        {
            desc:      "invalid JSON array falls back to string",
            paramType: api.PropertyType{"array"},
            raw:       "[1, 2, unclosed",
            want:      "[1, 2, unclosed",
        },
        {
            desc:      "invalid JSON object falls back to string",
            paramType: api.PropertyType{"object"},
            raw:       `{"key": unclosed`,
            want:      `{"key": unclosed`,
        },
        // Edge cases
        {
            desc:      "integer overflow should use int64",
            paramType: api.PropertyType{"integer"},
            raw:       "2147483648", // Beyond int32 max
            want:      int64(2147483648),
        },
        {
            desc:      "float with many decimal places",
            paramType: api.PropertyType{"number"},
            raw:       "3.141592653589793",
            want:      3.141592653589793,
        },
        {
            desc:      "string with JSON-like content",
            paramType: api.PropertyType{"string"},
            raw:       `{"this": "is", "just": "a string"}`,
            want:      `{"this": "is", "just": "a string"}`,
        },
        {
            desc:      "whitespace-only string",
            paramType: api.PropertyType{"string"},
            raw:       " ",
            want:      " ",
        },
        // Unknown parameter (no type specified in tools)
        {
            desc:      "parameter not in tool definition defaults to string",
            paramType: api.PropertyType{},
            raw:       "some value",
            want:      "some value",
        },
        // Union type tests
        {
            desc:      "string or number union - valid number",
            paramType: api.PropertyType{"string", "number"},
            raw:       "42.5",
            want:      42.5,
        },
        {
            desc:      "string or number union - non-numeric string",
            paramType: api.PropertyType{"string", "number"},
            raw:       "hello",
            want:      "hello",
        },
        {
            desc:      "number or string union - valid number (order shouldn't matter)",
            paramType: api.PropertyType{"number", "string"},
            raw:       "42.5",
            want:      42.5,
        },
        {
            desc:      "integer or null union - valid integer",
            paramType: api.PropertyType{"integer", "null"},
            raw:       "123",
            want:      123,
        },
        {
            desc:      "integer or null union - null value",
            paramType: api.PropertyType{"integer", "null"},
            raw:       "null",
            want:      nil,
        },
        {
            desc:      "null or integer union - null value (order shouldn't matter)",
            paramType: api.PropertyType{"null", "integer"},
            raw:       "null",
            want:      nil,
        },
        {
            desc:      "boolean or string union - valid boolean",
            paramType: api.PropertyType{"boolean", "string"},
            raw:       "true",
            want:      true,
        },
        {
            desc:      "boolean or string union - non-boolean becomes string",
            paramType: api.PropertyType{"boolean", "string"},
            raw:       "yes",
            want:      "yes",
        },
        {
            desc:      "string or boolean union - valid boolean (precedence test)",
            paramType: api.PropertyType{"string", "boolean"},
            raw:       "false",
            want:      false, // Should be boolean, not string "false"
        },
        {
            desc:      "integer or number union - integer value",
            paramType: api.PropertyType{"integer", "number"},
            raw:       "42",
            want:      42,
        },
        {
            desc:      "integer or number union - float value",
            paramType: api.PropertyType{"integer", "number"},
            raw:       "42.5",
            want:      42.5,
        },
        {
            desc:      "number or integer union - integer value (precedence test)",
            paramType: api.PropertyType{"number", "integer"},
            raw:       "42",
            want:      42, // Should try integer first due to precedence
        },
        {
            desc:      "array or object union - valid array",
            paramType: api.PropertyType{"array", "object"},
            raw:       `[1, 2, 3]`,
            want:      []any{float64(1), float64(2), float64(3)},
        },
        {
            desc:      "array or object union - valid object",
            paramType: api.PropertyType{"array", "object"},
            raw:       `{"key": "value"}`,
            want:      map[string]any{"key": "value"},
        },
        {
            desc:      "object or array union - valid array (precedence test)",
            paramType: api.PropertyType{"object", "array"},
            raw:       `[1, 2, 3]`,
            want:      []any{float64(1), float64(2), float64(3)},
        },
        {
            desc:      "complex multi-type union - null",
            paramType: api.PropertyType{"string", "number", "boolean", "null"},
            raw:       "null",
            want:      nil,
        },
        {
            desc:      "complex multi-type union - boolean",
            paramType: api.PropertyType{"string", "number", "boolean", "null"},
            raw:       "true",
            want:      true,
        },
        {
            desc:      "complex multi-type union - number",
            paramType: api.PropertyType{"string", "number", "boolean", "null"},
            raw:       "3.14",
            want:      3.14,
        },
        {
            desc:      "complex multi-type union - string",
            paramType: api.PropertyType{"string", "number", "boolean", "null"},
            raw:       "hello",
            want:      "hello",
        },
        {
            desc:      "integer string union - integer string becomes integer",
            paramType: api.PropertyType{"integer", "string"},
            raw:       "123",
            want:      123,
        },
        {
            desc:      "string integer union - integer string becomes integer (precedence)",
            paramType: api.PropertyType{"string", "integer"},
            raw:       "123",
            want:      123, // Integer has higher precedence than string
        },
    }

    for _, tc := range cases {
        t.Run(tc.desc, func(t *testing.T) {
            got := parseValue(tc.raw, tc.paramType)
            if !reflect.DeepEqual(got, tc.want) {
                t.Errorf("got %v (type %T), want %v (type %T)", got, got, tc.want, tc.want)
            }
        })
    }
}

func TestQwenXMLTransform(t *testing.T) {
    cases := []struct {
        desc string
        raw  string
        want string
    }{
        {
            desc: "simple example",
            raw: `<function=get_current_temperature>
<parameter=location>
San Francisco
</parameter>
<parameter=unit>
celsius
</parameter>
</function>`,
            want: `<function name="get_current_temperature">
<parameter name="location">
San Francisco
</parameter>
<parameter name="unit">
celsius
</parameter>
</function>`,
        },
        // even though quotes aren't expected in these tags, we have these tests to
        // make sure they're escaped so they don't blow up the xml parser in case
        // they happen
        {
            desc: "names with quotes",
            raw: `<function="get current temperature">
<parameter="location with spaces">
San Francisco
</parameter>
<parameter="unit with spaces">
celsius
</parameter>
</function>`,
            want: `<function name="&#34;get current temperature&#34;">
<parameter name="&#34;location with spaces&#34;">
San Francisco
</parameter>
<parameter name="&#34;unit with spaces&#34;">
celsius
</parameter>
</function>`,
        },
    }

    for _, tc := range cases {
        got := transformToXML(tc.raw)
        if got != tc.want {
            t.Errorf("got %q, want %q", got, tc.want)
        }
    }
}

func TestTrailingWhitespaceLen(t *testing.T) {
    cases := []struct {
        desc string
        s    string
        want int
    }{
        {desc: "no whitespace", s: "abc", want: 0},
        {desc: "trailing whitespace", s: "abc ", want: 1},
        {desc: "trailing whitespace with newlines", s: "abc \n", want: 2},
        {desc: "only whitespace", s: " \n ", want: 4},
        {desc: "leading whitespace doesn't count", s: " \n abc", want: 0},
    }

    for _, tc := range cases {
        got := trailingWhitespaceLen(tc.s)
        if got != tc.want {
            t.Errorf("got %d, want %d", got, tc.want)
        }
    }
}

@@ -0,0 +1,217 @@
package renderers

import (
    "encoding/json"
    "fmt"
    "reflect"
    "strings"

    "github.com/ollama/ollama/api"
)

var (
    imStartTag = "<|im_start|>"
    imEndTag   = "<|im_end|>"
)

// renderAdditionalKeys renders all JSON fields except the ones in handledKeys.
// This follows the same approach as the reference implementation, which gives
// a particular key ordering
func renderAdditionalKeys(obj any, handledKeys map[string]bool) string {
    data, err := json.Marshal(obj)
    if err != nil {
        return ""
    }

    var m map[string]any
    if err := json.Unmarshal(data, &m); err != nil {
        return ""
    }

    var sb strings.Builder
    for key, value := range m {
        if handledKeys[key] {
            continue
        }

        // Check if value is a map or array (needs JSON serialization)
        switch v := value.(type) {
        case map[string]any, []any:
            jsonBytes, _ := json.Marshal(v)
            // TODO(drifkin): it would be nice to format the JSON here similarly to
            // python's default json.dumps behavior (spaces after commas and colons).
            // This would let us be byte-for-byte compatible with the reference
            // implementation for most common inputs
            jsonStr := string(jsonBytes)
            sb.WriteString("\n<" + key + ">" + jsonStr + "</" + key + ">")
        case nil:
            continue
        default:
            // Simple types, convert to string
            sb.WriteString("\n<" + key + ">" + fmt.Sprintf("%v", value) + "</" + key + ">")
        }
    }

    return sb.String()
}

func Qwen3CoderRenderer(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
    var sb strings.Builder

    // filter out system messages and choose the first (if any) to win
    var systemMessage string
    var filteredMessages []api.Message
    for _, message := range messages {
        if message.Role != "system" {
            filteredMessages = append(filteredMessages, message)
            continue
        }

        if systemMessage == "" {
            systemMessage = message.Content
        }
    }

    if systemMessage != "" || len(tools) > 0 {
        sb.WriteString(imStartTag + "system\n")

        // if we have tools but no system message, match the reference implementation by providing a default system message
        if systemMessage == "" {
            systemMessage = "You are Qwen, a helpful AI assistant that can interact with a computer to solve tasks."
        }

        sb.WriteString(systemMessage)

        if len(tools) > 0 {
            sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
            sb.WriteString("<tools>")
            for _, tool := range tools {
                sb.WriteString("\n")
                sb.WriteString("<function>\n")
                sb.WriteString("<name>" + tool.Function.Name + "</name>")
                if tool.Function.Description != "" {
                    sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
                }
                sb.WriteString("\n<parameters>")

                for name, prop := range tool.Function.Parameters.Properties {
                    sb.WriteString("\n<parameter>")
                    sb.WriteString("\n<name>" + name + "</name>")

                    if len(prop.Type) > 0 {
                        // TODO(!!!)(drifkin): we should match the reference implementation for
                        // more complex types here instead of using this format
                        sb.WriteString("\n<type>" + prop.ToTypeScriptType() + "</type>")
                    }

                    if prop.Description != "" {
                        sb.WriteString("\n<description>" + prop.Description + "</description>")
                    }

                    // Render any additional keys not already handled
                    handledKeys := map[string]bool{
                        "type":        true,
                        "description": true,
                    }
                    sb.WriteString(renderAdditionalKeys(prop, handledKeys))

                    sb.WriteString("\n</parameter>")
                }

                // Render extra keys for parameters (everything except 'type' and 'properties')
                paramHandledKeys := map[string]bool{
                    "type":       true,
                    "properties": true,
                }
                sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))

                sb.WriteString("\n</parameters>")
                sb.WriteString("\n</function>")
            }
            sb.WriteString("\n</tools>")
            sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
        }

        sb.WriteString(imEndTag + "\n")
    }

    for i, message := range filteredMessages {
        lastMessage := i == len(filteredMessages)-1
        prefill := lastMessage && message.Role == "assistant"
        switch message.Role {
        case "assistant":
            if len(message.ToolCalls) > 0 {
                sb.WriteString(imStartTag + "assistant\n")
                if message.Content != "" {
                    sb.WriteString(message.Content + "\n")
                }
                for _, toolCall := range message.ToolCalls {
                    sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
                    for name, value := range toolCall.Function.Arguments {
                        valueStr := formatToolCallArgument(value)
                        sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
                    }
                    sb.WriteString("\n</function>\n</tool_call>")
                }
                sb.WriteString("<|im_end|>\n")
            } else {
                sb.WriteString(imStartTag + "assistant\n")
                sb.WriteString(message.Content)
                if !prefill {
                    sb.WriteString(imEndTag + "\n")
                }
            }
        case "tool":
            // consecutive tool responses should share a single `<im_start>user`, but
            // have their own <tool_response> tags

            // only start a new user block if this is the first tool response
            if i == 0 || filteredMessages[i-1].Role != "tool" {
                sb.WriteString(imStartTag + "user\n")
            }

            sb.WriteString("<tool_response>\n")
            sb.WriteString(message.Content)
            sb.WriteString("\n</tool_response>\n")

            // close the user block only if this is the last tool response
            if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
                sb.WriteString(imEndTag + "\n")
            }
        default:
            sb.WriteString(imStartTag + message.Role + "\n")
            sb.WriteString(message.Content)
            sb.WriteString(imEndTag + "\n")
        }

        if lastMessage && !prefill {
            sb.WriteString(imStartTag + "assistant\n")
        }
    }

    return sb.String(), nil
}

func formatToolCallArgument(value any) string {
    if value == nil {
        return "null"
    }

    switch v := value.(type) {
    case string:
        return v
    case []byte:
        return string(v)
    }

    if reflect.TypeOf(value) != nil {
        kind := reflect.TypeOf(value).Kind()
        if kind == reflect.Map || kind == reflect.Slice || kind == reflect.Array {
            if marshalled, err := json.Marshal(value); err == nil {
                return string(marshalled)
            }
        }
    }

    return fmt.Sprintf("%v", value)
}

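As a quick usage sketch (hypothetical driver code, not part of the diff), rendering a minimal conversation yields the prompt that the tests below assert against:

prompt, _ := Qwen3CoderRenderer([]api.Message{
    {Role: "system", Content: "You are a helpful assistant."},
    {Role: "user", Content: "Hello, how are you?"},
}, nil, nil)
// prompt ==
// <|im_start|>system
// You are a helpful assistant.<|im_end|>
// <|im_start|>user
// Hello, how are you?<|im_end|>
// <|im_start|>assistant
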
@@ -0,0 +1,338 @@
package renderers

import (
    "testing"

    "github.com/google/go-cmp/cmp"
    "github.com/ollama/ollama/api"
)

func TestQwen3CoderRenderer(t *testing.T) {
    tests := []struct {
        name     string
        msgs     []api.Message
        tools    []api.Tool
        expected string
    }{
        {
            name: "basic",
            msgs: []api.Message{
                {Role: "system", Content: "You are a helpful assistant."},
                {Role: "user", Content: "Hello, how are you?"},
            },
            expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello, how are you?<|im_end|>
<|im_start|>assistant
`,
        },
        {
            name: "with tools and response",
            msgs: []api.Message{
                {Role: "system", Content: "You are a helpful assistant with access to tools."},
                {Role: "user", Content: "What is the weather like in San Francisco?"},
                {
                    Role:    "assistant",
                    Content: "I'll check the weather in San Francisco for you.",
                    ToolCalls: []api.ToolCall{
                        {
                            Function: api.ToolCallFunction{
                                Name: "get_weather",
                                Arguments: map[string]any{
                                    "unit": "fahrenheit",
                                },
                            },
                        },
                    },
                },
                {Role: "tool", Content: "{\"location\": \"San Francisco, CA\", \"temperature\": 68, \"condition\": \"partly cloudy\", \"humidity\": 65, \"wind_speed\": 12}", ToolName: "get_weather"},
                {Role: "user", Content: "That sounds nice! What about New York?"},
            },
            tools: []api.Tool{
                {Function: api.ToolFunction{
                    Name:        "get_weather",
                    Description: "Get the current weather in a given location",
                    Parameters: api.ToolFunctionParameters{
                        Required: []string{"unit"},
                        Properties: map[string]api.ToolProperty{
                            "unit": {Type: api.PropertyType{"string"}, Enum: []any{"celsius", "fahrenheit"}, Description: "The unit of temperature"},
                            // TODO(drifkin): add multiple params back once we have predictable
                            // order via some sort of ordered map type (see
                            // <https://github.com/ollama/ollama/issues/12244>)
                            /*
                                "location": {Type: api.PropertyType{"string"}, Description: "The city and state, e.g. San Francisco, CA"},
                            */
                        },
                    },
                }},
            },
            expected: `<|im_start|>system
You are a helpful assistant with access to tools.

# Tools

You have access to the following functions:

<tools>
<function>
<name>get_weather</name>
<description>Get the current weather in a given location</description>
<parameters>
<parameter>
<name>unit</name>
<type>string</type>
<description>The unit of temperature</description>
<enum>["celsius","fahrenheit"]</enum>
</parameter>
<required>["unit"]</required>
</parameters>
</function>
</tools>

If you choose to call a function ONLY reply in the following format with NO suffix:

<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>

<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
What is the weather like in San Francisco?<|im_end|>
<|im_start|>assistant
I'll check the weather in San Francisco for you.

<tool_call>
<function=get_weather>
<parameter=unit>
fahrenheit
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"location": "San Francisco, CA", "temperature": 68, "condition": "partly cloudy", "humidity": 65, "wind_speed": 12}
</tool_response>
<|im_end|>
<|im_start|>user
That sounds nice! What about New York?<|im_end|>
<|im_start|>assistant
`,
        },
        {
            name: "parallel tool calls",
            msgs: []api.Message{
                {Role: "system", Content: "You are a helpful assistant with access to tools."},
                {Role: "user", Content: "call double(1) and triple(2)"},
                {Role: "assistant", Content: "I'll call double(1) and triple(2) for you.", ToolCalls: []api.ToolCall{
                    {Function: api.ToolCallFunction{Name: "double", Arguments: map[string]any{"number": "1"}}},
                    {Function: api.ToolCallFunction{Name: "triple", Arguments: map[string]any{"number": "2"}}},
                }},
                {Role: "tool", Content: "{\"number\": 2}", ToolName: "double"},
                {Role: "tool", Content: "{\"number\": 6}", ToolName: "triple"},
            },
            tools: []api.Tool{
                {Function: api.ToolFunction{Name: "double", Description: "Double a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
                    "number": {Type: api.PropertyType{"string"}, Description: "The number to double"},
                }}}},
                {Function: api.ToolFunction{Name: "triple", Description: "Triple a number", Parameters: api.ToolFunctionParameters{Properties: map[string]api.ToolProperty{
                    "number": {Type: api.PropertyType{"string"}, Description: "The number to triple"},
                }}}},
            },
            expected: `<|im_start|>system
You are a helpful assistant with access to tools.

# Tools

You have access to the following functions:

<tools>
<function>
<name>double</name>
<description>Double a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to double</description>
</parameter>
</parameters>
</function>
<function>
<name>triple</name>
<description>Triple a number</description>
<parameters>
<parameter>
<name>number</name>
<type>string</type>
<description>The number to triple</description>
</parameter>
</parameters>
</function>
</tools>

If you choose to call a function ONLY reply in the following format with NO suffix:

<tool_call>
<function=example_function_name>
<parameter=example_parameter_1>
value_1
</parameter>
<parameter=example_parameter_2>
This is the value for the second parameter
that can span
multiple lines
</parameter>
</function>
</tool_call>

<IMPORTANT>
Reminder:
- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags
- Required parameters MUST be specified
- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
</IMPORTANT><|im_end|>
<|im_start|>user
call double(1) and triple(2)<|im_end|>
<|im_start|>assistant
I'll call double(1) and triple(2) for you.

<tool_call>
<function=double>
<parameter=number>
1
</parameter>
</function>
</tool_call>
<tool_call>
<function=triple>
<parameter=number>
2
</parameter>
</function>
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
{"number": 2}
</tool_response>
<tool_response>
{"number": 6}
</tool_response>
<|im_end|>
<|im_start|>assistant
`,
        },
        {
            name: "prefill",
            msgs: []api.Message{
                {Role: "system", Content: "You are a helpful assistant."},
                {Role: "user", Content: "Tell me something interesting."},
                {Role: "assistant", Content: "I'll tell you something interesting about cats"},
            },
            expected: `<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Tell me something interesting.<|im_end|>
<|im_start|>assistant
I'll tell you something interesting about cats`,
        },
        {
            name: "complex tool call arguments should remain json encoded",
            msgs: []api.Message{
                {Role: "user", Content: "call tool"},
                {Role: "assistant", ToolCalls: []api.ToolCall{
                    {Function: api.ToolCallFunction{
                        Name: "echo",
                        Arguments: map[string]any{
                            "payload": map[string]any{"foo": "bar"},
                        },
                    }},
                }},
                {Role: "tool", Content: "{\"payload\": {\"foo\": \"bar\"}}", ToolName: "echo"},
            },
            expected: `<|im_start|>user
call tool<|im_end|>
<|im_start|>assistant
<tool_call>
|
||||
<function=echo>
|
||||
<parameter=payload>
|
||||
{"foo":"bar"}
|
||||
</parameter>
|
||||
</function>
|
||||
</tool_call><|im_end|>
|
||||
<|im_start|>user
|
||||
<tool_response>
|
||||
{"payload": {"foo": "bar"}}
|
||||
</tool_response>
|
||||
<|im_end|>
|
||||
<|im_start|>assistant
|
||||
`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := Qwen3CoderRenderer(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatToolCallArgument(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
arg any
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
arg: "foo",
|
||||
// notice no quotes around the string
|
||||
expected: "foo",
|
||||
},
|
||||
{
|
||||
name: "map",
|
||||
arg: map[string]any{"foo": "bar"},
|
||||
expected: "{\"foo\":\"bar\"}",
|
||||
},
|
||||
{
|
||||
name: "number",
|
||||
arg: 1,
|
||||
expected: "1",
|
||||
},
|
||||
{
|
||||
name: "boolean",
|
||||
arg: true,
|
||||
expected: "true",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := formatToolCallArgument(tt.arg)
|
||||
if got != tt.expected {
|
||||
t.Errorf("formatToolCallArgument(%v) = %v, want %v", tt.arg, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@@ -0,0 +1,26 @@
package renderers

import (
	"fmt"

	"github.com/ollama/ollama/api"
)

type rendererFunc func([]api.Message, []api.Tool, *api.ThinkValue) (string, error)

func RenderWithRenderer(name string, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	renderer := rendererForName(name)
	if renderer == nil {
		return "", fmt.Errorf("unknown renderer %q", name)
	}
	return renderer(msgs, tools, think)
}

func rendererForName(name string) rendererFunc {
	switch name {
	case "qwen3-coder":
		return Qwen3CoderRenderer
	default:
		return nil
	}
}
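Note: a minimal sketch of driving this registry from calling code. The import path matches the one used by server/prompt.go later in this commit; the message and the panic-style error handling are illustrative only:

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/renderers"
)

func main() {
	msgs := []api.Message{{Role: "user", Content: "hi"}}

	// Any name without an entry in rendererForName yields an "unknown renderer" error.
	prompt, err := renderers.RenderWithRenderer("qwen3-coder", msgs, nil, nil)
	if err != nil {
		panic(err)
	}
	fmt.Println(prompt)
}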
@@ -12,18 +12,18 @@ import (

const spmWhitespaceSep = "▁"

type SentencePieceModel struct {
type SentencePiece struct {
	maxTokenLen int
	vocab       *Vocabulary
}

var _ TextProcessor = (*SentencePieceModel)(nil)
var _ TextProcessor = (*SentencePiece)(nil)

func (spm SentencePieceModel) Vocabulary() *Vocabulary {
func (spm SentencePiece) Vocabulary() *Vocabulary {
	return spm.vocab
}

func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
	logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])

	counter := map[int]int{}

@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
		"max token len", maxTokenLen)

	return SentencePieceModel{
	return SentencePiece{
		maxTokenLen: maxTokenLen,
		vocab:       vocab,
	}
}

func (spm SentencePieceModel) Is(id int32, special Special) bool {
func (spm SentencePiece) Is(id int32, special Special) bool {
	return spm.vocab.Is(id, special)
}

func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
	fragments := []fragment{{value: s}}
	for _, special := range spm.vocab.SpecialVocabulary() {
		id := spm.vocab.Encode(special)

@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
	return item
}

func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
func (spm SentencePiece) Decode(ids []int32) (string, error) {
	var sb strings.Builder
	for _, id := range ids {
		data := spm.vocab.Decode(id)

@@ -12,7 +12,7 @@ import (
	"github.com/ollama/ollama/convert/sentencepiece"
)

func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
func loadSentencePieceVocab(t *testing.T) SentencePiece {
	t.Helper()

	bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))

@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
		}
	}

	return NewSentencePieceModel(&v)
	return NewSentencePiece(&v)
}

func TestSentencePieceEncode(t *testing.T) {

@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
	})
}

func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
func TestSentencePieceDecodeByteTokens(t *testing.T) {
	vocab := &Vocabulary{
		Values: []string{
			"normal",

@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
		Scores: []float32{0, 0, 0, 0, 0},
	}

	spm := NewSentencePieceModel(vocab)
	spm := NewSentencePiece(vocab)

	tests := []struct {
		name string
@@ -0,0 +1,167 @@
package model

import (
	"fmt"
	"iter"
	"strings"
	"unicode"

	"github.com/ollama/ollama/logutil"
)

type WordPiece struct {
	vocab *Vocabulary
}

// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from the original WordPiece, which uses "##" to indicate subwords.
const ggmlPrefix = "▁"

var wordPieceReplacer = strings.NewReplacer(
	" .", ".",
	" ?", "?",
	" !", "!",
	" ,", ",",
	" ' ", "'",
	" n't", "n't",
	" 'm", "'m",
	" do not", " don't",
	" 's", "'s",
	" 've", "'ve",
	" 're", "'re",
)

// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
	var sb strings.Builder
	for i, id := range ids {
		if id < 0 || int(id) >= len(wpm.vocab.Values) {
			return "", fmt.Errorf("invalid token id: %d", id)
		}

		var separator string
		piece := wpm.vocab.Values[id]
		if i > 0 &&
			(strings.HasPrefix(piece, ggmlPrefix) ||
				(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
			separator = " "
		}

		sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
	}

	return sb.String(), nil
}

// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
	return func(yield func(string) bool) {
		runes := make([]rune, 0, len(s)*3)
		for _, r := range s {
			switch {
			case r >= 0x4E00 && r <= 0x9FFF,
				r >= 0x3400 && r <= 0x4DBF,
				r >= 0x20000 && r <= 0x2A6DF,
				r >= 0x2A700 && r <= 0x2B73F,
				r >= 0x2B740 && r <= 0x2B81F,
				r >= 0x2B820 && r <= 0x2CEAF,
				r >= 0xF900 && r <= 0xFAFF,
				r >= 0x2F800 && r <= 0x2FA1F:
				runes = append(runes, ' ', r, ' ')
			default:
				runes = append(runes, r)
			}
		}

		for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
			// split on punctuation but keep it
			var start int
			for start < len(w) {
				end := strings.IndexFunc(w[start:], unicode.IsPunct)
				if end < 0 {
					end = len(w) - start
				} else if end == 0 {
					end = 1
				}

				if !yield(w[start : start+end]) {
					return
				}

				start += end
			}
		}
	}
}

// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
	var ids []int32

	// TODO: use [UNK] from config
	unk := wpm.vocab.Encode("[UNK]")
	for word := range wpm.words(s) {
		var start int
		var pieces []int32
		for start < len(word) {
			end := len(word)

			var piece int32
			for start < end {
				subword := word[start:end]
				if start == 0 {
					subword = ggmlPrefix + subword
				}

				// TODO: some models might not want [ToLower]
				piece = wpm.vocab.Encode(strings.ToLower(subword))
				if piece >= 0 {
					break
				}

				end--
			}

			if piece < 0 {
				// Unknown token
				pieces = pieces[:0]
				break
			}

			pieces = append(pieces, piece)
			start = end
		}

		if len(pieces) > 0 {
			ids = append(ids, pieces...)
		} else {
			ids = append(ids, unk)
		}
	}

	if addSpecial && len(ids) > 0 {
		ids = wpm.vocab.addSpecials(ids)
	}

	logutil.Trace("encoded", "string", s, "ids", ids)
	return ids, nil
}

// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
	return wpm.vocab.Is(id, special)
}

// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
	return wpm.vocab
}

var _ TextProcessor = (*WordPiece)(nil)

func NewWordPiece(vocab *Vocabulary) WordPiece {
	return WordPiece{
		vocab: vocab,
	}
}
@@ -0,0 +1,51 @@
package model

import (
	"slices"
	"testing"

	"github.com/google/go-cmp/cmp"
)

func TestWordPiece(t *testing.T) {
	wpm := NewWordPiece(
		&Vocabulary{
			Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
			AddBOS: true,
			AddEOS: true,
			BOS:    []int32{1},
			EOS:    []int32{2},
		})

	ids, err := wpm.Encode("Hello world!", true)
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
		t.Errorf("unexpected ids (-want +got):\n%s", diff)
	}

	words, err := wpm.Decode(ids)
	if err != nil {
		t.Fatal(err)
	}

	if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}
}

func TestWordPieceWords(t *testing.T) {
	var wpm WordPiece

	basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
	if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}

	chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
	if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
		t.Errorf("unexpected words (-want +got):\n%s", diff)
	}
}
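Note: the Encode loop in wordpiece.go above is a greedy longest-prefix match against the vocabulary. A self-contained sketch of that loop, using an assumed toy vocabulary (not a real model's) rather than the package's Vocabulary type:

package main

import (
	"fmt"
	"strings"
)

func main() {
	vocab := map[string]bool{"▁hello": true, "▁world": true, "s": true}
	word := "worlds"

	var pieces []string
	start := 0
	for start < len(word) {
		// Try the longest remaining substring first, shrinking from the right.
		end := len(word)
		matched := ""
		for start < end {
			sub := word[start:end]
			if start == 0 {
				sub = "▁" + sub // GGML word-boundary prefix, as in wordpiece.go above
			}
			if vocab[strings.ToLower(sub)] {
				matched = sub
				break
			}
			end--
		}
		if matched == "" {
			pieces = []string{"[UNK]"} // the whole word falls back to the unknown token
			break
		}
		pieces = append(pieces, matched)
		start = end
	}
	fmt.Println(pieces) // [▁world s]
}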
@@ -105,16 +105,18 @@ type ChatCompletionRequest struct {
	Tools           []api.Tool `json:"tools"`
	Reasoning       *Reasoning `json:"reasoning,omitempty"`
	ReasoningEffort *string    `json:"reasoning_effort,omitempty"`
	DebugRenderOnly bool       `json:"_debug_render_only"`
}

type ChatCompletion struct {
	Id                string   `json:"id"`
	Object            string   `json:"object"`
	Created           int64    `json:"created"`
	Model             string   `json:"model"`
	SystemFingerprint string   `json:"system_fingerprint"`
	Choices           []Choice `json:"choices"`
	Usage             Usage    `json:"usage,omitempty"`
	Id                string         `json:"id"`
	Object            string         `json:"object"`
	Created           int64          `json:"created"`
	Model             string         `json:"model"`
	SystemFingerprint string         `json:"system_fingerprint"`
	Choices           []Choice       `json:"choices"`
	Usage             Usage          `json:"usage,omitempty"`
	DebugInfo         *api.DebugInfo `json:"_debug_info,omitempty"`
}

type ChatCompletionChunk struct {

@@ -141,6 +143,7 @@ type CompletionRequest struct {
	Temperature     *float32 `json:"temperature"`
	TopP            float32  `json:"top_p"`
	Suffix          string   `json:"suffix"`
	DebugRenderOnly bool     `json:"_debug_render_only"`
}

type Completion struct {

@@ -273,8 +276,8 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
			}
			return nil
		}(r.DoneReason),
		}},
		Usage: toUsage(r),
		}}, Usage: toUsage(r),
		DebugInfo: r.DebugInfo,
	}
}

@@ -568,13 +571,14 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
	}

	return &api.ChatRequest{
		Model:    r.Model,
		Messages: messages,
		Format:   format,
		Options:  options,
		Stream:   &r.Stream,
		Tools:    r.Tools,
		Think:    think,
		Model:           r.Model,
		Messages:        messages,
		Format:          format,
		Options:         options,
		Stream:          &r.Stream,
		Tools:           r.Tools,
		Think:           think,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}

@@ -648,11 +652,12 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
	}

	return api.GenerateRequest{
		Model:   r.Model,
		Prompt:  r.Prompt,
		Options: options,
		Stream:  &r.Stream,
		Suffix:  r.Suffix,
		Model:           r.Model,
		Prompt:          r.Prompt,
		Options:         options,
		Stream:          &r.Stream,
		Suffix:          r.Suffix,
		DebugRenderOnly: r.DebugRenderOnly,
	}, nil
}

@@ -100,6 +100,10 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
			req.System = c.Args
		case "license":
			licenses = append(licenses, c.Args)
		case "renderer":
			req.Renderer = c.Args
		case "parser":
			req.Parser = c.Args
		case "message":
			role, msg, _ := strings.Cut(c.Args, ": ")
			messages = append(messages, api.Message{Role: role, Content: msg})

@@ -320,7 +324,7 @@ func (c Command) String() string {
	switch c.Name {
	case "model":
		fmt.Fprintf(&sb, "FROM %s", c.Args)
	case "license", "template", "system", "adapter":
	case "license", "template", "system", "adapter", "renderer", "parser":
		fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
	case "message":
		role, message, _ := strings.Cut(c.Args, ": ")

@@ -346,7 +350,7 @@ const (
var (
	errMissingFrom        = errors.New("no FROM line")
	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"renderer\", \"parser\", \"parameter\", or \"message\"")
)

type ParserError struct {

@@ -606,7 +610,7 @@ func isValidMessageRole(role string) bool {

func isValidCommand(cmd string) bool {
	switch strings.ToLower(cmd) {
	case "from", "license", "template", "system", "adapter", "parameter", "message":
	case "from", "license", "template", "system", "adapter", "renderer", "parser", "parameter", "message":
		return true
	default:
		return false

@@ -198,6 +198,34 @@ BADCOMMAND param1 value1
	}
}

func TestParseFileRenderer(t *testing.T) {
	input := `
FROM foo
RENDERER renderer1
`

	reader := strings.NewReader(input)

	modelfile, err := ParseFile(reader)
	require.NoError(t, err)

	assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "renderer", Args: "renderer1"}}, modelfile.Commands)
}

func TestParseFileParser(t *testing.T) {
	input := `
FROM foo
PARSER parser1
`

	reader := strings.NewReader(input)

	modelfile, err := ParseFile(reader)
	require.NoError(t, err)

	assert.Equal(t, []Command{{Name: "model", Args: "foo"}, {Name: "parser", Args: "parser1"}}, modelfile.Commands)
}

func TestParseFileMessages(t *testing.T) {
	cases := []struct {
		input string
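Note: with the _debug_render_only plumbing added to the openai package above, an OpenAI-compatible client can ask the server for the rendered prompt instead of a completion. A rough sketch (the address is Ollama's default local port; the model name is a placeholder):

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
)

func main() {
	body, _ := json.Marshal(map[string]any{
		"model":              "some-model",
		"messages":           []map[string]string{{"role": "user", "content": "hi"}},
		"_debug_render_only": true,
	})

	resp, err := http.Post("http://127.0.0.1:11434/v1/chat/completions", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	out, _ := io.ReadAll(resp.Body)
	// The response carries "_debug_info" (rendered template, image count)
	// rather than model output.
	fmt.Println(string(out))
}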
@@ -1,126 +0,0 @@
package parser

import (
	"encoding/json"
	"errors"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/harmony"
)

type TokenParserType int

const (
	TokenParserTypeDefault TokenParserType = iota
	TokenParserTypeHarmony
)

type TokenParser struct {
	messageHandler MessageHandler
	parserEngine   ParserInternals
	toolParser     ToolParser
	lastToken      string
	tokenRepeat    int
	repeatLimit    int
}

const defaultTokenRepeatLimit = 30

type MessageHandler interface {
	AddContent(token string) (content, thinking string, toolContent string)
}

type ParserInternals interface {
	AddImplicitStartOrPrefill(prefillString string)
}

type ToolParser interface {
	Add(token string)
	Drain() (toolName *string, toolContent string)
}

// Default implementation for the TokenParser interface as a no-op passthrough
type defaultMessageHandler struct{}

func (defaultMessageHandler) AddContent(token string) (string, string, string) {
	return token, "", ""
}

type defaultEngine struct{}

func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}

type defaultToolParser struct{}

func (defaultToolParser) Add(token string) {}

func (defaultToolParser) Drain() (*string, string) { return nil, "" }

func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
	switch parserType {
	case TokenParserTypeHarmony:
		harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
		return TokenParser{
			messageHandler: harmonyMessageHandler,
			parserEngine:   harmonyMessageHandler.HarmonyParser,
			toolParser:     harmonyMessageHandler.ToolParser,
			repeatLimit:    defaultTokenRepeatLimit,
		}

	default:
		return TokenParser{
			messageHandler: defaultMessageHandler{},
			parserEngine:   defaultEngine{},
			toolParser:     defaultToolParser{},
			repeatLimit:    30,
		}
	}
}

func (p *TokenParser) AddContent(token string) (string, string, error) {
	if p.repeatLimitReached(token) {
		return "", "", errors.New("token repeat limit reached")
	}
	content, thinking, toolContent := p.messageHandler.AddContent(token)
	p.toolParser.Add(toolContent)
	return content, thinking, nil
}

// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
func (p *TokenParser) repeatLimitReached(token string) bool {
	if p == nil {
		return false
	}
	trimmed := strings.TrimSpace(token)
	if trimmed == p.lastToken {
		p.tokenRepeat++
	} else {
		p.tokenRepeat = 0
	}
	p.lastToken = trimmed

	return p.tokenRepeat >= p.repeatLimit
}

// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall {
	toolName, toolContent := p.toolParser.Drain()
	if toolName != nil {
		*toolName = strings.TrimPrefix(*toolName, "functions.")
		var args api.ToolCallFunctionArguments
		if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
			return nil
		}
		return []api.ToolCall{
			{
				Function: api.ToolCallFunction{
					Name:      *toolName,
					Arguments: args,
				},
			},
		}
	}
	return nil
}
@@ -11,7 +11,6 @@ import (
	"image"
	"log"
	"log/slog"
	"math"
	"net"
	"net/http"
	"os"

@@ -32,9 +31,9 @@ import (
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/ml/nn/pooling"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/runner/common"
	"github.com/ollama/ollama/sample"

@@ -406,7 +405,7 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) run(ctx context.Context) {
	s.ready.Wait()

	supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
	supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone

	var activeBatch batchState
	for {

@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er

	// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
	var batchInputs []*input.Input
	var batchOutputs []int32
	var batch input.Batch

	resumeSeq := -1

@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
			batch.Sequences = append(batch.Sequences, seq.cache.Id)

			seq.iBatch = len(batch.Outputs)
			if i+1 == len(seq.inputs) {
				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
			seq.iBatch = len(batchOutputs)
			if i+1 == len(seq.inputs) || seq.embeddingOnly {
				batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
			}
			logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
			seq.pendingInputs = append(seq.pendingInputs, inp)

@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er

	// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
	batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
	batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
	nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
	if err != nil {
		err = fmt.Errorf("failed to build graph: %w", err)

@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
		}

		// sample a token
		vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
		vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
		if err != nil {
			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)

@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
		return
	}

	tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)

	if req.Options == nil {
		opts := api.DefaultOptions()
		req.Options = &opts

@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
			return
		case content, ok := <-seq.responses:
			if ok {
				var thinking string
				var err error
				content, thinking, err = tokenParser.AddContent(content)
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
					close(seq.quit)
					return
				}

				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					Content:  content,
					Thinking: thinking,
					Content: content,
				}); err != nil {
					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
					close(seq.quit)

@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {

				flusher.Flush()
			} else {
				toolCalls := tokenParser.Drain()
				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
					ToolCalls:       toolCalls,
					Done:            true,
					DoneReason:      seq.doneReason,
					PromptEvalCount: seq.numPromptInputs,

@@ -913,7 +900,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
	if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
	if pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone {
		http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
		return
	}

@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
		batch.Positions[i] = int32(i)
	}

	batch.Outputs = make([]int32, s.parallel)
	for i := range batch.Outputs {
		batch.Outputs[i] = int32(i)
	}

	batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
	batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)

	cache := s.model.Config().Cache
	if cache != nil {
@@ -323,6 +323,8 @@ func createModel(r api.CreateRequest, name model.Name, baseLayers []*layerGGML,
		RootFS: RootFS{
			Type: "layers",
		},
		Renderer: r.Renderer,
		Parser:   r.Parser,
	}

	var layers []Layer
@@ -24,6 +24,7 @@ import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/fs/gguf"
	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/thinking"

@@ -94,8 +95,9 @@ func (m *Model) Capabilities() []model.Capability {
		return capabilities
	}

	builtinParser := parsers.ParserForName(m.Config.Parser)
	// Check for tools capability
	if slices.Contains(m.Template.Vars(), "tools") {
	if slices.Contains(m.Template.Vars(), "tools") || (builtinParser != nil && builtinParser.HasToolSupport()) {
		capabilities = append(capabilities, model.CapabilityTools)
	}

@@ -112,7 +114,8 @@ func (m *Model) Capabilities() []model.Capability {
	// Check for thinking capability
	openingTag, closingTag := thinking.InferTags(m.Template.Template)
	hasTags := openingTag != "" && closingTag != ""
	if hasTags || slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily) {
	isGptoss := slices.Contains([]string{"gptoss", "gpt-oss"}, m.Config.ModelFamily)
	if hasTags || isGptoss || (builtinParser != nil && builtinParser.HasThinkingSupport()) {
		capabilities = append(capabilities, model.CapabilityThinking)
	}

@@ -198,6 +201,20 @@ func (m *Model) String() string {
		})
	}

	if m.Config.Renderer != "" {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "renderer",
			Args: m.Config.Renderer,
		})
	}

	if m.Config.Parser != "" {
		modelfile.Commands = append(modelfile.Commands, parser.Command{
			Name: "parser",
			Args: m.Config.Parser,
		})
	}

	for k, v := range m.Options {
		switch v := v.(type) {
		case []any:

@@ -238,6 +255,8 @@ type ConfigV2 struct {
	ModelFamilies []string `json:"model_families"`
	ModelType     string   `json:"model_type"`
	FileType      string   `json:"file_type"`
	Renderer      string   `json:"renderer,omitempty"`
	Parser        string   `json:"parser,omitempty"`

	// required by spec
	Architecture string `json:"architecture"`
@@ -11,6 +11,7 @@ import (

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/model/renderers"
	"github.com/ollama/ollama/template"
)

@@ -41,18 +42,12 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
			}
		}

		thinkVal := false
		thinkLevel := ""
		if think != nil {
			thinkVal = think.Bool()
			thinkLevel = think.String()
		}
		var b bytes.Buffer
		if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
		p, err := renderPrompt(m, append(system, msgs[i:]...), tools, think)
		if err != nil {
			return "", nil, err
		}

		s, err := tokenize(ctx, b.String())
		s, err := tokenize(ctx, p)
		if err != nil {
			return "", nil, err
		}

@@ -101,6 +96,23 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
	}

	// truncate any messages that do not fit into the context window
	p, err := renderPrompt(m, append(system, msgs[currMsgIdx:]...), tools, think)
	if err != nil {
		return "", nil, err
	}

	return p, images, nil
}

func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
	if m.Config.Renderer != "" {
		rendered, err := renderers.RenderWithRenderer(m.Config.Renderer, msgs, tools, think)
		if err != nil {
			return "", err
		}
		return rendered, nil
	}

	var b bytes.Buffer
	thinkVal := false
	thinkLevel := ""

@@ -108,9 +120,8 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
		thinkVal = think.Bool()
		thinkLevel = think.String()
	}
	if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[currMsgIdx:]...), Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
		return "", nil, err
	if err := m.Template.Execute(&b, template.Values{Messages: msgs, Tools: tools, Think: thinkVal, ThinkLevel: thinkLevel, IsThinkSet: think != nil}); err != nil {
		return "", err
	}

	return b.String(), images, nil
	return b.String(), nil
}
188
server/routes.go

@@ -35,8 +35,8 @@ import (
	"github.com/ollama/ollama/harmony"
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/logutil"
	"github.com/ollama/ollama/model/parsers"
	"github.com/ollama/ollama/openai"
	"github.com/ollama/ollama/parser"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/registry"
	"github.com/ollama/ollama/template"

@@ -47,6 +47,18 @@ import (
	"github.com/ollama/ollama/version"
)

func shouldUseHarmony(model *Model) bool {
	if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
		// heuristic to check whether the template expects to be parsed via harmony:
		// search for harmony tags that are nearly always used
		if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
			return true
		}
	}

	return false
}

func experimentEnabled(name string) bool {
	return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
@@ -196,17 +208,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		return
	}

	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
	var parserType parser.TokenParserType
	useHarmony := shouldUseHarmony(m) && !req.Raw
	var harmonyMessageHandler *harmony.HarmonyMessageHandler
	var harmonyToolParser *harmony.HarmonyToolCallAccumulator
	if useHarmony {
		parserType = parser.TokenParserTypeHarmony
	} else {
		parserType = parser.TokenParserTypeDefault
	}
	var functionNameMap *harmony.FunctionNameMap

	if useHarmony {
		functionNameMap = harmony.NewFunctionNameMap()
		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
		harmonyMessageHandler.HarmonyParser.AddImplicitStart()
		harmonyToolParser = harmonyMessageHandler.CreateToolParser()
	}

	// Validate Think value: string values currently only allowed for gptoss models

@@ -322,10 +330,10 @@ func (s *Server) GenerateHandler(c *gin.Context) {

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
		c.JSON(http.StatusOK, api.DebugTemplateResponse{
		c.JSON(http.StatusOK, api.GenerateResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			DebugInfo: api.DebugInfo{
			DebugInfo: &api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},

@@ -350,19 +358,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
		var sb strings.Builder
		defer close(ch)
		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt:     prompt,
			Images:     images,
			Format:     req.Format,
			Options:    opts,
			ParserType: parserType,
			Prompt:  prompt,
			Images:  images,
			Format:  req.Format,
			Options: opts,
		}, func(cr llm.CompletionResponse) {
			res := api.GenerateResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Response:  cr.Content,
				Done:      cr.Done,
				Thinking:  cr.Thinking,
				ToolCalls: cr.ToolCalls,
				Metrics: api.Metrics{
					PromptEvalCount:    cr.PromptEvalCount,
					PromptEvalDuration: cr.PromptEvalDuration,

@@ -371,22 +376,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
				},
			}

			if res.Done {
				res.DoneReason = cr.DoneReason.String()
				res.TotalDuration = time.Since(checkpointStart)
				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
			}

			if useHarmony {
				for i, tool := range res.ToolCalls {
					res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
				}
				if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
					ch <- res
				}
				return
			}
			if thinkingState != nil {
				content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
				res.Response = content
				res.Thinking = thinking
				harmonyToolParser.Add(toolContent)
			} else if thinkingState != nil {
				thinking, content := thinkingState.AddContent(cr.Content)
				res.Thinking = thinking
				res.Response = content

@@ -397,6 +392,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
			}

			if cr.Done {
				if useHarmony {
					toolName, toolContent := harmonyToolParser.Drain()
					if toolName != nil {
						*toolName = strings.TrimPrefix(*toolName, "functions.")
						var args api.ToolCallFunctionArguments
						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
							ch <- gin.H{"error": errStr}
							return
						}

						res.ToolCalls = append(res.ToolCalls, api.ToolCall{
							Function: api.ToolCallFunction{
								Name:      *toolName,
								Arguments: args,
							},
						})
					}
				}

				res.DoneReason = cr.DoneReason.String()
				res.TotalDuration = time.Since(checkpointStart)
				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)

				if !req.Raw {
					tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
					if err != nil {

@@ -470,7 +489,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
	}

	truncate := true

	if req.Truncate != nil && !*req.Truncate {
		truncate = false
	}

@@ -537,7 +555,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
			return
		}

		if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
			ctxLen--
		}

		if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
			ctxLen--
		}

		tokens = tokens[:ctxLen]

		s, err = r.Detokenize(c.Request.Context(), tokens)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
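Note: a toy walk-through of the slot reservation above, with made-up numbers (in the handler the BOS/EOS decisions come from the GGUF metadata): if the context window is 8 tokens and both BOS and EOS will be added, the input is cut to 6 tokens before detokenizing.

package main

import "fmt"

func main() {
	ctxLen := 8
	tokens := []int{101, 102, 103, 104, 105, 106, 107, 108, 109, 110}

	addBOS, addEOS := true, true // assumed for illustration
	if addBOS {
		ctxLen-- // reserve a slot for the BOS token that will be prepended
	}
	if addEOS {
		ctxLen-- // reserve a slot for the EOS token that will be appended
	}

	tokens = tokens[:ctxLen]
	fmt.Println(len(tokens)) // 6
}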
@@ -1599,27 +1626,32 @@ func (s *Server) ChatHandler(c *gin.Context) {
	}
	msgs = filterThinkTags(msgs, m)

	useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
	var parserType parser.TokenParserType
	if useHarmony {
		parserType = parser.TokenParserTypeHarmony
	} else {
		parserType = parser.TokenParserTypeDefault
	var builtinParser parsers.Parser
	if m.Config.Parser != "" {
		builtinParser = parsers.ParserForName(m.Config.Parser)
	}

	var harmonyMessageHandler *harmony.HarmonyMessageHandler
	var harmonyToolParser *harmony.HarmonyToolCallAccumulator

	useHarmony := shouldUseHarmony(m) || m.Config.Parser == "harmony"

	processedTools := req.Tools
	var functionNameMap *harmony.FunctionNameMap
	var prefillString string
	// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
	if useHarmony {
		prefillString = harmony.Prefill(msgs[len(msgs)-1])
		functionNameMap = harmony.NewFunctionNameMap()
		harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
		var lastMessage *api.Message
		if len(msgs) > 0 {
			lastMessage = &msgs[len(msgs)-1]
		}
		harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
		harmonyToolParser = harmonyMessageHandler.CreateToolParser()

		// make a copy of tools to pass to the chat prompt. Function names may be
		// renamed to be valid Harmony function names.
		processedTools = make([]api.Tool, len(req.Tools))
		copy(processedTools, req.Tools)
		for i, tool := range processedTools {
			processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
			processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
		}
	}

@@ -1632,10 +1664,10 @@ func (s *Server) ChatHandler(c *gin.Context) {

	// If debug mode is enabled, return the rendered template instead of calling the model
	if req.DebugRenderOnly {
		c.JSON(http.StatusOK, api.DebugTemplateResponse{
		c.JSON(http.StatusOK, api.ChatResponse{
			Model:     req.Model,
			CreatedAt: time.Now().UTC(),
			DebugInfo: api.DebugInfo{
			DebugInfo: &api.DebugInfo{
				RenderedTemplate: prompt,
				ImageCount:       len(images),
			},

@@ -1672,17 +1704,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
		defer close(ch)

		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
			Prompt:        prompt,
			Images:        images,
			Format:        req.Format,
			Options:       opts,
			ParserType:    parserType,
			PrefillString: prefillString,
			Prompt:  prompt,
			Images:  images,
			Format:  req.Format,
			Options: opts,
		}, func(r llm.CompletionResponse) {
			res := api.ChatResponse{
				Model:     req.Model,
				CreatedAt: time.Now().UTC(),
				Message:   api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
				Message:   api.Message{Role: "assistant", Content: r.Content},
				Done:      r.Done,
				Metrics: api.Metrics{
					PromptEvalCount: r.PromptEvalCount,

@@ -1697,14 +1727,54 @@ func (s *Server) ChatHandler(c *gin.Context) {
				res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
			}

			// TODO(drifkin): fold this as much as possible into the generic m.Config.Parser logic
			if useHarmony {
				for i, tool := range res.Message.ToolCalls {
					res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
				content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
				res.Message.Content = content
				res.Message.Thinking = thinking
				harmonyToolParser.Add(toolContent)

				if r.Done {
					toolName, toolContent := harmonyToolParser.Drain()
					if toolName != nil {
						*toolName = strings.TrimPrefix(*toolName, "functions.")
						*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
						var args api.ToolCallFunctionArguments
						if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
							errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
							ch <- gin.H{"error": errStr}
							return
						}
						res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
					}
				}

				// only send messages with meaningful content (empty messages confuse clients)
				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
					ch <- res
				}

				return
			} else if builtinParser != nil {
				slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)

				content, thinking, toolCalls, err := builtinParser.Add(r.Content, req.Tools)
				if err != nil {
					ch <- gin.H{"error": err.Error()}
					return
				}

				res.Message.Content = content
				res.Message.Thinking = thinking
				res.Message.ToolCalls = toolCalls

				if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
					slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
					ch <- res
				} else {
					slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
				}

				return
			}
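Note: this diff never shows the parsers.Parser interface itself; inferred from the call sites above (ParserForName in Capabilities, Add/HasToolSupport/HasThinkingSupport here), it must look roughly like the sketch below, which may differ in detail from the real github.com/ollama/ollama/model/parsers package:

package parsers

import "github.com/ollama/ollama/api"

type Parser interface {
	// Add consumes one streamed chunk and returns whatever content, thinking
	// text, and completed tool calls can be emitted so far.
	Add(s string, tools []api.Tool) (content string, thinking string, calls []api.ToolCall, err error)
	HasToolSupport() bool
	HasThinkingSupport() bool
}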
@@ -180,7 +180,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
		t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	var response api.DebugTemplateResponse
	var response api.GenerateResponse
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}

@@ -385,7 +385,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
		t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
	}

	var response api.DebugTemplateResponse
	var response api.ChatResponse
	if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
@@ -7,6 +7,7 @@ import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"strings"
	"testing"
	"time"

@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
			name: "content streams as it arrives",
			steps: []step{
				{
					input:       llm.CompletionResponse{Content: "Hello", Done: false},
					input:       llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
					wantContent: "Hello",
				},
				{

@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
					wantContent: ", world",
				},
				{
					input:       llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
					input:       llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "!",
				},
			},

@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
			name: "thinking streams separately from content",
			steps: []step{
				{
					input:        llm.CompletionResponse{Thinking: "Thinking...", Done: false},
					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
					wantThinking: "Thinking...",
				},
				{
					input:       llm.CompletionResponse{Content: "Answer", Done: false},
					wantContent: "Answer",
					input: llm.CompletionResponse{Content: "<|end|>", Done: false},
					// No output expected - just closes the analysis message and resets state to normal
				},
				{
					input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
					wantContent: "Answer", // After message end, state is reset to normal
				},
				{
					input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					// No output expected - just closes the assistant message
				},
			},
		},

@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
			name: "partial tags buffer until complete",
			steps: []step{
				{
					input:        llm.CompletionResponse{Thinking: "Deep ", Done: false},
					input: llm.CompletionResponse{Content: "<|chan", Done: false},
					// No output - partial tag
				},
				{
					input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
					// No output - still building tags
				},
				{
					input:        llm.CompletionResponse{Content: "age|>Deep ", Done: false},
					wantThinking: "Deep ",
				},
				{
					input:        llm.CompletionResponse{Thinking: "thought", Done: false},
					input:        llm.CompletionResponse{Content: "thought<|end|>", Done: false},
					wantThinking: "thought",
				},
				{
					input:       llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "Done",
					input:       llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent: "Done", // After message end, state is reset to normal
				},
			},
		},

@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
			name: "simple assistant after analysis",
			steps: []step{
				{
					input:        llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
					input:        llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent:  "Answer",
					wantThinking: "Think",
				},

@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
			name: "tool call parsed and returned correctly",
			steps: []step{
				{
					input:         llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
					input:         llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
					wantContent:   "The weather is sunny",
					wantToolCalls: []api.ToolCall{
						{

@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
			name: "tool call with streaming JSON across chunks",
			steps: []step{
				{
					input: llm.CompletionResponse{Done: false},
					input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
					// No output yet - incomplete JSON
				},
				{
					input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
					input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
					// Still no output - incomplete JSON
				},
				{
					input:         llm.CompletionResponse{Content: "2\"}", Done: true},
					wantToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{

@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
	gin.SetMode(gin.TestMode)

	mockResponses := []llm.CompletionResponse{
		{Content: "First ", Done: false},
		{Content: "<|message|>First ", Done: false},
		{Content: "chunk ", Done: false},
		{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
		{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
	}

	mock := mockRunner{
@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
type expectedChunk struct {
|
||||
afterResponse int // Which mock response this chunk should appear after
|
||||
content string // Expected content in this chunk
|
||||
thinking string // Expected thinking in this chunk
|
||||
}
|
||||
|
||||
	testCases := []struct {
		name           string
		mockResponses  []llm.CompletionResponse
		expectedChunks []expectedChunk
		wantContent    string
		wantThinking   string
	}{
		{
			name: "simple message without thinking",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
				{Content: "how can I help?", Done: false},
				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 1, content: "Hello, "},
				{afterResponse: 2, content: "how can I help?"},
			},
			wantContent: "Hello, how can I help?",
		},
		{
			name: "message with analysis channel for thinking",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|channel|>analysis<|message|>", Done: false},
				{Content: "Let me think ", Done: false},
				{Content: "about this problem...", Done: false},
				{Content: "<|end|>", Done: false},
				{Content: "<|start|>assistant<|message|>", Done: false},
				{Content: "The answer ", Done: false},
				{Content: "is 42", Done: false},
				{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 2, thinking: "Let me think "},
				{afterResponse: 3, thinking: "about this problem..."},
				{afterResponse: 6, content: "The answer "},
				{afterResponse: 7, content: "is 42"},
			},
			wantContent:  "The answer is 42",
			wantThinking: "Let me think about this problem...",
		},
		{
			name: "streaming with partial tags across boundaries",
			mockResponses: []llm.CompletionResponse{
				{Content: "<|chan", Done: false},
				{Content: "nel|>analy", Done: false},
				{Content: "sis<|mess", Done: false},
				{Content: "age|>Think", Done: false},
				{Content: "ing deeply...<|end|>", Done: false},
				{Content: "<|start|>assi", Done: false},
				{Content: "stant<|message|>Result ", Done: false},
				{Content: "computed<|e", Done: false},
				{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
			},
			expectedChunks: []expectedChunk{
				{afterResponse: 4, thinking: "Think"},
				{afterResponse: 5, thinking: "ing deeply..."},
				{afterResponse: 7, content: "Result "},
				{afterResponse: 8, content: "computed"},
			},
			wantContent:  "Result computed",
			wantThinking: "Thinking deeply...",
		},
	}

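	// Each case below replays its mockResponses through the mock runner, then
	// accumulates the streamed content/thinking, compares the totals against
	// wantContent/wantThinking, and checks that the final chunk reports done=true.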
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// Channel to synchronize mock responses with chunk verification
			responsesSent := make(chan int, len(tc.mockResponses))

			mock := mockRunner{
				CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
					// Send mock responses one at a time, notifying when each is sent
					for i, resp := range tc.mockResponses {
						fn(resp)
						responsesSent <- i + 1
					}
					close(responsesSent)
					return nil
				},
			}

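			// Stub scheduler: loadFn immediately hands each request a runnerRef
			// backed by the mock runner above, so no real model is ever loaded.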
			s := Server{
				sched: &Scheduler{
					pendingReqCh:  make(chan *LlmRequest, 1),
					finishedReqCh: make(chan *LlmRequest, 1),
					expiredCh:     make(chan *runnerRef, 1),
					unloadedCh:    make(chan any, 1),
					loaded:        make(map[string]*runnerRef),
					newServerFn:   newMockServer(&mock),
					getGpuFn:      discover.GetGPUInfo,
					getCpuFn:      discover.GetCPUInfo,
					reschedDelay:  250 * time.Millisecond,
					loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
						req.successCh <- &runnerRef{
							llama: &mock,
						}
						return false
					},
				},
			}

			go s.sched.Run(t.Context())

			// Create a minimal model
			_, digest := createHarmonyTestModel(t)

			// Create model with passthrough template
			stream := false
			w := createRequest(t, s.CreateHandler, api.CreateRequest{
				Model:    "harmony-test",
				Files:    map[string]string{"file.gguf": digest},
				Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
				Stream:   &stream,
			})
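			// The template above renders the prompt verbatim; the otherwise-empty
			// `{{ with .Tools }}{{ end }}` reference appears to exist only so the
			// template advertises tool support (an inference, not stated here).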

			if w.Code != http.StatusOK {
				t.Fatalf("failed to create model: %d", w.Code)
			}

			// Test chat endpoint with streaming
			streamTrue := true
			w = createRequest(t, s.ChatHandler, api.ChatRequest{
				Model:    "harmony-test",
				Messages: []api.Message{{Role: "user", Content: "Hello"}},
				Stream:   &streamTrue,
				Tools:    getTestTools(),
			})

			if w.Code != http.StatusOK {
				t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
			}

			// Parse streaming response
			var chunks []api.ChatResponse
			var content, thinking strings.Builder

			decoder := json.NewDecoder(w.Body)
			for decoder.More() {
				var chunk api.ChatResponse
				if err := decoder.Decode(&chunk); err != nil {
					t.Fatalf("failed to decode chunk: %v", err)
				}
				chunks = append(chunks, chunk)

				// Accumulate content and thinking from each chunk
				content.WriteString(chunk.Message.Content)
				thinking.WriteString(chunk.Message.Thinking)

				// Debug output
				t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
			}

			// Verify we got streaming chunks
			if len(chunks) == 0 {
				t.Fatal("expected streaming chunks, got none")
			}

			gotContent := content.String()
			gotThinking := thinking.String()

			if gotContent != tc.wantContent {
				t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
			}
			if gotThinking != tc.wantThinking {
				t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
			}

			// Verify last chunk has done=true
			lastChunk := chunks[len(chunks)-1]
			if !lastChunk.Done {
				t.Error("expected last chunk to have done=true")
			}
		})
	}
}
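
// Aside: a simplified sketch of the buffering that the "partial tags across
// boundaries" case relies on — a scanner that withholds any trailing text that
// could be the start of a "<|...|>" tag until the next chunk resolves it. This
// is illustrative only, not the package's Harmony parser; all names are made up.
package main

import (
	"fmt"
	"strings"
)

// strip removes any complete "<|...|>" tags from s.
func strip(s string) string {
	for {
		i := strings.Index(s, "<|")
		if i < 0 {
			return s
		}
		j := strings.Index(s[i:], "|>")
		if j < 0 {
			return s // unterminated; feed has already buffered such suffixes
		}
		s = s[:i] + s[i+j+2:]
	}
}

// feed returns the text that is safe to emit now plus the remainder that must
// be buffered because it may still grow into a complete tag.
func feed(buf, chunk string) (emit, rest string) {
	s := buf + chunk
	if i := strings.LastIndex(s, "<|"); i >= 0 && !strings.Contains(s[i:], "|>") {
		s, rest = s[:i], s[i:] // unterminated tag opener: hold it back
	} else if strings.HasSuffix(s, "<") {
		s, rest = s[:len(s)-1], "<" // a lone '<' might become "<|" next chunk
	}
	return strip(s), rest
}

func main() {
	var pending string
	for _, chunk := range []string{"Result ", "computed<|e", "nd|>"} {
		emit, rest := feed(pending, chunk)
		pending = rest
		if emit != "" {
			fmt.Printf("emit %q\n", emit) // "Result " then "computed"; "<|end|>" is consumed
		}
	}
}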