Compare commits

..

13 Commits

Author SHA1 Message Date
ParthSareen
2ab70e82d0 remove rebase debug 2025-05-15 14:57:42 -07:00
ParthSareen
717fa7a44a Add sentinel errors, remove redundant calls 2025-05-15 14:29:31 -07:00
ParthSareen
53f7946fb6 add tests, organize, comments 2025-05-14 15:39:02 -07:00
ParthSareen
bc83789be9 tools package and utils 2025-05-13 17:44:45 -07:00
ParthSareen
4059b8db01 renaming and splitting stuff up 2025-05-13 17:43:15 -07:00
ParthSareen
b8b9c0c7cf checkpoint 2025-05-13 17:43:15 -07:00
ParthSareen
779547fcde checkpoint - cleanup still left, functionality setup 2025-05-13 17:43:15 -07:00
ParthSareen
6cb7494061 checkpoint for new parser
TODO:
- cleanup routes interface
- internal/external states
2025-05-13 17:43:15 -07:00
ParthSareen
a44734b030 add new parser, tests, and templates 2025-05-13 17:43:15 -07:00
ParthSareen
b5a982ecb0 wip 2025-05-13 17:43:15 -07:00
ParthSareen
516a540df7 jsonv2 decoder 2025-05-13 17:43:15 -07:00
ParthSareen
7f2f996cd6 server/routes: catch when JSON tool was used 2025-05-13 17:43:15 -07:00
ParthSareen
610054a234 model: support tools streaming and improve parsing 2025-05-13 17:43:15 -07:00
96 changed files with 2375 additions and 3393 deletions

View File

@@ -51,8 +51,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/include
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-cpu/amx)
add_compile_definitions(NDEBUG)
set(GGML_CPU ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src)
set_property(TARGET ggml PROPERTY EXCLUDE_FROM_ALL TRUE)

View File

@@ -405,7 +405,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
- [AppFlowy](https://github.com/AppFlowy-IO/AppFlowy) (AI collaborative workspace with Ollama, cross-platform and self-hostable)
- [Lumina](https://github.com/cushydigit/lumina.git) (A lightweight, minimal React.js frontend for interacting with Ollama servers)
- [Tiny Notepad](https://pypi.org/project/tiny-notepad) (A lightweight, notepad-like interface to chat with ollama available on PyPI)
### Cloud

View File

@@ -747,38 +747,11 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
case float64:
v = fmt.Sprintf("%g", vData)
case []any:
targetWidth := 10 // Small width where we are displaying the data in a column
var itemsToShow int
totalWidth := 1 // Start with 1 for opening bracket
// Find how many we can fit
for i := range vData {
itemStr := fmt.Sprintf("%v", vData[i])
width := runewidth.StringWidth(itemStr)
// Add separator width (", ") for all items except the first
if i > 0 {
width += 2
}
// Check if adding this item would exceed our width limit
if totalWidth+width > targetWidth && i > 0 {
break
}
totalWidth += width
itemsToShow++
}
// Format the output
if itemsToShow < len(vData) {
v = fmt.Sprintf("%v", vData[:itemsToShow])
v = strings.TrimSuffix(v, "]")
v += fmt.Sprintf(" ...+%d more]", len(vData)-itemsToShow)
} else {
v = fmt.Sprintf("%v", vData)
n := 3
if len(vData) < n {
n = len(vData)
}
v = fmt.Sprintf("%v", vData[:n])
default:
v = fmt.Sprintf("%T", vData)
}
@@ -799,19 +772,10 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
head := func(s string, n int) (rows [][]string) {
scanner := bufio.NewScanner(strings.NewReader(s))
count := 0
for scanner.Scan() {
text := strings.TrimSpace(scanner.Text())
if text == "" {
continue
for scanner.Scan() && (len(rows) < n || n < 0) {
if text := scanner.Text(); text != "" {
rows = append(rows, []string{"", strings.TrimSpace(text)})
}
count++
if n < 0 || count <= n {
rows = append(rows, []string{"", text})
}
}
if n >= 0 && count > n {
rows = append(rows, []string{"", "..."})
}
return
}
@@ -1236,11 +1200,11 @@ func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
return err
}
if err := client.Heartbeat(cmd.Context()); err != nil {
if !(strings.Contains(err.Error(), " refused") || strings.Contains(err.Error(), "could not connect")) {
if !strings.Contains(err.Error(), " refused") {
return err
}
if err := startApp(cmd.Context(), client); err != nil {
return fmt.Errorf("ollama server not responding - %w", err)
return errors.New("could not connect to ollama app, is it running?")
}
}
return nil
@@ -1318,7 +1282,7 @@ func NewCLI() *cobra.Command {
}
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_K_M)")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")
showCmd := &cobra.Command{
Use: "show MODEL",

View File

@@ -225,7 +225,6 @@ Weigh anchor!
System
You are a pirate!
Ahoy, matey!
...
`
if diff := cmp.Diff(expect, b.String()); diff != "" {

View File

@@ -4,27 +4,17 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"syscall"
"unsafe"
"github.com/ollama/ollama/api"
"golang.org/x/sys/windows"
)
const (
Installer = "OllamaSetup.exe"
)
func startApp(ctx context.Context, client *api.Client) error {
if len(isProcRunning(Installer)) > 0 {
return fmt.Errorf("upgrade in progress...")
}
// log.Printf("XXX Attempting to find and start ollama app")
AppName := "ollama app.exe"
exe, err := os.Executable()
if err != nil {
@@ -66,41 +56,3 @@ func startApp(ctx context.Context, client *api.Client) error {
}
return waitForServer(ctx, client)
}
func isProcRunning(procName string) []uint32 {
pids := make([]uint32, 2048)
var ret uint32
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
pids = pids[:ret]
var matches []uint32
for _, pid := range pids {
if pid == 0 {
continue
}
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
if err != nil {
continue
}
defer windows.CloseHandle(hProcess)
var module windows.Handle
var cbNeeded uint32
cb := (uint32)(unsafe.Sizeof(module))
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
continue
}
var sz uint32 = 1024 * 8
moduleName := make([]uint16, sz)
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
continue
}
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
if strings.EqualFold(exeFile, procName) {
matches = append(matches, pid)
}
}
return matches
}

View File

@@ -53,11 +53,8 @@ func (ModelParameters) KV(t *Tokenizer) ggml.KV {
}
for _, sv := range t.SpecialVocabulary {
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
kv[fmt.Sprintf("tokenizer.ggml.%s_token_id", sv.Key())] = uint32(sv.ID)
if len(sv.IDs) > 0 {
kv[fmt.Sprintf("tokenizer.ggml.%s_token_ids", sv.Key())] = sv.IDs
}
kv[fmt.Sprintf("tokenizer.ggml.add_%s_token", sv.Key())] = sv.AddToken
}
return kv
@@ -194,8 +191,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &phi3Model{}
case "Qwen2ForCausalLM":
conv = &qwen2Model{}
case "Qwen2_5_VLForConditionalGeneration":
conv = &qwen25VLModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

View File

@@ -139,8 +139,7 @@ func (p *llamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
}
for _, t := range ts {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") ||
strings.HasSuffix(t.Name(), "attn_q_proj.weight") || strings.HasSuffix(t.Name(), "attn_k_proj.weight") {
if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
if !p.skipRepack {
t.SetRepacker(p.repack)
}
@@ -182,9 +181,9 @@ func (p *llamaModel) repack(name string, data []float32, shape []uint64) ([]floa
}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_q_proj.weight") {
if strings.HasSuffix(name, "attn_q.weight") {
heads = p.NumAttentionHeads
} else if strings.HasSuffix(name, "attn_k.weight") || strings.HasSuffix(name, "attn_k_proj.weight") {
} else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)

View File

@@ -94,9 +94,7 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
var text []Tensor
for _, t := range ts {
if !strings.HasPrefix(t.Name(), "v.") && !strings.HasPrefix(t.Name(), "mm.") {
text = append(text, t)
} else if t.Name() == "v.position_embd.gate" {
if t.Name() == "v.position_embd.gate" {
for _, name := range []string{"v.position_embd.gate", "v.tile_position_embd.gate"} {
tt := t.Clone()
tt.SetRepacker(m.repack(name))
@@ -107,21 +105,23 @@ func (m *mllamaModel) Tensors(ts []Tensor) []*ggml.Tensor {
WriterTo: tt,
})
}
} else {
if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_q.weight") || strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(m.repack(t.Name()))
} else if strings.HasSuffix(t.Name(), "attn_gate") || strings.HasSuffix(t.Name(), "ffn_gate") {
t.SetRepacker(m.repack(t.Name()))
}
} else if t.Name() == "v.pre_tile_position_embd.gate" || t.Name() == "v.post_tile_position_embd.gate" {
t.SetRepacker(m.repack(t.Name()))
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else if strings.HasPrefix(t.Name(), "v.") || strings.HasPrefix(t.Name(), "mm.") {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
} else {
text = append(text, t)
}
}
@@ -137,35 +137,16 @@ func (m *mllamaModel) repack(name string) Repacker {
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if strings.HasSuffix(name, "attn_q.weight") || strings.HasSuffix(name, "attn_k.weight") {
heads := m.VisionModel.AttentionHeads
if err := t.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
t, err = tensor.Tanh(t)
if err != nil {
return nil, err
}
if err := t.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := t.Reshape(dims...); err != nil {
return nil, err
}
if err := t.Transpose(); err != nil {
return nil, err
}
} else {
t, err = tensor.Tanh(t)
if name == "v.position_embd.gate" {
t, err = tensor.Sub(float32(1), t)
if err != nil {
return nil, err
}
if name == "v.position_embd.gate" {
t, err = tensor.Sub(float32(1), t)
if err != nil {
return nil, err
}
}
}
t = tensor.Materialize(t)

View File

@@ -15,7 +15,6 @@ type qwen2Model struct {
Type string `json:"type"`
Factor ropeFactor `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
MropeSection []int32 `json:"mrope_section"`
} `json:"rope_scaling"`
RMSNormEPS float32 `json:"rms_norm_eps"`
}
@@ -40,8 +39,6 @@ func (q *qwen2Model) KV(t *Tokenizer) ggml.KV {
case "yarn":
kv["qwen2.rope.scaling.type"] = q.RopeScaling.Type
kv["qwen2.rope.scaling.factor"] = q.RopeScaling.Factor
case "mrope", "default":
kv["qwen2.rope.mrope_section"] = q.RopeScaling.MropeSection
default:
panic("unknown rope scaling type")
}

View File

@@ -1,102 +0,0 @@
package convert
import (
"cmp"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type qwen25VLModel struct {
qwen2Model
VisionModel struct {
Depth uint32 `json:"depth"`
HiddenSize uint32 `json:"hidden_size"`
NumHeads uint32 `json:"num_heads"`
InChannels uint32 `json:"in_chans"`
PatchSize uint32 `json:"patch_size"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
SpatialPatchSize uint32 `json:"spatial_patch_size"`
WindowSize uint32 `json:"window_size"`
RMSNormEps float32 `json:"layer_norm_epsilon"`
RopeTheta float32 `json:"rope_theta"`
FullAttentionBlocks []int32 `json:"fullatt_block_indexes"`
TemporalPatchSize uint32 `json:"temporal_patch_size"`
} `json:"vision_config"`
}
var _ ModelConverter = (*qwen25VLModel)(nil)
func (q *qwen25VLModel) KV(t *Tokenizer) ggml.KV {
kv := q.ModelParameters.KV(t)
kv["general.architecture"] = "qwen25vl"
for k, v := range q.qwen2Model.KV(t) {
if strings.HasPrefix(k, "qwen2.") {
kv[strings.Replace(k, "qwen2.", "qwen25vl.", 1)] = v
}
}
if q.VisionModel.FullAttentionBlocks == nil {
kv["qwen25vl.vision.fullatt_block_indexes"] = []int32{7, 15, 23, 31}
}
kv["qwen25vl.vision.block_count"] = cmp.Or(q.VisionModel.Depth, 32)
kv["qwen25vl.vision.embedding_length"] = q.VisionModel.HiddenSize
kv["qwen25vl.vision.attention.head_count"] = cmp.Or(q.VisionModel.NumHeads, 16)
kv["qwen25vl.vision.num_channels"] = q.VisionModel.InChannels
kv["qwen25vl.vision.patch_size"] = cmp.Or(q.VisionModel.PatchSize, 14)
kv["qwen25vl.vision.spatial_merge_size"] = cmp.Or(q.VisionModel.SpatialMergeSize, 2)
kv["qwen25vl.vision.spatial_patch_size"] = q.VisionModel.SpatialPatchSize
kv["qwen25vl.vision.window_size"] = cmp.Or(q.VisionModel.WindowSize, 112)
kv["qwen25vl.vision.attention.layer_norm_epsilon"] = cmp.Or(q.VisionModel.RMSNormEps, 1e-6)
kv["qwen25vl.vision.rope.freq_base"] = cmp.Or(q.VisionModel.RopeTheta, 1e4)
kv["qwen25vl.vision.fullatt_block_indexes"] = q.VisionModel.FullAttentionBlocks
kv["qwen25vl.vision.temporal_patch_size"] = cmp.Or(q.VisionModel.TemporalPatchSize, 2)
return kv
}
func (q *qwen25VLModel) Tensors(ts []Tensor) []*ggml.Tensor {
var out []*ggml.Tensor
for _, t := range ts {
if strings.Contains(t.Name(), "patch_embed.proj") {
for t := range splitDim(t, 2,
strings.NewReplacer("patch_embed.proj", "patch_embd_0"),
strings.NewReplacer("patch_embed.proj", "patch_embd_1"),
) {
t.Shape = slices.DeleteFunc(t.Shape, func(i uint64) bool { return i == 1 })
out = append(out, t)
}
} else if strings.Contains(t.Name(), "attn.qkv") {
out = append(out, slices.Collect(splitDim(t, 0,
strings.NewReplacer("attn.qkv", "attn_q"),
strings.NewReplacer("attn.qkv", "attn_k"),
strings.NewReplacer("attn.qkv", "attn_v"),
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
}
return out
}
func (p *qwen25VLModel) Replacements() []string {
return append(
p.qwen2Model.Replacements(),
"visual", "v",
"blocks", "blk",
"attn.proj", "attn_out",
"norm1", "ln1",
"norm2", "ln2",
)
}

View File

@@ -47,7 +47,7 @@ func convertFull(t *testing.T, fsys fs.FS) (*os.File, ggml.KV, ggml.Tensors) {
}
t.Cleanup(func() { r.Close() })
m, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}
@@ -332,7 +332,7 @@ func TestConvertAdapter(t *testing.T) {
}
defer r.Close()
m, err := ggml.Decode(r, -1)
m, _, err := ggml.Decode(r, -1)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,56 +0,0 @@
package convert
import (
"iter"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
)
// splitDim splits a tensor along a specified dimension into multiple tensors. The dimension
// is split evenly based on the number of replacers provided.
func splitDim(t Tensor, dim int, replacers ...*strings.Replacer) iter.Seq[*ggml.Tensor] {
return func(yield func(*ggml.Tensor) bool) {
for i, replacer := range replacers {
shape := slices.Clone(t.Shape())
shape[dim] = shape[dim] / uint64(len(replacers))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(i*int(shape[dim]), (i+1)*int(shape[dim]))
tt := t.Clone()
tt.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))
for i := range shape {
dims[i] = int(shape[i])
}
var t tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
t, err := t.Slice(slice...)
if err != nil {
return nil, err
}
t = tensor.Materialize(t)
// flatten tensor so it can be written as a vector
if err := t.Reshape(t.Shape().TotalSize()); err != nil {
return nil, err
}
return native.VectorF32(t.(*tensor.Dense))
})
if !yield(&ggml.Tensor{
Name: replacer.Replace(t.Name()),
Kind: t.Kind(),
Shape: shape,
WriterTo: tt,
}) {
break
}
}
}
}

View File

@@ -110,7 +110,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
if f, err := fsys.Open("tokenizer_config.json"); errors.Is(err, os.ErrNotExist) {
// noop
} else if err != nil {
return nil, err
} else {
@@ -172,34 +171,6 @@ func parseTokenizer(fsys fs.FS, specialTokenTypes []string) (*Tokenizer, error)
}
}
if f, err := fsys.Open("generation_config.json"); errors.Is(err, os.ErrNotExist) {
} else if err != nil {
return nil, err
} else {
defer f.Close()
var p map[string]json.RawMessage
if err := json.NewDecoder(f).Decode(&p); err != nil {
return nil, err
}
for _, st := range specialTokenTypes {
if bts, ok := p[fmt.Sprintf("%s_token_id", st)]; ok {
var ids []int32
if err := json.Unmarshal(bts, &ids); err != nil {
// value is not a list so the existing ID is used
continue
}
if i := slices.IndexFunc(t.SpecialVocabulary, func(sv *SpecialVocabulary) bool {
return sv.Type == st
}); i >= 0 {
t.SpecialVocabulary[i].IDs = ids
}
}
}
}
return t, nil
}
@@ -309,9 +280,6 @@ type SpecialVocabulary struct {
ID int
Content string
AddToken bool
// IDs is populated by generation_config.json
IDs []int32
}
func (sv SpecialVocabulary) Key() string {

View File

@@ -247,67 +247,6 @@ func TestParseTokenizer(t *testing.T) {
Pre: "default",
},
},
{
name: "generation config eos token ids",
fsys: createTokenizerFS(t, t.TempDir(), map[string]io.Reader{
"tokenizer.json": strings.NewReader(`{
"added_tokens": [
{
"id": 0,
"content": "<bos>",
"special": true
},
{
"id": 1,
"content": "<eos>",
"special": true
},
{
"id": 2,
"content": "<eot>",
"special": true
},
{
"id": 3,
"content": "<eom>",
"special": true
}
],
"model": {
"vocab": {
"<bos>": 0,
"<eos>": 1,
"<eot>": 2,
"<eom>": 3
}
}
}`),
"tokenizer_config.json": strings.NewReader(`{
"add_bos_token": true,
"add_eos_token": false,
"bos_token": "<bos>",
"eos_token": "<eos>"
}`),
"generation_config.json": strings.NewReader(`{
"bos_token_id": 0,
"eos_token_id": [1, 2, 3]
}`),
}),
specialTokenTypes: []string{"pad", "eos", "bos", "unk"},
want: &Tokenizer{
Vocabulary: &Vocabulary{
Model: "gpt2",
Tokens: []string{"<bos>", "<eos>", "<eot>", "<eom>"},
Scores: []float32{0, 1, 2, 3},
Types: []int32{3, 3, 3, 3},
},
SpecialVocabulary: []*SpecialVocabulary{
{Type: "eos", Content: "<eos>", ID: 1, IDs: []int32{1, 2, 3}, AddToken: false},
{Type: "bos", Content: "<bos>", ID: 0, AddToken: true},
},
Pre: "default",
},
},
}
for _, tt := range cases {

View File

@@ -15,7 +15,6 @@ import (
type GGML struct {
container
model
Length int64
}
type model interface {
@@ -127,7 +126,6 @@ func (kv KV) OllamaEngineRequired() bool {
"mistral3",
"llama4",
"mllama",
"qwen25vl",
}, kv.Architecture())
}
@@ -387,12 +385,12 @@ func DetectContentType(b []byte) string {
//
// It collects array values for arrays with a size less than or equal to
// maxArraySize. If the maxArraySize is negative, all arrays are collected.
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
rs = bufioutil.NewBufferedSeeker(rs, 32<<10)
var magic uint32
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
return nil, err
return nil, 0, err
}
var c container
@@ -402,25 +400,24 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
case FILE_MAGIC_GGUF_BE:
c = &containerGGUF{ByteOrder: binary.BigEndian, maxArraySize: maxArraySize}
default:
return nil, errors.New("invalid file magic")
return nil, 0, errors.New("invalid file magic")
}
model, err := c.Decode(rs)
if err != nil {
return nil, err
return nil, 0, err
}
offset, err := rs.Seek(0, io.SeekCurrent)
if err != nil {
return nil, err
return nil, 0, err
}
// final model type
return &GGML{
container: c,
model: model,
Length: offset,
}, nil
}, offset, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
@@ -652,20 +649,6 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph

View File

@@ -35,7 +35,7 @@ func TestWriteGGUF(t *testing.T) {
}
defer r.Close()
ff, err := Decode(r, 0)
ff, _, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}

1
go.mod
View File

@@ -19,6 +19,7 @@ require (
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/dlclark/regexp2 v1.11.4
github.com/emirpasic/gods/v2 v2.0.0-alpha
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1
github.com/google/go-cmp v0.6.0
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0

2
go.sum
View File

@@ -69,6 +69,8 @@ github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3
github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY=
github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1 h1:+VexzzkMLb1tnvpuQdGT/DicIRW7MN8ozsXqBMgp0Hk=
github.com/go-json-experiment/json v0.0.0-20250417205406-170dfdcf87d1/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M=
github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=

View File

@@ -19,7 +19,7 @@ func TestVisionModels(t *testing.T) {
}
testCases := []testCase{
{
model: "qwen2.5vl",
model: "llava:7b",
},
{
model: "llama3.2-vision",
@@ -60,7 +60,6 @@ func TestVisionModels(t *testing.T) {
}
func TestIntegrationSplitBatch(t *testing.T) {
skipUnderMinVRAM(t, 6)
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{

View File

@@ -544,7 +544,7 @@ func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext,
cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
cparams.penalty_repeat = C.float(params.PenaltyRepeat)
cparams.penalty_freq = C.float(params.PenaltyFreq)
cparams.penalty_present = C.float(params.PenaltyPresent)
cparams.penalty_present = C.float(params.PenaltyFreq)
cparams.seed = C.uint32_t(params.Seed)
grammar := C.CString(params.Grammar)
@@ -602,7 +602,7 @@ type Grammar struct {
mu sync.Mutex
}
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []int32) *Grammar {
func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogTokens []uint32) *Grammar {
cGrammar := C.CString(grammar)
defer C.free(unsafe.Pointer(cGrammar))
@@ -622,7 +622,7 @@ func NewGrammar(grammar string, vocabIds []uint32, vocabValues []string, eogToke
cEogTokens[i] = C.uint32_t(token)
}
g := C.grammar_init(cGrammar, unsafe.SliceData(cTokens), C.size_t(len(cTokens)), unsafe.SliceData(cPieces), unsafe.SliceData(cEogTokens), C.size_t(len(cEogTokens)))
g := C.grammar_init(cGrammar, (*C.uint32_t)(unsafe.Pointer(&cTokens[0])), C.size_t(len(cTokens)), (**C.char)(unsafe.Pointer(&cPieces[0])), (*C.uint32_t)(unsafe.Pointer(&cEogTokens[0])), C.size_t(len(cEogTokens)))
if g == nil {
return nil
}

View File

@@ -1,277 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Thu, 1 May 2025 13:45:12 -0700
Subject: [PATCH] add argsort and cuda copy for i32
---
ggml/src/ggml-cpu/ops.cpp | 43 ++++++++++++++
ggml/src/ggml-cuda/argsort.cu | 102 +++++++++++++++++++++++++++++++++-
ggml/src/ggml-cuda/cpy.cu | 49 ++++++++++++++++
3 files changed, 192 insertions(+), 2 deletions(-)
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index becdae07..7a44b6cf 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -6890,6 +6890,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
+static void ggml_compute_forward_argsort_i32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT(nb0 == sizeof(int32_t));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int64_t nr = ggml_nrows(src0);
+
+ ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+ for (int64_t i = ith; i < nr; i += nth) {
+ int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+ const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
+
+ for (int64_t j = 0; j < ne0; j++) {
+ dst_data[j] = j;
+ }
+
+ // C doesn't have a functional sort, so we do a bubble sort instead
+ for (int64_t j = 0; j < ne0; j++) {
+ for (int64_t k = j + 1; k < ne0; k++) {
+ if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+ (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+ int32_t tmp = dst_data[j];
+ dst_data[j] = dst_data[k];
+ dst_data[k] = tmp;
+ }
+ }
+ }
+ }
+}
+
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6901,6 +6940,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
+ case GGML_TYPE_I32:
+ {
+ ggml_compute_forward_argsort_i32(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 607ded85..53b02634 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
}
}
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
+ extern __shared__ int shared_mem[];
+ int * indices = shared_mem;
+
+ const int tid = threadIdx.x;
+ const int row = blockIdx.y;
+
+ // Initialize all indices, handling the case where threads < ncols_pad
+ for (int i = tid; i < ncols_pad; i += blockDim.x) {
+ indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
+ }
+ __syncthreads();
+
+ // Bitonic sort
+ for (int k = 2; k <= ncols_pad; k *= 2) {
+ for (int j = k/2; j > 0; j /= 2) {
+ for (int i = tid; i < ncols_pad; i += blockDim.x) {
+ const int ij = i ^ j;
+ if (ij > i) {
+ // Only compare values within the actual data range
+ if (i < ncols && ij < ncols) {
+ if ((i & k) == 0) {
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ } else {
+ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ }
+ } else {
+ if (order == GGML_SORT_ORDER_ASC) {
+ if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ } else {
+ if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+ int tmp = indices[i];
+ indices[i] = indices[ij];
+ indices[ij] = tmp;
+ }
+ }
+ }
+ }
+ }
+ }
+ __syncthreads();
+ }
+ }
+
+ // Write sorted indices to output, only threads handling valid data
+ for (int i = tid; i < ncols; i += blockDim.x) {
+ dst[row * ncols + i] = indices[i];
+ }
+}
+
+static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+ // Bitonic sort requires ncols to be power of 2
+ const int ncols_pad = next_power_of_2(ncols);
+
+ // Ensure thread count doesn't exceed maximum (typically 1024)
+ const int max_threads = 1024; // This is the typical max for most GPUs
+ const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
+
+ const dim3 block_dims(threads_per_block, 1, 1);
+ const dim3 block_nums(1, nrows, 1);
+ const size_t shared_mem = ncols_pad * sizeof(int);
+
+ // Check if shared memory size is within limits
+ const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+ // Instead of logging an error, use GGML_ASSERT with a descriptive message
+ GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
+
+ // Launch kernels with the updated thread configuration
+ if (order == GGML_SORT_ORDER_ASC) {
+ k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else if (order == GGML_SORT_ORDER_DESC) {
+ k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+ } else {
+ GGML_ABORT("fatal error");
+ }
+}
+
+
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
- argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ if (src0->type == GGML_TYPE_I32) {
+ argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ } else {
+ argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+ }
}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index 2d46176e..47383486 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
*dsti = *xi;
}
+static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+ const int32_t * xi = (const int32_t *) cxi;
+ int32_t * dsti = (int32_t *) cdsti;
+
+ *dsti = *xi;
+}
+
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
+// First, add this template function after the other template functions
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+ const int nb12, const int nb13) {
+ const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= ne) {
+ return;
+ }
+
+ const int64_t i03 = i/(ne00 * ne01 * ne02);
+ const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+ const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
+ const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+ const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+ const int64_t i13 = i/(ne10 * ne11 * ne12);
+ const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+ const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+ const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+ const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+
+ cpy_1(cx + x_offset, cdst + dst_offset);
+}
+
+// Then modify the ggml_cpy_i32_i32_cuda function to use the new template
+static void ggml_cpy_i32_i32_cuda(
+ const char * cx, char * cdst, const int ne,
+ const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+ const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
+
+ const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+ cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+ (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -631,6 +676,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -686,6 +733,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
+ } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+ return (void*) cpy_i32_i32<cpy_1_i32_i32>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@@ -1,9 +1,12 @@
package llm
import (
"cmp"
"fmt"
"log/slog"
"maps"
"os"
"slices"
"strconv"
"strings"
@@ -82,11 +85,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
var graphOffload uint64
// Projectors loaded into GPU0 only
var llamaEngineProjectorWeights uint64
// Projectors loaded with output layer
var ollamaEngineProjectorWeights uint64
var ollamaEngineProjectorGraph uint64
var projectorWeights uint64
var projectorGraph uint64
// Conditional output size on GPU 0
var memoryLayerOutput uint64
@@ -111,23 +111,21 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
for _, projector := range projectors {
llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
weight := projectorMemoryRequirements(projector)
projectorWeights += weight
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
if llamaEngineProjectorWeights == 0 {
ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
opts.NumCtx = max(opts.NumCtx, 2048)
if projectorWeights == 0 && projectorGraph == 0 {
projectorWeights, projectorGraph = f.VisionGraphSize()
}
layers := f.Tensors().GroupLayers()
// add one layer worth of memory as a buffer
if blk0, ok := layers["blk.0"]; ok {
layerSize = blk0.Size()
} else {
slog.Warn("model missing blk.0 layer size")
}
// add one layer (chosing the max layer) worth of memory as a buffer
layerSize = slices.MaxFunc(slices.Collect(maps.Values(layers)), func(a, b ggml.Layer) int {
return cmp.Compare(a.Size(), b.Size())
}).Size()
var kvct string
if envconfig.FlashAttention() &&
@@ -165,7 +163,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
graphFullOffload = graphPartialOffload
}
// Output layer handled at the end if we have space
if layer, ok := layers["output_norm"]; ok {
memoryLayerOutput += layer.Size()
}
@@ -175,7 +172,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
memoryLayerOutput += layer.Size()
}
gpuZeroOverhead := llamaEngineProjectorWeights
// Output layer handled at the end if we have space
gpuZeroOverhead := projectorWeights + projectorGraph
// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
var layerCount int
@@ -218,8 +216,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
if len(gpusWithSpace) > 0 {
gpuZeroID = gpusWithSpace[0].i
gpuAllocations[gpuZeroID] += gpuZeroOverhead
} else {
overflow += gpuZeroOverhead
}
// For all the layers, find where they can fit on the GPU(s)
@@ -260,24 +256,21 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
// Determine if we need to consider output then find where it fits
memoryLastLayer := memoryLayerOutput + ollamaEngineProjectorWeights + ollamaEngineProjectorGraph
if memoryLastLayer > 0 {
if opts.NumGPU < 0 || layerCount < opts.NumGPU {
for j := len(gpusWithSpace); j > 0; j-- {
g := gpusWithSpace[layerCount%j]
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
if g.g.FreeMemory > overhead+used+memoryLastLayer {
gpuAllocations[g.i] += memoryLastLayer
layerCounts[g.i]++
layerCount++
break
}
if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) {
for j := len(gpusWithSpace); j > 0; j-- {
g := gpusWithSpace[layerCount%j]
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
if g.g.FreeMemory > overhead+used+memoryLayerOutput {
gpuAllocations[g.i] += memoryLayerOutput
layerCounts[g.i]++
layerCount++
break
}
}
if layerCount < int(f.KV().BlockCount())+1 {
fullyLoaded = false
overflow += memoryLastLayer
overflow += memoryLayerOutput
}
}
@@ -335,8 +328,8 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
memoryLayerOutput: memoryLayerOutput,
graphFullOffload: graphFullOffload,
graphPartialOffload: graphPartialOffload,
projectorWeights: llamaEngineProjectorWeights + ollamaEngineProjectorWeights,
projectorGraph: ollamaEngineProjectorGraph,
projectorWeights: projectorWeights,
projectorGraph: projectorGraph,
}
if gpus[0].Library == "cpu" {
@@ -422,7 +415,7 @@ func projectorMemoryRequirements(filename string) (weights uint64) {
}
defer file.Close()
ggml, err := ggml.Decode(file, 1024)
ggml, _, err := ggml.Decode(file, 1024)
if err != nil {
return 0
}

View File

@@ -121,7 +121,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
}
defer f.Close()
ggml, err := ggml.Decode(f, maxArraySize)
ggml, _, err := ggml.Decode(f, maxArraySize)
return ggml, err
}

View File

@@ -6,6 +6,7 @@ import (
"encoding/binary"
"fmt"
"math"
"os"
"slices"
"strconv"
"strings"
@@ -14,7 +15,6 @@ import (
)
type Backend interface {
Load(ctx context.Context, progress func(float32)) error
Config() fs.Config
Get(name string) Tensor
NewContext() Context
@@ -52,6 +52,10 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@@ -68,9 +72,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(string, BackendParams) (Backend, error))
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -78,9 +82,9 @@ func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)
backends[name] = f
}
func NewBackend(modelPath string, params BackendParams) (Backend, error) {
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(modelPath, params)
return backend(ctx, f, params)
}
return nil, fmt.Errorf("unsupported backend")
@@ -128,8 +132,6 @@ type Tensor interface {
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Mul(ctx Context, t2 Tensor) Tensor
Div(ctx Context, t2 Tensor) Tensor
Mulmat(ctx Context, t2 Tensor) Tensor
MulmatFullPrec(ctx Context, t2 Tensor) Tensor
MulmatID(ctx Context, t2, ids Tensor) Tensor
@@ -138,11 +140,11 @@ type Tensor interface {
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor
SumRows(ctx Context) Tensor
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor
@@ -170,7 +172,6 @@ type Tensor interface {
Duplicate(ctx Context) Tensor
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
}
// ScaledDotProductAttention implements a fused attention

View File

@@ -30,7 +30,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
"github.com/ollama/ollama/ml/nn/rope"
"golang.org/x/sync/errgroup"
)
@@ -45,15 +44,8 @@ func devices() []*C.struct_ggml_backend_device {
}
type Backend struct {
// modelPath is the location of the model data
modelPath string
meta *fsggml.GGML
// tensorLoadTargets maps from the name of the tensor in the file
// to the name that is used by the model definition
tensorLoadTargets map[string][]string
sched *C.struct_ggml_backend_sched
schedBackends []*C.struct_ggml_backend
schedBufts []*C.struct_ggml_backend_buffer_type
@@ -72,14 +64,8 @@ type Backend struct {
maxGraphNodes int
}
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}
@@ -321,6 +307,73 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
}
}
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
if target == "" {
target = t.Name
}
tt, ok := tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(r.Name())
if err != nil {
slog.Warn("file open error", "file", r.Name(), "error", err)
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
return err
}
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
slog.Warn("file read error", "file", r.Name(), "error", err)
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if err := g.Wait(); err != nil {
return nil, err
}
// map devices to backend buffer types so new tensors can be assigned to the correct device
deviceBufferTypes := make(map[*C.struct_ggml_backend_device]*C.struct_ggml_backend_buffer_type)
@@ -344,11 +397,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
return &Backend{
modelPath: modelPath,
flashAttention: params.FlashAttention,
meta: meta,
tensorLoadTargets: targets,
tensors: tensors,
flashAttention: params.FlashAttention,
meta: meta,
tensors: tensors,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&schedBackends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
@@ -375,77 +426,6 @@ func init() {
ml.RegisterBackend("ggml", New)
}
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
var doneBytes atomic.Uint64
totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
for _, t := range b.meta.Tensors().Items() {
t := t
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(b.tensorLoadTargets[t.Name])))
for i := range tts {
target := b.tensorLoadTargets[t.Name][i]
if target == "" {
target = t.Name
}
tt, ok := b.tensors[target]
if !ok {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
// seeking around within an FD shared between all goroutines.
file, err := os.Open(b.modelPath)
if err != nil {
slog.Warn("file open error", "file", b.modelPath, "error", err)
return err
}
defer file.Close()
sr := io.NewSectionReader(file, int64(b.meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
return err
}
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
slog.Warn("file read error", "file", b.modelPath, "error", err)
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if progress != nil {
done := doneBytes.Add(uint64(n))
progress(float32(done) / float32(totalBytes))
}
}
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
return nil
}
func (b *Backend) Config() fs.Config {
return b.meta.KV()
}
@@ -887,13 +867,6 @@ func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) Div(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_div(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
@@ -942,8 +915,6 @@ func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
if len(shape) != 4 {
panic("expected 4 dimensions")
} else if shape[3] != 0 {
panic("cuda does not support 4d tensors")
}
return &Tensor{
@@ -1011,13 +982,6 @@ func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
}
}
func (t *Tensor) SumRows(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_sum_rows(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
@@ -1089,13 +1053,16 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}
func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase, ropeScale float32, options ...func(*rope.Options)) ml.Tensor {
// Default options
opts := &rope.Options{OriginalContextLength: 131072, Factors: &Tensor{}}
const (
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
ropeTypeMrope C.int = 8
ropeTypeVision C.int = 24
)
// Apply any provided options
for _, option := range options {
option(opts)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
@@ -1106,19 +1073,16 @@ func (t *Tensor) RoPE(ctx ml.Context, positions ml.Tensor, ropeDim int, ropeBase
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx,
dequant,
positions.(*Tensor).t,
opts.Factors.(*Tensor).t,
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(opts.Type),
C.int(opts.OriginalContextLength),
C.int(ropeType),
131072, // YaRN n_ctx_train
C.float(ropeBase),
C.float(ropeScale),
C.float(0.0),
C.float(1.0),
C.float(32.0),
C.float(1.0),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
),
}
}
@@ -1212,10 +1176,3 @@ func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
}
}
func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}

View File

@@ -3,7 +3,7 @@ package cpu
// #cgo CFLAGS: -O3 -Wno-implicit-function-declaration
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -I${SRCDIR}/amx -I${SRCDIR}/llamafile -I${SRCDIR}/.. -I${SRCDIR}/../../include
// #cgo CPPFLAGS: -DNDEBUG -DGGML_USE_LLAMAFILE
// #cgo CPPFLAGS: -DGGML_USE_LLAMAFILE
// #cgo linux CPPFLAGS: -D_GNU_SOURCE
// #cgo darwin,arm64 CPPFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 LDFLAGS: -framework Accelerate

View File

@@ -6822,45 +6822,6 @@ static void ggml_compute_forward_argsort_f32(
}
}
static void ggml_compute_forward_argsort_i32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(int32_t));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) {
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j;
}
// C doesn't have a functional sort, so we do a bubble sort instead
for (int64_t j = 0; j < ne0; j++) {
for (int64_t k = j + 1; k < ne0; k++) {
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
int32_t tmp = dst_data[j];
dst_data[j] = dst_data[k];
dst_data[k] = tmp;
}
}
}
}
}
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6872,10 +6833,6 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
case GGML_TYPE_I32:
{
ggml_compute_forward_argsort_i32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");

View File

@@ -85,107 +85,13 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
}
}
template<ggml_sort_order order>
static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
extern __shared__ int shared_mem[];
int * indices = shared_mem;
const int tid = threadIdx.x;
const int row = blockIdx.y;
// Initialize all indices, handling the case where threads < ncols_pad
for (int i = tid; i < ncols_pad; i += blockDim.x) {
indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
}
__syncthreads();
// Bitonic sort
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k/2; j > 0; j /= 2) {
for (int i = tid; i < ncols_pad; i += blockDim.x) {
const int ij = i ^ j;
if (ij > i) {
// Only compare values within the actual data range
if (i < ncols && ij < ncols) {
if ((i & k) == 0) {
if (order == GGML_SORT_ORDER_ASC) {
if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
} else {
if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
}
} else {
if (order == GGML_SORT_ORDER_ASC) {
if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
} else {
if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
}
}
}
}
}
__syncthreads();
}
}
// Write sorted indices to output, only threads handling valid data
for (int i = tid; i < ncols; i += blockDim.x) {
dst[row * ncols + i] = indices[i];
}
}
static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// Bitonic sort requires ncols to be power of 2
const int ncols_pad = next_power_of_2(ncols);
// Ensure thread count doesn't exceed maximum (typically 1024)
const int max_threads = 1024; // This is the typical max for most GPUs
const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
const dim3 block_dims(threads_per_block, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// Check if shared memory size is within limits
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
// Instead of logging an error, use GGML_ASSERT with a descriptive message
GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
// Launch kernels with the updated thread configuration
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ABORT("fatal error");
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -194,9 +100,5 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
if (src0->type == GGML_TYPE_I32) {
argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}

View File

@@ -38,13 +38,6 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
*dsti = *xi;
}
static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
const int32_t * xi = (const int32_t *) cxi;
int32_t * dsti = (int32_t *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -75,44 +68,6 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
// First, add this template function after the other template functions
template <cpy_kernel_t cpy_1>
static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
// Then modify the ggml_cpy_i32_i32_cuda function to use the new template
static void ggml_cpy_i32_i32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -678,8 +633,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -735,8 +688,6 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_i32_i32<cpy_1_i32_i32>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@@ -4,6 +4,6 @@ package metal
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"
// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
// #cgo CPPFLAGS: -DGGML_METAL_EMBED_LIBRARY -I.. -I../../include
// #cgo LDFLAGS: -framework Metal -framework MetalKit
import "C"

View File

@@ -1,21 +0,0 @@
// fast provides implementations of fast (fused) operations for increased performance.
package fast
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn/rope"
)
// fastRoPE is an interface for tensors that support fast rotary positional embedding.
type fastRoPE interface {
RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
}
// RoPE applies rotary positional embedding to tensor `t`.
func RoPE(ctx ml.Context, t, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor {
if t, ok := t.(fastRoPE); ok {
return t.RoPE(ctx, positions, dim, base, scale, options...)
}
panic("RoPE not implemented for this tensor type")
}

View File

@@ -1,33 +0,0 @@
package rope
import "github.com/ollama/ollama/ml"
// Options contains optional parameters for RoPE function
type Options struct {
OriginalContextLength int
Type int
Factors ml.Tensor
}
// WithOriginalContextLength sets a custom context length
func WithOriginalContextLength(n int) func(*Options) {
return func(opts *Options) {
opts.OriginalContextLength = n
}
}
// WithType sets RoPE type to NeoX
func WithTypeNeoX() func(*Options) {
return func(opts *Options) {
opts.Type = 2
}
}
// WithFactors sets custom rope factors
func WithFactors(factors ml.Tensor) func(*Options) {
return func(opts *Options) {
if factors != nil {
opts.Factors = factors
}
}
}

View File

@@ -2,30 +2,16 @@ package input
import "github.com/ollama/ollama/ml"
// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
// Tensor is the embedding data. Implementations may chose what to
// store here or it may be nil if not needed. However, any ml.Tensor
// objects must be stored here and not in Data.
Tensor ml.Tensor
// Data is implementation-specific opaque data, such as metadata on how
// to layout Tensor. It may be nil if not needed. It may also store larger
// objects such as complete images if they are to be processed later.
Data any
}
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is represents a non-text element such as an
// image (or part of one if the image can be processed in pieces).
// It may be used either together with Token or on its own.
Multimodal []Multimodal
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
@@ -46,7 +32,7 @@ type Input struct {
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal []Multimodal
Multimodal any
}
// Batch contains the inputs for a model forward pass

View File

@@ -40,13 +40,12 @@ type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is one or more tensors, each with optional model-specific
// opaque metadata. Typically, the tensors might be views into an embedding
// with each view representing a chunk of data that can be processed independently
// in different batches.
// The return value is most typically an ml.Tensor, however, different
// type are possible, such as an object containing a tensor plus
// additional metadata, a slice of tensors or even just the original input.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
EncodeMultimodal(ml.Context, []byte) (any, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.
@@ -98,8 +97,14 @@ func Register(name string, f func(fs.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(modelPath string, params ml.BackendParams) (Model, error) {
b, err := ml.NewBackend(modelPath, params)
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(ctx, r, params)
if err != nil {
return nil, err
}
@@ -128,7 +133,7 @@ func NewTextProcessor(s string) (TextProcessor, error) {
return nil, err
}
defer r.Close()
meta, err := fsggml.Decode(r, -1)
meta, _, err := fsggml.Decode(r, -1)
if err != nil {
return nil, err
}

View File

@@ -7,8 +7,6 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
@@ -45,13 +43,10 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
// TODO: set EOT to EOS otherwise 0 will stop generation
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
),
Layers: make([]Layer, c.Uint("block_count")),
@@ -85,10 +80,11 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -98,7 +94,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -128,7 +124,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}
type MLP struct {

View File

@@ -60,16 +60,12 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
EOS: int32(1),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
EOT: int32(106),
AddEOT: c.Bool("tokenizer.ggml.add_eot_token", false),
},
),
ImageProcessor: newImageProcessor(c),
@@ -86,7 +82,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -112,22 +108,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return []input.Multimodal{{Tensor: visionOutputs}}, nil
return visionOutputs, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal[0].Tensor
inputMultimodal := inp.Multimodal.(ml.Tensor)
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders

View File

@@ -7,8 +7,6 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -75,6 +73,7 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
@@ -84,7 +83,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -95,7 +94,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, opts.ropeScale, rope.WithTypeNeoX())
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -113,7 +112,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.TextConfig.ropeGlobalBase
}
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
}
type TextMLP struct {
@@ -166,7 +165,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
// set image embeddings
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal[0].Tensor
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {

View File

@@ -1,23 +1,22 @@
package llama
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
@@ -33,6 +32,10 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -40,13 +43,13 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
// TODO: set EOT to EOS otherwise 0 will stop generation
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
@@ -54,11 +57,10 @@ func New(c fs.Config) (model.Model, error) {
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
@@ -75,31 +77,31 @@ type SelfAttention struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeDim := cmp.Or(opts.ropeDim, headDim)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, attention)
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
@@ -120,11 +122,11 @@ type Layer struct {
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
@@ -147,20 +149,22 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -40,13 +41,13 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
// TODO: set EOT to EOS otherwise 0 will stop generation
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
@@ -62,7 +63,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel
}
@@ -102,79 +103,70 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
var multimodal []input.Multimodal
aspectRatio := image.Point{ratioW, ratioH}
var offset int
patchesPerChunk := projectedOutputs.Dim(1)
if aspectRatio.Y*aspectRatio.X > 1 {
patchesPerChunk = projectedOutputs.Dim(1) / (aspectRatio.X*aspectRatio.Y + 1)
for range aspectRatio.Y {
for x := range aspectRatio.X {
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
var separator separator
if x < aspectRatio.X-1 {
separator.x = true // <|tile_x_separator|>
} else {
separator.y = true // <|tile_y_separator|>
}
multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator})
offset += patchesPerChunk
}
}
}
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
return multimodal, nil
return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
}
type separator struct {
x bool
y bool
type chunks struct {
*Model
ml.Tensor
aspectRatio image.Point
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
if inp.Multimodal == nil {
result = append(result, inp)
continue
}
t := inp.Multimodal.(*chunks)
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
for i, mm := range inp.Multimodal {
patchesPerChunk := mm.Tensor.Dim(1)
var offset int
patchesPerChunk := t.Dim(1)
if t.aspectRatio.Y*t.aspectRatio.X > 1 {
patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
if i < len(inp.Multimodal)-1 {
separator := mm.Data.(*separator)
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if separator.x {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
for range t.aspectRatio.Y {
for x := range t.aspectRatio.X {
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if x < t.aspectRatio.X-1 {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
offset += patchesPerChunk
}
if separator.y {
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
} else {
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
}
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
result = append(result, imageInputs...)
}

View File

@@ -8,8 +8,6 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -33,8 +31,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
}
if opts.useQKNorm {
@@ -82,7 +80,7 @@ func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tens
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
}
return nextStates
@@ -212,7 +210,12 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
@@ -252,5 +255,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -31,26 +32,31 @@ var _ model.MultimodalProcessor = (*Model)(nil)
var _ model.TextProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
if err != nil {
return nil, err
}
m := &Model{
TextModel: textModel,
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
// TODO: set EOT to EOS otherwise 0 will stop generation
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
TextModel: newTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
@@ -99,7 +105,7 @@ func newMultiModalProjector(c fs.Config) *MultiModalProjector {
}
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -123,14 +129,37 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
// split into patches to be sent to the text transformer
rows := make([]input.Multimodal, size.Y)
parent := imageFeatures{tensor: features}
rows := make([]*imageRow, size.Y)
for i := range rows {
rows[i].Tensor = features.View(ctx, features.Stride(1)*size.X*i, features.Dim(0), features.Stride(1), size.X)
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
}
return rows, nil
}
type imageFeatures struct {
tensor ml.Tensor
dataOnce sync.Once
data []float32
}
type imageRow struct {
parent *imageFeatures
s int
shape []int
}
func (r *imageRow) data() []float32 {
n := 1
for _, s := range r.shape {
n *= s
}
return r.parent.data[r.s*n : (r.s+1)*n]
}
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@@ -139,14 +168,15 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
for i, row := range inp.Multimodal {
inputMultimodal := inp.Multimodal.([]*imageRow)
for i, row := range inputMultimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
if i == len(inp.Multimodal)-1 {
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
if i == len(inputMultimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})
} else {

View File

@@ -1,24 +1,27 @@
package mistral3
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type TextModel struct {
model.Base
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -36,15 +39,19 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeType := uint32(0)
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -55,7 +62,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
@@ -102,7 +109,20 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
// image embeddings
for _, image := range batch.Multimodal {
imageFeature := image.Multimodal[0].Tensor
row := image.Multimodal.(*imageRow)
row.parent.dataOnce.Do(func() {
// use a new, throwaway context so the image tensor is not added to the graph
temp := m.Backend().NewContext()
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
row.parent.data = row.parent.tensor.Floats()
temp.Close()
})
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
if err != nil {
panic(err)
}
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
}
@@ -121,18 +141,24 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
return m.Output.Forward(ctx, hiddenState)
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
func NewTextModel(c fs.Config) (*TextModel, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
textModel := &TextModel{
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
return textModel, nil
}

View File

@@ -170,7 +170,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count")),
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),

View File

@@ -3,7 +3,6 @@ package mllama
import (
"bytes"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -38,13 +37,13 @@ func New(c fs.Config) (model.Model, error) {
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
// TODO: set EOT to EOS otherwise 0 will stop generation
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
@@ -59,7 +58,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -74,17 +73,13 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return nil, err
}
if ratio.numTiles() < m.maxNumTiles {
// Pad tiles to maxNumTiles
f32s = slices.Grow(f32s, m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles)
f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
pixelValues, err := ctx.Input().FromFloatSlice(f32s, m.imageSize, m.imageSize, m.numChannels, ratio.numTiles())
if err != nil {
return nil, err
}
pixelValues = pixelValues.Pad(ctx, 0, 0, 0, m.ImageProcessor.maxNumTiles-ratio.numTiles())
aspectRatio, err := ctx.Input().FromIntSlice([]int32{int32(ratio.rank)}, 1)
if err != nil {
return nil, err
@@ -92,9 +87,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
projectedOutputs := m.Projector.Forward(ctx, crossAttentionStates)
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
return m.Projector.Forward(ctx, crossAttentionStates), nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
@@ -110,7 +103,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(batch.Multimodal) > 0 {
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal.(ml.Tensor)
}
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))

View File

@@ -8,8 +8,6 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
type TextSelfAttention struct {
@@ -23,14 +21,15 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -45,7 +44,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
}
return key, nil
@@ -200,8 +199,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
type TextModelOptions struct {
hiddenSize, numHeads, numKVHeads int
ropeDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
crossAttentionLayers []int32
}
@@ -241,10 +240,10 @@ func newTextModel(c fs.Config) *TextModel {
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
ropeDim: int(c.Uint("rope.dimension_count")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
crossAttentionLayers: c.Ints("attention.cross_attention_layers"),
},
}

View File

@@ -16,6 +16,8 @@ type VisionSelfAttention struct {
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
Gate ml.Tensor `gguf:"attn_gate"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
@@ -23,16 +25,27 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState ml.Tensor, op
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), batchSize)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), batchSize)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), batchSize)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
scores := key.Mulmat(ctx, query)
scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Reshape(ctx, headDim, attention.Dim(1), opts.numHeads, batchSize)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
hiddenState = sa.Output.Forward(ctx, attention)
return hiddenState
}
type VisionMLP struct {
@@ -63,18 +76,21 @@ func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState ml.Tensor, opts
// self attention
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, opts)
if e.AttentionGate != nil {
hiddenState = hiddenState.Mul(ctx, e.AttentionGate)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
hiddenState = hiddenState.Add(ctx, residual)
if e.MLPGate != nil {
hiddenState = hiddenState.Mul(ctx, e.MLPGate)
}
hiddenState = hiddenState.Add(ctx, residual)
return hiddenState
}

View File

@@ -7,7 +7,4 @@ import (
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"
)

View File

@@ -1,170 +0,0 @@
package qwen2
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
headDim, ropeDim int
eps, ropeBase, ropeScale float32
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeDim := cmp.Or(opts.ropeDim, headDim)
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key := attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value := attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
return attn.Output.Forward(ctx, attention)
}
type MLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type DecoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
Attention *Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (d DecoderLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.MLP.Forward(ctx, hiddenStates)
return hiddenStates.Add(ctx, residual)
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []DecoderLayer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
// Forward implements model.Model.
func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
hiddenStates = m.Output.Forward(ctx, hiddenStates)
return hiddenStates, nil
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
}
func New(c fs.Config) (model.Model, error) {
m := Model{
Layers: make([]DecoderLayer, c.Uint("block_count")),
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func init() {
model.Register("qwen2", New)
}

View File

@@ -1,160 +0,0 @@
package qwen25vl
import (
"bytes"
"fmt"
"image"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v,vision"`
ImageProcessor
}
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
m := &Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
TextModel: NewTextModel(c),
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, nil, err
}
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, nil, err
}
// Calculate tensor dimensions
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
numPatches := grid.Temporal * grid.Height * grid.Width
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
if err != nil {
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
}
return pixelValues, grid, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
pixels, grid, err := m.PixelValues(ctx, multimodalData)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
var (
imageToken int32 = 151655
visionStartToken int32 = 151652
visionEndToken int32 = 151653
)
nImg := 0
for _, inp := range inputs {
if inp.Multimodal == nil {
// If not a multimodal input, add it to the result unchanged
result = append(result, inp)
} else {
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
// the image tokens with a prompt, so we add a prefix here
nImg++
pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
if err != nil {
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
}
for i := range pre {
result = append(result, input.Input{Token: pre[i]})
}
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token
result = append(result, input.Input{Token: visionStartToken})
// Add the image token with the multimodal tensor data at the first position
result = append(result, input.Input{
Token: imageToken,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: patchesPerChunk,
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, input.Input{Token: visionEndToken})
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
}
func init() {
model.Register("qwen25vl", New)
}

View File

@@ -1,151 +0,0 @@
package qwen25vl
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
ropeDim, originalContextLength int
eps, ropeBase, ropeScale float32
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func NewTextModel(c fs.Config) *TextModel {
m := TextModel{
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
ropeDim: int(c.Uint("rope.dimension_count", 128)),
originalContextLength: int(c.Uint("context_length", 128000)),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
},
}
return &m
}
// SelfAttention implements the multi-head self-attention mechanism
// with separate projections for query, key, value and output transformations
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)
}
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
}
// MLP implements the feed-forward network component with SwiGLU activation
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
// Apply SwiGLU activation gating
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
// Project back to hidden dimension
return mlp.Down.Forward(ctx, hiddenState)
}
// Layer represents a single transformer layer combining self-attention and feed-forward components
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
// Self-attention branch with residual connection
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
// Feed-forward branch with residual connection
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
// Initial token embedding
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}
// Process through transformer layers
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}

View File

@@ -1,391 +0,0 @@
package qwen25vl
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
// We only support batch size of 1
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
// Create a flat slice for the mask (all -inf initially to block all attention)
flat := make([]float32, seqLength*seqLength)
for i := range flat {
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
}
// Fill in the mask with zeros for tokens that CAN attend to each other
for i := 1; i < len(bounds); i++ {
start := bounds[i-1]
end := bounds[i]
// Enable attention within this sequence block by setting values to 0
for row := start; row < end; row++ {
for col := start; col < end; col++ {
idx := row*seqLength + col
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
}
}
}
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
if err != nil {
panic(err)
}
// Reshape to match [seqLength, seqLength, 1] for broadcasting
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
return mask
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_out"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))
// Scaled dot-product attention
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)
if mask != nil {
kq = kq.Add(ctx, mask)
}
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
// VisionMLP implements the multi-layer perceptron
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
// Using activation as specified in config (likely GELU or SiLU/Swish)
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
upOutput := mlp.Up.Forward(ctx, hiddenStates)
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
return mlp.Down.Forward(ctx, hiddenStates)
}
type VisionEncoderLayer struct {
Norm1 *nn.RMSNorm `gguf:"ln1"`
SelfAttention *VisionSelfAttention
Norm2 *nn.RMSNorm `gguf:"ln2"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
// VisionModelOptions contains configuration options
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
patchSize int
numChannels int
eps float32
ropeTheta float32
spatialMergeSize int
windowSize int
fullAttnBlocks []int32
temporalPatchSize int
}
type PatchEmbedding struct {
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
}
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
numPatches := pixelValues.Shape()[1]
// Reshape the input tensor to match the expected dimensions
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
// Permute the tensor to bring the temporal dimension to the front
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Split the tensor into parts for the temporal convolutions
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
s0, s1 := opts.patchSize, opts.patchSize // Use full stride
p0, p1 := 0, 0 // padding
d0, d1 := 1, 1 // dilation
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
// Add the outputs from the two temporal convolutions
out := out0.Add(ctx, out1)
// Reshape the output tensor to match the expected dimensions
return out.Reshape(ctx, opts.hiddenSize, numPatches)
}
// VisionPatchMerger implements patch merging for the Qwen vision model
type VisionPatchMerger struct {
LNQ *nn.RMSNorm `gguf:"ln_q"`
MLP0 *nn.Linear `gguf:"mlp.0"`
MLP2 *nn.Linear `gguf:"mlp.2"`
}
// Forward computes patch merging for the vision model
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
// Reshape the normalized output to view the hidden size dimension
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
hidden := pm.MLP0.Forward(ctx, reshaped)
activated := hidden.GELU(ctx)
output := pm.MLP2.Forward(ctx, activated)
return output
}
// VisionModel implements the Qwen vision model
type VisionModel struct {
PatchEmbedding *PatchEmbedding
Layers []VisionEncoderLayer `gguf:"blk"`
PatchMerger *VisionPatchMerger `gguf:"merger"`
*VisionModelOptions
}
// Forward computes the vision model for an input tensor
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
// Extract patch embeddings
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
positionEmbedding := m.PositionalEmbedding(ctx, grid)
windowIndex, bounds := m.WindowIndex(ctx, grid)
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
// Apply encoder layers
for i, layer := range m.Layers {
if slices.Contains(m.fullAttnBlocks, int32(i)) {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
} else {
hiddenStates = layer.Forward(
ctx,
hiddenStates,
cos,
sin,
mask,
m.VisionModelOptions,
)
}
}
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
reverseWindowIndex := windowIndex.Argsort(ctx)
return hiddenStates.Rows(ctx, reverseWindowIndex)
}
// WindowIndex divides the grid into windows and returns:
// 1. A tensor containing flattened indices of all grid points organized by windows
// 2. A slice of boundaries that mark where each window's data begins and ends
// in the flattened representation, scaled by spatialMergeSize squared
//
// The boundaries slice always starts with 0 and contains cumulative ending
// positions for each window, allowing downstream processing to identify
// window boundaries in the tensor data.
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
llmGridH := grid.Height / m.spatialMergeSize
llmGridW := grid.Width / m.spatialMergeSize
// Calculate window parameters
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
// Initialize index_new slice
var index []int32
// Initialize bounds with the first element as 0
bounds := []int{0}
totalSeqLen := 0
// Process each window without padding
for wh := range numWindowsH {
for ww := range numWindowsW {
// Calculate window boundaries
hStart := wh * vitMergerWindowSize
wStart := ww * vitMergerWindowSize
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
// Calculate sequence length for this window
seqLen := (hEnd - hStart) * (wEnd - wStart)
// Collect indices for this window
for h := hStart; h < hEnd; h++ {
for w := wStart; w < wEnd; w++ {
index = append(index, int32(h*llmGridW+w))
}
}
totalSeqLen += seqLen
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
}
}
t, err := ctx.Input().FromIntSlice(index, len(index))
if err != nil {
panic(err)
}
return t, bounds
}
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
dim := m.headDim / 2
freq := dim / 2
theta := float64(m.ropeTheta)
merge := m.spatialMergeSize
// Create frequency patterns for position encoding
maxGridSize := max(grid.Height, grid.Width)
freqVals := make([]float32, freq*maxGridSize)
for i := range maxGridSize {
for j := range freq {
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
}
}
freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
if err != nil {
panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
}
// Create position coordinates (y,x pairs) for the grid
// In PyTorch: Equivalent to generating position ids with torch.arange()
coords := make([]int32, 0, grid.Height*grid.Width*2)
for y := range grid.Height {
for x := range grid.Width {
coords = append(coords, int32(y), int32(x))
}
}
pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
if err != nil {
panic(fmt.Errorf("failed to create tensor from positions: %w", err))
}
// Reshape and permute positions to match spatial merging pattern
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
// Use position indices to look up corresponding frequency values
positionalEmbedding := freqs.Rows(ctx, pos)
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
return positionalEmbedding
}
// newVisionModel creates a new instance of the Qwen vision model
func newVisionModel(c fs.Config) *VisionModel {
patchSize := int(c.Uint("vision.patch_size", 14))
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
numHeads := int(c.Uint("vision.attention.head_count", 16))
numChannels := int(c.Uint("vision.num_channels", 3))
eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
windowSize := int(c.Uint("vision.window_size", 112))
fullAttnBlocks := c.Ints("qwen25vl.vision.fullatt_block_indexes", []int32{7, 15, 23, 31})
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
model := &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: hiddenSize,
numHeads: numHeads,
headDim: hiddenSize / numHeads,
patchSize: patchSize,
numChannels: numChannels,
eps: eps,
ropeTheta: ropeTheta,
spatialMergeSize: spatialMergeSize,
windowSize: windowSize,
temporalPatchSize: temporalPatchSize,
fullAttnBlocks: fullAttnBlocks,
},
}
return model
}

View File

@@ -1,184 +0,0 @@
package qwen25vl
import (
"fmt"
"image"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
type ImageProcessor struct {
numChannels int
patchSize int
temporalPatchSize int
mergeSize int
minPixels int
maxPixels int
factor int
rescaleFactor float32
imageMean []float32
imageStd []float32
}
// newImageProcessor creates a new image processor with default values
func newImageProcessor(c fs.Config) ImageProcessor {
patchSize := int(c.Uint("vision.patch_size", 14))
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
return ImageProcessor{
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
patchSize: patchSize,
temporalPatchSize: 2,
mergeSize: mergeSize,
minPixels: 56 * 56,
maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
factor: patchSize * mergeSize,
rescaleFactor: 1.0 / 255.0,
imageMean: imageproc.ClipDefaultMean[:],
imageStd: imageproc.ClipDefaultSTD[:],
}
}
// SmartResize implements the smart resize algorithm
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
factor := p.factor
if height < factor || width < factor {
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
} else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 {
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio))
}
round := func(x float64) int { return int(math.RoundToEven(x)) }
hBar := round(float64(height)/float64(factor)) * factor
wBar := round(float64(width)/float64(factor)) * factor
if hBar*wBar > p.maxPixels {
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
} else if hBar*wBar < p.minPixels {
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
}
return hBar, wBar
}
type Grid struct {
Height int
Width int
Temporal int
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
origWidth := img.Bounds().Dx()
origHeight := img.Bounds().Dy()
// Calculate smart resize dimensions
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
// Resize image using existing functions
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
normalizedPixels := imageproc.Normalize(
resizedImg,
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
true, // rescale
true, // channelFirst
)
// Calculate grid dimensions
grid := &Grid{
Height: resizedHeight / p.patchSize,
Width: resizedWidth / p.patchSize,
Temporal: 1, // For single images, temporal dimension is 1
}
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
if err != nil {
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
}
// Return patches and grid dimensions
return patches, grid, nil
}
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
channels := p.numChannels
patchSize := p.patchSize
mergeSize := p.mergeSize
temporalPatchSize := p.temporalPatchSize
// Calculate output dimensions
numPatches := grid.Temporal * grid.Height * grid.Width
patchDim := channels * temporalPatchSize * patchSize * patchSize
result := make([]float32, numPatches*patchDim)
patchIndex := 0
// Single temporal frame handling (copies to all frames)
for range grid.Temporal {
for h := 0; h < grid.Height; h += mergeSize {
for w := 0; w < grid.Width; w += mergeSize {
// Handle the 2x2 merged patches
for mh := range mergeSize {
for mw := range mergeSize {
baseOffset := patchIndex * patchDim
// Extract patch data for first temporal frame
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
for py := range patchSize {
for px := range patchSize {
// Calculate source pixel coordinates
y := (h+mh)*patchSize + py
x := (w+mw)*patchSize + px
// Source index in input tensor (CHW format)
srcIdx := c*height*width + y*width + x
// Destination index in first temporal frame
dstIdx := channelOffset + (py * patchSize) + px
if srcIdx < len(pixels) && dstIdx < len(result) {
result[dstIdx] = pixels[srcIdx]
}
}
}
}
// Copy first temporal frame to all other frames
if temporalPatchSize > 1 {
for c := range channels {
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
firstFrameOffset := channelOffset
frameSize := patchSize * patchSize
// Copy first frame to all other frames
for tp := 1; tp < temporalPatchSize; tp++ {
currentFrameOffset := channelOffset + (tp * frameSize)
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
result[firstFrameOffset:firstFrameOffset+frameSize])
}
}
}
patchIndex++
}
}
}
}
}
return result, nil
}

View File

@@ -1,239 +0,0 @@
package qwen3
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
eps float32
ropeBase, ropeScale float32
keyLength, valueLength int
numExperts, numExpertsUsed int
normTopKProb bool
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
type Attention struct {
QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"`
Query *nn.Linear `gguf:"attn_q"`
KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
return sa.Output.Forward(ctx, attention)
}
type MLP interface {
Forward(ml.Context, ml.Tensor, *Options) ml.Tensor
}
type sparse struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}
func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routerLogits := mlp.Router.Forward(ctx, hiddenStates)
routingWeights := routerLogits.Softmax(ctx)
selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed)
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts)
if opts.normTopKProb {
routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1))
routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx))
routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1))
}
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
upStates := mlp.Up.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = mlp.Gate.MulmatID(ctx, hiddenStates, selectedExperts)
hiddenStates = hiddenStates.SILU(ctx)
hiddenStates = hiddenStates.Mul(ctx, upStates)
experts := mlp.Down.MulmatID(ctx, hiddenStates, selectedExperts)
experts = experts.Mul(ctx, routingWeights)
nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2)))
}
return nextStates
}
type dense struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *Options) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
*Attention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP
}
func (d *Layer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenStates
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Layers []Layer `gguf:"blk"`
*Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs, err = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, m.ropeScale, rope.WithTypeNeoX()), nil
}
var _ model.Model = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
layers := make([]Layer, c.Uint("block_count"))
for i := range layers {
if c.String("general.architecture") == "qwen3moe" {
layers[i].MLP = &sparse{}
} else {
layers[i].MLP = &dense{}
}
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
Layers: layers,
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
keyLength: int(c.Uint("attention.key_length")),
valueLength: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
normTopKProb: c.Bool("norm_top_k_prob", true),
},
}
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
func init() {
model.Register("qwen3", New)
model.Register("qwen3moe", New)
}

View File

@@ -5,13 +5,116 @@ import (
"context"
"iter"
"log/slog"
"slices"
"strings"
"sync"
"github.com/dlclark/regexp2"
heap "github.com/emirpasic/gods/v2/trees/binaryheap"
"github.com/ollama/ollama/logutil"
)
type Special int32
const (
SpecialBOS Special = iota
SpecialEOS
)
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}
type Vocabulary struct {
Values []string
Types []int32
Scores []float32
Merges []string
BOS, EOS, EOT int32
AddBOS, AddEOS, AddEOT bool
specialOnce sync.Once
special []string
valuesOnce sync.Once
values map[string]int32
mergeOnce sync.Once
merge map[string]int32
}
func (v *Vocabulary) Is(id int32, special Special) bool {
switch special {
case SpecialBOS:
return id == v.BOS
case SpecialEOS:
return id == v.EOS || id == v.EOT
default:
return false
}
}
func (v *Vocabulary) Encode(s string) int32 {
v.valuesOnce.Do(func() {
v.values = make(map[string]int32, len(v.Values))
for i, value := range v.Values {
v.values[value] = int32(i)
}
})
if id, ok := v.values[s]; ok {
return id
}
return -1
}
func (v *Vocabulary) Decode(id int32) string {
return v.Values[id]
}
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if slices.Contains([]int{105, 106}, i) {
v.special = append(v.special, v.Values[i])
} else if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}
})
return v.special
}
func (v *Vocabulary) Merge(left, right string) int {
v.mergeOnce.Do(func() {
v.merge = make(map[string]int32, len(v.Merges))
for i, merge := range v.Merges {
v.merge[merge] = int32(i)
}
})
if id, ok := v.merge[left+" "+right]; ok {
return int(id)
}
return -1
}
type BytePairEncoding struct {
pre *regexp2.Regexp
vocab *Vocabulary
@@ -201,12 +304,27 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = bpe.vocab.addSpecials(ids)
if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
ids = append([]int32{bpe.vocab.BOS}, ids...)
}
if bpe.vocab.AddEOS {
if ids[len(ids)-1] == bpe.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
ids = append(ids, bpe.vocab.EOS)
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids)
return ids, nil
}
@@ -234,6 +352,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String())
return sb.String(), nil
}

View File

@@ -182,12 +182,27 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
if addSpecial && len(ids) > 0 {
ids = spm.vocab.addSpecials(ids)
if spm.vocab.AddBOS {
if ids[0] == spm.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
ids = append([]int32{spm.vocab.BOS}, ids...)
}
if spm.vocab.AddEOS {
if ids[len(ids)-1] == spm.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
ids = append(ids, spm.vocab.EOS)
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "ids", ids)
return ids, nil
}
@@ -246,6 +261,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
}
}
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String())
return sb.String(), nil
}

View File

@@ -1,17 +0,0 @@
package model
const (
TOKEN_TYPE_NORMAL = iota + 1
TOKEN_TYPE_UNKNOWN
TOKEN_TYPE_CONTROL
TOKEN_TYPE_USER_DEFINED
TOKEN_TYPE_UNUSED
TOKEN_TYPE_BYTE
)
type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocabulary() *Vocabulary
}

View File

@@ -1,112 +0,0 @@
package model
import (
"log/slog"
"slices"
"sync"
)
type Special int32
const (
SpecialBOS Special = iota
SpecialEOS
)
type Vocabulary struct {
Values []string
Types []int32
Scores []float32
Merges []string
BOS, EOS []int32
AddBOS, AddEOS bool
specialOnce sync.Once
special []string
valuesOnce sync.Once
values map[string]int32
mergeOnce sync.Once
merge map[string]int32
}
func (v *Vocabulary) Is(id int32, special Special) bool {
switch special {
case SpecialBOS:
return slices.Contains(v.BOS, id)
case SpecialEOS:
return slices.Contains(v.EOS, id)
default:
return false
}
}
func (v *Vocabulary) addSpecials(ids []int32) []int32 {
if v.AddBOS && len(v.BOS) > 0 {
if slices.Contains(v.BOS, ids[0]) {
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
}
slog.Debug("adding bos token to prompt", "id", v.BOS)
ids = append([]int32{v.BOS[0]}, ids...)
}
if v.AddEOS && len(v.EOS) > 0 {
if slices.Contains(v.BOS, ids[len(ids)-1]) {
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
}
slog.Debug("adding eos token to prompt", "id", v.EOS)
ids = append(ids, v.EOS[0])
}
return ids
}
func (v *Vocabulary) Encode(s string) int32 {
v.valuesOnce.Do(func() {
v.values = make(map[string]int32, len(v.Values))
for i, value := range v.Values {
v.values[value] = int32(i)
}
})
if id, ok := v.values[s]; ok {
return id
}
return -1
}
func (v *Vocabulary) Decode(id int32) string {
return v.Values[id]
}
func (v *Vocabulary) SpecialVocabulary() []string {
v.specialOnce.Do(func() {
for i := range v.Values {
if v.Types[i] == TOKEN_TYPE_CONTROL {
v.special = append(v.special, v.Values[i])
}
}
})
return v.special
}
func (v *Vocabulary) Merge(left, right string) int {
v.mergeOnce.Do(func() {
v.merge = make(map[string]int32, len(v.Merges))
for i, merge := range v.Merges {
v.merge[merge] = int32(i)
}
})
if id, ok := v.merge[left+" "+right]; ok {
return int(id)
}
return -1
}

View File

@@ -104,8 +104,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
"used", numPast, "remaining", len(prompt)-numPast)
slot.Inputs = prompt[:numPast]
prompt = prompt[numPast:]
slot.Inputs = slot.Inputs[:numPast]
return slot, prompt, nil
}

View File

@@ -136,8 +136,8 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
"used", numPast, "remaining", int32(len(prompt))-numPast)
slot.Inputs = prompt[:numPast]
prompt = prompt[numPast:]
slot.Inputs = slot.Inputs[:numPast]
return slot, prompt, nil
}

View File

@@ -3,6 +3,7 @@ package ollamarunner
import (
"errors"
"fmt"
"image"
"testing"
"time"
@@ -11,6 +12,10 @@ import (
)
func TestCountCommon(t *testing.T) {
imgA := image.NewRGBA(image.Rect(0, 0, 100, 100))
imgB := image.NewRGBA(image.Rect(0, 0, 50, 50))
imgC := image.NewRGBA(image.Rect(50, 50, 100, 100))
tests := []struct {
name string
t1 []input.Input
@@ -31,20 +36,20 @@ func TestCountCommon(t *testing.T) {
},
{
name: "Image Prefix",
t1: []input.Input{{MultimodalHash: 1}},
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
expected: 1,
},
{

View File

@@ -1,116 +0,0 @@
package ollamarunner
import (
"errors"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
// Tensors can't be used across multiple compute graphs. This is a problem
// if a single embedding is split across batches using views since all of
// the views will have the same source tensor. We also don't want to
// recompute the entire embedding for each batch.
//
// To avoid this, we compute all of the tensors for the embedding on the
// first use and then store the result in system memory. When we need
// additional tensors, we recreate them from the stored data.
// multimodalEntry represents the embeddings of a single object (such
// as an image).
type multimodalEntry struct {
// mm is the original set of tensors created by EncodeMultimodal
mm []input.Multimodal
// data is the computed result of mm. Nil if not yet computed
data [][]float32
}
// multimodalStore maps from an individual tensor (of which there
// may be many in a single multimodal object) to its parent embedding
type multimodalStore map[ml.Tensor]*multimodalEntry
func newMultimodalStore() multimodalStore {
return make(multimodalStore)
}
// addMultimodal stores an embedding for later use in a compute graph
func (m multimodalStore) addMultimodal(embedding []input.Multimodal) {
entry := &multimodalEntry{mm: embedding}
for _, e := range embedding {
if e.Tensor != nil {
m[e.Tensor] = entry
}
}
}
// getMultimodal takes a source set of tensors (which may contain a whole or
// parts of one or more images) and returns the equivalent that can be used in
// the current context
func (m multimodalStore) getMultimodal(backend ml.Backend, ctx ml.Context, in []input.Multimodal, reserve bool) ([]input.Multimodal, error) {
out := make([]input.Multimodal, len(in))
for i := range out {
if in[i].Tensor != nil {
var err error
out[i].Tensor, err = m.getTensor(backend, ctx, in[i].Tensor, reserve)
if err != nil {
return nil, err
}
}
out[i].Data = in[i].Data
}
return out, nil
}
func (m multimodalStore) getTensor(backend ml.Backend, ctx ml.Context, in ml.Tensor, reserve bool) (ml.Tensor, error) {
entry := m[in]
if entry.data == nil {
computeCtx := backend.NewContext()
defer computeCtx.Close()
var tensors []ml.Tensor
for _, t := range entry.mm {
if t.Tensor != nil {
tensors = append(tensors, t.Tensor)
}
}
if len(tensors) == 0 {
return nil, nil
}
computeCtx.Forward(tensors...)
entry.data = make([][]float32, len(entry.mm))
if !reserve {
computeCtx.Compute(tensors...)
for i, t := range entry.mm {
if t.Tensor != nil {
entry.data[i] = t.Tensor.Floats()
}
}
} else {
err := computeCtx.Reserve()
if err != nil {
return nil, err
}
}
}
for i, t := range entry.mm {
if in == t.Tensor {
if !reserve {
return ctx.Input().FromFloatSlice(entry.data[i], t.Tensor.Shape()...)
} else {
return ctx.Input().Empty(t.Tensor.DType(), t.Tensor.Shape()...), nil
}
}
}
return nil, errors.New("multimodal tensor not found")
}

View File

@@ -1,14 +1,12 @@
package ollamarunner
import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"hash/maphash"
"image"
"log"
"log/slog"
"net"
@@ -22,7 +20,6 @@ import (
"time"
"unicode/utf8"
"golang.org/x/image/bmp"
"golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api"
@@ -43,9 +40,6 @@ type Sequence struct {
// multimodal embeddings
ctxs []ml.Context
// mmStore holds multimodal embeddings to mange memory and enable splitting across batches
mmStore multimodalStore
// batch index
iBatch int
@@ -107,7 +101,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
startTime := time.Now()
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
inputs, ctxs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
} else if len(inputs) == 0 {
@@ -162,7 +156,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
return &Sequence{
ctxs: ctxs,
mmStore: mmStore,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
@@ -181,10 +174,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, error) {
var inputs []input.Input
var ctxs []ml.Context
var mmStore multimodalStore
var parts []string
var matches [][]string
@@ -195,7 +187,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
re := regexp.MustCompile(`\[img-(\d+)\]`)
parts = re.Split(prompt, -1)
matches = re.FindAllStringSubmatch(prompt, -1)
mmStore = newMultimodalStore()
} else {
parts = []string{prompt}
}
@@ -205,7 +196,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
// text - tokenize
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
for _, t := range tokens {
@@ -225,7 +216,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
}
if imageIndex < 0 {
return nil, nil, nil, fmt.Errorf("invalid image index: %d", n)
return nil, nil, fmt.Errorf("invalid image index: %d", n)
}
ctx := s.model.Backend().NewContext()
@@ -233,15 +224,13 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
ctxs = append(ctxs, ctx)
imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
s.multimodalHash.Reset()
_, _ = s.multimodalHash.Write(images[imageIndex].Data)
imageHash := s.multimodalHash.Sum64()
mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
@@ -251,11 +240,11 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
var err error
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return nil, nil, nil, err
return nil, nil, err
}
}
return inputs, ctxs, mmStore, nil
return inputs, ctxs, nil
}
type Server struct {
@@ -374,9 +363,6 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
ctx := s.model.Backend().NewContext()
defer ctx.Close()
var batchInputs []int32
var batch input.Batch
@@ -447,11 +433,7 @@ func (s *Server) processBatch() error {
batchInputs = append(batchInputs, inp.Token)
if inp.Multimodal != nil {
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
if err != nil {
return err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
}
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -477,6 +459,9 @@ func (s *Server) processBatch() error {
return nil
}
ctx := s.model.Backend().NewContext()
defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
@@ -735,71 +720,12 @@ func (s *Server) reserveWorstCaseGraph() error {
ctx := s.model.Backend().NewContext()
defer ctx.Close()
var err error
inputs := make([]input.Input, s.batchSize)
mmStore := newMultimodalStore()
// Multimodal strategy:
// - Encode a 2048x2048 image. This assumes that a single image of this
// size is sufficient to trigger the worst case. This is currently true
// because for existing models, only a single image fits in a batch.
// - Add the embedding to a full batch of tokens - this is necessary because
// the model may be looking for non-image data, such as <image> tags.
// - Run PostTokenize to execute any transformations between generated
// embeddings and what the forward pass expects.
// - The result may now be larger than a batch (images may not fit in a
// single batch), so trim based on what will fit and must be grouped together.
// - Fill out the rest of the space with text tokens.
if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
mmCtx := s.model.Backend().NewContext()
defer mmCtx.Close()
img := image.NewGray(image.Rect(0, 0, 2048, 2048))
var buf bytes.Buffer
bmp.Encode(&buf, img)
if inputs[0].Multimodal, err = multimodalProcessor.EncodeMultimodal(mmCtx, buf.Bytes()); err == nil {
mmStore.addMultimodal(inputs[0].Multimodal)
inputs, err = multimodalProcessor.PostTokenize(inputs)
if err != nil {
return err
}
for i, inp := range inputs {
minBatch := 1 + inp.SameBatch
if minBatch > s.batchSize {
inputs = inputs[i:min(i+minBatch, len(inputs))]
break
} else if i+minBatch > s.batchSize {
inputs = inputs[:i]
break
}
}
if len(inputs) < s.batchSize {
newInputs := make([]input.Input, s.batchSize)
copy(newInputs, inputs)
inputs = newInputs
}
}
}
var batch input.Batch
batchInputs := make([]int32, len(inputs))
inputs := make([]int32, s.batchSize)
batch.Positions = make([]int32, len(inputs))
batch.Sequences = make([]int, len(inputs))
for i, inp := range inputs {
batchInputs[i] = inp.Token
if inp.Multimodal != nil {
mm, err := mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, true)
if err != nil {
return err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: i, Multimodal: mm})
}
for i := range inputs {
batch.Positions[i] = int32(i)
}
@@ -808,7 +734,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Outputs[i] = int32(i)
}
batch.Inputs, err = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return err
}
@@ -845,7 +772,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(mpath, params)
s.model, err = model.New(ctx, mpath, params)
if err != nil {
panic(err)
}
@@ -874,14 +801,6 @@ func (s *Server) loadModel(
panic(err)
}
err = s.model.Backend().Load(ctx,
func(progress float32) {
s.progress = progress
})
if err != nil {
panic(err)
}
s.status = llm.ServerStatusReady
s.ready.Done()
}
@@ -936,6 +855,9 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,

View File

@@ -176,7 +176,7 @@ func NewGrammarSampler(model model.TextProcessor, grammarStr string) (*GrammarSa
vocabIds[i] = uint32(i)
}
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, model.Vocabulary().EOS)
grammar := llama.NewGrammar(grammarStr, vocabIds, pieces, []uint32{uint32(model.Vocabulary().EOS), uint32(model.Vocabulary().EOT)})
if grammar == nil {
return nil, errors.New("sample: failed to initialize grammar")
}

View File

@@ -295,7 +295,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
}
defer bin.Close()
f, err := ggml.Decode(bin, -1)
f, _, err := ggml.Decode(bin, -1)
if err != nil {
return nil, err
}
@@ -430,7 +430,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
fnWrap := func(n uint64) {
done := doneBytes.Add(n)
progress := float32(done) / float32(totalBytes)
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0000000000000000000", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantizeType), Digest: "0", Total: layer.Size, Completed: int64(progress * float32(layer.Size))})
}
ftype, err := ggml.ParseFileType(quantizeType)
if err != nil {
@@ -467,7 +467,7 @@ func quantizeLayer(layer *layerGGML, quantizeType string, fn func(resp api.Progr
return nil, err
}
f, err := ggml.Decode(temp, 1024)
f, _, err := ggml.Decode(temp, 1024)
if err != nil {
slog.Error(fmt.Sprintf("error decoding ggml: %s\n", err))
return nil, err
@@ -501,26 +501,47 @@ func ggufLayers(digest string, fn func(resp api.ProgressResponse)) ([]*layerGGML
return nil, errOnlyGGUFSupported
}
f, err := ggml.Decode(blob, -1)
stat, err := blob.Stat()
if err != nil {
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if f.KV().Kind() == "adapter" {
mediatype = "application/vnd.ollama.image.adapter"
} else if (f.KV().Uint("block_count") == 0 && f.KV().Uint("vision.block_count") > 0) || f.KV().Kind() == "projector" {
// if a model has vision.block_count but not block_count, it is a standalone vision model
mediatype = "application/vnd.ollama.image.projector"
}
var offset int64
for offset < stat.Size() {
f, n, err := ggml.Decode(blob, 1024)
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return nil, err
}
layer, err := NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
}
mediatype := "application/vnd.ollama.image.model"
if f.KV().Kind() == "adapter" {
mediatype = "application/vnd.ollama.image.adapter"
} else if _, ok := f.KV()[fmt.Sprintf("%s.vision.block_count", f.KV().Architecture())]; ok || f.KV().Kind() == "projector" {
mediatype = "application/vnd.ollama.image.projector"
}
layers = append(layers, &layerGGML{layer, f})
var layer Layer
if digest != "" && n == stat.Size() && offset == 0 {
layer, err = NewLayerFromLayer(digest, mediatype, blob.Name())
if err != nil {
slog.Debug("could not create new layer from layer", "error", err)
return nil, err
}
}
// Fallback to creating layer from file copy (either NewLayerFromLayer failed, or digest empty/n != stat.Size())
if layer.Digest == "" {
layer, err = NewLayer(io.NewSectionReader(blob, offset, n), mediatype)
if err != nil {
return nil, err
}
}
layers = append(layers, &layerGGML{layer, f})
offset = n
}
return detectChatTemplate(layers)
}

View File

@@ -75,7 +75,7 @@ func (m *Model) Capabilities() []model.Capability {
if err == nil {
defer r.Close()
f, err := ggml.Decode(r, 1024)
f, _, err := ggml.Decode(r, 1024)
if err == nil {
if _, ok := f.KV()[fmt.Sprintf("%s.pooling_type", f.KV().Architecture())]; ok {
capabilities = append(capabilities, model.CapabilityEmbedding)

View File

@@ -10,9 +10,6 @@ import (
"log/slog"
"net/http"
"os"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
@@ -64,7 +61,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
}
defer blob.Close()
f, err := ggml.Decode(blob, -1)
f, _, err := ggml.Decode(blob, -1)
if err != nil {
return nil, err
}
@@ -128,124 +125,3 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil
}
func parseObjects(s string) []map[string]any {
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
return nil
} else {
offset += int(decoder.InputOffset())
objs = append(objs, obj)
}
}
return objs
}
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return nil, false
}
templateObjects := parseObjects(b.String())
if len(templateObjects) == 0 {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range templateObjects[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
responseObjects := parseObjects(s)
if len(responseObjects) == 0 {
return nil, false
}
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}
return all
}
var objs []map[string]any
for _, p := range responseObjects {
objs = append(objs, collect(p)...)
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}
return toolCalls, len(toolCalls) > 0
}

View File

@@ -1,179 +0,0 @@
package server
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path/filepath"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
{"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}
func TestParseObjects(t *testing.T) {
tests := []struct {
input string
want []map[string]any
}{
{
input: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, Canada"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
},
},
{
input: `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </toolcall> <toolcall>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, ON"}} </toolcall>`,
want: []map[string]any{
{"name": "get_current_weather", "arguments": map[string]any{"format": "fahrenheit", "location": "San Francisco, CA"}},
{"name": "get_current_weather", "arguments": map[string]any{"format": "celsius", "location": "Toronto, ON"}},
},
},
{
input: `{"name": "get_current_weather", "arguments": `,
want: nil,
},
}
for _, tc := range tests {
t.Run(tc.input, func(t *testing.T) {
got := parseObjects(tc.input)
if diff := cmp.Diff(got, tc.want); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -120,30 +120,14 @@ func getTensorNewType(kv fsggml.KV, qs *quantizeState, newType fsggml.TensorType
if newType.IsQuantized() {
nx := shape[0]
ny := uint64(1)
if len(shape) > 1 {
ny = shape[1]
}
qk_k := newType.BlockSize()
// Check if first dimension is divisible by block size
if nx%qk_k != 0 {
// Store the original type for logging
originalType := newType
// Select appropriate fallback based on original type
switch newType {
case fsggml.TensorTypeQ4_K:
newType = fsggml.TensorTypeQ5_0
case fsggml.TensorTypeQ5_K:
newType = fsggml.TensorTypeQ5_1
case fsggml.TensorTypeQ6_K:
newType = fsggml.TensorTypeQ8_0
}
// Final check - if still incompatible, fall back to F16
if nx%newType.BlockSize() != 0 {
newType = fsggml.TensorTypeF16
}
slog.Warn(fmt.Sprintf("tensor cols %d are not divisible by %d, required for %s - using fallback quantization %s",
nx, qk_k, originalType.String(), newType.String()))
slog.Warn(fmt.Sprintf("tensor cols %d x %d are not divisible by %d, required for %s. Falling back to quantization %s", nx, ny, qk_k, newType.String(), fsggml.TensorTypeF16.String()))
newType = fsggml.TensorTypeF16
}
}
return newType

View File

@@ -271,7 +271,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatal(err.Error())
}
defer fp.Close()
meta, err := fsggml.Decode(fp, -1)
meta, _, err := fsggml.Decode(fp, -1)
if err != nil {
t.Fatal(err.Error())
}
@@ -303,7 +303,7 @@ func TestQuantizeModel(t *testing.T) {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
}
defer fpNew.Close()
newMeta, err := fsggml.Decode(fpNew, -1)
newMeta, _, err := fsggml.Decode(fpNew, -1)
if err != nil {
t.Fatalf("failed to load the quantized model %s: %s", tmp.Name(), err)
}

View File

@@ -38,6 +38,7 @@ import (
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/tools"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -1482,11 +1483,20 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
var toolParser tools.Parser
if len(req.Tools) > 0 {
toolParser, err = tools.NewParser(m.Template.Template)
if err != nil {
slog.Error("failed to create tool parser", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
ch := make(chan any)
go func() {
defer close(ch)
var sb strings.Builder
var toolCallIndex int = 0
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -1512,37 +1522,26 @@ func (s *Server) ChatHandler(c *gin.Context) {
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
// TODO: tool call checking and filtering should be moved outside of this callback once streaming
// however this was a simple change for now without reworking streaming logic of this (and other)
// handlers
if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
ch <- res
return
}
// Streaming tool calls:
// If tools are recognized, use a flag to track the sending of a tool downstream
// This ensures that content is cleared from the message on the last chunk sent
sb.WriteString(r.Content)
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
res.Message.ToolCalls = toolCalls
for i := range toolCalls {
toolCalls[i].Function.Index = toolCallIndex
toolCallIndex++
if len(req.Tools) > 0 && !toolParser.Done {
if r.Content == "" {
return
}
res.Message.Content = ""
sb.Reset()
ch <- res
return
}
if r.Done {
// Send any remaining content if no tool calls were detected
if toolCallIndex == 0 {
res.Message.Content = sb.String()
toolCalls, content, err := toolParser.Add(r.Content)
if err == nil {
if len(content) > 0 {
res.Message.Content = content
slog.Debug("tools: setting content to", "content", content)
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
}
} else if errors.Is(err, tools.ErrAccumulateMore) {
return
} else {
slog.Debug("tools: error", "error", err)
}
ch <- res
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
}
@@ -1551,11 +1550,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var sb strings.Builder
var toolCalls []api.ToolCall
for rr := range ch {
switch t := rr.(type) {
case api.ChatResponse:
sb.WriteString(t.Message.Content)
resp = t
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
}
case gin.H:
msg, ok := t["error"].(string)
if !ok {
@@ -1571,12 +1574,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
resp.Message.Content = sb.String()
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
if len(toolCalls) > 0 {
resp.Message.ToolCalls = toolCalls
}
c.JSON(http.StatusOK, resp)

44
tools/testdata/llama3.2.gotmpl vendored Normal file
View File

@@ -0,0 +1,44 @@
<|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
{{ if .System }}{{ .System }}
{{- end }}
{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question.
You are a helpful assistant with tool calling capabilities.
{{- end }}<|eot_id|>
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|>
{{- if and $.Tools $last }}
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
{{ range $.Tools }}
{{- . }}
{{ end }}
{{ .Content }}<|eot_id|>
{{- else }}
{{ .Content }}<|eot_id|>
{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}
{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|>
{{- if .ToolCalls }}
{{ range .ToolCalls }}
{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}
{{- else }}
{{ .Content }}
{{- end }}{{ if not $last }}<|eot_id|>{{ end }}
{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|>
{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}
{{- end }}
{{- end }}

24
tools/testdata/llama3.2.out vendored Normal file
View File

@@ -0,0 +1,24 @@
<|start_header_id|>system<|end_header_id|>
Cutting Knowledge Date: December 2023
You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question.
You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|>
22<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt.
Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables.
{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

51
tools/testdata/qwen2.5-coder.gotmpl vendored Normal file
View File

@@ -0,0 +1,51 @@
{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|>
{{- else if .Messages }}
{{- if or .System .Tools }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
{{- else }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}

31
tools/testdata/qwen2.5-coder.out vendored Normal file
View File

@@ -0,0 +1,31 @@
<|im_start|>system
You are a knowledgeable assistant. You can answer questions and perform tasks.
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call><|im_end|>
<|im_start|>user
What's the weather like today in Paris?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
22
</tool_response><|im_end|>
<|im_start|>assistant
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
<|im_start|>user
What's the weather like today in San Francisco and Toronto?<|im_end|>
<|im_start|>assistant

50
tools/testdata/qwen3.gotmpl vendored Normal file
View File

@@ -0,0 +1,50 @@
{{- if .Messages }}
{{- if or .System .Tools }}<|im_start|>system
{{- if .System }}
{{ .System }}
{{- end }}
{{- if .Tools }}
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{{- range .Tools }}
{"type": "function", "function": {{ .Function }}}
{{- end }}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>
{{- end }}<|im_end|>
{{ end }}
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 -}}
{{- if eq .Role "user" }}<|im_start|>user
{{ .Content }}<|im_end|>
{{ else if eq .Role "assistant" }}<|im_start|>assistant
{{ if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{ end }}</tool_call>
{{- end }}{{ if not $last }}<|im_end|>
{{ end }}
{{- else if eq .Role "tool" }}<|im_start|>user
<tool_response>
{{ .Content }}
</tool_response><|im_end|>
{{ end }}
{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant
{{ end }}
{{- end }}
{{- else }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }}

31
tools/testdata/qwen3.out vendored Normal file
View File

@@ -0,0 +1,31 @@
<|im_start|>system
You are a knowledgeable assistant. You can answer questions and perform tasks.
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}
</tools>
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call><|im_end|>
<|im_start|>user
What's the weather like today in Paris?<|im_end|>
<|im_start|>assistant
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|im_end|>
<|im_start|>user
<tool_response>
22
</tool_response><|im_end|>
<|im_start|>assistant
The current temperature in Paris, France is 22 degrees Celsius.<|im_end|>
<|im_start|>user
What's the weather like today in San Francisco and Toronto?<|im_end|>
<|im_start|>assistant

228
tools/tools.go Normal file
View File

@@ -0,0 +1,228 @@
package tools
import (
"errors"
"io"
"log/slog"
"strings"
gotmpl "text/template"
jsonv2 "github.com/go-json-experiment/json"
jsontext "github.com/go-json-experiment/json/jsontext"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
// Sentinel errors for parsing states
var (
ErrPartialPrefix = errors.New("partial prefix detected")
ErrPrefixNotFound = errors.New("prefix not found")
ErrInvalidToolCall = errors.New("invalid tool call format")
ErrAccumulateMore = errors.New("need to accumulate more content")
)
type Parser struct {
greedyParse bool
prefixFound bool
tmpl gotmpl.Template
sb strings.Builder
prefix string
index int
name string
arguments string
Done bool
}
// parseJSONToolCalls attempts to parse a JSON string into a slice ToolCalls.
// It first tries to incrementally decode the JSON to handle partial inputs.
// Returns:
// - []api.ToolCall: The parsed tool calls if successful
// - error: ErrPartialJSON if JSON is incomplete, ErrInvalidToolCall if invalid, or nil if successful
func (p *Parser) parseJSONToolCalls(s string) ([]api.ToolCall, error) {
// First try incremental decoding to handle partial JSON
dec := jsontext.NewDecoder(strings.NewReader(s))
if got, err := dec.ReadValue(); err == nil {
s = got.String()
}
// Attempt full unmarshal of the JSON
var resp any
if err := jsonv2.Unmarshal([]byte(s), &resp); errors.Is(err, io.ErrUnexpectedEOF) {
slog.Debug("incomplete JSON detected", "input", s)
return nil, ErrAccumulateMore
} else if err != nil {
slog.Debug("failed to unmarshal response", "error", err)
return nil, ErrInvalidToolCall
}
// Collect all nested objects that could contain tool calls
objs := collect(resp)
if len(objs) == 0 {
return nil, ErrInvalidToolCall
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[p.name].(string)
a, aok := kv[p.arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}
// Valid JSON, no tool calls found
if len(toolCalls) == 0 {
return nil, ErrInvalidToolCall
}
return toolCalls, nil
}
// checkPrefix processes a string to find and handle a prefix pattern.
//
// Returns:
// - The processed string with prefix removed if found
// - error: ErrPartialPrefix if prefix is incomplete, ErrPrefixNotFound if not found, or nil if successful
func (p *Parser) checkPrefix(s string) (string, error) {
// Keep original for overlap checks
original := s
s = strings.TrimSpace(s)
if s == "" {
return "", nil
}
// If no prefix defined, just return trimmed string
if p.prefix == "" {
return s, nil
}
// Check for prefix at start of string
if processedStr, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix {
// Found prefix at start - accumulate for potential tool
p.prefixFound = true
return processedStr, nil
}
// Check if prefix overlaps end of string
if overlap := suffixOverlap(original, p.prefix); overlap > 0 {
// Return everything except overlapping portion
p.sb.Reset()
p.sb.WriteString(original[len(original)-overlap:])
return original[0 : len(original)-overlap], ErrAccumulateMore
}
// Check if prefix appears in middle of string
if idx := strings.Index(original, p.prefix); idx != -1 {
// Save remainder starting at prefix for next pass
p.sb.Reset()
p.sb.WriteString(strings.TrimSpace(original[idx:]))
// Return everything before prefix
return original[:idx], ErrAccumulateMore
}
// No partial prefix found
return s, nil
}
// Add processes a string input to parse tool calls and content.
// It handles prefix detection and JSON parsing to extract tool calls.
//
// Returns:
// - tools: Any parsed tool calls
// - content: Non-tool call content
// - error: One of the sentinel errors or nil if successful
func (p *Parser) Add(s string) (tools []api.ToolCall, content string, err error) {
p.sb.WriteString(s)
s = p.sb.String()
// Check for prefix pattern in input
s, err = p.checkPrefix(s)
if err != nil {
if s != "" {
// Return content before prefix
return nil, s, nil
}
// Need more input to complete prefix
return nil, "", ErrAccumulateMore
}
// Exit if prefix exists in template, greedy parsing is off, and prefix not found
if !p.greedyParse && !p.prefixFound {
p.sb.Reset()
return nil, "", ErrPrefixNotFound
}
toolCalls, err := p.parseJSONToolCalls(s)
if err != nil {
if errors.Is(err, ErrAccumulateMore) {
return nil, "", err
} else {
p.sb.Reset()
// Do not try greedy parsing if JSON not found
p.greedyParse = false
if p.prefix == "" {
p.Done = true
}
if p.prefixFound {
// Drop tokens since prefix was found
return nil, "", ErrAccumulateMore
}
return nil, s, nil
}
}
for _, tc := range toolCalls {
tc.Function.Index = p.index
p.index++
}
// Mark as done if no prefix needed
if p.prefix == "" {
p.Done = true
}
p.sb.Reset()
return toolCalls, "", nil
}
// NewParser creates a new tool call parser from a template. It extracts the tool call format,
// prefix, and field names from the template to use for parsing tool calls from model output.
//
// Returns an error if the template does not contain valid tool call formatting.
func NewParser(templateToProcess *gotmpl.Template) (Parser, error) {
parsed, err := template.Parse(templateToProcess.Root.String())
if err != nil {
return Parser{}, err
}
tt, err := toolTemplate(parsed)
if err != nil {
return Parser{}, err
}
tp := toolPrefix(templateToProcess)
tp = strings.TrimSpace(tp)
name, arguments, err := extractToolArgs(tt)
if err != nil {
return Parser{}, err
}
return Parser{
tmpl: *tt,
sb: strings.Builder{},
prefix: tp,
greedyParse: true,
name: name,
arguments: arguments,
}, nil
}

491
tools/tools_test.go Normal file
View File

@@ -0,0 +1,491 @@
package tools
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestParseToolCalls(t *testing.T) {
p := filepath.Join("testdata")
t1 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
}
t2 := api.ToolCall{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
}
cases := []struct {
name string
model string
output string
expectedToolCall []api.ToolCall
expectedTokens string
}{
{
name: "mistral malformed json with tool calls prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
},
{
name: "mistral multiple tool calls without prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "mistral tool calls with text between no prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "mistral valid json with tool calls prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "mistral multiple tool calls with text between and prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2, t1, t2},
expectedTokens: "",
},
{
name: "mistral incomplete json with tool calls prefix",
model: "mistral",
output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
},
{
name: "mistral invalid tool call with explanatory text no prefix",
model: "mistral",
output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "mistral tool calls without prefix",
model: "mistral",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "command r plus tool calls with json block format",
model: "command-r-plus",
output: "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```",
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "firefunction tool calls with functools prefix",
model: "firefunction",
output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "llama3 groq single tool call with xml tags",
model: "llama3-groq-tool-use",
output: `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
{
name: "xlam tool calls with wrapper object",
model: "xlam",
output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5-coder single tool call with prefix",
model: "qwen2.5-coder",
output: `<tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
{
name: "qwen2.5-coder multiple tool calls with and without prefix",
model: "qwen2.5-coder",
output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} <tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call> <tool_call>{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1, t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5-coder multiple tool calls without prefix",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5-coder plain text response no tool calls",
model: "qwen2.5-coder",
output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
expectedToolCall: []api.ToolCall{},
expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.",
},
{
name: "qwen2.5-coder tool calls with trailing text",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens after call",
},
{
name: "qwen2.5-coder tool calls with initial text",
model: "qwen2.5-coder",
output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
},
{
name: "qwen2.5 tool calls with prefix and trailing text",
model: "qwen2.5-coder",
output: `<tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call> some tokens after call`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5 tool calls with prefix and initial text",
model: "qwen2.5-coder",
output: `some tokens before call <tool_call> [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] </tool_call>`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "some tokens before call",
},
{
name: "qwen2.5 tool calls without prefix and valid tool call",
model: "qwen2.5-coder",
output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`,
expectedToolCall: []api.ToolCall{t1, t2},
expectedTokens: "",
},
{
name: "qwen2.5 tool calls without prefix and invalid tool call",
model: "qwen2.5-coder",
output: `[{"options": "foo"}]`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `[{"options": "foo"}]`,
},
{
name: "qwen2.5 tool calls with prefix and invalid tool call",
model: "qwen2.5-coder",
output: `<tool_call> [{"options": "foo"}] </tool_call> `,
expectedToolCall: []api.ToolCall{},
expectedTokens: ``,
},
{
name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}</tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
},
{
name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)",
model: "qwen3",
output: `<think>Okay, let me think what tool we should use...</think> <tool_call> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "<think>Okay, let me think what tool we should use...</think>",
},
{
name: "qwen3 empty think prefix without tool prefix and invalid tool call",
model: "qwen3",
output: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 empty think prefix with tool prefix and valid tool call",
model: "qwen3",
output: `<think></think><tool_call>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: `<think></think>`,
},
{
name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)",
model: "qwen3",
output: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think>< fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)",
model: "qwen3",
output: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think><tool_c fakeout{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "qwen3 invalid tool call with malformed tool prefix",
model: "qwen3",
output: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<think></think><tool_cfakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} </tool_call>`,
},
{
name: "llama3.2 valid tool call without prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{t1},
expectedTokens: "",
},
{
name: "llama3.2 incomplete tool call without prefix",
model: "llama3.2",
output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, `,
expectedToolCall: []api.ToolCall{},
expectedTokens: "",
},
{
name: "llama3.2 tool call with leading text",
model: "llama3.2",
output: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `some non json text{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
},
{
name: "llama3.2 tool call with invalid tool prefix (no prefix in template)",
model: "llama3.2",
output: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
expectedToolCall: []api.ToolCall{},
expectedTokens: `<tool_call>{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`,
},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
actual := &bytes.Buffer{} // Create new buffer for each test
if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
// fmt.Printf("tmpl: %s\n", tmpl.Root.String())
tp, err := NewParser(tmpl.Template)
if err != nil {
t.Fatal(err)
}
got := []api.ToolCall{}
var gotTokens strings.Builder
var add bool
tokens := strings.Fields(tt.output)
for _, tok := range tokens {
s := " " + tok
add = true
if !tp.Done {
toolCalls, content, err := tp.Add(s)
if err == nil {
if content != "" {
fmt.Printf("content: %q\n", content)
gotTokens.WriteString(content)
add = false
} else if len(toolCalls) > 0 {
got = append(got, toolCalls...)
add = false
}
} else if errors.Is(err, ErrAccumulateMore) {
add = false
}
}
if add {
gotTokens.WriteString(s)
}
}
// Compare tool calls if we expect any
if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" {
t.Errorf("tool calls mismatch (-got +want):\n%s", diff)
}
// Compare tokens if we expect any
stripped := strings.TrimSpace(gotTokens.String())
if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" {
t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens)
t.Errorf("tokens mismatch (-got +want):\n%s", diff)
}
})
})
}
}
func TestParseJSONToolCalls(t *testing.T) {
tests := []struct {
name string
input string
parser *Parser
wantToolCalls []api.ToolCall
wantErr error
}{
{
name: "valid single tool call",
input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test_tool",
Arguments: map[string]any{
"arg1": "value1",
},
},
},
},
wantErr: nil,
},
{
name: "incomplete JSON",
input: `{"name": "test_tool", "arguments": {"arg1": `,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: nil,
wantErr: ErrAccumulateMore,
},
{
name: "invalid JSON",
input: `not json at all`,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: nil,
wantErr: ErrInvalidToolCall,
},
{
name: "missing required fields",
input: `{"other": "field"}`,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: nil,
wantErr: ErrInvalidToolCall,
},
{
name: "multiple tool calls in array",
input: `[
{"name": "tool1", "arguments": {"arg1": 1}},
{"name": "tool2", "arguments": {"arg2": "value"}}
]`,
parser: &Parser{name: "name", arguments: "arguments"},
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "tool1",
Arguments: map[string]any{
"arg1": float64(1),
},
},
},
{
Function: api.ToolCallFunction{
Name: "tool2",
Arguments: map[string]any{
"arg2": "value",
},
},
},
},
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotCalls, err := tt.parser.parseJSONToolCalls(tt.input)
if err != tt.wantErr {
t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr)
}
if len(gotCalls) != 0 && tt.wantErr != nil {
t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil)
}
if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" {
t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff)
}
})
}
}

257
tools/utils.go Normal file
View File

@@ -0,0 +1,257 @@
package tools
import (
"bytes"
"errors"
"log/slog"
"slices"
"strings"
gotmpl "text/template"
"text/template/parse"
jsonv2 "github.com/go-json-experiment/json"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
)
// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition.
// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any
// immediate text nodes that follow. This is used to identify tool call prefixes and formatting.
//
// Returns:
// - string: The extracted text following the first ".ToolCalls" condition found
// - bool: Whether a ".ToolCalls" condition was found in the template
func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) {
if tmpl == nil || tmpl.Tree == nil {
slog.Debug("TextAfterToolCalls: template or tree is nil")
return "", false
}
var result string
var found bool
var walk func(nodes []parse.Node)
walk = func(nodes []parse.Node) {
for _, node := range nodes {
if found {
return
}
switch n := node.(type) {
case *parse.IfNode:
if isToolCallsNode(n) {
// Collect immediate TextNode(s) at start of IfNode's list
var sb strings.Builder
for _, innerNode := range n.List.Nodes {
if tn, ok := innerNode.(*parse.TextNode); ok {
sb.Write(tn.Text)
} else {
// Stop at first non-text node
break
}
}
result = sb.String()
found = true
return
}
// Recurse into child nodes
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.ListNode:
walk(n.Nodes)
case *parse.RangeNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
case *parse.WithNode:
walk(n.List.Nodes)
if n.ElseList != nil {
walk(n.ElseList.Nodes)
}
default:
// Continue to next node
continue
}
if found {
return
}
}
}
walk(tmpl.Tree.Root.Nodes)
return result, found
}
// isToolCallsNode detects if a node's condition includes ".ToolCalls"
func isToolCallsNode(n *parse.IfNode) bool {
for _, cmd := range n.Pipe.Cmds {
for _, arg := range cmd.Args {
if field, ok := arg.(*parse.FieldNode); ok {
if slices.Contains(field.Ident, "ToolCalls") {
return true
}
}
}
}
return false
}
// TODO(parthsareen): get full prefix from the template instead of just the first token
// toolPrefix returns the prefix for the tool call if it exists from a template
func toolPrefix(tmpl *gotmpl.Template) string {
tokenText, ok := extractToolCallsFormat(tmpl)
if !ok {
return ""
}
tokenText = strings.TrimSpace(tokenText)
if tokenText == "" {
return ""
}
first := strings.Fields(tokenText)[0]
start := -1
end := -1
for i, r := range tokenText {
if r == '<' || r == '[' {
start = i
}
if (r == '>' || r == ']') && start != -1 {
end = i
break
}
}
if start != -1 && end != -1 {
// return the token including the [ or < and the ] or >
return tokenText[start : end+1]
} else if start != -1 {
// get until the [ or < - in the case tag was not closed
return tokenText[:start]
} else if end != -1 {
// get after the ] or > - in the case tag was not opened
return tokenText[end+1:]
}
return first
}
// toolTemplate creates a subtree from the node that ranges over .ToolCalls
//
// Returns:
// - *gotmpl.Template: The subtree containing the .ToolCalls range
// - error: Error if parsing failed
func toolTemplate(t *template.Template) (*gotmpl.Template, error) {
tmpl := t.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, errors.New("failed to find tool template")
}
return tmpl, nil
}
// suffixOverlap returns the length of the longest suffix overlap between two strings
//
// Returns:
// - int: The length of the longest suffix overlap
func suffixOverlap(s, prefix string) int {
max := min(len(prefix), len(s))
for i := max; i > 0; i-- {
if strings.HasSuffix(s, prefix[:i]) {
return i
}
}
return 0
}
// extractToolArgs executes a template with a known tool call format to extract the name and arguments
//
// Returns:
// - string: The name of the tool call
// - string: The arguments of the tool call
// - error: Error if parsing failed
func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) {
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return "", "", err
}
var obj any
err = jsonv2.Unmarshal(b.Bytes(), &obj)
if err != nil {
return "", "", err
}
var objs []map[string]any
switch v := obj.(type) {
case map[string]any:
objs = []map[string]any{v}
case []map[string]any:
objs = v
case []any:
objs = collect(v)
}
if len(objs) == 0 {
return "", "", errors.New("no template objects found")
}
// find the keys that correspond to the name and arguments fields
for k, v := range objs[0] {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments)
return "", "", errors.New("missing required fields in tool call template")
}
return name, arguments, nil
}
// collect recursively traverses an object to collect all nested maps
//
// Returns:
// - []map[string]any: A slice of all nested maps found in the object
func collect(obj any) []map[string]any {
var all []map[string]any
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
default:
return nil
}
return all
}

464
tools/utils_test.go Normal file
View File

@@ -0,0 +1,464 @@
package tools
import (
"testing"
gotmpl "text/template"
"github.com/ollama/ollama/template"
)
func TestExtractToolCallsFormat(t *testing.T) {
cases := []struct {
name string
template string
want string
found bool
}{
{
name: "nil template",
template: "",
want: "",
found: false,
},
{
name: "basic tool call with text",
template: "{{if .ToolCalls}}Hello world{{end}}",
want: "Hello world",
found: true,
},
{
name: "tool call with json format",
template: "{{if .ToolCalls}}```json\n{{end}}",
want: "```json\n",
found: true,
},
{
name: "tool call in range",
template: "{{range .ToolCalls}}tool: {{.}}{{end}}",
want: "",
found: false,
},
{
name: "tool call with multiple text nodes",
template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}",
want: "First text",
found: true,
},
{
name: "nested if without tool calls",
template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}",
want: "",
found: false,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tc.template)
if err != nil && tc.template != "" {
t.Fatalf("failed to parse template: %v", err)
}
got, found := extractToolCallsFormat(tmpl)
if got != tc.want {
t.Errorf("got text %q, want %q", got, tc.want)
}
if found != tc.found {
t.Errorf("got found %v, want %v", found, tc.found)
}
})
}
}
func TestToolPrefix(t *testing.T) {
cases := []struct {
name string
template string
want string
}{
{
name: "basic tool call with action prefix",
template: "{{if .ToolCalls}}Action: ```json{{end}}",
want: "Action:",
},
{
name: "incomplete functools bracket",
template: "{{if .ToolCalls}}functools[{{end}}",
want: "functools",
},
{
name: "tool call with angle brackets",
template: "{{if .ToolCalls}}Hello, world! <tool_call>{{end}}",
want: "<tool_call>",
},
{
name: "multiple tool call formats",
template: "{{if .ToolCalls}}[tool_call] <tool_call>{{end}}",
want: "[tool_call]",
},
{
name: "single angle bracket tool call",
template: "{{if .ToolCalls}}<tool_call>{{end}}",
want: "<tool_call>",
},
{
name: "incomplete angle bracket after tool call",
template: "{{if .ToolCalls}}[tool_call] <{{end}}",
want: "[tool_call]",
},
{
name: "angle bracket prefix with tool call",
template: "{{if .ToolCalls}}> <tool_call>{{end}}",
want: "<tool_call>",
},
{
name: "uppercase tool call with incomplete bracket",
template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}",
want: "[TOOL_CALL]",
},
{
name: "uppercase tool call with adjacent bracket",
template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}",
want: "[TOOL_CALL]",
},
{
name: "tool call with pipe delimiters",
template: "{{if .ToolCalls}}<|tool_call|>{{end}}",
want: "<|tool_call|>",
},
{
name: "tool with no prefix",
template: "{{if .ToolCalls}}{{end}}",
want: "",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got := toolPrefix(tmpl)
if got != tt.want {
t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want)
}
})
}
}
func TestToolTemplate(t *testing.T) {
cases := []struct {
name string
template string
want bool
}{
{
name: "basic tool call range",
template: "{{range .ToolCalls}}test{{end}}",
want: true,
},
{
name: "no tool calls",
template: "{{range .Other}}test{{end}}",
want: false,
},
{
name: "nested tool calls",
template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}",
want: true,
},
{
name: "empty template",
template: "",
want: false,
},
{
name: "tool calls in if statement",
template: "{{if .ToolCalls}}test{{end}}",
want: false,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
parsed, err := template.Parse(tmpl.Root.String())
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
_, err = toolTemplate(parsed)
if err != nil && tt.want {
t.Errorf("toolTemplate() = %v; want %v", err, tt.want)
}
})
}
}
func TestSuffixOverlap(t *testing.T) {
cases := []struct {
name string
s string
d string
want int
}{
{
name: "no overlap",
s: "hello world",
d: "",
want: 0,
},
{
name: "full overlap",
s: "<tool_call>",
d: "<tool_call>",
want: 11,
},
{
name: "partial overlap",
s: "text <tool_call>",
d: "<tool_call>",
want: 11,
},
{
name: "delimiter longer than string",
s: "<tool>",
d: "<tool_call>",
want: 0,
},
{
name: "empty string",
s: "",
d: "<tool_call>",
want: 0,
},
{
name: "empty delimiter",
s: "<tool_call>",
d: "",
want: 0,
},
{
name: "single char overlap",
s: "test<",
d: "<tool_call>",
want: 1,
},
{
name: "partial tool call",
s: "hello <tool_",
d: "<tool_call>",
want: 6,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := suffixOverlap(tt.s, tt.d)
if got != tt.want {
t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want)
}
})
}
}
func TestExtractToolArgs(t *testing.T) {
cases := []struct {
name string
template string
want string
ok bool
}{
{
name: "basic tool call with text after",
template: `{{if .ToolCalls}}tool response{{end}}`,
want: "tool response",
ok: true,
},
{
name: "tool call with mixed content after",
template: `{{if .ToolCalls}}<tool_call>{{.Something}}{{end}}`,
want: "<tool_call>",
ok: true,
},
{
name: "tool call with no text after",
template: `{{if .ToolCalls}}{{.Something}}{{end}}`,
want: "",
ok: true,
},
{
name: "nested tool call",
template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`,
want: "[TOOL_CALL]",
ok: true,
},
{
name: "no tool calls",
template: `{{if .Something}}no tools here{{end}}`,
want: "",
ok: false,
},
{
name: "empty template",
template: ``,
want: "",
ok: false,
},
{
name: "multiple tool calls sections",
template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`,
want: "first",
ok: true,
},
{
name: "range over tool calls",
template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`,
want: "",
ok: true,
},
{
name: "tool calls with pipe delimiters",
template: `{{if .ToolCalls}}<|tool|>{{end}}`,
want: "<|tool|>",
ok: true,
},
{
name: "tool calls with nested template",
template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`,
want: "",
ok: true,
},
{
name: "tool calls with whitespace variations",
template: `{{if .ToolCalls}} tool {{end}}`,
want: " tool ",
ok: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
tmpl, err := gotmpl.New("test").Parse(tt.template)
if err != nil {
t.Fatalf("failed to parse template: %v", err)
}
got, ok := extractToolCallsFormat(tmpl)
if got != tt.want {
t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want)
}
if ok != tt.ok {
t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok)
}
})
}
}
func TestCollect(t *testing.T) {
cases := []struct {
name string
obj any
want []map[string]any
}{
{
name: "simple map",
obj: map[string]any{
"key": "value",
},
want: []map[string]any{
{"key": "value"},
},
},
{
name: "nested map",
obj: map[string]any{
"outer": map[string]any{
"inner": "value",
},
},
want: []map[string]any{
{"outer": map[string]any{"inner": "value"}},
{"inner": "value"},
},
},
{
name: "array of maps",
obj: []any{
map[string]any{"key1": "val1"},
map[string]any{"key2": "val2"},
},
want: []map[string]any{
{"key1": "val1"},
{"key2": "val2"},
},
},
{
name: "deeply nested",
obj: map[string]any{
"l1": map[string]any{
"l2": map[string]any{
"l3": "value",
},
},
},
want: []map[string]any{
{"l1": map[string]any{"l2": map[string]any{"l3": "value"}}},
{"l2": map[string]any{"l3": "value"}},
{"l3": "value"},
},
},
{
name: "non-map value",
obj: "string",
want: nil,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
got := collect(tt.obj)
if len(got) != len(tt.want) {
t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want))
return
}
// Compare each map in the result
for i := range tt.want {
if !mapsEqual(got[i], tt.want[i]) {
t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i])
}
}
})
}
}
// mapsEqual compares two maps for deep equality
func mapsEqual(m1, m2 map[string]any) bool {
if len(m1) != len(m2) {
return false
}
for k, v1 := range m1 {
v2, ok := m2[k]
if !ok {
return false
}
switch val1 := v1.(type) {
case map[string]any:
val2, ok := v2.(map[string]any)
if !ok || !mapsEqual(val1, val2) {
return false
}
default:
if v1 != v2 {
return false
}
}
}
return true
}