Merge remote-tracking branch 'upstream/main' into vulkanV3
This commit is contained in:
commit
d97c2ab8b9
|
|
@ -22,7 +22,7 @@
|
|||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50-virtual;60-virtual;61-virtual;70-virtual;75-virtual;80-virtual;86-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;120-virtual",
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120",
|
||||
"CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
|
||||
}
|
||||
},
|
||||
|
|
@ -30,14 +30,14 @@
|
|||
"name": "JetPack 5",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "72-virtual;87-virtual"
|
||||
"CMAKE_CUDA_ARCHITECTURES": "72;87"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "JetPack 6",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "87-virtual"
|
||||
"CMAKE_CUDA_ARCHITECTURES": "87"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -104,6 +104,8 @@ RUN go mod download
|
|||
COPY . .
|
||||
ARG GOFLAGS="'-ldflags=-w -s'"
|
||||
ENV CGO_ENABLED=1
|
||||
ARG CGO_CFLAGS
|
||||
ARG CGO_CXXFLAGS
|
||||
RUN --mount=type=cache,target=/root/.cache/go-build \
|
||||
go build -trimpath -buildmode=pie -o /bin/ollama .
|
||||
|
||||
|
|
|
|||
|
|
@ -411,6 +411,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
|
||||
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
|
||||
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
|
||||
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
|
||||
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
|
||||
|
||||
### Cloud
|
||||
|
||||
|
|
@ -537,6 +539,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
|
||||
- [Ollama for D](https://github.com/kassane/ollama-d)
|
||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
|
|
@ -597,6 +602,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
|||
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
|
||||
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
|
||||
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
|
||||
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
|
||||
|
||||
### Supported backends
|
||||
|
||||
|
|
|
|||
48
api/types.go
48
api/types.go
|
|
@ -90,6 +90,10 @@ type GenerateRequest struct {
|
|||
// (request that thinking _not_ be used) and unset (use the old behavior
|
||||
// before this option was introduced)
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
|
|
@ -120,6 +124,10 @@ type ChatRequest struct {
|
|||
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
|
||||
// for supported models.
|
||||
Think *ThinkValue `json:"think,omitempty"`
|
||||
|
||||
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
|
||||
// template instead of calling the model.
|
||||
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
|
|
@ -278,16 +286,23 @@ func mapToTypeScriptType(jsonType string) string {
|
|||
}
|
||||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
|
|
@ -308,6 +323,19 @@ type ChatResponse struct {
|
|||
Metrics
|
||||
}
|
||||
|
||||
// DebugInfo contains debug information for template rendering
|
||||
type DebugInfo struct {
|
||||
RenderedTemplate string `json:"rendered_template"`
|
||||
ImageCount int `json:"image_count,omitempty"`
|
||||
}
|
||||
|
||||
// DebugTemplateResponse is returned when _debug_render_only is set to true
|
||||
type DebugTemplateResponse struct {
|
||||
Model string `json:"model"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
DebugInfo DebugInfo `json:"_debug_info"`
|
||||
}
|
||||
|
||||
type Metrics struct {
|
||||
TotalDuration time.Duration `json:"total_duration,omitempty"`
|
||||
LoadDuration time.Duration `json:"load_duration,omitempty"`
|
||||
|
|
@ -860,7 +888,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
|||
if t < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
} else {
|
||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
}
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
|||
req string
|
||||
exp *Duration
|
||||
}{
|
||||
{
|
||||
name: "Unset",
|
||||
req: `{ }`,
|
||||
exp: nil,
|
||||
},
|
||||
{
|
||||
name: "Positive Integer",
|
||||
req: `{ "keep_alive": 42 }`,
|
||||
|
|
@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
|||
{
|
||||
name: "Positive Float",
|
||||
req: `{ "keep_alive": 42.5 }`,
|
||||
exp: &Duration{42 * time.Second},
|
||||
exp: &Duration{42500 * time.Millisecond},
|
||||
},
|
||||
{
|
||||
name: "Positive Integer String",
|
||||
|
|
@ -436,3 +441,50 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params ToolFunctionParameters
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple object with string property",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
{
|
||||
name: "marshal failure returns empty string",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Defs: func() any {
|
||||
// Create a cycle that will cause json.Marshal to fail
|
||||
type selfRef struct {
|
||||
Self *selfRef
|
||||
}
|
||||
s := &selfRef{}
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.params.String()
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1612,6 +1612,7 @@ func NewCLI() *cobra.Command {
|
|||
appendEnvDocs(cmd, []envconfig.EnvVar{
|
||||
envVars["OLLAMA_DEBUG"],
|
||||
envVars["OLLAMA_HOST"],
|
||||
envVars["OLLAMA_CONTEXT_LENGTH"],
|
||||
envVars["OLLAMA_KEEP_ALIVE"],
|
||||
envVars["OLLAMA_MAX_LOADED_MODELS"],
|
||||
envVars["OLLAMA_MAX_QUEUE"],
|
||||
|
|
|
|||
|
|
@ -15,19 +15,24 @@ import (
|
|||
|
||||
type gptossModel struct {
|
||||
ModelParameters
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
Experts uint32 `json:"num_experts"`
|
||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||
InitialContextLength uint32 `json:"initial_context_length"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
HiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
AttentionHeads uint32 `json:"num_attention_heads"`
|
||||
KeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
Experts uint32 `json:"num_experts"`
|
||||
LocalExperts uint32 `json:"num_local_experts"`
|
||||
ExpertsPerToken uint32 `json:"experts_per_token"`
|
||||
RMSNormEpsilon float32 `json:"rms_norm_eps"`
|
||||
InitialContextLength uint32 `json:"initial_context_length"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScalingFactor float32 `json:"rope_scaling_factor"`
|
||||
RopeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
} `json:"rope_scaling"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*gptossModel)(nil)
|
||||
|
|
@ -36,11 +41,11 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
|||
kv := m.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "gptoss"
|
||||
kv["general.file_type"] = uint32(4)
|
||||
kv["gptoss.context_length"] = uint32(m.RopeScalingFactor * float32(m.InitialContextLength))
|
||||
kv["gptoss.context_length"] = cmp.Or(m.MaxPositionEmbeddings, uint32(m.RopeScalingFactor*float32(m.InitialContextLength)))
|
||||
kv["gptoss.block_count"] = m.HiddenLayers
|
||||
kv["gptoss.embedding_length"] = m.HiddenSize
|
||||
kv["gptoss.feed_forward_length"] = m.IntermediateSize
|
||||
kv["gptoss.expert_count"] = m.Experts
|
||||
kv["gptoss.expert_count"] = cmp.Or(m.Experts, m.LocalExperts)
|
||||
kv["gptoss.expert_used_count"] = m.ExpertsPerToken
|
||||
kv["gptoss.attention.head_count"] = m.AttentionHeads
|
||||
kv["gptoss.attention.head_count_kv"] = m.KeyValueHeads
|
||||
|
|
@ -49,7 +54,7 @@ func (m *gptossModel) KV(t *Tokenizer) ggml.KV {
|
|||
kv["gptoss.attention.layer_norm_rms_epsilon"] = cmp.Or(m.RMSNormEpsilon, 1e-5)
|
||||
kv["gptoss.attention.sliding_window"] = m.SlidingWindow
|
||||
kv["gptoss.rope.freq_base"] = m.RopeTheta
|
||||
kv["gptoss.rope.scaling.factor"] = m.RopeScalingFactor
|
||||
kv["gptoss.rope.scaling.factor"] = cmp.Or(m.RopeScalingFactor, m.RopeScaling.Factor)
|
||||
kv["gptoss.rope.scaling.original_context_length"] = m.InitialContextLength
|
||||
kv["tokenizer.ggml.bos_token_id"] = uint32(199998) // <|startoftext|>
|
||||
kv["tokenizer.ggml.add_bos_token"] = false
|
||||
|
|
@ -92,6 +97,11 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name += ".weight"
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
|
|
@ -104,25 +114,47 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
|||
}
|
||||
|
||||
func (m *gptossModel) Replacements() []string {
|
||||
return []string{
|
||||
// noop replacements so other replacements will not be applied
|
||||
".blocks", ".blocks",
|
||||
".scales", ".scales",
|
||||
// real replacements
|
||||
"block", "blk",
|
||||
"attn.norm", "attn_norm",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.sinks", "attn_sinks",
|
||||
"attn.out", "attn_out",
|
||||
"mlp.norm", "ffn_norm",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||
"mlp.mlp2_", "ffn_down_exps.",
|
||||
"embedding", "token_embd",
|
||||
"norm", "output_norm",
|
||||
"unembedding", "output",
|
||||
"scale", "weight",
|
||||
var replacements []string
|
||||
if m.MaxPositionEmbeddings > 0 {
|
||||
// hf flavored model
|
||||
replacements = []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_out",
|
||||
"self_attn.sinks", "attn_sinks",
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"mlp.router", "ffn_gate_inp",
|
||||
"mlp.experts.gate_up_proj_", "ffn_gate_up_exps.",
|
||||
"mlp.experts.down_proj_", "ffn_down_exps.",
|
||||
"model.norm", "output_norm",
|
||||
}
|
||||
} else {
|
||||
replacements = []string{
|
||||
// noop replacements so other replacements will not be applied
|
||||
".blocks", ".blocks",
|
||||
".scales", ".scales",
|
||||
// real replacements
|
||||
"block", "blk",
|
||||
"attn.norm", "attn_norm",
|
||||
"attn.qkv", "attn_qkv",
|
||||
"attn.sinks", "attn_sinks",
|
||||
"attn.out", "attn_out",
|
||||
"mlp.norm", "ffn_norm",
|
||||
"mlp.gate", "ffn_gate_inp",
|
||||
"mlp.mlp1_", "ffn_gate_up_exps.",
|
||||
"mlp.mlp2_", "ffn_down_exps.",
|
||||
"embedding", "token_embd",
|
||||
"norm", "output_norm",
|
||||
"unembedding", "output",
|
||||
"scale", "weight",
|
||||
}
|
||||
}
|
||||
return replacements
|
||||
}
|
||||
|
||||
type mxfp4 struct {
|
||||
|
|
@ -140,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
|||
blocksDims[i] = int(d)
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
|
||||
bts := b.Bytes()
|
||||
var tmp [16]byte
|
||||
for i := 0; i < b.Len(); i += 16 {
|
||||
for j := range 8 {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[i+j], bts[i+j+8]
|
||||
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
|
||||
copy(bts[i:i+16], tmp[:])
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||
|
||||
var s bytes.Buffer
|
||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||
|
|
@ -174,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
|||
return 0, err
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
return int64(len(u8s)), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
|
|||
const (
|
||||
tensorKindFP32 uint32 = iota
|
||||
tensorKindFP16
|
||||
tensorKindMXFP4 = 4
|
||||
tensorKindBF16 = 30
|
||||
tensorKindMXFP4 = 39
|
||||
)
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
|
|
@ -124,26 +125,41 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
|||
}
|
||||
defer f.Close()
|
||||
|
||||
if seeker, ok := f.(io.Seeker); ok {
|
||||
if _, err := seeker.Seek(st.offset, io.SeekStart); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
if _, err := io.CopyN(io.Discard, f, st.offset); err != nil {
|
||||
return 0, err
|
||||
r, err := func() (io.Reader, error) {
|
||||
if readerAt, ok := f.(io.ReaderAt); ok {
|
||||
return io.NewSectionReader(readerAt, st.offset, st.size), nil
|
||||
} else if seeker, ok := f.(io.Seeker); ok {
|
||||
_, err := seeker.Seek(st.offset, io.SeekStart)
|
||||
return f, err
|
||||
} else {
|
||||
_, err := io.CopyN(io.Discard, f, st.offset)
|
||||
return f, err
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
br := bufio.NewReaderSize(r, min(32<<10, int(st.size)))
|
||||
// special case when input and output are same type and the
|
||||
// tensor doesn't need repacking
|
||||
if (st.repacker == nil) &&
|
||||
((st.dtype == "F32" && st.Kind() == tensorKindFP32) ||
|
||||
(st.dtype == "F16" && st.Kind() == tensorKindFP16) ||
|
||||
(st.dtype == "U8")) {
|
||||
return io.CopyN(w, br, st.size)
|
||||
}
|
||||
|
||||
var f32s []float32
|
||||
switch st.dtype {
|
||||
case "F32":
|
||||
f32s = make([]float32, st.size/4)
|
||||
if err = binary.Read(f, binary.LittleEndian, f32s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, f32s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case "F16":
|
||||
u16s := make([]uint16, st.size/2)
|
||||
if err = binary.Read(f, binary.LittleEndian, u16s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, u16s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
|
|
@ -154,14 +170,11 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
|||
|
||||
case "BF16":
|
||||
u8s := make([]uint8, st.size)
|
||||
if err = binary.Read(f, binary.LittleEndian, u8s); err != nil {
|
||||
if err = binary.Read(br, binary.LittleEndian, u8s); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
f32s = bfloat16.DecodeFloat32(u8s)
|
||||
case "U8":
|
||||
// U8 tensors do not support repacking or type conversion.
|
||||
return io.CopyN(w, f, st.size)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown data type: %s", st.dtype)
|
||||
}
|
||||
|
|
@ -175,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
|||
|
||||
switch st.Kind() {
|
||||
case tensorKindFP32:
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, u8s)
|
||||
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,232 @@
|
|||
package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
func TestSafetensors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
root, err := os.OpenRoot(t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
cases := []struct {
|
||||
name,
|
||||
dtype string
|
||||
offset,
|
||||
size int64
|
||||
shape []uint64
|
||||
setup func(*testing.T, *os.File)
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "fp32-fp32",
|
||||
dtype: "F32",
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp32-fp16",
|
||||
dtype: "F32",
|
||||
size: 32 * 4, // 32 floats, each 4 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, f32s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp16-fp16",
|
||||
dtype: "F16",
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x3c, 0x00, 0x40, 0x00, 0x42, 0x00, 0x44, 0x00, 0x45, 0x00, 0x46, 0x00, 0x47,
|
||||
0x00, 0x48, 0x80, 0x48, 0x00, 0x49, 0x80, 0x49, 0x00, 0x4a, 0x80, 0x4a, 0x00, 0x4b, 0x80, 0x4b,
|
||||
0x00, 0x4c, 0x40, 0x4c, 0x80, 0x4c, 0xc0, 0x4c, 0x00, 0x4d, 0x40, 0x4d, 0x80, 0x4d, 0xc0, 0x4d,
|
||||
0x00, 0x4e, 0x40, 0x4e, 0x80, 0x4e, 0xc0, 0x4e, 0x00, 0x4f, 0x40, 0x4f, 0x80, 0x4f, 0xc0, 0x4f,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "fp16-fp32",
|
||||
dtype: "F16",
|
||||
size: 32 * 2, // 32 floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u16s := make([]uint16, 32)
|
||||
for i := range u16s {
|
||||
u16s[i] = float16.Fromfloat32(float32(i)).Bits()
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u16s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bf16-bf16",
|
||||
dtype: "BF16",
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{16, 2},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x80, 0x3f, 0x00, 0x40, 0x40, 0x40, 0x80, 0x40, 0xa0, 0x40, 0xc0, 0x40, 0xe0, 0x40,
|
||||
0x00, 0x41, 0x10, 0x41, 0x20, 0x41, 0x30, 0x41, 0x40, 0x41, 0x50, 0x41, 0x60, 0x41, 0x70, 0x41,
|
||||
0x80, 0x41, 0x88, 0x41, 0x90, 0x41, 0x98, 0x41, 0xa0, 0x41, 0xa8, 0x41, 0xb0, 0x41, 0xb8, 0x41,
|
||||
0xc0, 0x41, 0xc8, 0x41, 0xd0, 0x41, 0xd8, 0x41, 0xe0, 0x41, 0xe8, 0x41, 0xf0, 0x41, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "bf16-fp32",
|
||||
dtype: "BF16",
|
||||
size: 32 * 2, // 32 brain floats, each 2 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
f32s := make([]float32, 32)
|
||||
for i := range f32s {
|
||||
f32s[i] = float32(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, bfloat16.EncodeFloat32(f32s)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40,
|
||||
0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40,
|
||||
0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41,
|
||||
0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41,
|
||||
0x00, 0x00, 0x80, 0x41, 0x00, 0x00, 0x88, 0x41, 0x00, 0x00, 0x90, 0x41, 0x00, 0x00, 0x98, 0x41,
|
||||
0x00, 0x00, 0xa0, 0x41, 0x00, 0x00, 0xa8, 0x41, 0x00, 0x00, 0xb0, 0x41, 0x00, 0x00, 0xb8, 0x41,
|
||||
0x00, 0x00, 0xc0, 0x41, 0x00, 0x00, 0xc8, 0x41, 0x00, 0x00, 0xd0, 0x41, 0x00, 0x00, 0xd8, 0x41,
|
||||
0x00, 0x00, 0xe0, 0x41, 0x00, 0x00, 0xe8, 0x41, 0x00, 0x00, 0xf0, 0x41, 0x00, 0x00, 0xf8, 0x41,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "u8-u8",
|
||||
dtype: "U8",
|
||||
size: 32, // 32 brain floats, each 1 bytes
|
||||
shape: []uint64{32},
|
||||
setup: func(t *testing.T, f *os.File) {
|
||||
u8s := make([]uint8, 32)
|
||||
for i := range u8s {
|
||||
u8s[i] = uint8(i)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, u8s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
},
|
||||
want: []byte{
|
||||
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
|
||||
0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := filepath.Base(t.Name())
|
||||
st := safetensor{
|
||||
fs: root.FS(),
|
||||
path: path,
|
||||
dtype: tt.dtype,
|
||||
offset: tt.offset,
|
||||
size: tt.size,
|
||||
tensorBase: &tensorBase{
|
||||
name: tt.name,
|
||||
shape: tt.shape,
|
||||
},
|
||||
}
|
||||
|
||||
f, err := root.Create(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
tt.setup(t, f)
|
||||
|
||||
var b bytes.Buffer
|
||||
if _, err := st.WriteTo(&b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, b.Bytes()); diff != "" {
|
||||
t.Errorf("safetensor.WriteTo() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -97,6 +97,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
return a < b
|
||||
})
|
||||
gpuCount := 0
|
||||
gpuOrdinalID := 0
|
||||
for _, match := range matches {
|
||||
slog.Debug("evaluating amdgpu node " + match)
|
||||
fp, err := os.Open(match)
|
||||
|
|
@ -187,10 +188,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
// Keep track of numeric IDs based on valid GPUs
|
||||
gpuID := gpuCount
|
||||
gpuCount += 1
|
||||
|
||||
// Look up the memory for the current node
|
||||
totalMemory := uint64(0)
|
||||
usedMemory := uint64(0)
|
||||
|
|
@ -269,7 +266,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
if uniqueID != 0 {
|
||||
ID = fmt.Sprintf("GPU-%016x", uniqueID)
|
||||
} else {
|
||||
ID = strconv.Itoa(gpuID)
|
||||
ID = strconv.Itoa(gpuOrdinalID)
|
||||
}
|
||||
|
||||
gpuInfo := RocmGPUInfo{
|
||||
|
|
@ -280,6 +277,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
ID: ID,
|
||||
filterID: gpuOrdinalID,
|
||||
Name: name,
|
||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
|
|
@ -287,13 +285,40 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
DriverMinor: driverMinor,
|
||||
},
|
||||
usedFilepath: usedFile,
|
||||
index: gpuID,
|
||||
index: gpuCount,
|
||||
}
|
||||
|
||||
// Keep track of numeric IDs based on valid GPUs
|
||||
gpuCount += 1
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
if len(visibleDevices) > 0 {
|
||||
include := false
|
||||
for _, visible := range visibleDevices {
|
||||
if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) {
|
||||
include = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !include {
|
||||
reason := "filtering out device per user request"
|
||||
slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Ordinal IDs are based on the visible GPUs
|
||||
gpuOrdinalID += 1
|
||||
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
if totalMemory < IGPUMemLimit {
|
||||
reason := "unsupported Radeon iGPU detected skipping"
|
||||
slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
|
|
@ -306,7 +331,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
}
|
||||
if int(major) < minVer {
|
||||
reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch)
|
||||
slog.Warn(reason, "gpu", gpuID)
|
||||
slog.Warn(reason, "gpu", gpuInfo.ID)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
|
|
@ -315,29 +340,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
if len(visibleDevices) > 0 {
|
||||
include := false
|
||||
for _, visible := range visibleDevices {
|
||||
if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) {
|
||||
include = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !include {
|
||||
reason := "filtering out device per user request"
|
||||
slog.Info(reason, "id", gpuInfo.ID, "visible_devices", visibleDevices)
|
||||
unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{
|
||||
GpuInfo: gpuInfo.GpuInfo,
|
||||
Reason: reason,
|
||||
})
|
||||
|
||||
continue
|
||||
}
|
||||
}
|
||||
slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
|
||||
// Final validation is gfx compatibility - load the library if we haven't already loaded it
|
||||
// even if the user overrides, we still need to validate the library
|
||||
|
|
@ -391,7 +395,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
|
||||
// Check for env var workarounds
|
||||
if name == "1002:687f" { // Vega RX 56
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, "HSA_ENABLE_SDMA=0")
|
||||
}
|
||||
|
||||
// The GPU has passed all the verification steps and is supported
|
||||
|
|
@ -520,19 +524,26 @@ func verifyKFDDriverAccess() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||
} else {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric so is our preferred on linux
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
return "ROCR_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
return "ROCR_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) {
|
|||
UnreliableFreeMemory: true,
|
||||
|
||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||
filterID: i,
|
||||
DependencyPath: []string{libDir},
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
|
|
@ -200,19 +201,26 @@ func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "rocm" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("rocmGetVisibleDevicesEnv skipping over non-rocm device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
// If the devices requires a numeric ID, for filtering purposes, we use the unfiltered ID number
|
||||
if _, err := strconv.Atoi(info.ID); err == nil {
|
||||
ids = append(ids, fmt.Sprintf("%d", info.filterID))
|
||||
} else {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// There are 3 potential env vars to use to select GPUs.
|
||||
// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
|
||||
// HIP_VISIBLE_DEVICES supports numeric IDs only
|
||||
// GPU_DEVICE_ORDINAL supports numeric IDs only
|
||||
return "HIP_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
return "HIP_VISIBLE_DEVICES=" + strings.Join(ids, ",")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,19 +16,6 @@ import (
|
|||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "cuda" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("cudaGetVisibleDevicesEnv skipping over non-cuda device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
}
|
||||
|
||||
func cudaVariant(gpuInfo CudaGPUInfo) string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
|
|
|
|||
|
|
@ -439,6 +439,15 @@ func GetGPUInfo() GpuInfoList {
|
|||
}
|
||||
|
||||
rocmGPUs, err = AMDGetGPUInfo()
|
||||
|
||||
// The ID field is used in context of the filtered set of GPUS
|
||||
// so we have to replace any of these numeric IDs with their
|
||||
// placement in this set of GPUs
|
||||
for i := range rocmGPUs {
|
||||
if _, err := strconv.Atoi(rocmGPUs[i].ID); err == nil {
|
||||
rocmGPUs[i].ID = strconv.Itoa(i)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
bootstrapErrors = append(bootstrapErrors, err)
|
||||
}
|
||||
|
|
@ -790,25 +799,16 @@ func getVerboseState() C.uint16_t {
|
|||
|
||||
// Given the list of GPUs this instantiation is targeted for,
|
||||
// figure out the visible devices environment variable
|
||||
//
|
||||
// If different libraries are detected, the first one is what we use
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||
if len(l) == 0 {
|
||||
return "", ""
|
||||
return nil
|
||||
}
|
||||
switch l[0].Library {
|
||||
case "cuda":
|
||||
return cudaGetVisibleDevicesEnv(l)
|
||||
case "rocm":
|
||||
return rocmGetVisibleDevicesEnv(l)
|
||||
case "oneapi":
|
||||
return oneapiGetVisibleDevicesEnv(l)
|
||||
case "vulkan":
|
||||
return vkGetVisibleDevicesEnv(l)
|
||||
default:
|
||||
slog.Debug("no filter required for library " + l[0].Library)
|
||||
return "", ""
|
||||
vd := []string{}
|
||||
// Only filter the AMD GPUs at this level, let all NVIDIA devices through
|
||||
if tmp := rocmGetVisibleDevicesEnv(l); tmp != "" {
|
||||
vd = append(vd, tmp)
|
||||
}
|
||||
return vd
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
|
|
|
|||
|
|
@ -62,9 +62,9 @@ func GetCPUMem() (memInfo, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() (string, string) {
|
||||
func (l GpuInfoList) GetVisibleDevicesEnv() []string {
|
||||
// No-op on darwin
|
||||
return "", ""
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSystemInfo() SystemInfo {
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
//go:build linux || windows
|
||||
|
||||
package discover
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func oneapiGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids := []string{}
|
||||
for _, info := range gpuInfo {
|
||||
if info.Library != "oneapi" {
|
||||
// TODO shouldn't happen if things are wired correctly...
|
||||
slog.Debug("oneapiGetVisibleDevicesEnv skipping over non-sycl device", "library", info.Library)
|
||||
continue
|
||||
}
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "ONEAPI_DEVICE_SELECTOR", "level_zero:" + strings.Join(ids, ",")
|
||||
}
|
||||
|
|
@ -27,8 +27,8 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
|||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||
DependencyPath []string `json:"lib_path,omitempty"`
|
||||
|
||||
// Extra environment variables specific to the GPU as list of [key,value]
|
||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
||||
// Extra environment variables specific to the GPU as list of [key=value]
|
||||
EnvWorkarounds []string `json:"envs,omitempty"`
|
||||
|
||||
// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
|
||||
// the FreeMemory is best effort, and may over or under report actual memory usage
|
||||
|
|
@ -36,10 +36,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
|||
UnreliableFreeMemory bool
|
||||
|
||||
// GPU information
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
FlashAttention bool `json:"flash_attention"` // is flash attention supported
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
filterID int //nolint:unused,nolintlint // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
|
||||
// Driver Information - TODO no need to put this on each GPU
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
|
|
|
|||
|
|
@ -75,7 +75,7 @@ for part in client.chat('gpt-oss:120b', messages=messages, stream=True):
|
|||
import { Ollama } from 'ollama';
|
||||
|
||||
const ollama = new Ollama({
|
||||
host: 'https://ollama.com'
|
||||
host: 'https://ollama.com',
|
||||
headers: {
|
||||
Authorization: "Bearer <api key>"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -185,6 +185,8 @@ var (
|
|||
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
|
||||
// Auth enables authentication between the Ollama client and server
|
||||
UseAuth = Bool("OLLAMA_AUTH")
|
||||
// Enable the new memory estimation logic
|
||||
NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES")
|
||||
)
|
||||
|
||||
func String(s string) func() string {
|
||||
|
|
@ -271,6 +273,7 @@ func AsMap() map[string]EnvVar {
|
|||
"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
|
||||
"OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"},
|
||||
"OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},
|
||||
"OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"},
|
||||
|
||||
// Informational
|
||||
"HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"},
|
||||
|
|
|
|||
|
|
@ -7,9 +7,11 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
)
|
||||
|
||||
|
|
@ -275,7 +277,7 @@ type Tensor struct {
|
|||
|
||||
func (t Tensor) block() (n int) {
|
||||
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||
return -1
|
||||
return math.MaxInt
|
||||
}
|
||||
|
||||
return
|
||||
|
|
@ -288,24 +290,24 @@ func (t Tensor) blockSize() uint64 {
|
|||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
0, // F32
|
||||
1, // F16
|
||||
24, // I8
|
||||
25, // I16
|
||||
26, // I32
|
||||
27, // I64
|
||||
28, // F64
|
||||
30: // BF16
|
||||
TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
return 1
|
||||
case
|
||||
2, // Q4_0
|
||||
3, // Q4_1
|
||||
4, // MXFP4
|
||||
6, // Q5_0
|
||||
7, // Q5_1
|
||||
8, // Q8_0
|
||||
9, // Q8_1
|
||||
20: // IQ4_NL
|
||||
TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
|
|
@ -328,8 +330,6 @@ func (t TensorType) TypeSize() uint64 {
|
|||
return 2 + blockSize/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + blockSize/2
|
||||
case TensorTypeMXFP4, 39:
|
||||
return 1 + blockSize/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + blockSize/2
|
||||
case TensorTypeQ5_1:
|
||||
|
|
@ -380,6 +380,8 @@ func (t TensorType) TypeSize() uint64 {
|
|||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + blockSize/2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
|
@ -479,7 +481,9 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCountMax()
|
||||
headsKV := f.KV().HeadCountKVMax()
|
||||
|
|
@ -675,7 +679,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
|||
kv[i] *= context
|
||||
}
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
|
|
@ -750,6 +759,11 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
|
|||
|
||||
// SupportsKVCacheType checks if the requested cache type is supported
|
||||
func (f GGML) SupportsKVCacheType(cacheType string) bool {
|
||||
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
|
||||
// gpt-oss uses attention with sinks which does not support quantized cache types
|
||||
slog.Warn("model only supports non-quantized cache types ", "mode", arch)
|
||||
return cacheType == "f16"
|
||||
}
|
||||
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
|
||||
}
|
||||
|
||||
|
|
@ -766,6 +780,13 @@ func (f GGML) SupportsFlashAttention() bool {
|
|||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"gptoss", "gpt-oss",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
|
|
|
|||
|
|
@ -533,12 +533,15 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
|||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
slices.SortStableFunc(
|
||||
ts,
|
||||
func(a, b *Tensor) int {
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.block(), b.block()),
|
||||
cmp.Compare(a.Name, b.Name),
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
var s uint64
|
||||
for i := range ts {
|
||||
|
|
|
|||
|
|
@ -11,24 +11,24 @@ import (
|
|||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
}
|
||||
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
rand.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
|
|
@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
|
|||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 608,
|
||||
Offset: 592,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
|
|
|
|||
|
|
@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType {
|
|||
return TensorTypeQ4_0
|
||||
case fileTypeQ4_1:
|
||||
return TensorTypeQ4_1
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
case FileTypeQ8_0:
|
||||
return TensorTypeQ8_0
|
||||
case fileTypeQ5_0:
|
||||
|
|
@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
|||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
|
|
@ -191,8 +191,8 @@ const (
|
|||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
|
|
@ -226,6 +226,7 @@ const (
|
|||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
|
|
@ -318,7 +319,7 @@ func (t TensorType) String() string {
|
|||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
case TensorTypeMXFP4:
|
||||
case 4, TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
default:
|
||||
return "unknown"
|
||||
|
|
|
|||
|
|
@ -1,9 +1,8 @@
|
|||
package server
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
|
|
@ -19,18 +18,6 @@ const (
|
|||
harmonyParserState_ParsingContent
|
||||
)
|
||||
|
||||
func shouldUseHarmony(model Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s harmonyParserState) String() string {
|
||||
switch s {
|
||||
// we're looking for the message start tag
|
||||
|
|
@ -275,19 +262,21 @@ const (
|
|||
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
|
||||
// This is a higher level interface that maps harmony concepts into ollama concepts
|
||||
type HarmonyMessageHandler struct {
|
||||
state harmonyMessageState
|
||||
harmonyParser *HarmonyParser
|
||||
state harmonyMessageState
|
||||
HarmonyParser *HarmonyParser
|
||||
FunctionNameMap *FunctionNameMap
|
||||
}
|
||||
|
||||
// NewHarmonyMessageHandler creates a new message handler
|
||||
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
|
||||
return &HarmonyMessageHandler{
|
||||
state: harmonyMessageState_Normal,
|
||||
harmonyParser: &HarmonyParser{
|
||||
HarmonyParser: &HarmonyParser{
|
||||
MessageStartTag: "<|start|>",
|
||||
MessageEndTag: "<|end|>",
|
||||
HeaderEndTag: "<|message|>",
|
||||
},
|
||||
FunctionNameMap: NewFunctionNameMap(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -298,11 +287,11 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
|||
thinkingSb := strings.Builder{}
|
||||
toolContentSb := strings.Builder{}
|
||||
|
||||
events := h.harmonyParser.AddContent(content)
|
||||
events := h.HarmonyParser.AddContent(content)
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case HarmonyEventHeaderComplete:
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event header complete", "header", event.Header)
|
||||
logutil.Trace("harmony event header complete", "header", event.Header)
|
||||
switch event.Header.Channel {
|
||||
case "analysis":
|
||||
if event.Header.Recipient != "" {
|
||||
|
|
@ -325,7 +314,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
|||
h.state = harmonyMessageState_Normal
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "harmony event content", "content", event.Content, "state", h.state)
|
||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||
if h.state == harmonyMessageState_Normal {
|
||||
contentSb.WriteString(event.Content)
|
||||
} else if h.state == harmonyMessageState_Thinking {
|
||||
|
|
@ -378,3 +367,97 @@ func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
|
|||
func (a *HarmonyToolCallAccumulator) Content() string {
|
||||
return a.acc.String()
|
||||
}
|
||||
|
||||
// FunctionNameMap maps a user-specified function name to a valid function
|
||||
// name for harmony (which look like TypeScript identifiers). This is needed to
|
||||
// transform user-specified function names, which might contain characters that
|
||||
// are not allowed in TypeScript identifiers
|
||||
type FunctionNameMap struct {
|
||||
userToHarmony map[string]string
|
||||
harmonyToUser map[string]string
|
||||
}
|
||||
|
||||
func NewFunctionNameMap() *FunctionNameMap {
|
||||
return &FunctionNameMap{
|
||||
userToHarmony: make(map[string]string),
|
||||
harmonyToUser: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
|
||||
harmonyFunctionName := m.deriveName(userFunctionName)
|
||||
m.userToHarmony[userFunctionName] = harmonyFunctionName
|
||||
m.harmonyToUser[harmonyFunctionName] = userFunctionName
|
||||
return harmonyFunctionName
|
||||
}
|
||||
|
||||
// OriginalFromConverted looks up the reverse-mapping of a previously-converted
|
||||
// user->harmony function name. To unmap reliably, the mapping must exist, as
|
||||
// the conversion process is not reversible without the appropriate state
|
||||
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
|
||||
if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
|
||||
return userFunctionName
|
||||
}
|
||||
slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
|
||||
// fallback to the original function name if we can't find a mapping
|
||||
return harmonyFunctionName
|
||||
}
|
||||
|
||||
// convertToValidChars converts a user-specified function name to a valid
|
||||
// TypeScript identifier.
|
||||
//
|
||||
// Limitations:
|
||||
//
|
||||
// - This doesn't restrict reserved TypeScript keywords.
|
||||
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
|
||||
// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
|
||||
// identifiers these models were trained on, so in the end we might want to
|
||||
// convert unicode-heavy identifiers to their closest ASCII equivalents.
|
||||
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
|
||||
mapper := func(r rune) rune {
|
||||
// first, replace certain characters with underscores
|
||||
if r == ' ' || r == '-' || r == '.' {
|
||||
return '_'
|
||||
}
|
||||
|
||||
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
|
||||
return r
|
||||
}
|
||||
|
||||
// finally, remove any other characters
|
||||
return -1
|
||||
}
|
||||
candidate := strings.Map(mapper, userFunctionName)
|
||||
|
||||
// set a default name if we end up with nothing left
|
||||
if candidate == "" {
|
||||
return "unnamed"
|
||||
}
|
||||
|
||||
// if the candidate starts with a number, prepend an underscore to make it a
|
||||
// valid identifier
|
||||
if unicode.IsDigit(rune(candidate[0])) {
|
||||
candidate = "_" + candidate
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
|
||||
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
|
||||
originalCandidate := m.convertToValidChars(userFunctionName)
|
||||
candidate := originalCandidate
|
||||
|
||||
// Check for dupes, and if so, add a number to the end.
|
||||
// We start at 2 because if we have dupes and the first is never renamed, it
|
||||
// makes sense for them to be named, say, `f`, `f_2`, `f_3`
|
||||
count := 2
|
||||
for {
|
||||
if _, exists := m.harmonyToUser[candidate]; !exists {
|
||||
break
|
||||
}
|
||||
candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
|
||||
count++
|
||||
}
|
||||
|
||||
return candidate
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
package server
|
||||
package harmony
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
|
@ -467,3 +467,71 @@ func TestHarmonyParserStreaming(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
|
||||
// handle any saving (and therefore no dupe handling)
|
||||
func TestFunctionConvertToValidChars(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
|
||||
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
|
||||
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
|
||||
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
|
||||
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
|
||||
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
|
||||
{name: "leading number", in: "123", want: "_123"},
|
||||
{name: "$ allowed", in: "$", want: "$"},
|
||||
// show that we allow weird unicode letter characters, though we might want
|
||||
// to convert them to their closest ASCII equivalents in the future
|
||||
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
|
||||
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
|
||||
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := NewFunctionNameMap()
|
||||
got := parser.convertToValidChars(tt.in)
|
||||
if got != tt.want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFunctionConvertAndAdd(t *testing.T) {
|
||||
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
|
||||
tests := []struct {
|
||||
name string
|
||||
in []string
|
||||
want []string
|
||||
}{
|
||||
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
|
||||
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
|
||||
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
|
||||
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
parser := NewFunctionNameMap()
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
for j, in := range tt.in {
|
||||
got := parser.ConvertAndAdd(in)
|
||||
want := tt.want[j]
|
||||
if got != want {
|
||||
t.Errorf("case %d: got %q, want %q", i, got, want)
|
||||
}
|
||||
// check that the maps are correct
|
||||
if parser.userToHarmony[in] != want {
|
||||
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
|
||||
}
|
||||
if parser.harmonyToUser[want] != in {
|
||||
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -2,10 +2,13 @@
|
|||
|
||||
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
|
||||
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
|
||||
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
|
||||
|
||||
|
||||
The integration tests have 2 modes of operating.
|
||||
|
||||
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
|
||||
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
|
||||
|
||||
> [!IMPORTANT]
|
||||
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
|
||||
|
|
|
|||
|
|
@ -390,7 +390,7 @@ func TestAPIEmbeddings(t *testing.T) {
|
|||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
req := api.EmbeddingRequest{
|
||||
Model: "orca-mini",
|
||||
Model: libraryEmbedModels[0],
|
||||
Prompt: "why is the sky blue?",
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBlueSky(t *testing.T) {
|
||||
|
|
@ -37,8 +36,8 @@ func TestUnicode(t *testing.T) {
|
|||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K",
|
||||
Prompt: "天空为什么是蓝色的?",
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
|
|
@ -50,8 +49,20 @@ func TestUnicode(t *testing.T) {
|
|||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
DoGenerate(ctx, t, client, req, []string{"散射", "频率"}, 120*time.Second, 120*time.Second)
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
slog.Info("loading", "model", req.Model)
|
||||
err := client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoGenerate(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
|
|
@ -69,7 +80,9 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
|||
}
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
|
|
@ -84,7 +97,9 @@ func TestUnicodeModelDir(t *testing.T) {
|
|||
}
|
||||
|
||||
modelDir, err := os.MkdirTemp("", "ollama_埃")
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(modelDir)
|
||||
slog.Info("unicode", "OLLAMA_MODELS", modelDir)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,254 +7,167 @@ import (
|
|||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "qwen3:0.6b",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "british", "festival"},
|
||||
}
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
|
||||
}
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
// Note: CPU based inference can crawl so don't give up too quickly
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 30*time.Second)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredict(t *testing.T) {
|
||||
// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
|
||||
func TestConcurrentGenerate(t *testing.T) {
|
||||
// Assumes all requests have the same model
|
||||
req, resp := GenerateRequests()
|
||||
reqLimit := len(req)
|
||||
iterLimit := 5
|
||||
numParallel := int(envconfig.NumParallel() + 1)
|
||||
iterLimit := 3
|
||||
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < 4*format.GibiByte {
|
||||
reqLimit = min(reqLimit, 2)
|
||||
iterLimit = 2
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
|
||||
slog.Info("loading", "model", req[0].Model)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: req[0].Model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req[0].Model, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(reqLimit)
|
||||
for i := 0; i < reqLimit; i++ {
|
||||
r := rand.New(rand.NewSource(0))
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j)
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
k := r.Int() % len(req)
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
|
||||
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
|
||||
// Stress the scheduler and attempt to load more models than will fit to cause thrashing
|
||||
// This test will always load at least 2 models even on CPU based systems
|
||||
func TestMultiModelStress(t *testing.T) {
|
||||
s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
|
||||
s := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if s == "" {
|
||||
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
|
||||
s = "0"
|
||||
}
|
||||
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if maxVram < 2*format.GibiByte {
|
||||
t.Skip("VRAM less than 2G, skipping model stress tests")
|
||||
|
||||
// All models compatible with ollama-engine
|
||||
smallModels := []string{
|
||||
"llama3.2:1b",
|
||||
"qwen3:0.6b",
|
||||
"gemma2:2b",
|
||||
"deepseek-r1:1.5b", // qwen2 arch
|
||||
"gemma3:270m",
|
||||
}
|
||||
mediumModels := []string{
|
||||
"llama3.2:3b", // ~3.4G
|
||||
"qwen3:8b", // ~6.6G
|
||||
"gpt-oss:20b", // ~15G
|
||||
"deepseek-r1:7b", // ~5.6G
|
||||
"gemma3:4b", // ~5.8G
|
||||
"gemma2:9b", // ~8.1G
|
||||
}
|
||||
|
||||
type model struct {
|
||||
name string
|
||||
size uint64 // Approximate amount of VRAM they typically use when fully loaded in VRAM
|
||||
}
|
||||
|
||||
smallModels := []model{
|
||||
{
|
||||
name: "llama3.2:1b",
|
||||
size: 2876 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "qwen3:0.6b",
|
||||
size: 1600 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "gemma:2b",
|
||||
size: 2364 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "deepseek-r1:1.5b",
|
||||
size: 2048 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "starcoder2:3b",
|
||||
size: 2166 * format.MebiByte,
|
||||
},
|
||||
}
|
||||
mediumModels := []model{
|
||||
{
|
||||
name: "qwen3:8b",
|
||||
size: 6600 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "llama2",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "deepseek-r1:7b",
|
||||
size: 5600 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "mistral",
|
||||
size: 4620 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "dolphin-mistral",
|
||||
size: 4620 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "gemma:7b",
|
||||
size: 5000 * format.MebiByte,
|
||||
},
|
||||
{
|
||||
name: "codellama:7b",
|
||||
size: 5118 * format.MebiByte,
|
||||
},
|
||||
}
|
||||
|
||||
// These seem to be too slow to be useful...
|
||||
// largeModels := []model{
|
||||
// {
|
||||
// name: "llama2:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "codellama:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "orca-mini:13b",
|
||||
// size: 7400 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "gemma:7b",
|
||||
// size: 5000 * format.MebiByte,
|
||||
// },
|
||||
// {
|
||||
// name: "starcoder2:15b",
|
||||
// size: 9100 * format.MebiByte,
|
||||
// },
|
||||
// }
|
||||
|
||||
var chosenModels []model
|
||||
var chosenModels []string
|
||||
switch {
|
||||
case maxVram < 10000*format.MebiByte:
|
||||
slog.Info("selecting small models")
|
||||
chosenModels = smallModels
|
||||
// case maxVram < 30000*format.MebiByte:
|
||||
default:
|
||||
slog.Info("selecting medium models")
|
||||
chosenModels = mediumModels
|
||||
// default:
|
||||
// slog.Info("selecting large models")
|
||||
// chosenModels = largeModels
|
||||
}
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
|
||||
for i := range req {
|
||||
if i > len(chosenModels) {
|
||||
break
|
||||
}
|
||||
req[i].Model = chosenModels[i].name
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Make sure all the models are pulled before we get started
|
||||
for _, r := range req {
|
||||
require.NoError(t, PullIfMissing(ctx, client, r.Model))
|
||||
for _, model := range chosenModels {
|
||||
if err := PullIfMissing(ctx, client, model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
|
||||
for i := 0; i < len(req); i++ {
|
||||
// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
|
||||
if i > 1 && consumed > maxVram {
|
||||
slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
|
||||
break
|
||||
// Determine how many models we can load in parallel before we exceed VRAM
|
||||
// The intent is to go 1 over what can fit so we force the scheduler to thrash
|
||||
targetLoadCount := 0
|
||||
slog.Info("Loading models to find how many can fit in VRAM before overflowing")
|
||||
for i, model := range chosenModels {
|
||||
req := &api.GenerateRequest{Model: model}
|
||||
slog.Info("loading", "model", model)
|
||||
err = client.Generate(ctx, req, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", model, err)
|
||||
}
|
||||
consumed += chosenModels[i].size
|
||||
slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
|
||||
targetLoadCount++
|
||||
if i > 0 {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
if len(models.Models) < targetLoadCount {
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
}
|
||||
slog.Info("found model load capacity", "target", targetLoadCount, "current", loaded, "chosen", chosenModels[:targetLoadCount])
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetLoadCount == len(chosenModels) {
|
||||
// TODO consider retrying the medium models
|
||||
slog.Warn("all models being used without exceeding VRAM, set OLLAMA_MAX_VRAM so test can pick larger models")
|
||||
}
|
||||
|
||||
r := rand.New(rand.NewSource(0))
|
||||
var wg sync.WaitGroup
|
||||
for i := range targetLoadCount {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
reqs, resps := GenerateRequests()
|
||||
for j := 0; j < 3; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
k := r.Int() % len(reqs)
|
||||
reqs[k].Model = chosenModels[i]
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
|
||||
DoGenerate(ctx, t, client, reqs[k], resps[k],
|
||||
120*time.Second, // Be extra patient for the model to load initially
|
||||
10*time.Second, // Once results start streaming, fail if they stall
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ package integration
|
|||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
|
@ -20,7 +22,7 @@ func TestLongInputContext(t *testing.T) {
|
|||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Model: smol,
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
|
|
@ -34,7 +36,7 @@ func TestLongInputContext(t *testing.T) {
|
|||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "germany", "france", "england", "austria", "prussia", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
|
|
@ -47,7 +49,7 @@ func TestContextExhaustion(t *testing.T) {
|
|||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "llama2",
|
||||
Model: smol,
|
||||
Prompt: "Write me a story with a ton of emojis?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
|
|
@ -61,5 +63,104 @@ func TestContextExhaustion(t *testing.T) {
|
|||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := GenerateRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
req[k].Context = c
|
||||
req[k].Prompt = "tell me more!"
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
// Get the server running (if applicable) warm the model up with a single initial empty request
|
||||
slog.Info("loading", "model", modelOverride)
|
||||
err := client.Generate(ctx,
|
||||
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
|
||||
func(response api.GenerateResponse) error { return nil },
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", modelOverride, err)
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numParallel)
|
||||
for i := range numParallel {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
k := i % len(req)
|
||||
req[k].Model = modelOverride
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
return
|
||||
}
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
if assistant == nil {
|
||||
t.Fatalf("didn't get an assistant response for context")
|
||||
}
|
||||
req[k].Messages = append(req[k].Messages,
|
||||
*assistant,
|
||||
api.Message{Role: "user", Content: "tell me more!"},
|
||||
)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVisionModels(t *testing.T) {
|
||||
|
|
@ -32,7 +31,9 @@ func TestVisionModels(t *testing.T) {
|
|||
for _, v := range testCases {
|
||||
t.Run(v.model, func(t *testing.T) {
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
|
|
@ -52,7 +53,9 @@ func TestVisionModels(t *testing.T) {
|
|||
// Note: sometimes it returns "the ollamas" sometimes "the ollams"
|
||||
resp := "the ollam"
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
|
|
@ -62,7 +65,9 @@ func TestVisionModels(t *testing.T) {
|
|||
func TestIntegrationSplitBatch(t *testing.T) {
|
||||
skipUnderMinVRAM(t, 6)
|
||||
image, err := base64.StdEncoding.DecodeString(imageEncoding)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma3:4b",
|
||||
// Fill up a chunk of the batch so the image will partially spill over into the next one
|
||||
|
|
@ -84,7 +89,9 @@ func TestIntegrationSplitBatch(t *testing.T) {
|
|||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,47 +0,0 @@
|
|||
//go:build integration
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
|
||||
// package to avoid circular dependencies
|
||||
|
||||
var (
|
||||
stream = false
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
}
|
||||
resp = [2][]string{
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestIntegrationSimple(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
GenerateTestHelper(ctx, t, req[0], resp[0])
|
||||
}
|
||||
|
|
@ -13,12 +13,12 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestMaxQueue(t *testing.T) {
|
||||
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
|
||||
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
|
||||
return
|
||||
|
|
@ -45,7 +45,9 @@ func TestMaxQueue(t *testing.T) {
|
|||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Context for the worker threads so we can shut them down
|
||||
// embedCtx, embedCancel := context.WithCancel(ctx)
|
||||
|
|
@ -89,7 +91,9 @@ func TestMaxQueue(t *testing.T) {
|
|||
switch {
|
||||
case genErr == nil:
|
||||
successCount++
|
||||
require.Greater(t, len(resp.Embedding), 5) // somewhat arbitrary, but sufficient to be reasonable
|
||||
if len(resp.Embedding) < 5 { // somewhat arbitrary, but sufficient to be reasonable
|
||||
t.Fatalf("embeddings shorter than expected: %d", len(resp.Embedding))
|
||||
}
|
||||
case errors.Is(genErr, context.Canceled):
|
||||
canceledCount++
|
||||
case strings.Contains(genErr.Error(), "busy"):
|
||||
|
|
@ -97,7 +101,9 @@ func TestMaxQueue(t *testing.T) {
|
|||
case strings.Contains(genErr.Error(), "connection reset by peer"):
|
||||
resetByPeerCount++
|
||||
default:
|
||||
require.NoError(t, genErr, "%d request failed", i)
|
||||
if genErr != nil {
|
||||
t.Fatalf("%d request failed", i)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("embed finished", "id", i)
|
||||
|
|
@ -108,8 +114,13 @@ func TestMaxQueue(t *testing.T) {
|
|||
embedwg.Wait()
|
||||
|
||||
slog.Info("embeds completed", "success", successCount, "busy", busyCount, "reset", resetByPeerCount, "canceled", canceledCount)
|
||||
require.Equal(t, resetByPeerCount, 0, "Connections reset by peer, have you updated your fd and socket limits?")
|
||||
require.True(t, busyCount > 0, "no requests hit busy error but some should have")
|
||||
require.True(t, canceledCount == 0, "no requests should have been canceled due to timeout")
|
||||
|
||||
if resetByPeerCount != 0 {
|
||||
t.Fatalf("Connections reset by peer, have you updated your fd and socket limits? %d", resetByPeerCount)
|
||||
}
|
||||
if busyCount == 0 {
|
||||
t.Fatalf("no requests hit busy error but some should have")
|
||||
}
|
||||
if canceledCount > 0 {
|
||||
t.Fatalf("no requests should have been canceled due to timeout %d", canceledCount)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -9,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
|
|
@ -25,11 +26,11 @@ import (
|
|||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
smol = "llama3.2:1b"
|
||||
smol = "llama3.2:1b"
|
||||
stream = false
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -435,7 +436,9 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
|||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
require.NoError(t, startServer(t, ctx, testEndpoint))
|
||||
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return client, testEndpoint, func() {
|
||||
|
|
@ -468,19 +471,25 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
|||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, genReq.Model))
|
||||
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) {
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
var context []int
|
||||
fn := func(response api.GenerateResponse) error {
|
||||
// fmt.Print(".")
|
||||
buf.Write([]byte(response.Response))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
if len(response.Context) > 0 {
|
||||
context = response.Context
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -503,9 +512,11 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", genReq.Model, "error", genErr)
|
||||
return
|
||||
return context
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %s", genErr, genReq.Model, genReq.Prompt)
|
||||
}
|
||||
require.NoError(t, genErr, "failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
|
|
@ -515,11 +526,14 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||
break
|
||||
}
|
||||
}
|
||||
require.True(t, atLeastOne, "%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in %s", genReq.Model, anyResp, response)
|
||||
}
|
||||
slog.Info("test pass", "model", genReq.Model, "prompt", genReq.Prompt, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return context
|
||||
}
|
||||
|
||||
// Generate a set of requests
|
||||
|
|
@ -528,65 +542,125 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: smol,
|
||||
Prompt: "why is the ocean blue?",
|
||||
Prompt: "why is the ocean blue? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Prompt: "why is the color of dirt brown? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Prompt: "what is the origin of independence day? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: smol,
|
||||
Prompt: "what is the composition of air?",
|
||||
Prompt: "what is the composition of air? Be brief but factual in your reply",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
[][]string{
|
||||
{"sunlight", "scattering", "interact"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "british"},
|
||||
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
|
||||
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
|
||||
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
|
||||
{"fourth", "july", "declaration", "independence"},
|
||||
{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
}
|
||||
|
||||
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
|
||||
stallTimer := time.NewTimer(initialTimeout)
|
||||
var buf bytes.Buffer
|
||||
role := "assistant"
|
||||
fn := func(response api.ChatResponse) error {
|
||||
// fmt.Print(".")
|
||||
role = response.Message.Role
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
if !stallTimer.Reset(streamTimeout) {
|
||||
return errors.New("stall was detected while streaming response, aborting")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
stream := true
|
||||
req.Stream = &stream
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Chat(ctx, &req, fn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stallTimer.C:
|
||||
if buf.Len() == 0 {
|
||||
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
|
||||
} else {
|
||||
t.Errorf("generate stalled. Response so far:%s", buf.String())
|
||||
}
|
||||
case <-done:
|
||||
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
|
||||
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
|
||||
return nil
|
||||
}
|
||||
if genErr != nil {
|
||||
t.Fatalf("%s failed with %s request prompt %v", genErr, req.Model, req.Messages)
|
||||
}
|
||||
|
||||
// Verify the response contains the expected data
|
||||
response := buf.String()
|
||||
atLeastOne := false
|
||||
for _, resp := range anyResp {
|
||||
if strings.Contains(strings.ToLower(response), resp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
|
||||
}
|
||||
|
||||
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
}
|
||||
return &api.Message{Role: role, Content: buf.String()}
|
||||
}
|
||||
|
||||
func ChatRequests() ([]api.ChatRequest, [][]string) {
|
||||
genReqs, results := GenerateRequests()
|
||||
reqs := make([]api.ChatRequest, len(genReqs))
|
||||
// think := api.ThinkValue{Value: "low"}
|
||||
for i := range reqs {
|
||||
reqs[i].Model = genReqs[i].Model
|
||||
reqs[i].Stream = genReqs[i].Stream
|
||||
reqs[i].KeepAlive = genReqs[i].KeepAlive
|
||||
// reqs[i].Think = &think
|
||||
reqs[i].Messages = []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: genReqs[i].Prompt,
|
||||
},
|
||||
}
|
||||
}
|
||||
return reqs, results
|
||||
}
|
||||
|
||||
func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
||||
// TODO use info API in the future
|
||||
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
|
||||
maxVram, err := strconv.ParseUint(s, 10, 64)
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Don't hammer on small VRAM cards...
|
||||
if maxVram < gb*format.GibiByte {
|
||||
t.Skip("skipping with small VRAM to avoid timeouts")
|
||||
|
|
@ -594,6 +668,39 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
|
|||
}
|
||||
}
|
||||
|
||||
// Skip if the target model isn't X% GPU loaded to avoid excessive runtime
|
||||
func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to list running models: %s", err)
|
||||
}
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
if m.Name != model {
|
||||
continue
|
||||
}
|
||||
gpuPercent := 0
|
||||
switch {
|
||||
case m.SizeVRAM == 0:
|
||||
gpuPercent = 0
|
||||
case m.SizeVRAM == m.Size:
|
||||
gpuPercent = 100
|
||||
case m.SizeVRAM > m.Size || m.Size == 0:
|
||||
t.Logf("unexpected size detected: %d", m.SizeVRAM)
|
||||
default:
|
||||
sizeCPU := m.Size - m.SizeVRAM
|
||||
cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 110)
|
||||
gpuPercent = int(100 - cpuPercent)
|
||||
}
|
||||
if gpuPercent < minPercent {
|
||||
t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
|
||||
}
|
||||
return
|
||||
}
|
||||
t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
|
||||
}
|
||||
|
||||
func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
|
||||
deadline, hasDeadline := t.Deadline()
|
||||
if !hasDeadline {
|
||||
|
|
|
|||
|
|
@ -378,9 +378,7 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
|
|||
maskTensor := ctx.Input().FromFloatSlice(mask, length, batchSize)
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||
maskTensor = out
|
||||
maskTensor = maskTensor.Cast(ctx, c.config.MaskDType)
|
||||
}
|
||||
|
||||
return maskTensor
|
||||
|
|
|
|||
|
|
@ -962,8 +962,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
const bool output_all = false;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
|
|
|
|||
|
|
@ -62,6 +62,22 @@ func BackendInit() {
|
|||
C.llama_backend_init()
|
||||
}
|
||||
|
||||
func EnumerateGPUs() []string {
|
||||
var ids []string
|
||||
|
||||
for i := range C.ggml_backend_dev_count() {
|
||||
device := C.ggml_backend_dev_get(i)
|
||||
|
||||
if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
C.ggml_backend_dev_get_props(device, &props)
|
||||
ids = append(ids, C.GoString(props.id))
|
||||
}
|
||||
}
|
||||
|
||||
return ids
|
||||
}
|
||||
|
||||
func GetModelArch(modelPath string) (string, error) {
|
||||
mp := C.CString(modelPath)
|
||||
defer C.free(unsafe.Pointer(mp))
|
||||
|
|
|
|||
|
|
@ -1,32 +0,0 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Sun, 22 Jun 2025 09:22:05 -0700
|
||||
Subject: [PATCH] temporary prevent rocm+cuda mixed loading
|
||||
|
||||
---
|
||||
ggml/src/ggml-backend-reg.cpp | 12 ++++++++++--
|
||||
1 file changed, 10 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
|
||||
index 3040b2aa..f1e9c180 100644
|
||||
--- a/ggml/src/ggml-backend-reg.cpp
|
||||
+++ b/ggml/src/ggml-backend-reg.cpp
|
||||
@@ -581,8 +581,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
||||
|
||||
ggml_backend_load_best("blas", silent, dir_path);
|
||||
ggml_backend_load_best("cann", silent, dir_path);
|
||||
- ggml_backend_load_best("cuda", silent, dir_path);
|
||||
- ggml_backend_load_best("hip", silent, dir_path);
|
||||
+
|
||||
+ // Avoid mixed hip+cuda configurations
|
||||
+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
|
||||
+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
|
||||
+ if (!hip_devices && !rocr_devices) {
|
||||
+ ggml_backend_load_best("cuda", silent, dir_path);
|
||||
+ } else {
|
||||
+ ggml_backend_load_best("hip", silent, dir_path);
|
||||
+ }
|
||||
+
|
||||
ggml_backend_load_best("metal", silent, dir_path);
|
||||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
|
|
@ -13,7 +13,7 @@ checks.
|
|||
1 file changed, 18 insertions(+)
|
||||
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 57eae461..9db0c8b5 100644
|
||||
index 57eae461..c7f9dc3a 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -2671,12 +2671,24 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Michael Yang <git@mxy.ng>
|
||||
Date: Mon, 18 Aug 2025 16:58:39 -0700
|
||||
Subject: [PATCH] decode: disable output_all
|
||||
|
||||
---
|
||||
src/llama-context.cpp | 3 +--
|
||||
1 file changed, 1 insertion(+), 2 deletions(-)
|
||||
|
||||
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
|
||||
index 26a5cf9c..6ece5263 100644
|
||||
--- a/src/llama-context.cpp
|
||||
+++ b/src/llama-context.cpp
|
||||
@@ -962,8 +962,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||
const int64_t n_vocab = vocab.n_tokens();
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
|
||||
- // when computing embeddings, all tokens are output
|
||||
- const bool output_all = cparams.embeddings;
|
||||
+ const bool output_all = false;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Wed, 27 Aug 2025 14:39:48 -0700
|
||||
Subject: [PATCH] ggml: Enable resetting backend devices
|
||||
|
||||
Touching a CUDA device causes the allocation of a primary context
|
||||
with CUDA data structures (~300 MB of VRAM). If a device is
|
||||
unused then it can be reset to free these data structures.
|
||||
---
|
||||
ggml/include/ggml-backend.h | 1 +
|
||||
ggml/src/ggml-backend-impl.h | 4 ++++
|
||||
ggml/src/ggml-backend.cpp | 8 ++++++++
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 1 +
|
||||
5 files changed, 29 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index b602a7c78..fda5ceb24 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -167,6 +167,7 @@ extern "C" {
|
||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
|
||||
index 81749a5a3..6f10c353b 100644
|
||||
--- a/ggml/src/ggml-backend-impl.h
|
||||
+++ b/ggml/src/ggml-backend-impl.h
|
||||
@@ -178,6 +178,10 @@ extern "C" {
|
||||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
+
|
||||
+ // (optional) reset device, clearing existing allocations and context
|
||||
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||
+ void (*reset)(ggml_backend_dev_t dev);
|
||||
};
|
||||
|
||||
struct ggml_backend_device {
|
||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||
index 05a842ed5..6556943b0 100644
|
||||
--- a/ggml/src/ggml-backend.cpp
|
||||
+++ b/ggml/src/ggml-backend.cpp
|
||||
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||
+ if (device->iface.reset == NULL) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ device->iface.reset(device);
|
||||
+}
|
||||
+
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index c7f9dc3a5..e43fde523 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
+void ggml_cuda_reset_device(int device) {
|
||||
+ ggml_cuda_set_device(device);
|
||||
+ CUDA_CHECK(cudaDeviceReset());
|
||||
+}
|
||||
+
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
+
|
||||
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||
+ props->memory_total = props->memory_free = 0;
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||
}
|
||||
|
||||
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
+ ggml_cuda_reset_device(ctx->device);
|
||||
+}
|
||||
+
|
||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||
+ /* .reset = */ ggml_backend_cuda_device_reset,
|
||||
};
|
||||
|
||||
// backend reg
|
||||
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
- ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index c31f31923..cf22e60d2 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -40,6 +40,7 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
+#define cudaDeviceReset hipDeviceReset
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Daniel Hiltgen <daniel@ollama.com>
|
||||
Date: Fri, 29 Aug 2025 16:53:08 -0700
|
||||
Subject: [PATCH] harden uncaught exception registration
|
||||
|
||||
---
|
||||
ggml/src/ggml.cpp | 8 ++++++--
|
||||
1 file changed, 6 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml.cpp b/ggml/src/ggml.cpp
|
||||
index 0d388d45..f5bcb446 100644
|
||||
--- a/ggml/src/ggml.cpp
|
||||
+++ b/ggml/src/ggml.cpp
|
||||
@@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
||||
return false;
|
||||
}
|
||||
const auto prev{std::get_terminate()};
|
||||
- GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
- previous_terminate_handler = prev;
|
||||
+ // GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
+ if (prev != ggml_uncaught_exception) {
|
||||
+ previous_terminate_handler = prev;
|
||||
+ } else {
|
||||
+ GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
|
||||
+ }
|
||||
std::set_terminate(ggml_uncaught_exception);
|
||||
return true;
|
||||
}();
|
||||
110
llm/memory.go
110
llm/memory.go
|
|
@ -4,7 +4,7 @@ import (
|
|||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strconv"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
|
|
@ -14,13 +14,79 @@ import (
|
|||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
|
||||
// The list of GPUs returned will always be the same brand (library)
|
||||
// If the model can not be fit fully within the available GPU(s) nil is returned
|
||||
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus discover.GpuInfoList, numParallel int) discover.GpuInfoList {
|
||||
for _, gl := range gpus.ByLibrary() {
|
||||
sgl := append(make(discover.GpuInfoList, 0, len(gl)), gl...)
|
||||
|
||||
// TODO - potentially sort by performance capability, existing models loaded, etc.
|
||||
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
|
||||
// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
|
||||
sort.Sort(sort.Reverse(discover.ByFreeMemory(sgl)))
|
||||
|
||||
if !envconfig.SchedSpread() {
|
||||
// Try to pack into as few GPUs as possible, starting from 1 GPU
|
||||
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
|
||||
gpuSubset := sgl[:numGPUs]
|
||||
ok, estimatedVRAM := predictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel)
|
||||
|
||||
if ok {
|
||||
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
|
||||
"model", modelPath,
|
||||
"library", sgl[0].Library,
|
||||
"parallel", numParallel,
|
||||
"required", format.HumanBytes2(estimatedVRAM),
|
||||
"gpus", numGPUs)
|
||||
return gpuSubset
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO future refinements
|
||||
// - if multiple Libraries, see if any single GPU in any Library will fit
|
||||
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||
|
||||
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
|
||||
if ok, estimatedVRAM := predictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok {
|
||||
slog.Info("new model will fit in available VRAM, loading",
|
||||
"model", modelPath,
|
||||
"library", sgl[0].Library,
|
||||
"parallel", numParallel,
|
||||
"required", format.HumanBytes2(estimatedVRAM),
|
||||
"gpus", len(sgl))
|
||||
return sgl
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
|
||||
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus discover.GpuInfoList, numParallel int) discover.GpuInfoList {
|
||||
byLibrary := gpus.ByLibrary()
|
||||
if len(byLibrary) <= 1 {
|
||||
return gpus
|
||||
}
|
||||
var bestEstimate uint64
|
||||
var bestFit int
|
||||
for i, gl := range byLibrary {
|
||||
_, estimatedVRAM := predictServerFit(gl, f, adapters, projectors, opts, numParallel)
|
||||
if estimatedVRAM > bestEstimate {
|
||||
bestEstimate = estimatedVRAM
|
||||
bestFit = i
|
||||
}
|
||||
}
|
||||
return byLibrary[bestFit]
|
||||
}
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
|
||||
// Split up the GPUs by type and try them
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range allGpus.ByLibrary() {
|
||||
var layerCount int
|
||||
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel)
|
||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||
if opts.NumGPU < 0 {
|
||||
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
|
||||
|
|
@ -31,6 +97,10 @@ func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj
|
|||
return true, estimatedVRAM
|
||||
}
|
||||
}
|
||||
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory {
|
||||
return true, estimatedVRAM
|
||||
}
|
||||
}
|
||||
return false, estimatedVRAM
|
||||
}
|
||||
|
|
@ -49,7 +119,7 @@ type MemoryEstimate struct {
|
|||
TotalSize uint64
|
||||
|
||||
// For multi-GPU scenarios, this provides the tensor split parameter
|
||||
TensorSplit string
|
||||
TensorSplit []int
|
||||
|
||||
// For multi-GPU scenarios, this is the size in bytes per GPU
|
||||
GPUSizes []uint64
|
||||
|
|
@ -71,7 +141,7 @@ type MemoryEstimate struct {
|
|||
|
||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||
// The GPUs provided must all be the same Library
|
||||
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||
func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
|
||||
// Graph size for a partial offload, applies to all GPUs
|
||||
var graphPartialOffload uint64
|
||||
|
||||
|
|
@ -112,13 +182,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
|
||||
for _, projector := range projectors {
|
||||
llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
|
||||
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
if llamaEngineProjectorWeights == 0 {
|
||||
ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize()
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
|
|
@ -129,17 +195,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
slog.Warn("model missing blk.0 layer size")
|
||||
}
|
||||
|
||||
var kvct string
|
||||
if envconfig.FlashAttention() &&
|
||||
useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
f.SupportsFlashAttention() {
|
||||
f.SupportsFlashAttention()
|
||||
|
||||
var kvct string
|
||||
if useFlashAttention {
|
||||
requested := strings.ToLower(envconfig.KvCacheType())
|
||||
if requested != "" && f.SupportsKVCacheType(requested) {
|
||||
kvct = requested
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
|
||||
|
||||
if len(kv) > 0 {
|
||||
layerSize += kv[0]
|
||||
|
|
@ -184,7 +252,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
|
||||
// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
|
||||
var layerCount int
|
||||
layerCounts := make([]int, len(gpus))
|
||||
tensorSplit := make([]int, len(gpus))
|
||||
gpuAllocations := make([]uint64, len(gpus))
|
||||
type gs struct {
|
||||
i int
|
||||
|
|
@ -248,7 +316,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
|
||||
if g.g.FreeMemory > overhead+used+layerSize {
|
||||
gpuAllocations[g.i] += layerSize
|
||||
layerCounts[g.i]++
|
||||
tensorSplit[g.i]++
|
||||
layerCount++
|
||||
break
|
||||
} else {
|
||||
|
|
@ -273,7 +341,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
|
||||
if g.g.FreeMemory > overhead+used+memoryLastLayer {
|
||||
gpuAllocations[g.i] += memoryLastLayer
|
||||
layerCounts[g.i]++
|
||||
tensorSplit[g.i]++
|
||||
layerCount++
|
||||
break
|
||||
}
|
||||
|
|
@ -288,7 +356,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
|
||||
// Add the applicable (full or partial) graph allocations
|
||||
for i := range gpus {
|
||||
if layerCounts[i] <= 0 {
|
||||
if tensorSplit[i] <= 0 {
|
||||
continue
|
||||
}
|
||||
if fullyLoaded {
|
||||
|
|
@ -310,14 +378,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
|||
}
|
||||
memoryRequiredTotal = memoryRequiredPartial + overflow
|
||||
|
||||
tensorSplit := ""
|
||||
if len(gpus) > 1 {
|
||||
splits := make([]string, len(gpus))
|
||||
for i, count := range layerCounts {
|
||||
splits[i] = strconv.Itoa(count)
|
||||
}
|
||||
tensorSplit = strings.Join(splits, ",")
|
||||
}
|
||||
allocationsList := []string{}
|
||||
for _, a := range gpuAllocations {
|
||||
allocationsList = append(allocationsList, format.HumanBytes2(a))
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||
projectors := []string{}
|
||||
opts := api.DefaultOptions()
|
||||
t.Run("cpu", func(t *testing.T) {
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, 0, estimate.Layers)
|
||||
assert.Equal(t, uint64(0), estimate.Graph)
|
||||
})
|
||||
|
|
@ -88,7 +88,7 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||
// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
|
||||
for i, s := range []struct {
|
||||
layer0, layer1 uint64
|
||||
expect0, expect1 uint64
|
||||
expect0, expect1 int
|
||||
}{
|
||||
{1, 1, 1, 1},
|
||||
{2, 1, 2, 1},
|
||||
|
|
@ -112,9 +112,9 @@ func TestEstimateGPULayers(t *testing.T) {
|
|||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||
estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1)
|
||||
assert.Equal(t, s.expect0+s.expect1, estimate.Layers, "scenario %d: %v", i, s)
|
||||
assert.Equal(t, []int{s.expect0, s.expect1}, estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||
var layerSums uint64
|
||||
for _, b := range estimate.GPUSizes {
|
||||
layerSums += b
|
||||
|
|
|
|||
1165
llm/server.go
1165
llm/server.go
File diff suppressed because it is too large
Load Diff
|
|
@ -8,9 +8,178 @@ import (
|
|||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"golang.org/x/sync/semaphore"
|
||||
)
|
||||
|
||||
func TestLLMServerFitGPU(t *testing.T) {
|
||||
type gpu struct {
|
||||
library string
|
||||
free int
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
gpus []gpu
|
||||
layers []int
|
||||
numGPU int
|
||||
requireFull bool
|
||||
expected ml.GPULayersList
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "No GPU",
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{},
|
||||
},
|
||||
{
|
||||
name: "Full single GPU",
|
||||
gpus: []gpu{{free: 256 * format.MebiByte}},
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2}}},
|
||||
},
|
||||
{
|
||||
name: "Partial single GPU",
|
||||
gpus: []gpu{{free: 256 * format.MebiByte}},
|
||||
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1, 2}}},
|
||||
},
|
||||
{
|
||||
name: "Single GPU with numGPU 1",
|
||||
gpus: []gpu{{free: 256 * format.MebiByte}},
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: 1,
|
||||
expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1}}},
|
||||
},
|
||||
{
|
||||
name: "Single GPU with numGPU 0",
|
||||
gpus: []gpu{{free: 256 * format.MebiByte}},
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: 0,
|
||||
expected: ml.GPULayersList{},
|
||||
},
|
||||
{
|
||||
name: "Single GPU with numGPU 999",
|
||||
gpus: []gpu{{free: 256 * format.MebiByte}},
|
||||
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
|
||||
numGPU: 999,
|
||||
expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2, 3}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU fits on one",
|
||||
gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}},
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1, 2}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU split",
|
||||
gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}},
|
||||
layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1, 2}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU partial",
|
||||
gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}},
|
||||
layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU numGPU 1",
|
||||
gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}},
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: 1,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU numGPU 2",
|
||||
gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}},
|
||||
layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: 2,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU numGPU 999",
|
||||
gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}},
|
||||
layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: 999,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}, {ID: "gpu0", Layers: []int{2}}},
|
||||
},
|
||||
{
|
||||
name: "Multi GPU different libraries",
|
||||
gpus: []gpu{{library: "cuda", free: 128 * format.MebiByte}, {library: "rocm", free: 256 * format.MebiByte}},
|
||||
layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}},
|
||||
},
|
||||
{
|
||||
name: "requireFull",
|
||||
gpus: []gpu{{free: 256 * format.MebiByte}},
|
||||
layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
requireFull: true,
|
||||
expectedErr: ErrLoadRequiredFull,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var systemInfo discover.SystemInfo
|
||||
systemInfo.System.TotalMemory = format.GibiByte
|
||||
systemInfo.System.FreeMemory = 512 * format.MebiByte
|
||||
systemInfo.System.FreeSwap = 256 * format.MebiByte
|
||||
|
||||
gpus := make(discover.GpuInfoList, len(tt.gpus))
|
||||
for i := range tt.gpus {
|
||||
gpus[i].ID = fmt.Sprintf("gpu%d", i)
|
||||
gpus[i].Library = tt.gpus[i].library
|
||||
gpus[i].FreeMemory = uint64(tt.gpus[i].free)
|
||||
}
|
||||
|
||||
s := &ollamaServer{
|
||||
llmServer: llmServer{
|
||||
totalLayers: uint64(len(tt.layers)),
|
||||
options: api.Options{
|
||||
Runner: api.Runner{
|
||||
NumGPU: tt.numGPU,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{
|
||||
Weights: make([]ml.Memory, s.totalLayers),
|
||||
Cache: make([]ml.Memory, s.totalLayers),
|
||||
}, GPUs: make([]ml.DeviceMemory, len(gpus))}
|
||||
|
||||
for i := range tt.layers {
|
||||
s.mem.CPU.Weights[i].Size = uint64(tt.layers[i])
|
||||
}
|
||||
|
||||
for i := range s.mem.GPUs {
|
||||
s.mem.GPUs[i].ID = fmt.Sprintf("gpu%d", i)
|
||||
s.mem.GPUs[i].Weights = make([]ml.Memory, s.totalLayers)
|
||||
s.mem.GPUs[i].Cache = make([]ml.Memory, s.totalLayers)
|
||||
}
|
||||
|
||||
gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0)
|
||||
if err != tt.expectedErr {
|
||||
t.Fatalf("fitGPU returned error: %v", err)
|
||||
}
|
||||
if gpuLayers.Hash() != tt.expected.Hash() {
|
||||
t.Errorf("fitGPU assigned %v, want %v", gpuLayers, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLLMServerCompletionFormat(t *testing.T) {
|
||||
// This test was written to fix an already deployed issue. It is a bit
|
||||
// of a mess, and but it's good enough, until we can refactoring the
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package logutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
|
|
@ -27,3 +28,11 @@ func NewLogger(w io.Writer, level slog.Level) *slog.Logger {
|
|||
},
|
||||
}))
|
||||
}
|
||||
|
||||
func Trace(msg string, args ...any) {
|
||||
slog.Log(context.TODO(), LevelTrace, msg, args...)
|
||||
}
|
||||
|
||||
func TraceContext(ctx context.Context, msg string, args ...any) {
|
||||
slog.Log(ctx, LevelTrace, msg, args...)
|
||||
}
|
||||
|
|
|
|||
168
ml/backend.go
168
ml/backend.go
|
|
@ -5,12 +5,14 @@ import (
|
|||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs"
|
||||
)
|
||||
|
||||
|
|
@ -58,19 +60,89 @@ type CacheConfig struct {
|
|||
MaskBatchPadding int
|
||||
}
|
||||
|
||||
// GPULayers is a set of layers to be allocated on a single GPU
|
||||
type GPULayers struct {
|
||||
// ID is the identifier of the GPU, as reported in DeviceMemory
|
||||
ID string
|
||||
|
||||
// Layers is a set of layer indicies to load
|
||||
Layers []int
|
||||
}
|
||||
|
||||
func (g GPULayers) String() string {
|
||||
if len(g.Layers) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
slices.Sort(g.Layers)
|
||||
|
||||
contiguous := true
|
||||
base := g.Layers[0]
|
||||
for i := range g.Layers {
|
||||
if g.Layers[i] != base+i {
|
||||
contiguous = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if contiguous {
|
||||
return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1])
|
||||
} else {
|
||||
return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers)
|
||||
}
|
||||
}
|
||||
|
||||
// GPULayersList is a set of layer allocations across multiple GPUs
|
||||
type GPULayersList []GPULayers
|
||||
|
||||
func (l GPULayersList) String() string {
|
||||
if l.Sum() > 0 {
|
||||
return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l))
|
||||
} else {
|
||||
return fmt.Sprintf("%v", []GPULayers(l))
|
||||
}
|
||||
}
|
||||
|
||||
// Sum is the total number of layers assigned across all GPUs
|
||||
func (l GPULayersList) Sum() int {
|
||||
var sum int
|
||||
|
||||
for _, g := range l {
|
||||
sum += len(g.Layers)
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
var h maphash.Hash
|
||||
|
||||
// Hash is an identifier of this layer assignment
|
||||
func (l GPULayersList) Hash() uint64 {
|
||||
h.Reset()
|
||||
for _, g := range l {
|
||||
if len(g.Layers) > 0 {
|
||||
h.WriteString(g.ID)
|
||||
for _, l := range g.Layers {
|
||||
binary.Write(&h, binary.NativeEndian, int64(l))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// AllocMemory causes the backend to allocate memory for the model. If
|
||||
// false, this is only being used for discovering the required amount of
|
||||
// memory and cannot load the model for running.
|
||||
AllocMemory bool
|
||||
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
NumThreads int
|
||||
|
||||
// MainGPU is the index of the primary GPU to use
|
||||
MainGPU int
|
||||
|
||||
// NumGPULayers is the number of layers to offload to GPUs
|
||||
NumGPULayers int
|
||||
|
||||
// TensorSplit is the fraction of the model to offload to each GPU
|
||||
TensorSplit []float32
|
||||
// GPULayers is the set of layers to offload to GPUs
|
||||
GPULayers GPULayersList
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
|
|
@ -141,6 +213,28 @@ type DeviceMemory struct {
|
|||
Graph Memory
|
||||
}
|
||||
|
||||
// Allocated returns the total size of the memory that has been successfully
|
||||
// allocated on this device
|
||||
func (m DeviceMemory) Allocated() uint64 {
|
||||
var mem uint64
|
||||
|
||||
for _, w := range m.Weights {
|
||||
if w.Status == Allocated {
|
||||
mem += w.Size
|
||||
}
|
||||
}
|
||||
for _, c := range m.Cache {
|
||||
if c.Status == Allocated {
|
||||
mem += c.Size
|
||||
}
|
||||
}
|
||||
if m.Graph.Status == Allocated {
|
||||
mem += m.Graph.Size
|
||||
}
|
||||
|
||||
return mem
|
||||
}
|
||||
|
||||
func memoryPresent(mem []Memory) bool {
|
||||
return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 })
|
||||
}
|
||||
|
|
@ -172,7 +266,7 @@ func (m DeviceMemory) LogValue() slog.Value {
|
|||
// allocation is guaranteed to be provided so that if it failed, the caller can
|
||||
// accommodate that to make forward progress.
|
||||
type BackendMemory struct {
|
||||
// InputsWeights are always located on the CPU and cannot be moved
|
||||
// InputWeights are always located on the CPU and cannot be moved
|
||||
InputWeights Memory
|
||||
|
||||
// CPU model components are located in system memory. This does not
|
||||
|
|
@ -197,6 +291,58 @@ func (m BackendMemory) LogValue() slog.Value {
|
|||
return slog.GroupValue(attrs...)
|
||||
}
|
||||
|
||||
func sumMemory(mem []Memory) uint64 {
|
||||
var sum uint64
|
||||
|
||||
for _, m := range mem {
|
||||
sum += m.Size
|
||||
}
|
||||
|
||||
return sum
|
||||
}
|
||||
|
||||
// Log prints a high level summary of the memory (allocated or not)
|
||||
func (m BackendMemory) Log(level slog.Level) {
|
||||
var total uint64
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.InputWeights.Size + sumMemory(m.CPU.Weights); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := sumMemory(gpu.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := sumMemory(m.CPU.Cache); sum > 0 {
|
||||
slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
for _, gpu := range m.GPUs {
|
||||
if sum := gpu.Graph.Size; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
}
|
||||
if sum := m.CPU.Graph.Size; sum > 0 {
|
||||
slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum))
|
||||
total += sum
|
||||
}
|
||||
|
||||
if total > 0 {
|
||||
slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total))
|
||||
}
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(string, BackendParams) (Backend, error))
|
||||
|
||||
func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) {
|
||||
|
|
@ -226,6 +372,7 @@ type Context interface {
|
|||
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
ComputeWithNotify(func(), ...Tensor) // notify callback once compute has begun
|
||||
|
||||
// Reserve is analogous to Compute but rather than executing a
|
||||
// graph, simply preallocates memory. Typically called with a
|
||||
|
|
@ -250,10 +397,13 @@ type Tensor interface {
|
|||
|
||||
Shape() []int
|
||||
DType() DType
|
||||
Cast(ctx Context, dtype DType) Tensor
|
||||
|
||||
Bytes() []byte
|
||||
Floats() []float32
|
||||
|
||||
SetValueFromIntSlice(s []int32)
|
||||
|
||||
Neg(ctx Context) Tensor
|
||||
Add(ctx Context, t2 Tensor) Tensor
|
||||
Sub(ctx Context, t2 Tensor) Tensor
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import "C"
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
|
|
@ -62,27 +63,40 @@ var initDevices = sync.OnceFunc(func() {
|
|||
}
|
||||
})
|
||||
|
||||
type layerDevice struct {
|
||||
d C.ggml_backend_dev_t
|
||||
bt C.ggml_backend_buffer_type_t
|
||||
}
|
||||
|
||||
type Backend struct {
|
||||
// modelPath is the location of the model data
|
||||
modelPath string
|
||||
|
||||
meta *fsggml.GGML
|
||||
|
||||
// allocMemory means that memory should be allocated for tensors and not
|
||||
// just a dry run
|
||||
allocMemory bool
|
||||
|
||||
// tensorLoadTargets maps from the name of the tensor in the file
|
||||
// to the name that is used by the model definition
|
||||
tensorLoadTargets map[string][]string
|
||||
|
||||
schedMu sync.Mutex // Only one Compute can run at a time
|
||||
sched C.ggml_backend_sched_t
|
||||
schedBackends []C.ggml_backend_t
|
||||
schedBufts []C.ggml_backend_buffer_type_t
|
||||
|
||||
tensors map[string]*C.struct_ggml_tensor
|
||||
|
||||
// input is the backend used for inputs
|
||||
// input is the backend buffer type used for inputs
|
||||
input C.ggml_backend_buffer_type_t
|
||||
|
||||
// output is the backend device used for outputs
|
||||
output C.ggml_backend_dev_t
|
||||
|
||||
// layers is the backend used for repeating layers
|
||||
layers map[int]C.ggml_backend_buffer_type_t
|
||||
layers map[int]layerDevice
|
||||
|
||||
// requiredMemory is the cumulative memory allocations needed by the backend
|
||||
requiredMemory *ml.BackendMemory
|
||||
|
|
@ -99,6 +113,8 @@ type Backend struct {
|
|||
weightBuffers map[*C.struct_ggml_context]C.ggml_backend_buffer_t
|
||||
}
|
||||
|
||||
var once sync.Once
|
||||
|
||||
func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
||||
r, err := os.Open(modelPath)
|
||||
if err != nil {
|
||||
|
|
@ -111,15 +127,17 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info(
|
||||
"",
|
||||
"architecture", meta.KV().Architecture(),
|
||||
"file_type", meta.KV().FileType(),
|
||||
"name", meta.KV().String("general.name"),
|
||||
"description", meta.KV().String("general.description"),
|
||||
"num_tensors", len(meta.Tensors().Items()),
|
||||
"num_key_values", len(meta.KV()),
|
||||
)
|
||||
once.Do(func() {
|
||||
slog.Info(
|
||||
"",
|
||||
"architecture", meta.KV().Architecture(),
|
||||
"file_type", meta.KV().FileType(),
|
||||
"name", meta.KV().String("general.name"),
|
||||
"description", meta.KV().String("general.description"),
|
||||
"num_tensors", len(meta.Tensors().Items()),
|
||||
"num_key_values", len(meta.KV()),
|
||||
)
|
||||
})
|
||||
|
||||
initDevices()
|
||||
|
||||
|
|
@ -139,7 +157,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
switch C.ggml_backend_dev_type(d) {
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_CPU,
|
||||
C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
||||
cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d))
|
||||
bt := C.ggml_backend_dev_buffer_type(d)
|
||||
cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, bt)
|
||||
C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory))
|
||||
|
||||
btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU
|
||||
}
|
||||
}
|
||||
|
|
@ -160,6 +181,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
d: d,
|
||||
bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...),
|
||||
})
|
||||
C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory))
|
||||
|
||||
btDeviceMemory[bt] = &requiredMemory.GPUs[i]
|
||||
requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d))
|
||||
var props C.struct_ggml_backend_dev_props
|
||||
|
|
@ -169,56 +192,25 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1)
|
||||
}
|
||||
|
||||
useDefaultSplit := true
|
||||
for _, s := range params.TensorSplit {
|
||||
if s != 0 {
|
||||
useDefaultSplit = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// calculate splits
|
||||
splits := make([]float32, len(gpus))
|
||||
if useDefaultSplit {
|
||||
// default: split on free memory
|
||||
for i := range splits {
|
||||
var free, total C.size_t
|
||||
C.ggml_backend_dev_memory(gpus[i], &free, &total)
|
||||
splits[i] = float32(free)
|
||||
}
|
||||
} else {
|
||||
splits = params.TensorSplit
|
||||
}
|
||||
|
||||
var sum float32
|
||||
// cumulative sum of all splits
|
||||
for i := range splits {
|
||||
sum += splits[i]
|
||||
splits[i] = sum
|
||||
}
|
||||
|
||||
// normalize splits
|
||||
for i := range splits {
|
||||
splits[i] /= sum
|
||||
}
|
||||
|
||||
// inputs always use cpu
|
||||
input := cpuDeviceBufferType
|
||||
|
||||
// define a range of gpu layers. anything outside of this range is assigned to the cpu
|
||||
gpuRangeStart := max(0, blocks-params.NumGPULayers)
|
||||
gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1)
|
||||
assignLayer := func(i int) deviceBufferType {
|
||||
if i < gpuRangeStart || i >= gpuRangeStop {
|
||||
return cpuDeviceBufferType
|
||||
assignLayer := func(layer int) deviceBufferType {
|
||||
for _, p := range params.GPULayers {
|
||||
for _, l := range p.Layers {
|
||||
if l == layer {
|
||||
for i := range requiredMemory.GPUs {
|
||||
if requiredMemory.GPUs[i].ID == p.ID {
|
||||
return gpuDeviceBufferTypes[i]
|
||||
}
|
||||
}
|
||||
|
||||
return cpuDeviceBufferType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f })
|
||||
if index < 0 || index >= len(gpuDeviceBufferTypes) {
|
||||
return cpuDeviceBufferType
|
||||
}
|
||||
|
||||
return gpuDeviceBufferTypes[index]
|
||||
return cpuDeviceBufferType
|
||||
}
|
||||
|
||||
// repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1)
|
||||
|
|
@ -279,12 +271,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
tt := C.ggml_new_tensor(ctxs[bt], kind, C.int(len(t.source.Shape)), (*C.int64_t)(unsafe.Pointer(&t.source.Shape[0])))
|
||||
C.ggml_set_name(tt, cname)
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
logutil.Trace("created tensor", "name", name, "shape", t.source.Shape, "dtype", t.source.Kind, "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt))
|
||||
if layer == -1 {
|
||||
// Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case
|
||||
requiredMemory.InputWeights.Status = ml.Allocated
|
||||
if params.AllocMemory {
|
||||
requiredMemory.InputWeights.Status = ml.Allocated
|
||||
}
|
||||
requiredMemory.InputWeights.Size += uint64(size)
|
||||
} else {
|
||||
btDeviceMemory[bt].Weights[layer].Size += uint64(size)
|
||||
|
|
@ -355,12 +349,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
}
|
||||
|
||||
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
|
||||
for i := range btDeviceMemory[bt].Weights {
|
||||
if btDeviceMemory[bt].Weights[i].Size != 0 {
|
||||
if b != nil {
|
||||
btDeviceMemory[bt].Weights[i].Status = ml.Allocated
|
||||
} else {
|
||||
btDeviceMemory[bt].Weights[i].Status = ml.Failed
|
||||
if params.AllocMemory {
|
||||
for i := range btDeviceMemory[bt].Weights {
|
||||
if btDeviceMemory[bt].Weights[i].Size != 0 {
|
||||
if b != nil {
|
||||
btDeviceMemory[bt].Weights[i].Status = ml.Allocated
|
||||
} else {
|
||||
btDeviceMemory[bt].Weights[i].Status = ml.Failed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -381,28 +377,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
bbs[c] = b
|
||||
}
|
||||
|
||||
// Mimic llama runner logs summarizing layers and memory
|
||||
gpuLayers := 0
|
||||
for _, layer := range layers {
|
||||
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
|
||||
gpuLayers++
|
||||
}
|
||||
}
|
||||
slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers))
|
||||
|
||||
switch C.ggml_backend_dev_type(output.d) {
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
|
||||
slog.Info("offloading output layer to CPU")
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
|
||||
slog.Info("offloading output layer to GPU")
|
||||
gpuLayers++
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
||||
slog.Info("offloading output layer to ACCEL")
|
||||
}
|
||||
slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(layers)+1))
|
||||
|
||||
for bs := range maps.Values(bbs) {
|
||||
slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
|
||||
logutil.Trace("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)),
|
||||
"size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs))))
|
||||
}
|
||||
|
||||
// map tensor names to tensors for easy lookup later
|
||||
|
|
@ -423,6 +400,13 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
b := backends[d]
|
||||
bt := C.ggml_backend_get_default_buffer_type(b)
|
||||
|
||||
// Always include CPU as a fallback but otherwise, just use the devices where we assigned layers
|
||||
if !slices.Contains(cpuDeviceBufferType.bts, bt) {
|
||||
if c, ok := ctxs[bt]; !ok || C.ggml_get_first_tensor(c) == nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
deviceBufferTypes[d] = bt
|
||||
|
||||
schedBackends = append(schedBackends, b)
|
||||
|
|
@ -437,6 +421,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
maxGraphNodes := max(8192, len(meta.Tensors().Items())*5)
|
||||
return &Backend{
|
||||
modelPath: modelPath,
|
||||
allocMemory: params.AllocMemory,
|
||||
flashAttention: params.FlashAttention,
|
||||
meta: meta,
|
||||
tensorLoadTargets: targets,
|
||||
|
|
@ -452,10 +437,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) {
|
|||
schedBackends: schedBackends,
|
||||
schedBufts: schedBufts,
|
||||
input: deviceBufferTypes[input.d],
|
||||
layers: func() map[int]C.ggml_backend_buffer_type_t {
|
||||
m := make(map[int]C.ggml_backend_buffer_type_t)
|
||||
output: output.d,
|
||||
layers: func() map[int]layerDevice {
|
||||
m := make(map[int]layerDevice)
|
||||
for i, layer := range layers {
|
||||
m[i] = deviceBufferTypes[layer.d]
|
||||
m[i] = layerDevice{
|
||||
d: layer.d,
|
||||
bt: deviceBufferTypes[layer.d],
|
||||
}
|
||||
}
|
||||
return m
|
||||
}(),
|
||||
|
|
@ -484,6 +473,30 @@ func (b *Backend) Close() {
|
|||
}
|
||||
|
||||
func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
if !b.allocMemory {
|
||||
return errors.New("cannot load model without memory allocation")
|
||||
}
|
||||
|
||||
// Mimic llama runner logs summarizing layers and memory
|
||||
gpuLayers := 0
|
||||
for layer := range maps.Values(b.layers) {
|
||||
if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU {
|
||||
gpuLayers++
|
||||
}
|
||||
}
|
||||
slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers))
|
||||
|
||||
switch C.ggml_backend_dev_type(b.output) {
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_CPU:
|
||||
slog.Info("offloading output layer to CPU")
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_GPU:
|
||||
slog.Info("offloading output layer to GPU")
|
||||
gpuLayers++
|
||||
case C.GGML_BACKEND_DEVICE_TYPE_ACCEL:
|
||||
slog.Info("offloading output layer to ACCEL")
|
||||
}
|
||||
slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(b.layers)+1))
|
||||
|
||||
var doneBytes atomic.Uint64
|
||||
totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset
|
||||
|
||||
|
|
@ -523,6 +536,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
|||
const BS = 17 // MXFP4 block size
|
||||
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
|
||||
var s uint64
|
||||
var tmp [16]byte
|
||||
for s < t.Size() {
|
||||
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
||||
if err := ctx.Err(); err != nil {
|
||||
|
|
@ -534,37 +548,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
|||
return err
|
||||
}
|
||||
for j := range n / BS {
|
||||
for i := 1; i < BS; i++ {
|
||||
// swap nibbles
|
||||
t_lo := bts[j*BS+i] & 0x0F
|
||||
t_hi := bts[j*BS+i] & 0xF0
|
||||
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
|
||||
}
|
||||
// transform aaaa...bbbb... to abababab...
|
||||
oi := 0
|
||||
tmp := [16]byte{}
|
||||
for i := 1; i < 9; i++ {
|
||||
blk_a0 := bts[j*BS+i] & 0xF0
|
||||
blk_a1 := bts[j*BS+i] << 4
|
||||
blk_b0 := bts[j*BS+i+8] >> 4
|
||||
blk_b1 := bts[j*BS+i+8] & 0x0F
|
||||
// swap once more
|
||||
out0 := blk_a0 | blk_b0
|
||||
out1 := blk_a1 | blk_b1
|
||||
out_h0 := out0 & 0xF0
|
||||
out_l0 := out0 & 0x0F
|
||||
out_h1 := out1 & 0xF0
|
||||
out_l1 := out1 & 0x0F
|
||||
out0 = (out_h0 >> 4) | (out_l0 << 4)
|
||||
out1 = (out_h1 >> 4) | (out_l1 << 4)
|
||||
tmp[oi] = out0
|
||||
oi++
|
||||
tmp[oi] = out1
|
||||
oi++
|
||||
}
|
||||
for i := range tmp {
|
||||
bts[j*BS+i+1] = tmp[i]
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[j*BS+i], bts[j*BS+i+8]
|
||||
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
copy(bts[j*BS+1:j*BS+17], tmp[:])
|
||||
}
|
||||
|
||||
for _, tt := range tts {
|
||||
|
|
@ -640,6 +630,18 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
|||
})
|
||||
}
|
||||
|
||||
// Cleanup any backend state from devices that we didn't end up using
|
||||
nextDevice:
|
||||
for _, d := range append(gpus, append(accels, cpus...)...) {
|
||||
for _, backend := range b.schedBackends {
|
||||
if d == C.ggml_backend_get_device(backend) {
|
||||
continue nextDevice
|
||||
}
|
||||
}
|
||||
|
||||
C.ggml_backend_dev_reset(d)
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -730,11 +732,11 @@ func (c *Context) Input() ml.Context {
|
|||
}
|
||||
|
||||
func (c *Context) Layer(i int) ml.Context {
|
||||
if buft, ok := c.b.layers[i]; ok {
|
||||
if layer, ok := c.b.layers[i]; ok {
|
||||
return &Context{
|
||||
b: c.b,
|
||||
ctx: c.ctx,
|
||||
buft: buft,
|
||||
buft: layer.bt,
|
||||
allocatedBuffers: c.allocatedBuffers,
|
||||
maxGraphNodes: c.maxGraphNodes,
|
||||
layer: i,
|
||||
|
|
@ -757,6 +759,15 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
|||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
c.ComputeWithNotify(nil, tensors...)
|
||||
}
|
||||
|
||||
func (c *Context) ComputeWithNotify(cb func(), tensors ...ml.Tensor) {
|
||||
c.b.schedMu.Lock()
|
||||
defer c.b.schedMu.Unlock()
|
||||
if cb != nil {
|
||||
go cb()
|
||||
}
|
||||
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
|
||||
panic(fmt.Errorf("error computing ggml graph: %v", status))
|
||||
}
|
||||
|
|
@ -792,14 +803,16 @@ func (c *Context) Reserve() {
|
|||
|
||||
graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph
|
||||
graph.Size += uint64(bufferStatus.size)
|
||||
if bufferStatus.allocated && graph.Status != ml.Failed {
|
||||
graph.Status = ml.Allocated
|
||||
} else {
|
||||
graph.Status = ml.Failed
|
||||
if c.b.allocMemory {
|
||||
if bufferStatus.allocated && graph.Status != ml.Failed {
|
||||
graph.Status = ml.Allocated
|
||||
} else {
|
||||
graph.Status = ml.Failed
|
||||
}
|
||||
}
|
||||
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
|
||||
"size", format.HumanBytes2(uint64(bufferStatus.size)))
|
||||
logutil.Trace("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])),
|
||||
"buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size)))
|
||||
}
|
||||
|
||||
if !reserved {
|
||||
|
|
@ -829,23 +842,7 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
|||
panic("set Input or Layer before creating tensors")
|
||||
}
|
||||
|
||||
var cdtype uint32
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
cdtype = C.GGML_TYPE_F32
|
||||
case ml.DTypeF16:
|
||||
cdtype = C.GGML_TYPE_F16
|
||||
case ml.DTypeQ80:
|
||||
cdtype = C.GGML_TYPE_Q8_0
|
||||
case ml.DTypeQ40:
|
||||
cdtype = C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
cdtype = C.GGML_TYPE_I32
|
||||
case ml.DTypeMXFP4:
|
||||
cdtype = C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
cdtype := ggmlDType(dtype)
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
|
|
@ -868,10 +865,12 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
|||
cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer]
|
||||
|
||||
cache.Size += uint64(size)
|
||||
if b != nil {
|
||||
cache.Status = ml.Allocated
|
||||
} else {
|
||||
cache.Status = ml.Failed
|
||||
if c.b.allocMemory {
|
||||
if b != nil {
|
||||
cache.Status = ml.Allocated
|
||||
} else {
|
||||
cache.Status = ml.Failed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -890,7 +889,9 @@ func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
|||
|
||||
func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
t := c.newTensor(dtype, shape)
|
||||
C.ggml_set_zero(t.(*Tensor).t)
|
||||
if c.b.allocMemory {
|
||||
C.ggml_set_zero(t.(*Tensor).t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
|
|
@ -915,7 +916,7 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor {
|
|||
|
||||
t := c.newTensor(ml.DTypeF32, shape)
|
||||
|
||||
if len(s) > 0 {
|
||||
if c.b.allocMemory && len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
|
|
@ -927,7 +928,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor {
|
|||
|
||||
t := c.newTensor(ml.DTypeI32, shape)
|
||||
|
||||
if len(s) > 0 {
|
||||
if c.b.allocMemory && len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
|
|
@ -1019,6 +1020,12 @@ func (t *Tensor) Floats() (data []float32) {
|
|||
return
|
||||
}
|
||||
|
||||
func (t *Tensor) SetValueFromIntSlice(s []int32) {
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) DType() ml.DType {
|
||||
switch t.t._type {
|
||||
case C.GGML_TYPE_F32:
|
||||
|
|
@ -1038,6 +1045,32 @@ func (t *Tensor) DType() ml.DType {
|
|||
}
|
||||
}
|
||||
|
||||
func ggmlDType(dtype ml.DType) uint32 {
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
return C.GGML_TYPE_F32
|
||||
case ml.DTypeF16:
|
||||
return C.GGML_TYPE_F16
|
||||
case ml.DTypeQ80:
|
||||
return C.GGML_TYPE_Q8_0
|
||||
case ml.DTypeQ40:
|
||||
return C.GGML_TYPE_Q4_0
|
||||
case ml.DTypeI32:
|
||||
return C.GGML_TYPE_I32
|
||||
case ml.DTypeMXFP4:
|
||||
return C.GGML_TYPE_MXFP4
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Cast(ctx ml.Context, dtype ml.DType) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cast(ctx.(*Context).ctx, t.t, ggmlDType(dtype)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Neg(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
|
|
@ -1550,7 +1583,7 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor {
|
|||
func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor {
|
||||
// Unchecked to handle quantized types
|
||||
t := c.newTensor(dtype, shape)
|
||||
if len(s) > 0 {
|
||||
if c.b.allocMemory && len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -167,6 +167,7 @@ extern "C" {
|
|||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||
|
|
|
|||
|
|
@ -178,6 +178,10 @@ extern "C" {
|
|||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
|
||||
// (optional) reset device, clearing existing allocations and context
|
||||
// the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||
void (*reset)(ggml_backend_dev_t dev);
|
||||
};
|
||||
|
||||
struct ggml_backend_device {
|
||||
|
|
|
|||
|
|
@ -581,16 +581,8 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
|
|||
|
||||
ggml_backend_load_best("blas", silent, dir_path);
|
||||
ggml_backend_load_best("cann", silent, dir_path);
|
||||
|
||||
// Avoid mixed hip+cuda configurations
|
||||
const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES");
|
||||
const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES");
|
||||
if (!hip_devices && !rocr_devices) {
|
||||
ggml_backend_load_best("cuda", silent, dir_path);
|
||||
} else {
|
||||
ggml_backend_load_best("hip", silent, dir_path);
|
||||
}
|
||||
|
||||
ggml_backend_load_best("cuda", silent, dir_path);
|
||||
ggml_backend_load_best("hip", silent, dir_path);
|
||||
ggml_backend_load_best("metal", silent, dir_path);
|
||||
ggml_backend_load_best("rpc", silent, dir_path);
|
||||
ggml_backend_load_best("sycl", silent, dir_path);
|
||||
|
|
|
|||
|
|
@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
|||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||
if (device->iface.reset == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
device->iface.reset(device);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
package arm
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../.. -I${SRCDIR}/../../.. -I${SRCDIR}/../../../../include
|
||||
// #cgo CPPFLAGS: -I${SRCDIR}/../.. -I${SRCDIR}/../../.. -I${SRCDIR}/../../../../include -DHWCAP2_SVE2="2"
|
||||
import "C"
|
||||
|
|
|
|||
|
|
@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
|
|||
return id;
|
||||
}
|
||||
|
||||
void ggml_cuda_reset_device(int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
CUDA_CHECK(cudaDeviceReset());
|
||||
}
|
||||
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
|
|
@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
|||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||
props->memory_total = props->memory_free = 0;
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||
|
|
@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
|||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_reset_device(ctx->device);
|
||||
}
|
||||
|
||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||
|
|
@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
|||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||
/* .reset = */ ggml_backend_cuda_device_reset,
|
||||
};
|
||||
|
||||
// backend reg
|
||||
|
|
@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
|||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@
|
|||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceReset hipDeviceReset
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
|
|
|
|||
|
|
@ -19,8 +19,12 @@ static bool ggml_uncaught_exception_init = []{
|
|||
return false;
|
||||
}
|
||||
const auto prev{std::get_terminate()};
|
||||
GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
previous_terminate_handler = prev;
|
||||
// GGML_ASSERT(prev != ggml_uncaught_exception);
|
||||
if (prev != ggml_uncaught_exception) {
|
||||
previous_terminate_handler = prev;
|
||||
} else {
|
||||
GGML_LOG_WARN("%s double registration of ggml_uncaught_exception\n", __func__);
|
||||
}
|
||||
std::set_terminate(ggml_uncaught_exception);
|
||||
return true;
|
||||
}();
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package model
|
|||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"fmt"
|
||||
"iter"
|
||||
"log/slog"
|
||||
|
|
@ -109,7 +108,7 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||
r = 0x0143
|
||||
case r <= 0x0020:
|
||||
r = r + 0x0100
|
||||
case r >= 0x007e && r <= 0x00a0:
|
||||
case r >= 0x007f && r <= 0x00a0:
|
||||
r = r + 0x00a2
|
||||
}
|
||||
|
||||
|
|
@ -202,12 +201,11 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
|||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = bpe.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
|
|
@ -243,6 +241,6 @@ func (bpe BytePairEncoding) Decode(ids []int32) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
logutil.Trace("decoded", "string", sb.String(), "from", lazyIdsString{ids: ids})
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -207,6 +207,36 @@ func TestLlama(t *testing.T) {
|
|||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for b := 0x00; b <= 0xFF; b++ {
|
||||
input := string(rune(b))
|
||||
ids, err := tokenizer.Encode(input, false)
|
||||
if err != nil {
|
||||
t.Errorf("failed to encode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("failed to decode rune 0x%02X: %v", b, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if b == 0x00 {
|
||||
if len(decoded) != 0 {
|
||||
t.Errorf("Decode(Encode(0x00)) should be empty, got %v", ids)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != input {
|
||||
t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
package model
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
_ "image/jpeg"
|
||||
_ "image/png"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
|
|
@ -64,7 +63,7 @@ type MultimodalProcessor interface {
|
|||
// This function is also responsible for updating MultimodalHash for any Multimodal
|
||||
// that is modified to ensure that there is a unique hash value that accurately
|
||||
// represents the contents.
|
||||
PostTokenize([]input.Input) ([]input.Input, error)
|
||||
PostTokenize([]*input.Input) ([]*input.Input, error)
|
||||
}
|
||||
|
||||
// Base implements the common fields and methods for all models
|
||||
|
|
@ -105,6 +104,10 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
|
|||
}
|
||||
|
||||
arch := b.Config().Architecture()
|
||||
if b.Config().Uint("pooling_type", math.MaxUint32) != math.MaxUint32 {
|
||||
arch = arch + "_embed"
|
||||
}
|
||||
|
||||
f, ok := models[arch]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unsupported model architecture %q", arch)
|
||||
|
|
@ -198,7 +201,7 @@ func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
|
|||
names := fn(tagsCopy)
|
||||
for _, name := range names {
|
||||
if tensor := base.Backend().Get(strings.Join(name, ".")); tensor != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "found tensor", "", tensor)
|
||||
logutil.Trace("found tensor", "", tensor)
|
||||
vv.Set(reflect.ValueOf(tensor))
|
||||
break
|
||||
}
|
||||
|
|
@ -278,7 +281,7 @@ func canNil(t reflect.Type) bool {
|
|||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
|
||||
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
|
@ -287,8 +290,6 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
|||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
|
|
@ -302,7 +303,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t).Compute(t)
|
||||
ctx.Forward(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,73 @@
|
|||
package gemma3
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type embedModel struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*TextModel
|
||||
PoolingType uint32
|
||||
|
||||
Dense [2]*nn.Linear `gguf:"dense"`
|
||||
}
|
||||
|
||||
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
batch.Outputs = batch.Positions // return all positions
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
|
||||
switch m.PoolingType {
|
||||
case 0: // None
|
||||
case 1: // Mean
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
|
||||
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
default:
|
||||
return nil, errors.New("unsupported pooling type")
|
||||
}
|
||||
|
||||
for _, dense := range m.Dense {
|
||||
hiddenStates = dense.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
return hiddenStates, nil
|
||||
}
|
||||
|
||||
func newEmbedModel(c fs.Config) (model.Model, error) {
|
||||
m := &embedModel{
|
||||
SentencePieceModel: model.NewSentencePieceModel(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
TextModel: newTextModel(c),
|
||||
PoolingType: c.Uint("pooling_type", 0),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
|
@ -18,7 +18,7 @@ type Model struct {
|
|||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
|
@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||
return []input.Multimodal{{Tensor: visionOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
|
|
@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
inputMultimodal := inp.Multimodal[0].Tensor
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
input.Input{Token: 255999}, // "<start_of_image>""
|
||||
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
|
||||
&input.Input{Token: 255999}, // "<start_of_image>""
|
||||
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
|
||||
)
|
||||
|
||||
// add image token placeholders
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
|
||||
|
||||
result = append(result,
|
||||
input.Input{Token: 256000}, // <end_of_image>
|
||||
input.Input{Token: 108}, // "\n\n"
|
||||
&input.Input{Token: 256000}, // <end_of_image>
|
||||
&input.Input{Token: 108}, // "\n\n"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
@ -141,12 +141,11 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("gemma3", New)
|
||||
model.Register("gemma3_embed", newEmbedModel)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -159,8 +159,11 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
|
|||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
|
||||
|
||||
// set image embeddings
|
||||
|
|
@ -198,5 +201,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
|||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
return hiddenState
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ type Model struct {
|
|||
model.BytePairEncoding
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*VisionModel `gguf:"v"`
|
||||
*Projector `gguf:"mm"`
|
||||
*TextModel
|
||||
}
|
||||
|
|
@ -134,16 +134,16 @@ type separator struct {
|
|||
y bool
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
continue
|
||||
}
|
||||
|
||||
var imageInputs []input.Input
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
|
||||
var imageInputs []*input.Input
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
|
||||
|
||||
for i, mm := range inp.Multimodal {
|
||||
patchesPerChunk := mm.Tensor.Dim(1)
|
||||
|
|
@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
if i < len(inp.Multimodal)-1 {
|
||||
separator := mm.Data.(*separator)
|
||||
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
|
||||
if separator.x {
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
|
||||
}
|
||||
if separator.y {
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
|
||||
}
|
||||
} else {
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
|
||||
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
|
||||
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ type Model struct {
|
|||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*VisionModel `gguf:"v"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
|
|
@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
|
||||
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
|
||||
// that can be processed together.
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
for _, inp := range inputs {
|
||||
if len(inp.Multimodal) == 0 {
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
for i, row := range inp.Multimodal {
|
||||
// [IMG]
|
||||
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
|
||||
if i == len(inp.Multimodal)-1 {
|
||||
// [IMG_END]
|
||||
result = append(result, input.Input{Token: 13})
|
||||
result = append(result, &input.Input{Token: 13})
|
||||
} else {
|
||||
// [IMG_BREAK]
|
||||
result = append(result, input.Input{Token: 12})
|
||||
result = append(result, &input.Input{Token: 12})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ type Model struct {
|
|||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
|
||||
Projector *nn.Linear `gguf:"mm.0"`
|
||||
|
|
@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
||||
}
|
||||
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
for i := range inputs {
|
||||
if inputs[i].Multimodal != nil {
|
||||
inputs[i].Token = 128256 // <|image|>
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ type Model struct {
|
|||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*VisionModel `gguf:"v"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
|
@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
|
|||
}
|
||||
|
||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
var result []*input.Input
|
||||
|
||||
var (
|
||||
imageToken int32 = 151655
|
||||
|
|
@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||
}
|
||||
for i := range pre {
|
||||
result = append(result, input.Input{Token: pre[i]})
|
||||
result = append(result, &input.Input{Token: pre[i]})
|
||||
}
|
||||
|
||||
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, input.Input{Token: visionStartToken})
|
||||
result = append(result, &input.Input{Token: visionStartToken})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
result = append(result, input.Input{
|
||||
result = append(result, &input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: inp.Multimodal,
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
|
|
@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
|||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
|
||||
result = append(result, input.Input{Token: visionEndToken})
|
||||
result = append(result, &input.Input{Token: visionEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package model
|
|||
|
||||
import (
|
||||
"container/heap"
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
|
|
@ -25,7 +24,7 @@ func (spm SentencePieceModel) Vocabulary() *Vocabulary {
|
|||
}
|
||||
|
||||
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
||||
|
||||
counter := map[int]int{}
|
||||
var maxTokenLen int
|
||||
|
|
@ -39,7 +38,7 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
|
|||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
logutil.Trace("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
||||
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
||||
"max token len", maxTokenLen)
|
||||
|
||||
|
|
@ -182,12 +181,11 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
|
|||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "encoded", "string", s, "ids", ids)
|
||||
|
||||
if addSpecial && len(ids) > 0 {
|
||||
ids = spm.vocab.addSpecials(ids)
|
||||
}
|
||||
|
||||
logutil.Trace("encoded", "string", s, "ids", ids)
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
|
|
@ -246,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|||
}
|
||||
}
|
||||
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "decoded", "ids", ids, "string", sb.String())
|
||||
logutil.Trace("decoded", "ids", ids, "string", sb.String())
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
|||
slog.Warn("adding bos token to prompt which already has it", "id", v.BOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding bos token to prompt", "id", v.BOS)
|
||||
slog.Debug("adding bos token to prompt", "id", v.BOS[0])
|
||||
ids = append([]int32{v.BOS[0]}, ids...)
|
||||
}
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ func (v *Vocabulary) addSpecials(ids []int32) []int32 {
|
|||
slog.Warn("adding eos token to prompt which already has it", "id", v.EOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding eos token to prompt", "id", v.EOS)
|
||||
slog.Debug("adding eos token to prompt", "id", v.EOS[0])
|
||||
ids = append(ids, v.EOS[0])
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -557,12 +557,10 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
|||
|
||||
var think *api.ThinkValue
|
||||
if r.Reasoning != nil {
|
||||
options["reasoning"] = *r.Reasoning.Effort
|
||||
think = &api.ThinkValue{
|
||||
Value: *r.Reasoning.Effort,
|
||||
}
|
||||
} else if r.ReasoningEffort != nil {
|
||||
options["reasoning"] = *r.ReasoningEffort
|
||||
think = &api.ThinkValue{
|
||||
Value: *r.ReasoningEffort,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
|
|||
}
|
||||
|
||||
// Locking: Operations on InputCacheSlot (including finding one
|
||||
// through LoadCacheSlot) require a lock to be be held that serializes
|
||||
// through LoadCacheSlot) require a lock to be held that serializes
|
||||
// these operations with each other and llama.Decode
|
||||
|
||||
type InputCacheSlot struct {
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import (
|
|||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -216,6 +215,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
|
|||
}
|
||||
|
||||
type Server struct {
|
||||
// modelPath is the location of the model to be loaded
|
||||
modelPath string
|
||||
|
||||
// loadMu prevents more than one load attempt from occurring at a time
|
||||
loadMu sync.Mutex
|
||||
|
||||
// is the server ready to process requests?
|
||||
// protects access to model and image
|
||||
ready sync.WaitGroup
|
||||
|
|
@ -723,21 +728,12 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
type multiLPath []string
|
||||
|
||||
func (m *multiLPath) Set(value string) error {
|
||||
*m = append(*m, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
// loadModel allocates memory based on the given parameters and loads the weights. The
|
||||
// memory allocated is worst case for text models but not for vision.
|
||||
func (s *Server) loadModel(
|
||||
params llama.ModelParams,
|
||||
mpath string,
|
||||
lpath multiLPath,
|
||||
lpath []string,
|
||||
ppath string,
|
||||
kvSize int,
|
||||
kvCacheType string,
|
||||
|
|
@ -757,12 +753,10 @@ func (s *Server) loadModel(
|
|||
panic(err)
|
||||
}
|
||||
|
||||
if lpath.String() != "" {
|
||||
for _, path := range lpath {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
for _, path := range lpath {
|
||||
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -783,26 +777,81 @@ func (s *Server) loadModel(
|
|||
s.ready.Done()
|
||||
}
|
||||
|
||||
// load is the handler called by the Ollama server to process different
|
||||
// load operations
|
||||
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
|
||||
s.loadMu.Lock()
|
||||
defer s.loadMu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if s.status != llm.ServerStatusLaunched {
|
||||
http.Error(w, "model already loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req llm.LoadRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("load", "request", req)
|
||||
|
||||
switch req.Operation {
|
||||
// LoadOperationFit and LoadOperationAlloc have no meaning here - just return a successful response
|
||||
|
||||
case llm.LoadOperationCommit:
|
||||
s.batchSize = req.BatchSize
|
||||
s.parallel = req.Parallel
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
gpuIDs := llama.EnumerateGPUs()
|
||||
tensorSplit := make([]float32, len(gpuIDs))
|
||||
numGPU := 0
|
||||
for i := range gpuIDs {
|
||||
for _, layers := range req.GPULayers {
|
||||
if gpuIDs[i] == layers.ID {
|
||||
tensorSplit[i] = float32(len(layers.Layers))
|
||||
numGPU += len(layers.Layers)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
params := llama.ModelParams{
|
||||
NumGpuLayers: numGPU,
|
||||
MainGpu: req.MainGPU,
|
||||
UseMmap: req.UseMmap && len(req.LoraPath) == 0,
|
||||
TensorSplit: tensorSplit,
|
||||
Progress: func(progress float32) {
|
||||
s.progress = progress
|
||||
},
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusLoadingModel
|
||||
go s.loadModel(params, s.modelPath, req.LoraPath, req.ProjectorPath, req.KvSize, req.KvCacheType, req.FlashAttention, req.NumThreads, req.MultiUserCache)
|
||||
|
||||
case llm.LoadOperationClose:
|
||||
// No-op for us
|
||||
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
resp := llm.LoadResponse{Success: true}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||
mpath := fs.String("model", "", "Path to model binary file")
|
||||
ppath := fs.String("mmproj", "", "Path to projector binary file")
|
||||
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
|
||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||
nGpuLayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||
mainGpu := fs.Int("main-gpu", 0, "Main GPU")
|
||||
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
_ = fs.Bool("verbose", false, "verbose output (default: disabled)")
|
||||
noMmap := fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
|
||||
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
|
||||
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||
|
||||
var lpaths multiLPath
|
||||
fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
|
||||
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(fs.Output(), "Runner usage\n")
|
||||
|
|
@ -817,35 +866,11 @@ func Execute(args []string) error {
|
|||
llama.BackendInit()
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
parallel: *parallel,
|
||||
seqs: make([]*Sequence, *parallel),
|
||||
seqsSem: semaphore.NewWeighted(int64(*parallel)),
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
}
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
splits := strings.Split(*tensorSplit, ",")
|
||||
tensorSplitFloats = make([]float32, len(splits))
|
||||
for i, s := range splits {
|
||||
f, _ := strconv.ParseFloat(s, 32)
|
||||
tensorSplitFloats[i] = float32(f)
|
||||
}
|
||||
}
|
||||
|
||||
params := llama.ModelParams{
|
||||
NumGpuLayers: *nGpuLayers,
|
||||
MainGpu: *mainGpu,
|
||||
UseMmap: !*noMmap && lpaths.String() == "",
|
||||
TensorSplit: tensorSplitFloats,
|
||||
Progress: func(progress float32) {
|
||||
server.progress = progress
|
||||
},
|
||||
modelPath: *mpath,
|
||||
status: llm.ServerStatusLaunched,
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
|
|
@ -863,6 +888,7 @@ func Execute(args []string) error {
|
|||
defer listener.Close()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("POST /load", server.load)
|
||||
mux.HandleFunc("/embedding", server.embeddings)
|
||||
mux.HandleFunc("/completion", server.completion)
|
||||
mux.HandleFunc("/health", server.health)
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ func (c *InputCache) Close() {
|
|||
}
|
||||
|
||||
// Locking: Operations on InputCacheSlot (including finding one
|
||||
// through LoadCacheSlot) require a lock to be be held that serializes
|
||||
// through LoadCacheSlot) require a lock to be held that serializes
|
||||
// these operations with each other and processBatch
|
||||
|
||||
type InputCacheSlot struct {
|
||||
|
|
@ -86,7 +86,7 @@ type InputCacheSlot struct {
|
|||
Id int
|
||||
|
||||
// Inputs that are stored in the KV cache
|
||||
Inputs []input.Input
|
||||
Inputs []*input.Input
|
||||
|
||||
// is this cache actively being processed as part of a sequence?
|
||||
InUse bool
|
||||
|
|
@ -95,7 +95,7 @@ type InputCacheSlot struct {
|
|||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
|
||||
func (c *InputCache) LoadCacheSlot(prompt []*input.Input, cachePrompt bool) (*InputCacheSlot, []*input.Input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int32
|
||||
var err error
|
||||
|
|
@ -113,6 +113,10 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
|
|
@ -146,7 +150,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
|
|||
return slot, prompt, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||
longest := int32(-1)
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
|
|
@ -169,7 +173,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot
|
|||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
|
||||
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
|
||||
oldest := time.Now()
|
||||
var oldestSlot *InputCacheSlot
|
||||
|
||||
|
|
@ -205,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
|||
if longest > 0 && longestSlot != oldestSlot {
|
||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||
len(longestSlot.Inputs))
|
||||
oldestSlot.Inputs = make([]input.Input, longest)
|
||||
oldestSlot.Inputs = make([]*input.Input, longest)
|
||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||
if c.cache != nil {
|
||||
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
|
||||
|
|
@ -215,7 +219,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, i
|
|||
return oldestSlot, longest, nil
|
||||
}
|
||||
|
||||
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
|
||||
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
|
||||
var count int32
|
||||
|
||||
for i := range a {
|
||||
|
|
@ -250,7 +254,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
|
|||
}
|
||||
|
||||
type ErrReprocessInputs struct {
|
||||
Inputs []input.Input
|
||||
Inputs []*input.Input
|
||||
}
|
||||
|
||||
func (e *ErrReprocessInputs) Error() string {
|
||||
|
|
@ -283,13 +287,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
|
|||
"id", slot.Id, "error", err)
|
||||
|
||||
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
|
||||
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
|
||||
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
|
||||
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
|
||||
|
||||
// Reset the cache
|
||||
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
|
||||
slot.Inputs = []input.Input{}
|
||||
slot.Inputs = []*input.Input{}
|
||||
|
||||
// Return error with inputs that need to be reprocessed
|
||||
return &ErrReprocessInputs{Inputs: newInputs}
|
||||
|
|
|
|||
|
|
@ -13,50 +13,50 @@ import (
|
|||
func TestCountCommon(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
t1 []input.Input
|
||||
t2 []input.Input
|
||||
t1 []*input.Input
|
||||
t2 []*input.Input
|
||||
expected int32
|
||||
}{
|
||||
{
|
||||
name: "Equal",
|
||||
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Prefix",
|
||||
t1: []input.Input{{Token: 1}},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []*input.Input{{Token: 1}},
|
||||
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Image Prefix",
|
||||
t1: []input.Input{{MultimodalHash: 1}},
|
||||
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
||||
t1: []*input.Input{{MultimodalHash: 1}},
|
||||
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Mixed",
|
||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
||||
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Mixed, Same Length",
|
||||
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
|
||||
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
|
||||
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
t1: []input.Input{},
|
||||
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
t1: []*input.Input{},
|
||||
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Both Empty",
|
||||
t1: []input.Input{},
|
||||
t2: []input.Input{},
|
||||
t1: []*input.Input{},
|
||||
t2: []*input.Input{},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
|
@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
prompt []*input.Input
|
||||
longest expected
|
||||
best expected
|
||||
}{
|
||||
|
|
@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 1}},
|
||||
prompt: []*input.Input{{Token: 1}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 0, len: 0},
|
||||
},
|
||||
|
|
@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
Inputs: []*input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
longest: expected{result: 1, len: 2},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
|
|
@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 2}},
|
||||
prompt: []*input.Input{{Token: 2}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
|
|
@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}},
|
||||
prompt: []*input.Input{{Token: 1}},
|
||||
longest: expected{result: 0, len: 1},
|
||||
best: expected{result: 1, len: 1},
|
||||
},
|
||||
|
|
@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
Inputs: []*input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 2}, {Token: 3}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
|
|
@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
|
|||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{{Token: 1}},
|
||||
Inputs: []*input.Input{{Token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
longest: expected{result: 1, len: 1},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
|
|
@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input.Input
|
||||
prompt []*input.Input
|
||||
wantErr bool
|
||||
expectedSlotId int
|
||||
expectedPrompt int // expected length of remaining prompt
|
||||
|
|
@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
|
|
@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input.Input{},
|
||||
Inputs: []*input.Input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Only token 3 remains
|
||||
|
|
@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
wantErr: false,
|
||||
expectedSlotId: 0,
|
||||
expectedPrompt: 1, // Should leave 1 token for sampling
|
||||
|
|
@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input.Input{{Token: 1}, {Token: 2}},
|
||||
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
|
||||
wantErr: true,
|
||||
expectedSlotId: -1,
|
||||
expectedPrompt: -1,
|
||||
|
|
@ -393,7 +393,7 @@ func TestLoadCacheSlot(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
|
||||
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt, true)
|
||||
|
||||
// Check error state
|
||||
if (err != nil) != tt.wantErr {
|
||||
|
|
@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||
tests := []struct {
|
||||
name string
|
||||
numCtx int32
|
||||
inputs []input.Input
|
||||
inputs []*input.Input
|
||||
numKeep int32
|
||||
cacheErr bool
|
||||
wantErr any
|
||||
|
|
@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||
{
|
||||
name: "Normal shift",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: false, // No error
|
||||
wantErr: nil,
|
||||
|
|
@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||
{
|
||||
name: "Cache removal fails",
|
||||
numCtx: 10,
|
||||
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
|
||||
numKeep: 2,
|
||||
cacheErr: true,
|
||||
wantErr: &ErrReprocessInputs{},
|
||||
|
|
@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
|
|||
}
|
||||
slot := &InputCacheSlot{
|
||||
Id: 123,
|
||||
Inputs: make([]input.Input, len(tt.inputs)),
|
||||
Inputs: make([]*input.Input, len(tt.inputs)),
|
||||
}
|
||||
copy(slot.Inputs, tt.inputs)
|
||||
|
||||
|
|
|
|||
|
|
@ -11,11 +11,14 @@ import (
|
|||
"image"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -50,10 +53,10 @@ type Sequence struct {
|
|||
iBatch int
|
||||
|
||||
// prompt inputs left to evaluate
|
||||
inputs []input.Input
|
||||
inputs []*input.Input
|
||||
|
||||
// inputs that have been added to a batch but not yet submitted to Forward
|
||||
pendingInputs []input.Input
|
||||
pendingInputs []*input.Input
|
||||
|
||||
// tokens that have been generated but not returned yet (e.g. for stop sequences)
|
||||
pendingResponses []string
|
||||
|
|
@ -181,8 +184,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
|||
// inputs processes the prompt and images into a list of inputs
|
||||
// by splitting the prompt on [img-<n>] tags, tokenizing text and
|
||||
// decoding images
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
|
||||
var inputs []input.Input
|
||||
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
|
||||
var inputs []*input.Input
|
||||
var ctxs []ml.Context
|
||||
var mmStore multimodalStore
|
||||
|
||||
|
|
@ -209,7 +212,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
|||
}
|
||||
|
||||
for _, t := range tokens {
|
||||
inputs = append(inputs, input.Input{Token: t})
|
||||
inputs = append(inputs, &input.Input{Token: t})
|
||||
}
|
||||
|
||||
// image - decode and store
|
||||
|
|
@ -242,7 +245,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
|||
|
||||
mmStore.addMultimodal(imageEmbeddings)
|
||||
|
||||
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
|
||||
postTokenize = true
|
||||
}
|
||||
}
|
||||
|
|
@ -258,7 +261,48 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [
|
|||
return inputs, ctxs, mmStore, nil
|
||||
}
|
||||
|
||||
type batchState struct {
|
||||
// id provides a counter for trace logging batches
|
||||
id int
|
||||
|
||||
// ctx holds the backend context used for this batch
|
||||
ctx ml.Context
|
||||
|
||||
// modelOutput holds the outputs from this batch
|
||||
modelOutput ml.Tensor
|
||||
|
||||
// batchInputs holds the input token pointers which may start as
|
||||
// placeholders later filled in before calling ctx.Compute
|
||||
batchInputs []*input.Input
|
||||
|
||||
// batch contains the inputs for a model forward pass
|
||||
batch input.Batch
|
||||
|
||||
// full set of seqs at the time this batch was initiated
|
||||
seqs []*Sequence
|
||||
|
||||
// Signaled when this batches inputs are ready and compute can proceed
|
||||
inputsReadyCh chan struct{}
|
||||
|
||||
// Signaling when Compute is about to begin on this batch, and
|
||||
// seqs have been updated to prepare for the next batch
|
||||
computeStartedCh chan struct{}
|
||||
|
||||
// Signaled when this batches outputs are complete and the next batch can proceed
|
||||
outputsReadyCh chan struct{}
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
// modelPath is the location of the model to be loaded
|
||||
modelPath string
|
||||
|
||||
// loadMu prevents more than one load attempt from occurring at a time
|
||||
loadMu sync.Mutex
|
||||
|
||||
// lastLoad is the load request from the previous load attempt. Used to
|
||||
// detect if we can reuse an existing memory allocation.
|
||||
lastLoad llm.LoadRequest
|
||||
|
||||
// is the server ready to process requests?
|
||||
// protects access to model and image
|
||||
ready sync.WaitGroup
|
||||
|
|
@ -279,6 +323,12 @@ type Server struct {
|
|||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// Used to signal a hard failure during async processing which will panic the runner
|
||||
hardErrCh chan error
|
||||
|
||||
// Simple counter used only for trace logging batches
|
||||
batchID int
|
||||
|
||||
// protects access to everything below this line
|
||||
// this is context state needed for decoding
|
||||
mu sync.Mutex
|
||||
|
|
@ -351,33 +401,73 @@ func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
|
|||
s.seqsSem.Release(1)
|
||||
}
|
||||
|
||||
// track batch state between forwardBatch, computeBatch and predictForwardBatch
|
||||
|
||||
func (s *Server) run(ctx context.Context) {
|
||||
s.ready.Wait()
|
||||
|
||||
supportsAsync := s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32
|
||||
|
||||
var activeBatch batchState
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-s.hardErrCh:
|
||||
panic(err)
|
||||
default:
|
||||
err := s.processBatch()
|
||||
var err error
|
||||
activeBatch, err = s.forwardBatch(activeBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if supportsAsync {
|
||||
go s.computeBatch(activeBatch)
|
||||
} else {
|
||||
s.computeBatch(activeBatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) processBatch() error {
|
||||
// forwardBatch will calculate a batch.
|
||||
func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, err error) {
|
||||
// If we have a pending batch still processing, wait until Compute has started
|
||||
// before setting up the next batch so the seqs inputs are ready to receive their
|
||||
// token values and we get the correct input pointers for the batchInputs
|
||||
if pendingBatch.ctx != nil {
|
||||
logutil.Trace("forwardBatch waiting for compute to start", "pendingBatch.id", pendingBatch.id)
|
||||
<-pendingBatch.computeStartedCh
|
||||
logutil.Trace("forwardBatch compute started, setting up next batch", "pendingBatch.id", pendingBatch.id, "id", s.batchID)
|
||||
nextBatch.inputsReadyCh = pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
|
||||
} else {
|
||||
logutil.Trace("forwardBatch no pending batch detected", "batchID", s.batchID)
|
||||
// No pendingBatch, so the inputs will be ready in the seqs immediately
|
||||
nextBatch.inputsReadyCh = make(chan struct{}, 1)
|
||||
nextBatch.inputsReadyCh <- struct{}{}
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
for s.allNil() {
|
||||
s.cond.Wait() // Wait until an item is added
|
||||
}
|
||||
defer s.mu.Unlock()
|
||||
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
nextBatch.ctx = s.model.Backend().NewContext()
|
||||
defer func() {
|
||||
if err != nil {
|
||||
nextBatch.ctx.Close()
|
||||
nextBatch.ctx = nil
|
||||
}
|
||||
}()
|
||||
nextBatch.id = s.batchID
|
||||
nextBatch.seqs = append([]*Sequence{}, s.seqs...)
|
||||
nextBatch.computeStartedCh = make(chan struct{}, 1)
|
||||
nextBatch.outputsReadyCh = make(chan struct{}, 1)
|
||||
|
||||
var batchInputs []int32
|
||||
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
|
||||
var batchInputs []*input.Input
|
||||
var batch input.Batch
|
||||
|
||||
resumeSeq := -1
|
||||
|
|
@ -385,7 +475,6 @@ func (s *Server) processBatch() error {
|
|||
for range s.seqs {
|
||||
seqIdx = (seqIdx + 1) % len(s.seqs)
|
||||
seq := s.seqs[seqIdx]
|
||||
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
|
|
@ -393,12 +482,13 @@ func (s *Server) processBatch() error {
|
|||
// if past the num predict limit
|
||||
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
nextBatch.seqs[seqIdx] = nil
|
||||
continue
|
||||
}
|
||||
|
||||
if !s.cache.enabled {
|
||||
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
|
||||
seq.cache.Inputs = []input.Input{}
|
||||
seq.cache.Inputs = []*input.Input{}
|
||||
}
|
||||
|
||||
batchSize := s.batchSize
|
||||
|
|
@ -431,25 +521,28 @@ func (s *Server) processBatch() error {
|
|||
break
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
if errors.As(err, &reprocess) {
|
||||
// Prepend these inputs to the sequence's inputs queue for reprocessing
|
||||
seq.inputs = append(reprocess.Inputs, seq.inputs...)
|
||||
// Skip this sequence but continue processing the rest
|
||||
nextBatch.seqs[seqIdx] = nil // clear this sequence for this batch
|
||||
err = nil
|
||||
continue
|
||||
} else {
|
||||
return err
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
batchInputs = append(batchInputs, inp.Token)
|
||||
batchInputs = append(batchInputs, seq.inputs[i])
|
||||
if inp.Multimodal != nil {
|
||||
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
|
||||
var mm []input.Multimodal
|
||||
mm, err = seq.mmStore.getMultimodal(s.model.Backend(), nextBatch.ctx, inp.Multimodal, false)
|
||||
if err != nil {
|
||||
return err
|
||||
return
|
||||
}
|
||||
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
|
||||
}
|
||||
|
|
@ -461,6 +554,7 @@ func (s *Server) processBatch() error {
|
|||
if i+1 == len(seq.inputs) {
|
||||
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
||||
}
|
||||
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
|
||||
seq.pendingInputs = append(seq.pendingInputs, inp)
|
||||
}
|
||||
|
||||
|
|
@ -474,73 +568,168 @@ func (s *Server) processBatch() error {
|
|||
}
|
||||
|
||||
if len(batchInputs) == 0 {
|
||||
return nil
|
||||
logutil.Trace("forwardBatch no batchInputs, going idle", "batchID", s.batchID)
|
||||
nextBatch.ctx.Close()
|
||||
nextBatch.ctx = nil
|
||||
return
|
||||
}
|
||||
s.batchID++
|
||||
|
||||
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
||||
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
|
||||
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
|
||||
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
err = fmt.Errorf("failed to build graph: %w", err)
|
||||
return
|
||||
}
|
||||
nextBatch.batchInputs = batchInputs
|
||||
nextBatch.batch = batch
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Async processing of the next batch
|
||||
func (s *Server) computeBatch(activeBatch batchState) {
|
||||
if activeBatch.ctx == nil {
|
||||
// Nothing to compute
|
||||
return
|
||||
}
|
||||
defer activeBatch.ctx.Close()
|
||||
|
||||
// Wait until inputs are ready
|
||||
logutil.Trace("computeBatch: waiting for inputs to be ready", "batchID", activeBatch.id)
|
||||
<-activeBatch.inputsReadyCh
|
||||
logutil.Trace("computeBatch: inputs are ready", "batchID", activeBatch.id)
|
||||
|
||||
// Once we complete, signal the next batch of inputs are ready
|
||||
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
|
||||
defer func() {
|
||||
logutil.Trace("computeBatch: outputs are ready", "batchID", activeBatch.id)
|
||||
activeBatch.outputsReadyCh <- struct{}{}
|
||||
}()
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
// Gather the actual input token values now that they're ready
|
||||
batchInputs := make([]int32, len(activeBatch.batchInputs))
|
||||
for i := range batchInputs {
|
||||
batchInputs[i] = activeBatch.batchInputs[i].Token
|
||||
}
|
||||
|
||||
logits := modelOutput.Floats()
|
||||
|
||||
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
|
||||
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
|
||||
// decoded tokens.
|
||||
nextBatchTokens := make([]*input.Input, len(s.seqs))
|
||||
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
|
||||
for i, seq := range s.seqs {
|
||||
iBatches[i] = -1
|
||||
if seq == nil {
|
||||
continue
|
||||
}
|
||||
// Skip over any newly added or skipped sequences
|
||||
if activeBatch.seqs[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// After calling Forward, pending inputs are now in the cache
|
||||
// Detect if the sequence we're processing has already been completed and replaced
|
||||
// with a new sequence
|
||||
if seq != activeBatch.seqs[i] {
|
||||
logutil.Trace("computeBatch: sequence replaced, discarding its results", "batchID", activeBatch.id, "seqIdx", i)
|
||||
continue
|
||||
}
|
||||
|
||||
// Pending inputs will actually be in the cache after we call Compute.
|
||||
// However, we have already resolved any placeholder tokens.
|
||||
//
|
||||
// It's possible for incoming sequences to look at the values that we've
|
||||
// added to the cache here and start relying on them before we've done
|
||||
// the computation. This is OK as long as we ensure that this batch's
|
||||
// computation happens before any future batch's and we never fail
|
||||
// (unless we take down the whole runner).
|
||||
if len(seq.pendingInputs) > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
|
||||
seq.pendingInputs = []input.Input{}
|
||||
seq.pendingInputs = []*input.Input{}
|
||||
}
|
||||
|
||||
// don't sample prompt processing
|
||||
if len(seq.inputs) != 0 {
|
||||
if !s.cache.enabled {
|
||||
return errors.New("caching disabled but unable to fit entire input in a batch")
|
||||
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numPredicted++
|
||||
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
|
||||
seq.inputs = []*input.Input{nextToken}
|
||||
nextBatchTokens[i] = nextToken
|
||||
iBatches[i] = seq.iBatch
|
||||
}
|
||||
|
||||
// At this point the seqs are ready for forwardBatch to move forward so unblock
|
||||
s.mu.Unlock()
|
||||
|
||||
activeBatch.batch.Inputs.SetValueFromIntSlice(batchInputs)
|
||||
activeBatch.ctx.ComputeWithNotify(
|
||||
func() {
|
||||
logutil.Trace("computeBatch: signaling computeStartedCh", "batchID", activeBatch.id)
|
||||
activeBatch.computeStartedCh <- struct{}{}
|
||||
},
|
||||
activeBatch.modelOutput)
|
||||
|
||||
outputs := activeBatch.modelOutput.Floats()
|
||||
|
||||
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
logutil.Trace("computeBatch: decoding", "batchID", activeBatch.id)
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil || nextBatchTokens[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if seq.numPredicted == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
if seq.embeddingOnly {
|
||||
// TODO(jessegross): Embedding support
|
||||
slog.Warn("generation of embedding outputs not yet supported")
|
||||
seq.embedding <- outputs
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
// sample a token
|
||||
vocabSize := len(logits) / len(batch.Outputs)
|
||||
|
||||
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
||||
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to sample token: %w", err)
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
nextBatchTokens[i].Token = token
|
||||
|
||||
// if it's an end of sequence token, break
|
||||
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
|
||||
// TODO (jmorganca): we should send this back
|
||||
// as it's important for the /api/generate context
|
||||
// seq.responses <- piece
|
||||
|
||||
logutil.Trace("computeBatch: EOS", "batchID", activeBatch.id, "seqIdx", i)
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
continue
|
||||
}
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
return err
|
||||
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
seq.inputs = []input.Input{{Token: token}}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
sequence := strings.Join(seq.pendingResponses, "")
|
||||
|
||||
|
|
@ -564,6 +753,7 @@ func (s *Server) processBatch() error {
|
|||
if tokenTruncated || origLen == newLen {
|
||||
tokenLen--
|
||||
}
|
||||
|
||||
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
|
||||
|
||||
s.removeSequence(i, llm.DoneReasonStop)
|
||||
|
|
@ -582,8 +772,6 @@ func (s *Server) processBatch() error {
|
|||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
|
|
@ -654,7 +842,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
|
|
@ -710,6 +898,67 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
if s.model.Backend().Config().Uint("pooling_type", math.MaxUint32) == math.MaxUint32 {
|
||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
var req llm.EmbeddingRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, fmt.Sprintf("bad request: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
slog.Info("aborting embedding request due to client closing the connection")
|
||||
} else {
|
||||
http.Error(w, fmt.Sprintf("failed to acquire semaphore: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
found := false
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, fmt.Sprintf("failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if !found {
|
||||
s.seqsSem.Release(1)
|
||||
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
|
||||
Embedding: <-seq.embedding,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err := json.NewEncoder(w).Encode(&llm.ServerStatusResponse{
|
||||
|
|
@ -720,23 +969,15 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
type multiLPath []string
|
||||
|
||||
func (m *multiLPath) Set(value string) error {
|
||||
*m = append(*m, value)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
func (s *Server) reserveWorstCaseGraph() error {
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
var err error
|
||||
inputs := make([]input.Input, s.batchSize)
|
||||
inputs := make([]*input.Input, s.batchSize)
|
||||
for i := range inputs {
|
||||
inputs[i] = &input.Input{}
|
||||
}
|
||||
mmStore := newMultimodalStore()
|
||||
|
||||
// Multimodal strategy:
|
||||
|
|
@ -778,8 +1019,11 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||
}
|
||||
|
||||
if len(inputs) < s.batchSize {
|
||||
newInputs := make([]input.Input, s.batchSize)
|
||||
newInputs := make([]*input.Input, s.batchSize)
|
||||
copy(newInputs, inputs)
|
||||
for i := len(inputs); i < s.batchSize; i++ {
|
||||
newInputs[i] = &input.Input{}
|
||||
}
|
||||
inputs = newInputs
|
||||
}
|
||||
}
|
||||
|
|
@ -828,15 +1072,29 @@ func (s *Server) reserveWorstCaseGraph() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) initModel(
|
||||
// allocModel pre-allocates the maximum needed memory for a model
|
||||
// based on the given parameters
|
||||
func (s *Server) allocModel(
|
||||
mpath string,
|
||||
params ml.BackendParams,
|
||||
lpath multiLPath,
|
||||
loraPath []string,
|
||||
parallel int,
|
||||
kvCacheType string,
|
||||
kvSize int,
|
||||
multiUserCache bool,
|
||||
) error {
|
||||
) (panicErr error) {
|
||||
// Convert memory allocation panics to errors
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
debug.PrintStack()
|
||||
if err, ok := r.(error); ok {
|
||||
panicErr = err
|
||||
} else {
|
||||
panic(r)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
s.model, err = model.New(mpath, params)
|
||||
if err != nil {
|
||||
|
|
@ -844,7 +1102,7 @@ func (s *Server) initModel(
|
|||
}
|
||||
|
||||
// TODO(jessegross): LoRA loading
|
||||
if lpath.String() != "" {
|
||||
if len(loraPath) > 0 {
|
||||
return errors.New("loras are not yet implemented")
|
||||
}
|
||||
|
||||
|
|
@ -865,63 +1123,122 @@ func (s *Server) initModel(
|
|||
return s.reserveWorstCaseGraph()
|
||||
}
|
||||
|
||||
func (s *Server) load(
|
||||
ctx context.Context,
|
||||
mpath string,
|
||||
params ml.BackendParams,
|
||||
lpath multiLPath,
|
||||
parallel int,
|
||||
kvCacheType string,
|
||||
kvSize int,
|
||||
multiUserCache bool,
|
||||
) {
|
||||
err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache)
|
||||
if err != nil {
|
||||
var noMem ml.ErrNoMem
|
||||
if errors.As(err, &noMem) {
|
||||
// We can't yet handle this but in the future we will
|
||||
s.cache.Close()
|
||||
if s.model != nil {
|
||||
s.model.Backend().Close()
|
||||
}
|
||||
}
|
||||
|
||||
panic(err)
|
||||
// closeModel frees all memory associated with a model
|
||||
func (s *Server) closeModel() {
|
||||
s.cache.Close()
|
||||
s.cache = nil
|
||||
if s.model != nil {
|
||||
s.model.Backend().Close()
|
||||
s.model = nil
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("memory", "allocated", s.model.Backend().BackendMemory())
|
||||
|
||||
err = s.model.Backend().Load(ctx,
|
||||
// loadModel loads the weights for a model. The memory must already
|
||||
// have been allocated with allocModel
|
||||
func (s *Server) loadModel() {
|
||||
err := s.model.Backend().Load(context.TODO(),
|
||||
func(progress float32) {
|
||||
s.progress = progress
|
||||
})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
panic(fmt.Errorf("failed to load model: %v", err))
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
// load is the handler called by the Ollama server to process different
|
||||
// load operations
|
||||
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
|
||||
s.loadMu.Lock()
|
||||
defer s.loadMu.Unlock()
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
|
||||
if s.status != llm.ServerStatusLaunched {
|
||||
http.Error(w, "model already loaded", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var req llm.LoadRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
slog.Info("load", "request", req)
|
||||
|
||||
if req.Operation == llm.LoadOperationClose {
|
||||
s.closeModel()
|
||||
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.lastLoad.Operation = req.Operation
|
||||
loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad)
|
||||
|
||||
s.lastLoad = req
|
||||
|
||||
if loadModel {
|
||||
s.closeModel()
|
||||
|
||||
params := ml.BackendParams{
|
||||
AllocMemory: req.Operation != llm.LoadOperationFit,
|
||||
NumThreads: req.NumThreads,
|
||||
GPULayers: req.GPULayers,
|
||||
FlashAttention: req.FlashAttention,
|
||||
}
|
||||
|
||||
s.batchSize = req.BatchSize
|
||||
|
||||
err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache)
|
||||
if err != nil {
|
||||
s.closeModel()
|
||||
|
||||
var noMem ml.ErrNoMem
|
||||
if errors.As(err, &noMem) {
|
||||
resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
http.Error(w, fmt.Sprintf("failed to initialize model: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
mem := s.model.Backend().BackendMemory()
|
||||
|
||||
switch req.Operation {
|
||||
case llm.LoadOperationFit:
|
||||
// LoadOperationFit can't be used for anything else, so just close it
|
||||
s.closeModel()
|
||||
|
||||
// LoadOperationAlloc should stay open for future operations
|
||||
|
||||
case llm.LoadOperationCommit:
|
||||
s.status = llm.ServerStatusLoadingModel
|
||||
go s.loadModel()
|
||||
}
|
||||
|
||||
resp := llm.LoadResponse{Success: true, Memory: mem}
|
||||
if err := json.NewEncoder(w).Encode(&resp); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func Execute(args []string) error {
|
||||
fs := flag.NewFlagSet("runner", flag.ExitOnError)
|
||||
mpath := fs.String("model", "", "Path to model binary file")
|
||||
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
|
||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
||||
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
||||
_ = fs.Bool("verbose", false, "verbose output (default: disabled)")
|
||||
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
|
||||
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
|
||||
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
|
||||
|
||||
var lpaths multiLPath
|
||||
fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")
|
||||
|
||||
fs.Usage = func() {
|
||||
fmt.Fprintf(fs.Output(), "Runner usage\n")
|
||||
|
|
@ -933,39 +1250,18 @@ func Execute(args []string) error {
|
|||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
slog.Info("starting ollama engine")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
server := &Server{
|
||||
batchSize: *batchSize,
|
||||
status: llm.ServerStatusLoadingModel,
|
||||
modelPath: *mpath,
|
||||
status: llm.ServerStatusLaunched,
|
||||
hardErrCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
server.ready.Add(1)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
// no-mmap
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
splits := strings.Split(*tensorSplit, ",")
|
||||
tensorSplitFloats = make([]float32, len(splits))
|
||||
for i, s := range splits {
|
||||
f, _ := strconv.ParseFloat(s, 32)
|
||||
tensorSplitFloats[i] = float32(f)
|
||||
}
|
||||
}
|
||||
|
||||
params := ml.BackendParams{
|
||||
NumThreads: *threads,
|
||||
NumGPULayers: *numGPULayers,
|
||||
MainGPU: *mainGPU,
|
||||
TensorSplit: tensorSplitFloats,
|
||||
FlashAttention: *flashAttention,
|
||||
}
|
||||
|
||||
go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
|
||||
go server.run(ctx)
|
||||
|
||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||
|
|
@ -978,10 +1274,8 @@ func Execute(args []string) error {
|
|||
|
||||
mux := http.NewServeMux()
|
||||
// TODO: support embeddings
|
||||
mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "this model does not support embeddings", http.StatusNotImplemented)
|
||||
})
|
||||
|
||||
mux.HandleFunc("POST /load", server.load)
|
||||
mux.HandleFunc("POST /embedding", server.embeddings)
|
||||
mux.HandleFunc("POST /completion", server.completion)
|
||||
mux.HandleFunc("GET /health", server.health)
|
||||
|
||||
|
|
|
|||
103
server/routes.go
103
server/routes.go
|
|
@ -32,6 +32,7 @@ import (
|
|||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/openai"
|
||||
|
|
@ -45,6 +46,18 @@ import (
|
|||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
// search for harmony tags that are nearly always used
|
||||
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func experimentEnabled(name string) bool {
|
||||
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
|
||||
}
|
||||
|
|
@ -176,7 +189,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
// expire the runner
|
||||
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
|
|
@ -194,12 +207,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
useHarmony := shouldUseHarmony(*m) && !req.Raw
|
||||
var harmonyMessageHandler *HarmonyMessageHandler
|
||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
||||
useHarmony := shouldUseHarmony(m) && !req.Raw
|
||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.harmonyParser.AddImplicitStart()
|
||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
}
|
||||
|
||||
|
|
@ -314,6 +327,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||
prompt = b.String()
|
||||
}
|
||||
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
if req.DebugRenderOnly {
|
||||
c.JSON(http.StatusOK, api.DebugTemplateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
DebugInfo: api.DebugInfo{
|
||||
RenderedTemplate: prompt,
|
||||
ImageCount: len(images),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var thinkingState *thinking.Parser
|
||||
if !useHarmony {
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
|
|
@ -1477,14 +1503,14 @@ func (s *Server) PsHandler(c *gin.Context) {
|
|||
mr := api.ProcessModelResponse{
|
||||
Model: model.ShortName,
|
||||
Name: model.ShortName,
|
||||
Size: int64(v.estimatedTotal),
|
||||
SizeVRAM: int64(v.estimatedVRAM),
|
||||
Size: int64(v.totalSize),
|
||||
SizeVRAM: int64(v.vramSize),
|
||||
Digest: model.Digest,
|
||||
Details: modelDetails,
|
||||
ExpiresAt: v.expiresAt,
|
||||
}
|
||||
if v.Options != nil {
|
||||
mr.ContextLength = v.Options.NumCtx / v.numParallel
|
||||
mr.ContextLength = v.Options.NumCtx
|
||||
}
|
||||
// The scheduler waits to set expiresAt, so if a model is loading it's
|
||||
// possible that it will be set to the unix epoch. For those cases, just
|
||||
|
|
@ -1518,7 +1544,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
|
|
@ -1590,14 +1616,49 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
}
|
||||
msgs = filterThinkTags(msgs, m)
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
|
||||
var harmonyMessageHandler *harmony.HarmonyMessageHandler
|
||||
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
|
||||
|
||||
useHarmony := shouldUseHarmony(m)
|
||||
|
||||
processedTools := req.Tools
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
|
||||
var lastMessage *api.Message
|
||||
if len(msgs) > 0 {
|
||||
lastMessage = &msgs[len(msgs)-1]
|
||||
}
|
||||
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
|
||||
// make a copy of tools to pass to the chat prompt. Function names may be
|
||||
// renamed to be valid Harmony function names.
|
||||
processedTools = make([]api.Tool, len(req.Tools))
|
||||
copy(processedTools, req.Tools)
|
||||
for i, tool := range processedTools {
|
||||
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
|
||||
}
|
||||
}
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
useHarmony := shouldUseHarmony(*m)
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
if req.DebugRenderOnly {
|
||||
c.JSON(http.StatusOK, api.DebugTemplateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
DebugInfo: api.DebugInfo{
|
||||
RenderedTemplate: prompt,
|
||||
ImageCount: len(images),
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate Think value: string values currently only allowed for gptoss models
|
||||
if req.Think != nil && req.Think.IsString() && !useHarmony {
|
||||
|
|
@ -1605,19 +1666,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
var harmonyMessageHandler *HarmonyMessageHandler
|
||||
var harmonyToolParser *HarmonyToolCallAccumulator
|
||||
|
||||
if useHarmony {
|
||||
harmonyMessageHandler = NewHarmonyMessageHandler()
|
||||
var lastMessage *api.Message
|
||||
if len(msgs) > 0 {
|
||||
lastMessage = &msgs[len(msgs)-1]
|
||||
}
|
||||
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
|
||||
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
|
||||
}
|
||||
|
||||
var thinkingState *thinking.Parser
|
||||
openingTag, closingTag := thinking.InferTags(m.Template.Template)
|
||||
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
|
||||
|
|
@ -1625,6 +1673,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
OpeningTag: openingTag,
|
||||
ClosingTag: closingTag,
|
||||
}
|
||||
|
||||
if strings.HasSuffix(strings.TrimSpace(prompt), openingTag) {
|
||||
thinkingState.AddContent(openingTag)
|
||||
}
|
||||
}
|
||||
|
||||
var toolParser *tools.Parser
|
||||
|
|
@ -1670,6 +1722,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||
toolName, toolContent := harmonyToolParser.Drain()
|
||||
if toolName != nil {
|
||||
*toolName = strings.TrimPrefix(*toolName, "functions.")
|
||||
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
|
||||
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
|
||||
|
|
|
|||
|
|
@ -0,0 +1,413 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a test model
|
||||
stream := false
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-model",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request api.GenerateRequest
|
||||
expectDebug bool
|
||||
expectTemplate string
|
||||
expectNumImages int
|
||||
}{
|
||||
{
|
||||
name: "debug render only enabled",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello, world!",
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "Hello, world!",
|
||||
},
|
||||
{
|
||||
name: "debug render only disabled",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello, world!",
|
||||
DebugRenderOnly: false,
|
||||
},
|
||||
expectDebug: false,
|
||||
},
|
||||
{
|
||||
name: "debug render only with system prompt",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "User question",
|
||||
System: "You are a helpful assistant",
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "User question",
|
||||
},
|
||||
{
|
||||
name: "debug render only with template",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Template: "PROMPT: {{ .Prompt }}",
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "PROMPT: Hello",
|
||||
},
|
||||
{
|
||||
name: "debug render only with images",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Describe this image",
|
||||
Images: []api.ImageData{[]byte("fake-image-data")},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[img-0]\n\nDescribe this image",
|
||||
expectNumImages: 1,
|
||||
},
|
||||
{
|
||||
name: "debug render only with raw mode",
|
||||
request: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Raw prompt text",
|
||||
Raw: true,
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "Raw prompt text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
// Test both with and without streaming
|
||||
streamValues := []bool{false, true}
|
||||
for _, stream := range streamValues {
|
||||
streamSuffix := ""
|
||||
if stream {
|
||||
streamSuffix = " (streaming)"
|
||||
}
|
||||
t.Run(tt.name+streamSuffix, func(t *testing.T) {
|
||||
req := tt.request
|
||||
req.Stream = &stream
|
||||
w := createRequest(t, s.GenerateHandler, req)
|
||||
|
||||
if tt.expectDebug {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response api.DebugTemplateResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Model != tt.request.Model {
|
||||
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
|
||||
}
|
||||
|
||||
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
|
||||
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
|
||||
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
|
||||
}
|
||||
} else {
|
||||
// When debug is disabled, it should attempt normal processing
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatDebugRenderOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a test model
|
||||
stream := false
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-model",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
request api.ChatRequest
|
||||
expectDebug bool
|
||||
expectTemplate string
|
||||
expectNumImages int
|
||||
}{
|
||||
{
|
||||
name: "chat debug render only enabled",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant"},
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "system: You are a helpful assistant\nuser: Hello\n",
|
||||
},
|
||||
{
|
||||
name: "chat debug render only disabled",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
},
|
||||
DebugRenderOnly: false,
|
||||
},
|
||||
expectDebug: false,
|
||||
},
|
||||
{
|
||||
name: "chat debug with assistant message",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "user: Hello\nassistant: Hi there!\nuser: How are you?\n",
|
||||
},
|
||||
{
|
||||
name: "chat debug with images",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's in this image?",
|
||||
Images: []api.ImageData{[]byte("fake-image-data")},
|
||||
},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "user: [img-0]What's in this image?\n",
|
||||
expectNumImages: 1,
|
||||
},
|
||||
{
|
||||
name: "chat debug with tools",
|
||||
request: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Get the weather"},
|
||||
},
|
||||
Tools: api.Tools{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get weather information",
|
||||
},
|
||||
},
|
||||
},
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\",\"required\":null,\"properties\":null}}}]user: Get the weather\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
// Test both with and without streaming
|
||||
streamValues := []bool{false, true}
|
||||
for _, stream := range streamValues {
|
||||
streamSuffix := ""
|
||||
if stream {
|
||||
streamSuffix = " (streaming)"
|
||||
}
|
||||
t.Run(tt.name+streamSuffix, func(t *testing.T) {
|
||||
req := tt.request
|
||||
req.Stream = &stream
|
||||
w := createRequest(t, s.ChatHandler, req)
|
||||
|
||||
if tt.expectDebug {
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var response api.DebugTemplateResponse
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
|
||||
t.Fatalf("failed to unmarshal response: %v", err)
|
||||
}
|
||||
|
||||
if response.Model != tt.request.Model {
|
||||
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
|
||||
}
|
||||
|
||||
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
|
||||
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
|
||||
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
|
||||
}
|
||||
} else {
|
||||
// When debug is disabled, it should attempt normal processing
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -77,12 +77,13 @@ func TestGenerateChat(t *testing.T) {
|
|||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -620,12 +621,13 @@ func TestGenerate(t *testing.T) {
|
|||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
// add small delay to simulate loading
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -967,3 +969,233 @@ func TestGenerate(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Helper to create a standard thinking test setup
|
||||
setupThinkingTest := func(t *testing.T) (*mockRunner, *Server) {
|
||||
mock := &mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(mock),
|
||||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{llama: mock}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with thinking support
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
"llama.context_length": uint32(8192),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
// Create model with thinking template that adds <think> at the end
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-thinking",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: `{{- range .Messages }}
|
||||
{{- if eq .Role "user" }}user: {{ .Content }}
|
||||
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
|
||||
{{ end }}{{ end }}<think>`,
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
return mock, s
|
||||
}
|
||||
|
||||
mock, s := setupThinkingTest(t)
|
||||
|
||||
// Helper to test chat responses
|
||||
testChatRequest := func(t *testing.T, name string, userContent string, modelResponse string, expectedThinking string, expectedContent string, think bool) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
mock.CompletionResponse = llm.CompletionResponse{
|
||||
Content: modelResponse,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
}
|
||||
mock.CompletionFn = nil
|
||||
|
||||
streamRequest := false
|
||||
req := api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: userContent},
|
||||
},
|
||||
Stream: &streamRequest,
|
||||
}
|
||||
if think {
|
||||
req.Think = &api.ThinkValue{Value: think}
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ChatHandler, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.Thinking != expectedThinking {
|
||||
t.Errorf("expected thinking %q, got %q", expectedThinking, resp.Message.Thinking)
|
||||
}
|
||||
|
||||
if resp.Message.Content != expectedContent {
|
||||
t.Errorf("expected content %q, got %q", expectedContent, resp.Message.Content)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Test cases - Note: Template adds <think> at the end, and leading whitespace after <think> is eaten by the parser
|
||||
testChatRequest(t, "basic thinking response",
|
||||
"Help me solve this problem",
|
||||
" Let me think about this step by step... </think> The answer is 42.",
|
||||
"Let me think about this step by step... ",
|
||||
"The answer is 42.",
|
||||
true)
|
||||
|
||||
testChatRequest(t, "thinking with multiple sentences",
|
||||
"Explain quantum computing",
|
||||
" First, I need to understand the basics. Quantum bits can be in superposition. </think> Quantum computing uses quantum mechanics principles.",
|
||||
"First, I need to understand the basics. Quantum bits can be in superposition. ",
|
||||
"Quantum computing uses quantum mechanics principles.",
|
||||
true)
|
||||
|
||||
testChatRequest(t, "no thinking content",
|
||||
"What is 2+2?",
|
||||
"</think> The answer is 4.",
|
||||
"",
|
||||
"The answer is 4.",
|
||||
true)
|
||||
|
||||
testChatRequest(t, "thinking disabled but template still adds think tag",
|
||||
"Simple question",
|
||||
" My thoughts </think> The answer.",
|
||||
"",
|
||||
" My thoughts </think> The answer.",
|
||||
false)
|
||||
|
||||
// Test streaming response with template-added <think>
|
||||
t.Run("streaming with thinking", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
// Verify the prompt ends with <think> due to template
|
||||
if !strings.HasSuffix(r.Prompt, "<think>") {
|
||||
t.Errorf("expected prompt to end with <think>, got: %q", r.Prompt)
|
||||
}
|
||||
|
||||
// Simulate streaming chunks
|
||||
responses := []llm.CompletionResponse{
|
||||
{Content: " I need to consider", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
|
||||
{Content: " multiple factors here...", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
|
||||
{Content: " </think> Based on my analysis,", Done: false, PromptEvalCount: 1, PromptEvalDuration: 1},
|
||||
{Content: " the solution is straightforward.", Done: true, DoneReason: llm.DoneReasonStop, PromptEvalCount: 1, PromptEvalDuration: 1, EvalCount: 1, EvalDuration: 1},
|
||||
}
|
||||
|
||||
for _, resp := range responses {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
fn(resp)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
think := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Parse streaming responses
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var allThinking, allContent strings.Builder
|
||||
|
||||
for {
|
||||
var resp api.ChatResponse
|
||||
if err := decoder.Decode(&resp); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
allThinking.WriteString(resp.Message.Thinking)
|
||||
allContent.WriteString(resp.Message.Content)
|
||||
}
|
||||
|
||||
// Note: Leading whitespace after <think> is eaten by the parser
|
||||
if got := allThinking.String(); got != "I need to consider multiple factors here... " {
|
||||
t.Errorf("expected thinking %q, got %q", "I need to consider multiple factors here... ", got)
|
||||
}
|
||||
|
||||
if got := allContent.String(); got != "Based on my analysis, the solution is straightforward." {
|
||||
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -277,10 +277,11 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
|||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 100 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -427,10 +428,11 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
|||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 100 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
@ -608,10 +610,11 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
|||
getGpuFn: discover.GetGPUInfo,
|
||||
getCpuFn: discover.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) {
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
|
|
|||
369
server/sched.go
369
server/sched.go
|
|
@ -28,7 +28,6 @@ type LlmRequest struct {
|
|||
ctx context.Context //nolint:containedctx
|
||||
model *Model
|
||||
opts api.Options
|
||||
origNumCtx int // Track the initial ctx request
|
||||
sessionDuration *api.Duration
|
||||
successCh chan *runnerRef
|
||||
errCh chan error
|
||||
|
|
@ -41,10 +40,17 @@ type Scheduler struct {
|
|||
expiredCh chan *runnerRef
|
||||
unloadedCh chan any
|
||||
|
||||
loaded map[string]*runnerRef
|
||||
// loadedMu protects loaded and activeLoading
|
||||
loadedMu sync.Mutex
|
||||
|
||||
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int)
|
||||
// activeLoading is the model that we are currently working on loading,
|
||||
// including by evicting one or more other models. We can only load
|
||||
// one model at a time but new requests to models that already loaded can
|
||||
// happen in parallel
|
||||
activeLoading llm.LlamaServer
|
||||
loaded map[string]*runnerRef
|
||||
|
||||
loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
|
||||
newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
|
||||
getGpuFn func() discover.GpuInfoList
|
||||
getCpuFn func() discover.GpuInfoList
|
||||
|
|
@ -56,9 +62,6 @@ type Scheduler struct {
|
|||
// on a large GPU can cause stalling
|
||||
var defaultModelsPerGPU = 3
|
||||
|
||||
// Default automatic value for parallel setting
|
||||
var defaultParallel = 1
|
||||
|
||||
var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded")
|
||||
|
||||
func InitScheduler(ctx context.Context) *Scheduler {
|
||||
|
|
@ -79,24 +82,36 @@ func InitScheduler(ctx context.Context) *Scheduler {
|
|||
}
|
||||
|
||||
// context must be canceled to decrement ref count and release the runner
|
||||
func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) {
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
|
||||
if m.CheckCapabilities(model.CapabilityVision) == nil {
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
|
||||
req := &LlmRequest{
|
||||
ctx: c,
|
||||
model: model,
|
||||
model: m,
|
||||
opts: opts,
|
||||
sessionDuration: sessionDuration,
|
||||
successCh: make(chan *runnerRef),
|
||||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
select {
|
||||
case s.pendingReqCh <- req:
|
||||
default:
|
||||
req.errCh <- ErrMaxQueue
|
||||
s.loadedMu.Lock()
|
||||
runner := s.loaded[req.model.ModelPath]
|
||||
s.loadedMu.Unlock()
|
||||
if runner != nil && !runner.needsReload(c, req) {
|
||||
req.useLoadedRunner(runner, s.finishedReqCh)
|
||||
} else {
|
||||
select {
|
||||
case s.pendingReqCh <- req:
|
||||
default:
|
||||
req.errCh <- ErrMaxQueue
|
||||
}
|
||||
}
|
||||
return req.successCh, req.errCh
|
||||
}
|
||||
|
|
@ -122,21 +137,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||
case pending := <-s.pendingReqCh:
|
||||
// Block other requests until we get this pending request running
|
||||
pending.schedAttempts++
|
||||
if pending.origNumCtx == 0 {
|
||||
pending.origNumCtx = pending.opts.NumCtx
|
||||
}
|
||||
|
||||
if pending.ctx.Err() != nil {
|
||||
slog.Debug("pending request cancelled or timed out, skipping scheduling")
|
||||
continue
|
||||
}
|
||||
numParallel := int(envconfig.NumParallel())
|
||||
// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains(pending.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("mllama does not currently support parallel requests")
|
||||
}
|
||||
|
||||
for {
|
||||
var runnerToExpire *runnerRef
|
||||
|
|
@ -195,84 +200,26 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||
break
|
||||
}
|
||||
|
||||
// Embedding models should always be loaded with parallel=1
|
||||
if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil {
|
||||
numParallel = 1
|
||||
}
|
||||
// Update free memory from currently loaded models
|
||||
s.updateFreeSpace(gpus)
|
||||
|
||||
// Evaluate if the model will fit in the available system memory, or if we should unload a model first
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||
// simplifying assumption of defaultParallel when in CPU mode
|
||||
if numParallel <= 0 {
|
||||
numParallel = defaultParallel
|
||||
}
|
||||
|
||||
pending.opts.NumCtx = pending.origNumCtx * numParallel
|
||||
|
||||
if loadedCount == 0 {
|
||||
slog.Debug("cpu mode with first model, loading")
|
||||
s.loadFn(pending, ggml, gpus, numParallel)
|
||||
break
|
||||
}
|
||||
runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus)
|
||||
if runnerToExpire == nil {
|
||||
slog.Debug("cpu mode with available system memory or first model, loading")
|
||||
s.loadFn(pending, ggml, gpus, numParallel)
|
||||
break
|
||||
}
|
||||
// else we need to expire a runner
|
||||
} else if loadedCount == 0 {
|
||||
if loadedCount == 0 {
|
||||
// No models loaded. Load the model but prefer the best fit.
|
||||
slog.Debug("loading first model", "model", pending.model.ModelPath)
|
||||
g := pickBestFullFitByLibrary(pending, ggml, gpus, &numParallel)
|
||||
if g != nil {
|
||||
gpus = g
|
||||
} else {
|
||||
// Only allow partial loads when this is the first model
|
||||
gpus = pickBestPartialFitByLibrary(pending, ggml, gpus, &numParallel)
|
||||
}
|
||||
s.loadFn(pending, ggml, gpus, numParallel)
|
||||
s.loadFn(pending, ggml, gpus, false)
|
||||
break
|
||||
}
|
||||
|
||||
if runnerToExpire == nil {
|
||||
// More than one loaded model, so we have to see if the
|
||||
// new one fits
|
||||
//
|
||||
// We want to avoid loading on any GPUs that have other
|
||||
// models still loading on them to avoid potential races
|
||||
// with VRAM consumption ramping up during load
|
||||
availGpus := s.filterGPUsWithoutLoadingModels(gpus)
|
||||
// More than one loaded model, so we have to see if the
|
||||
// new one fits
|
||||
|
||||
// Update free memory from currently loaded models
|
||||
s.updateFreeSpace(availGpus)
|
||||
fitGpus := pickBestFullFitByLibrary(pending, ggml, availGpus, &numParallel)
|
||||
if fitGpus != nil {
|
||||
slog.Debug("new model fits with existing models, loading")
|
||||
s.loadFn(pending, ggml, fitGpus, numParallel)
|
||||
break
|
||||
}
|
||||
|
||||
// We couldn't find a set of GPUs to fully load the new
|
||||
// model. If no other models are loading (both GPU lists
|
||||
// are the same) then we need to unload another model to
|
||||
// make room
|
||||
if len(availGpus) < len(gpus) {
|
||||
// There are other requests pending, and this one
|
||||
// needs more time, so put it on the back of the
|
||||
// queue so that we might satisfy other pending
|
||||
// requests that aren't blocked
|
||||
go func() {
|
||||
// Process in a go routine to avoid deadlocking
|
||||
// the scheduler if our queue is full
|
||||
slog.Debug("delaying scheduling while other models finish loading", "attempts", pending.schedAttempts, "model", pending.model.ModelPath)
|
||||
time.Sleep(s.reschedDelay)
|
||||
s.pendingReqCh <- pending
|
||||
}()
|
||||
break
|
||||
}
|
||||
runnerToExpire = s.findRunnerToUnload()
|
||||
needEvict := s.loadFn(pending, ggml, gpus, true)
|
||||
if !needEvict {
|
||||
slog.Debug("new model fits with existing models, loading")
|
||||
break
|
||||
}
|
||||
|
||||
runnerToExpire = s.findRunnerToUnload()
|
||||
}
|
||||
|
||||
if runnerToExpire == nil {
|
||||
|
|
@ -293,8 +240,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
|||
}
|
||||
runnerToExpire.refMu.Unlock()
|
||||
// Wait for the unload to happen
|
||||
// Note: at this point we're queueing up all incoming requests, even if they were for
|
||||
// a different model that's loaded and not scheduled to be removed.
|
||||
slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
|
@ -434,26 +379,72 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
|
|||
}()
|
||||
}
|
||||
|
||||
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int) {
|
||||
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
|
||||
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
|
||||
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
|
||||
numParallel := int(envconfig.NumParallel())
|
||||
if numParallel < 1 {
|
||||
numParallel = 1
|
||||
}
|
||||
|
||||
// Embedding models should always be loaded with parallel=1
|
||||
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
|
||||
numParallel = 1
|
||||
}
|
||||
|
||||
// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
|
||||
// ref: https://github.com/ollama/ollama/issues/4165
|
||||
if slices.Contains(req.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
|
||||
numParallel = 1
|
||||
slog.Warn("mllama does not currently support parallel requests")
|
||||
}
|
||||
|
||||
sessionDuration := envconfig.KeepAlive()
|
||||
if req.sessionDuration != nil {
|
||||
sessionDuration = req.sessionDuration.Duration
|
||||
}
|
||||
llama, err := s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
// check for model compatibility
|
||||
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
|
||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
||||
|
||||
s.loadedMu.Lock()
|
||||
llama := s.activeLoading
|
||||
|
||||
if llama == nil {
|
||||
var err error
|
||||
llama, err = s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
// check for model compatibility
|
||||
if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") {
|
||||
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName)
|
||||
}
|
||||
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
|
||||
req.errCh <- err
|
||||
s.loadedMu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
s.activeLoading = llama
|
||||
} else {
|
||||
if s.activeLoading.ModelPath() != req.model.ModelPath {
|
||||
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath))
|
||||
}
|
||||
slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err)
|
||||
req.errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
err := llama.Load(req.ctx, gpus, requireFull)
|
||||
if err != nil {
|
||||
if errors.Is(err, llm.ErrLoadRequiredFull) {
|
||||
return true
|
||||
}
|
||||
|
||||
slog.Info("Load failed", "model", req.model.ModelPath, "error", err)
|
||||
s.activeLoading.Close()
|
||||
s.activeLoading = nil
|
||||
req.errCh <- err
|
||||
return false
|
||||
}
|
||||
|
||||
runner := &runnerRef{
|
||||
model: req.model,
|
||||
modelPath: req.model.ModelPath,
|
||||
|
|
@ -461,8 +452,8 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
|
|||
Options: &req.opts,
|
||||
sessionDuration: sessionDuration,
|
||||
gpus: gpus,
|
||||
estimatedVRAM: llama.EstimatedVRAM(),
|
||||
estimatedTotal: llama.EstimatedTotal(),
|
||||
vramSize: llama.VRAMSize(),
|
||||
totalSize: llama.TotalSize(),
|
||||
loading: true,
|
||||
pid: llama.Pid(),
|
||||
}
|
||||
|
|
@ -477,6 +468,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
|
|||
oldRunner.unload()
|
||||
oldRunner.refMu.Unlock()
|
||||
}
|
||||
s.activeLoading = nil
|
||||
s.loaded[req.model.ModelPath] = runner
|
||||
slog.Info("loaded runners", "count", len(s.loaded))
|
||||
s.loadedMu.Unlock()
|
||||
|
|
@ -503,6 +495,8 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
|
|||
}()
|
||||
req.successCh <- runner
|
||||
}()
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
|
||||
|
|
@ -521,7 +515,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
|
|||
r.refMu.Lock()
|
||||
if r.llama != nil {
|
||||
for _, gpu := range allGpus {
|
||||
predMap[predKey{gpu.Library, gpu.ID}] += r.llama.EstimatedVRAMByGPU(gpu.ID)
|
||||
predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID)
|
||||
}
|
||||
} else {
|
||||
slog.Warn("unexpected nil runner reference, memory prediction may be incorrect")
|
||||
|
|
@ -548,41 +542,17 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
|
|||
}
|
||||
}
|
||||
|
||||
// While models are loading the VRAM consumption numbers will be indeterminate, so we have
|
||||
// to avoid scheduling another model on the same GPU(s) that haven't stabilized.
|
||||
// This routine returns the set of GPUs that do not have an active loading model.
|
||||
// If all GPUs have loading models, an empty list will be returned (not a single CPU entry)
|
||||
func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList) discover.GpuInfoList {
|
||||
ret := append(discover.GpuInfoList{}, allGpus...)
|
||||
s.loadedMu.Lock()
|
||||
defer s.loadedMu.Unlock()
|
||||
for _, runner := range s.loaded {
|
||||
if runner.loading {
|
||||
slog.Debug("overlapping loads detected", "gpus", runner.gpus, "model", runner.modelPath)
|
||||
for _, busyGPU := range runner.gpus {
|
||||
for i := range ret {
|
||||
if ret[i].ID == busyGPU.ID {
|
||||
ret = append(ret[:i], ret[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// TODO consolidate sched_types.go
|
||||
type runnerRef struct {
|
||||
refMu sync.Mutex
|
||||
refCount uint // prevent unloading if > 0
|
||||
|
||||
llama llm.LlamaServer
|
||||
pid int
|
||||
loading bool // True only during initial load, then false forever
|
||||
gpus discover.GpuInfoList // Recorded at time of provisioning
|
||||
estimatedVRAM uint64
|
||||
estimatedTotal uint64
|
||||
llama llm.LlamaServer
|
||||
pid int
|
||||
loading bool // True only during initial load, then false forever
|
||||
gpus discover.GpuInfoList // Recorded at time of provisioning
|
||||
vramSize uint64
|
||||
totalSize uint64
|
||||
|
||||
sessionDuration time.Duration
|
||||
expireTimer *time.Timer
|
||||
|
|
@ -631,9 +601,6 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
|
|||
optsNew.NumGPU = -1
|
||||
}
|
||||
|
||||
// Normalize the NumCtx for parallelism
|
||||
optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed?
|
||||
|
|
@ -694,7 +661,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
|
|||
freeMemoryNow += gpu.FreeMemory
|
||||
}
|
||||
// If we're within ~80% of the estimated memory usage recovered, bail out
|
||||
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 {
|
||||
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
|
||||
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
|
||||
finished <- struct{}{}
|
||||
return
|
||||
|
|
@ -719,8 +686,8 @@ func (runner *runnerRef) LogValue() slog.Value {
|
|||
)
|
||||
}
|
||||
attrs = append(attrs,
|
||||
slog.String("size", format.HumanBytes2(runner.estimatedTotal)),
|
||||
slog.String("vram", format.HumanBytes2(runner.estimatedVRAM)),
|
||||
slog.String("size", format.HumanBytes2(runner.totalSize)),
|
||||
slog.String("vram", format.HumanBytes2(runner.vramSize)),
|
||||
slog.Int("parallel", runner.numParallel),
|
||||
slog.Int("pid", runner.pid),
|
||||
slog.String("model", runner.modelPath),
|
||||
|
|
@ -750,95 +717,7 @@ func (a ByDurationAndName) Less(i, j int) bool {
|
|||
// type BySize []*runnerRef
|
||||
// func (a BySize) Len() int { return len(a) }
|
||||
// func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM }
|
||||
|
||||
// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
|
||||
// The list of GPUs returned will always be the same brand (library)
|
||||
// If the model can not be fit fully within the available GPU(s) nil is returned
|
||||
// If numParallel is <= 0, this will attempt try to optimize parallelism based on available VRAM, and adjust
|
||||
// opts.NumCtx accordingly
|
||||
func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
|
||||
var numParallelToTry []int
|
||||
if *numParallel <= 0 {
|
||||
// If no specific parallel setting was provided, try larger then smaller, always end with 1
|
||||
numParallelToTry = append(numParallelToTry, defaultParallel, 1)
|
||||
} else {
|
||||
numParallelToTry = []int{*numParallel}
|
||||
}
|
||||
|
||||
for _, gl := range gpus.ByLibrary() {
|
||||
sgl := append(make(discover.GpuInfoList, 0, len(gl)), gl...)
|
||||
|
||||
// TODO - potentially sort by performance capability, existing models loaded, etc.
|
||||
// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
|
||||
// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
|
||||
sort.Sort(sort.Reverse(discover.ByFreeMemory(sgl)))
|
||||
|
||||
if !envconfig.SchedSpread() {
|
||||
for _, p := range numParallelToTry {
|
||||
req.opts.NumCtx = req.origNumCtx * p
|
||||
// Try to pack into as few GPUs as possible, starting from 1 GPU
|
||||
for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ {
|
||||
gpuSubset := sgl[:numGPUs]
|
||||
ok, estimatedVRAM := llm.PredictServerFit(gpuSubset, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p)
|
||||
|
||||
if ok {
|
||||
slog.Info("new model will fit in available VRAM across minimum required GPUs, loading",
|
||||
"model", req.model.ModelPath,
|
||||
"library", sgl[0].Library,
|
||||
"parallel", p,
|
||||
"required", format.HumanBytes2(estimatedVRAM),
|
||||
"gpus", numGPUs)
|
||||
*numParallel = p
|
||||
return gpuSubset
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO future refinements
|
||||
// - if multiple Libraries, see if any single GPU in any Library will fit
|
||||
// - try subsets of GPUs instead of just falling back to 1 or all in a family
|
||||
|
||||
// Now try all the GPUS (OLLAMA_SCHED_SPREAD is set)
|
||||
for _, p := range numParallelToTry {
|
||||
req.opts.NumCtx = req.origNumCtx * p
|
||||
if ok, estimatedVRAM := llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
|
||||
slog.Info("new model will fit in available VRAM, loading",
|
||||
"model", req.model.ModelPath,
|
||||
"library", sgl[0].Library,
|
||||
"parallel", p,
|
||||
"required", format.HumanBytes2(estimatedVRAM),
|
||||
"gpus", len(sgl))
|
||||
*numParallel = p
|
||||
return sgl
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If multiple Libraries are detected, pick the Library which loads the most layers for the model
|
||||
func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList {
|
||||
if *numParallel <= 0 {
|
||||
*numParallel = 1
|
||||
req.opts.NumCtx = req.origNumCtx
|
||||
}
|
||||
byLibrary := gpus.ByLibrary()
|
||||
if len(byLibrary) <= 1 {
|
||||
return gpus
|
||||
}
|
||||
var bestEstimate uint64
|
||||
var bestFit int
|
||||
for i, gl := range byLibrary {
|
||||
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
|
||||
if estimatedVRAM > bestEstimate {
|
||||
bestEstimate = estimatedVRAM
|
||||
bestFit = i
|
||||
}
|
||||
}
|
||||
return byLibrary[bestFit]
|
||||
}
|
||||
// func (a BySize) Less(i, j int) bool { return a[i].vramSize < a[j].vramSize }
|
||||
|
||||
// findRunnerToUnload finds a runner to unload to make room for a new model
|
||||
func (s *Scheduler) findRunnerToUnload() *runnerRef {
|
||||
|
|
@ -875,6 +754,13 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {
|
|||
func (s *Scheduler) unloadAllRunners() {
|
||||
s.loadedMu.Lock()
|
||||
defer s.loadedMu.Unlock()
|
||||
|
||||
if s.activeLoading != nil {
|
||||
slog.Debug("shutting down currently loading runner")
|
||||
s.activeLoading.Close()
|
||||
s.activeLoading = nil
|
||||
}
|
||||
|
||||
for model, runner := range s.loaded {
|
||||
if runner.llama != nil {
|
||||
slog.Debug("shutting down runner", "model", model)
|
||||
|
|
@ -901,18 +787,3 @@ func (s *Scheduler) expireRunner(model *Model) {
|
|||
runner.refMu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// If other runners are loaded, make sure the pending request will fit in system memory
|
||||
// If not, pick a runner to unload, else return nil and the request can be loaded
|
||||
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
|
||||
slog.Debug("evaluating if CPU model load will fit in available system memory")
|
||||
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
|
||||
if estimate.TotalSize <= gpus[0].FreeMemory {
|
||||
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO - optimization: try to find CPU only runners first, or partial offloads with enough in system memory to make room
|
||||
|
||||
return s.findRunnerToUnload()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ func TestLoad(t *testing.T) {
|
|||
return nil, errors.New("something failed to load model blah")
|
||||
}
|
||||
gpus := discover.GpuInfoList{}
|
||||
s.load(req, f, gpus, 0)
|
||||
s.load(req, f, gpus, false)
|
||||
require.Empty(t, req.successCh)
|
||||
require.Len(t, req.errCh, 1)
|
||||
s.loadedMu.Lock()
|
||||
|
|
@ -61,16 +61,17 @@ func TestLoad(t *testing.T) {
|
|||
err := <-req.errCh
|
||||
require.Contains(t, err.Error(), "this model may be incompatible")
|
||||
|
||||
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
|
||||
server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
|
||||
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
server.modelPath = model
|
||||
return server, nil
|
||||
}
|
||||
s.load(req, f, gpus, 0)
|
||||
s.load(req, f, gpus, false)
|
||||
select {
|
||||
case err := <-req.errCh:
|
||||
require.NoError(t, err)
|
||||
case resp := <-req.successCh:
|
||||
require.Equal(t, uint64(10), resp.estimatedVRAM)
|
||||
require.Equal(t, uint64(10), resp.vramSize)
|
||||
require.Equal(t, uint(1), resp.refCount)
|
||||
s.loadedMu.Lock()
|
||||
require.Len(t, s.loaded, 1)
|
||||
|
|
@ -79,7 +80,7 @@ func TestLoad(t *testing.T) {
|
|||
|
||||
req.model.ModelPath = "dummy_model_path"
|
||||
server.waitResp = errors.New("wait failure")
|
||||
s.load(req, f, gpus, 0)
|
||||
s.load(req, f, gpus, false)
|
||||
select {
|
||||
case err := <-req.errCh:
|
||||
require.Contains(t, err.Error(), "wait failure")
|
||||
|
|
@ -104,10 +105,11 @@ type reqBundle struct {
|
|||
}
|
||||
|
||||
func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
scenario.srv.modelPath = model
|
||||
return scenario.srv, nil
|
||||
}
|
||||
|
||||
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle {
|
||||
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration) *reqBundle {
|
||||
b := &reqBundle{}
|
||||
b.ctx, b.ctxDone = context.WithCancel(ctx)
|
||||
t.Helper()
|
||||
|
|
@ -144,7 +146,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
|
|||
successCh: make(chan *runnerRef, 1),
|
||||
errCh: make(chan error, 1),
|
||||
}
|
||||
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
|
||||
b.srv = &mockLlm{vramSize: vramSize, vramByGPU: map[string]uint64{"": vramSize}}
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -262,10 +264,10 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
|
|||
|
||||
// Multiple loaded models
|
||||
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
|
||||
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
|
||||
c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-3b", 10*format.GigaByte, nil)
|
||||
c := newScenarioRequest(t, ctx, "ollama-model-4a", 10*format.GigaByte, nil)
|
||||
c.req.opts.NumGPU = 0 // CPU load, will be allowed
|
||||
d := newScenarioRequest(t, ctx, "ollama-model-3c", 10*format.GigaByte, nil) // Needs prior unloaded
|
||||
|
||||
t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
|
||||
s.newServerFn = a.newServer
|
||||
|
|
@ -418,11 +420,12 @@ func TestExpireRunner(t *testing.T) {
|
|||
|
||||
var f *ggml.GGML
|
||||
gpus := discover.GpuInfoList{}
|
||||
server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}}
|
||||
server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}}
|
||||
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
|
||||
server.modelPath = model
|
||||
return server, nil
|
||||
}
|
||||
s.load(req, f, gpus, 0)
|
||||
s.load(req, f, gpus, false)
|
||||
|
||||
select {
|
||||
case err := <-req.errCh:
|
||||
|
|
@ -506,7 +509,7 @@ func TestUseLoadedRunner(t *testing.T) {
|
|||
sessionDuration: &api.Duration{Duration: 2},
|
||||
}
|
||||
finished := make(chan *LlmRequest)
|
||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||
llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
|
||||
r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1}
|
||||
req.useLoadedRunner(r1, finished)
|
||||
require.Equal(t, uint(1), r1.refCount)
|
||||
|
|
@ -541,8 +544,8 @@ func TestUpdateFreeSpace(t *testing.T) {
|
|||
gpus[0].FreeMemory = 900
|
||||
gpus[1].TotalMemory = 2000
|
||||
gpus[1].FreeMemory = 1900
|
||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}}
|
||||
llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}}
|
||||
llm1 := &mockLlm{vramByGPU: map[string]uint64{"1": 50, "2": 50}}
|
||||
llm2 := &mockLlm{vramByGPU: map[string]uint64{"1": 125, "2": 75}}
|
||||
r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1}
|
||||
r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1}
|
||||
|
||||
|
|
@ -557,40 +560,6 @@ func TestUpdateFreeSpace(t *testing.T) {
|
|||
require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory)
|
||||
}
|
||||
|
||||
func TestFilterGPUsWithoutLoadingModels(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
gpus := discover.GpuInfoList{
|
||||
{
|
||||
Library: "cuda",
|
||||
ID: "0",
|
||||
},
|
||||
{
|
||||
Library: "cuda",
|
||||
ID: "1",
|
||||
},
|
||||
}
|
||||
r1 := &runnerRef{gpus: discover.GpuInfoList{gpus[0]}, loading: true}
|
||||
|
||||
s := InitScheduler(ctx)
|
||||
s.loadedMu.Lock()
|
||||
s.loaded["a"] = r1
|
||||
s.loadedMu.Unlock()
|
||||
|
||||
tmp := s.filterGPUsWithoutLoadingModels(gpus)
|
||||
require.Len(t, tmp, 1)
|
||||
require.Equal(t, "1", tmp[0].ID)
|
||||
|
||||
r1.gpus = discover.GpuInfoList{gpus[1]}
|
||||
tmp = s.filterGPUsWithoutLoadingModels(gpus)
|
||||
require.Len(t, tmp, 1)
|
||||
require.Equal(t, "0", tmp[0].ID)
|
||||
|
||||
r1.gpus = discover.GpuInfoList{}
|
||||
tmp = s.filterGPUsWithoutLoadingModels(gpus)
|
||||
require.Len(t, tmp, 2)
|
||||
}
|
||||
|
||||
func TestFindRunnerToUnload(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
|
@ -615,7 +584,7 @@ func TestNeedsReload(t *testing.T) {
|
|||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||
llm := &mockLlm{vramByGPU: map[string]uint64{}}
|
||||
do := api.DefaultOptions()
|
||||
runner := &runnerRef{
|
||||
model: &Model{
|
||||
|
|
@ -662,8 +631,8 @@ func TestUnloadAllRunners(t *testing.T) {
|
|||
ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
|
||||
defer done()
|
||||
|
||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||
llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||
llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
|
||||
llm2 := &mockLlm{vramByGPU: map[string]uint64{}}
|
||||
s := InitScheduler(ctx)
|
||||
s.unloadAllRunners()
|
||||
|
||||
|
|
@ -681,7 +650,7 @@ func TestUnloadAllRunners(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestUnload(t *testing.T) {
|
||||
llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}}
|
||||
llm1 := &mockLlm{vramByGPU: map[string]uint64{}}
|
||||
r1 := &runnerRef{llama: llm1, numParallel: 1}
|
||||
r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
|
||||
r1.unload()
|
||||
|
|
@ -748,24 +717,40 @@ func TestHomogeneousGPUs(t *testing.T) {
|
|||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
|
||||
type mockLlm struct {
|
||||
pingResp error
|
||||
waitResp error
|
||||
completionResp error
|
||||
embeddingResp []float32
|
||||
embeddingRespErr error
|
||||
tokenizeResp []int
|
||||
tokenizeRespErr error
|
||||
detokenizeResp string
|
||||
detonekizeRespErr error
|
||||
closeResp error
|
||||
closeCalled bool
|
||||
estimatedVRAM uint64
|
||||
estimatedTotal uint64
|
||||
estimatedVRAMByGPU map[string]uint64
|
||||
modelPath string
|
||||
pingResp error
|
||||
waitResp error
|
||||
completionResp error
|
||||
embeddingResp []float32
|
||||
embeddingRespErr error
|
||||
tokenizeResp []int
|
||||
tokenizeRespErr error
|
||||
detokenizeResp string
|
||||
detonekizeRespErr error
|
||||
closeResp error
|
||||
closeCalled bool
|
||||
vramSize uint64
|
||||
totalSize uint64
|
||||
vramByGPU map[string]uint64
|
||||
}
|
||||
|
||||
func (s *mockLlm) ModelPath() string {
|
||||
return s.modelPath
|
||||
}
|
||||
|
||||
func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error {
|
||||
if requireFull {
|
||||
for _, g := range gpus {
|
||||
if g.FreeMemory >= s.vramSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return llm.ErrLoadRequiredFull
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp }
|
||||
func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp }
|
||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||
|
|
@ -788,7 +773,7 @@ func (s *mockLlm) Close() error {
|
|||
s.closeCalled = true
|
||||
return s.closeResp
|
||||
}
|
||||
func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM }
|
||||
func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal }
|
||||
func (s *mockLlm) EstimatedVRAMByGPU(gpuid string) uint64 { return s.estimatedVRAMByGPU[gpuid] }
|
||||
func (s *mockLlm) Pid() int { return -1 }
|
||||
func (s *mockLlm) VRAMSize() uint64 { return s.vramSize }
|
||||
func (s *mockLlm) TotalSize() uint64 { return s.totalSize }
|
||||
func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] }
|
||||
func (s *mockLlm) Pid() int { return -1 }
|
||||
|
|
|
|||
|
|
@ -103,7 +103,9 @@ func eat(s *Parser) (string, string, bool) {
|
|||
// note that we use the original content, not the trimmed one because we
|
||||
// don't want to eat any whitespace in the real content if there were no
|
||||
// thinking tags
|
||||
return "", s.acc.String(), false
|
||||
untrimmed := s.acc.String()
|
||||
s.acc.Reset()
|
||||
return "", untrimmed, false
|
||||
}
|
||||
case thinkingState_ThinkingStartedEatingWhitespace:
|
||||
trimmed := strings.TrimLeftFunc(s.acc.String(), unicode.IsSpace)
|
||||
|
|
|
|||
|
|
@ -58,6 +58,15 @@ func TestThinkingStreaming(t *testing.T) {
|
|||
wantContent: " abc",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
// regression test for a bug where we were transitioning directly to
|
||||
// ThinkingDone without clearing the buffer. This would cuase the first
|
||||
// step to be outputted twice
|
||||
{
|
||||
input: "def",
|
||||
wantThinking: "",
|
||||
wantContent: "def",
|
||||
wantStateAfter: thinkingState_ThinkingDone,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
|
|
|||
|
|
@ -224,22 +224,45 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
|||
return nil, 0
|
||||
}
|
||||
|
||||
start := -1
|
||||
var braces int
|
||||
var start int = -1
|
||||
var inString, escaped bool
|
||||
|
||||
for i := range buffer {
|
||||
c := buffer[i]
|
||||
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '\\' {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
|
||||
for i, c := range buffer {
|
||||
if c == '{' {
|
||||
if braces == 0 {
|
||||
start = i
|
||||
}
|
||||
braces++
|
||||
} else if c == '}' && braces > 0 {
|
||||
} else if c == '}' {
|
||||
braces--
|
||||
if braces == 0 && start != -1 {
|
||||
object := buffer[start : i+1]
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal(object, &data); err != nil {
|
||||
// not a valid object, keep looking
|
||||
start = -1
|
||||
continue
|
||||
}
|
||||
|
|
@ -282,6 +305,10 @@ func findArguments(buffer []byte) (map[string]any, int) {
|
|||
|
||||
return data, i
|
||||
}
|
||||
|
||||
if braces < 0 {
|
||||
braces = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package tools
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"text/template"
|
||||
|
||||
|
|
@ -40,13 +41,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "get_temperature",
|
||||
Description: "Retrieve the temperature for a given location",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
|
|
@ -68,13 +63,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "get_conditions",
|
||||
Description: "Retrieve the current weather conditions for a given location",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
|
|
@ -104,13 +93,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "get_address",
|
||||
Description: "Get the address of a given location",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
|
|
@ -126,13 +109,7 @@ func TestParser(t *testing.T) {
|
|||
Function: api.ToolFunction{
|
||||
Name: "add",
|
||||
Description: "Add two numbers",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {
|
||||
|
|
@ -1140,11 +1117,163 @@ func TestFindArguments(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "deepseek",
|
||||
buffer: []byte(`", "arguments": {"location": "Tokyo"}}</tool_call>`),
|
||||
buffer: []byte(`"arguments": {"location": "Tokyo"}}</tool_call>`),
|
||||
want: map[string]any{
|
||||
"location": "Tokyo",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with braces",
|
||||
buffer: []byte(`{"name": "process_code", "arguments": {"code": "if (x > 0) { return true; }"}}`),
|
||||
want: map[string]any{
|
||||
"code": "if (x > 0) { return true; }",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with nested json",
|
||||
buffer: []byte(`{"name": "send_data", "arguments": {"payload": "{\"nested\": {\"key\": \"value\"}}"}}`),
|
||||
want: map[string]any{
|
||||
"payload": `{"nested": {"key": "value"}}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with escaped quotes and braces",
|
||||
buffer: []byte(`{"name": "analyze", "arguments": {"text": "The JSON is: {\"key\": \"val{ue}\"}"}}`),
|
||||
want: map[string]any{
|
||||
"text": `The JSON is: {"key": "val{ue}"}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple objects with string containing braces",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"query": "find } in text"}} {"name": "other"}`),
|
||||
want: map[string]any{
|
||||
"query": "find } in text",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unmatched closing brace in string",
|
||||
buffer: []byte(`{"name": "search", "arguments": {"pattern": "regex: }"}}`),
|
||||
want: map[string]any{
|
||||
"pattern": "regex: }",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "complex nested with mixed braces",
|
||||
buffer: []byte(`{"name": "analyze", "arguments": {"data": "{\"items\": [{\"value\": \"}\"}, {\"code\": \"if (x) { return y; }\"}]}"}}`),
|
||||
want: map[string]any{
|
||||
"data": `{"items": [{"value": "}"}, {"code": "if (x) { return y; }"}]}`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with newline and braces",
|
||||
buffer: []byte(`{"name": "format", "arguments": {"template": "{\n \"key\": \"value\"\n}"}}`),
|
||||
want: map[string]any{
|
||||
"template": "{\n \"key\": \"value\"\n}",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "string with unicode escape",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"text": "Unicode: \u007B and \u007D"}}`),
|
||||
want: map[string]any{
|
||||
"text": "Unicode: { and }",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array arguments",
|
||||
buffer: []byte(`{"name": "batch", "arguments": ["item1", "item2", "{\"nested\": true}"]}`),
|
||||
want: nil, // This should return nil because arguments is not a map
|
||||
},
|
||||
{
|
||||
name: "escaped backslash before quote",
|
||||
buffer: []byte(`{"name": "path", "arguments": {"dir": "C:\\Program Files\\{App}\\"}}`),
|
||||
want: map[string]any{
|
||||
"dir": `C:\Program Files\{App}\`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single quotes not treated as string delimiters",
|
||||
buffer: []byte(`{"name": "query", "arguments": {"sql": "SELECT * FROM users WHERE name = '{admin}'"}}`),
|
||||
want: map[string]any{
|
||||
"sql": "SELECT * FROM users WHERE name = '{admin}'",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "incomplete json at buffer end",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"data": "some {"`),
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "multiple escaped quotes",
|
||||
buffer: []byte(`{"name": "echo", "arguments": {"msg": "He said \"Hello {World}\" loudly"}}`),
|
||||
want: map[string]any{
|
||||
"msg": `He said "Hello {World}" loudly`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "json with comments style string",
|
||||
buffer: []byte(`{"name": "code", "arguments": {"snippet": "// This is a comment with { and }"}}`),
|
||||
want: map[string]any{
|
||||
"snippet": "// This is a comment with { and }",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "consecutive escaped backslashes",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"path": "C:\\\\{folder}\\\\"}}`),
|
||||
want: map[string]any{
|
||||
"path": `C:\\{folder}\\`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty string with braces after",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"a": "", "b": "{value}"}}`),
|
||||
want: map[string]any{
|
||||
"a": "",
|
||||
"b": "{value}",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "unicode in key names",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"key{": "value", "key}": "value2"}}`),
|
||||
want: map[string]any{
|
||||
"key{": "value",
|
||||
"key}": "value2",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "very long string with braces",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"data": "` + strings.Repeat("a{b}c", 100) + `"}}`),
|
||||
want: map[string]any{
|
||||
"data": strings.Repeat("a{b}c", 100),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tab characters and braces",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"code": "\tif (true) {\n\t\treturn;\n\t}"}}`),
|
||||
want: map[string]any{
|
||||
"code": "\tif (true) {\n\t\treturn;\n\t}",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "null byte in string",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"data": "before\u0000{after}"}}`),
|
||||
want: map[string]any{
|
||||
"data": "before\x00{after}",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "escaped quote at end of string",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"data": "text with quote at end\\\""}}`),
|
||||
want: map[string]any{
|
||||
"data": `text with quote at end\"`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mixed array and object in arguments",
|
||||
buffer: []byte(`{"name": "test", "arguments": {"items": ["{", "}", {"key": "value"}]}}`),
|
||||
want: map[string]any{
|
||||
"items": []any{"{", "}", map[string]any{"key": "value"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
|
|
|||
Loading…
Reference in New Issue